# 导包

In [2]:
# -*- coding: utf-8 -*-
import json
import re
import math
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from keras.utils import plot_model
from struct import unpack, pack
import numpy as np
from math import sqrt
import scipy.io as scio
from sklearn.decomposition import PCA
from scipy import interpolate
import scipy.signal as signal
import matplotlib.pyplot as plt
from sklearn import manifold
from tqdm import tqdm
import scipy.io.wavfile as wav
from scipy.fftpack import fft
from random import shuffle
from collections import Counter  
from sklearn.ensemble import IsolationForest 
from functools import reduce
from sklearn.model_selection import train_test_split
!pip install pyts==0.7.1
from pyts.image import *
try:
    !cat /proc/cpuinfo | grep 'model name' |uniq
    !nvidia-smi
except:
    print('there is no GPU') 

model name	: Intel(R) Core(TM) i9-10900F CPU @ 2.80GHz
Fri Dec 18 21:16:43 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 455.45.01    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3070    Off  | 00000000:01:00.0  On |                  N/A |
|  0%   43C    P8    19W / 270W |    298MiB /  7979MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+--------------------------------

# 从.data提取CSI幅值的相关函数

In [3]:
# 从.dat文件中提取csi
def dbinv(x):
    return 10**(x / 10)
# 计算接收到的信号强度(RSS)(以dBm为单位)
def total_rss(data):
    rssi_mag = 0
    if data['rssi_a'] != 0:
        rssi_mag = rssi_mag + dbinv(data['rssi_a'])
    if data['rssi_b'] != 0:
        rssi_mag = rssi_mag + dbinv(data['rssi_b'])
    if data['rssi_c'] != 0:
        rssi_mag = rssi_mag + dbinv(data['rssi_c'])
    return 10 * np.log10(rssi_mag) - 44 - data['agc']
# 将CSI结构转换为信道矩阵H。
def get_scaled_csi(data):
    csi = data['csi']
    ntx = data['ntx']
    nrx = data['nrx']
    csi_sq = csi * np.conj(csi)
    csi_pwr = csi_sq.sum().real  # 求和
    rssi_pwr = dbinv(total_rss(data))
    scale = rssi_pwr / (csi_pwr / 30)
    if data['noise'] == -127:
        noise = -92
    else:
        noise = data['noise']
    thermal_noise_pwr = dbinv(noise)
    quant_error_pwr = scale * (nrx * ntx)
    total_noise_pwr = thermal_noise_pwr + quant_error_pwr
    ret = csi * sqrt(scale / total_noise_pwr)
    if ntx == 2:
        ret = ret * sqrt(2)
    elif ntx == 3:
        ret = ret * sqrt(dbinv(4.5))
    return ret

def expandable_or(x, y):
    z = x | y
    low = z & 0xff
    return unpack('b', pack('B', low))[0]

def read_bfree(array):
    result = {}
    timestamp_low = array[0] + (array[1] << 8) + (array[2] << 16) + (array[3] << 24)
    bf_count = array[4] + (array[5] << 8)
    nrx = array[8]  # 接收天线的数目
    ntx = array[9]
    rssi_a = array[10]
    rssi_b = array[11]
    rssi_c = array[12]
    # noise
    noise = unpack('b', pack('B', array[13]))[0]
    agc = array[14]
    antenna_sel = array[15]
    length = array[16] + (array[17] << 8)
    fake_rate_n_flags = array[18] + (array[19] << 8)
    calc_len = (30 * (nrx * ntx * 8 * 2 + 3) + 7) // 8
    payload = array[20:]  # csi数据部分

    if length != calc_len:
        print('数据发现错误!')
        exit(0)
    result['timestamp_low'] = timestamp_low
    result['bfree_count'] = bf_count
    result['rssi_a'] = rssi_a
    result['rssi_b'] = rssi_b
    result['rssi_c'] = rssi_c
    result['nrx'] = nrx
    result['ntx'] = ntx
    result['agc'] = agc
    result['rate'] = fake_rate_n_flags
    result['noise'] = noise
    csi = np.zeros((ntx, nrx, 30), dtype=np.complex64)
    # 现在开始构建numpy array
    idx = 0        
    for sub_idx in range(30):
        idx = idx + 3
        remainder = idx % 8  # 余数
        for m in range(nrx):
            for n in range(ntx):
                real = expandable_or((payload[idx // 8] >> remainder), (payload[idx // 8 + 1] << (8 - remainder)))
                img = expandable_or((payload[idx // 8 + 1] >> remainder), (payload[idx // 8 + 2] << (8 - remainder)))
                csi[n, m, sub_idx] = complex(real, img)     # 构建一个复数
                idx = idx + 16
    result['csi'] = csi
    perm = np.zeros(3, dtype=np.uint32)
    perm[0] = (antenna_sel & 0x3) + 1
    perm[1] = ((antenna_sel >> 2) & 0x3) + 1
    perm[2] = ((antenna_sel >> 4) & 0x3) + 1
    result['perm'] = perm
    return result

# 从.dat抽取生成CSI字典数组 2*3*30
def extract_csi(file_name):
    triangle = np.array([1, 3, 6])
    csis = []
    with open(file_name, 'rb') as f:
        buff = f.read()
        curr = 0    # 记录当前已经处理到了的位置
        length = len(buff)
        while curr < (length - 3):
            data_len = unpack('>H', buff[curr:curr+2])[0]  # 实际长度
            if data_len > (length - curr - 2):  # 防止越界的错误
                break
            code = unpack('B', buff[curr+2:curr+3])[0]  # 代码
            curr = curr + 3
            if code == 187:
                # 将CSI数据帧解析
                csi_dic = read_bfree(buff[curr:])
                perm = csi_dic['perm']
                nrx = csi_dic['nrx']
                csi = csi_dic['csi']
                if sum(perm) == triangle[nrx - 1]:  # 下标从0开始
                    csi[:, perm - 1, :] = csi[:, [x for x in range(nrx)], :]
                # csi = get_scaled_csi(data)
                csis.append(csi_dic)
            curr = curr + data_len - 1
    return csis

# 特征工程

In [4]:
# 提取每个天线对上的CSI
def process_data(raw_file_path, save_path):
  X={}
  for f in tqdm(os.listdir(raw_file_path)):
    if f.endswith('.dat'):
      file_name = os.path.join(raw_file_path, f)
      extracted_data = extract_csi(file_name)
      print('processing {} the length of this file is:{}'.format(file_name, len(extracted_data)))
      tx, rx, sub = extracted_data[0]['csi'].shape
      data_csi = np.zeros((len(extracted_data), tx, rx, sub), dtype=np.complex64)
      for i in range(len(extracted_data)):
        data_csi[i] = get_scaled_csi(extracted_data[i])
      data_csi = np.clip(np.abs(np.squeeze(data_csi)), 1e-8, 1e100)[:,:,:2,:].reshape(-1, 4, 30)   # (205, 2, 2, 30)
      data = np.zeros((data_csi.shape[0],4))  #N*4
      for ant in range(4):  # 每个天线对上的CSI变化趋势相同,为节约计算这里选择天线对即可
          data_csi_ant = data_csi[:, ant, :]  
          b, a = signal.butter(5, 4*2/30, 'low')
          var_max = 0
          s_max = None
          for s in range(30):
              carrier_data = signal.lfilter(b, a, data_csi_ant[:, s]) # N*1
              length = len(carrier_data)
              var_temp = np.var(carrier_data[length//5:3*length//5]) 
              if var_max < var_temp: 
                  var_max = var_temp
                  s_max = carrier_data
          data[:, ant] = s_max
      scio.savemat(os.path.join(save_path, f.split('.')[0]+'.mat'), {'csi': data})
      X[f]=data
  print('all raw file processed')
  return X
  
train_raw_file_path = r'../data/raw/train'
train_save_path = r'../data/raw/train'
test_raw_file_path = r'../data/processed/test'
test_save_path = r'../data/processed/test'


In [4]:
X_train = process_data(train_raw_file_path, train_save_path)
X_test = process_data(test_raw_file_path, test_save_path)

  3%|▎         | 27/780 [00:00<00:13, 57.47it/s]

processing ../data/raw/train/csi-s1-e1-a2-12.dat the length of this file is:255
processing ../data/raw/train/csi-s1-e1-a1-18.dat the length of this file is:315
processing ../data/raw/train/csi-s1-e1-a3-27.dat the length of this file is:242
processing ../data/raw/train/csi-s1-e1-a1-29.dat the length of this file is:287


  4%|▍         | 34/780 [00:00<00:12, 60.67it/s]

processing ../data/raw/train/csi-s1-e1-a5-40.dat the length of this file is:190
processing ../data/raw/train/csi-s1-e1-a1-35.dat the length of this file is:297
processing ../data/raw/train/csi-s1-e1-a4-32.dat the length of this file is:409


  5%|▌         | 41/780 [00:00<00:15, 46.53it/s]

processing ../data/raw/train/csi-s1-e1-a4-42.dat the length of this file is:384
processing ../data/raw/train/csi-s1-e1-a1-40.dat the length of this file is:347
processing ../data/raw/train/csi-s1-e1-a2-5.dat the length of this file is:281


  7%|▋         | 57/780 [00:00<00:14, 50.62it/s]

processing ../data/raw/train/csi-s1-e1-a4-45.dat the length of this file is:409
processing ../data/raw/train/csi-s1-e1-a5-24.dat the length of this file is:183
processing ../data/raw/train/csi-s1-e1-a2-29.dat the length of this file is:247
processing ../data/raw/train/csi-s1-e1-a3-41.dat the length of this file is:310


  9%|▉         | 71/780 [00:01<00:13, 52.15it/s]

processing ../data/raw/train/csi-s1-e1-a2-15.dat the length of this file is:254
processing ../data/raw/train/csi-s1-e1-a1-23.dat the length of this file is:297
processing ../data/raw/train/csi-s1-e1-a3-28.dat the length of this file is:284
processing ../data/raw/train/csi-s1-e1-a4-35.dat the length of this file is:398


 11%|█         | 82/780 [00:01<00:16, 43.36it/s]

processing ../data/raw/train/csi-s1-e1-a2-21.dat the length of this file is:304
processing ../data/raw/train/csi-s1-e1-a3-43.dat the length of this file is:304
processing ../data/raw/train/csi-s1-e1-a2-6.dat the length of this file is:293
processing ../data/raw/train/csi-s1-e1-a2-39.dat the length of this file is:314
processing ../data/raw/train/csi-s1-e1-a1-33.dat the length of this file is:298


 12%|█▏        | 92/780 [00:01<00:16, 42.38it/s]

processing ../data/raw/train/csi-s1-e1-a2-42.dat the length of this file is:256
processing ../data/raw/train/csi-s1-e1-a3-49.dat the length of this file is:296
processing ../data/raw/train/csi-s1-e1-a1-37.dat the length of this file is:285
processing ../data/raw/train/csi-s1-e1-a4-22.dat the length of this file is:354


 12%|█▏        | 97/780 [00:01<00:17, 39.71it/s]

processing ../data/raw/train/csi-s1-e1-a4-16.dat the length of this file is:352
processing ../data/raw/train/csi-s1-e1-a4-25.dat the length of this file is:346
processing ../data/raw/train/csi-s1-e1-a1-39.dat the length of this file is:302
processing ../data/raw/train/csi-s1-e1-a1-36.dat the length of this file is:292


 14%|█▍        | 113/780 [00:02<00:17, 39.08it/s]

processing ../data/raw/train/csi-s1-e1-a1-4.dat the length of this file is:360
processing ../data/raw/train/csi-s1-e1-a5-38.dat the length of this file is:177
processing ../data/raw/train/csi-s1-e1-a5-16.dat the length of this file is:173
processing ../data/raw/train/csi-s1-e1-a5-50.dat the length of this file is:177
processing ../data/raw/train/csi-s1-e1-a5-45.dat the length of this file is:257


 15%|█▌        | 119/780 [00:02<00:16, 40.09it/s]

processing ../data/raw/train/csi-s1-e1-a4-21.dat the length of this file is:385
processing ../data/raw/train/csi-s1-e1-a6-22.dat the length of this file is:168
processing ../data/raw/train/csi-s1-e1-a6-29.dat the length of this file is:143
processing ../data/raw/train/csi-s1-e1-a6-41.dat the length of this file is:146
processing ../data/raw/train/csi-s1-e1-a6-19.dat the length of this file is:179


 17%|█▋        | 134/780 [00:02<00:13, 49.12it/s]

processing ../data/raw/train/csi-s1-e1-a4-24.dat the length of this file is:378
processing ../data/raw/train/csi-s1-e1-a1-17.dat the length of this file is:295
processing ../data/raw/train/csi-s1-e1-a4-30.dat the length of this file is:405
processing ../data/raw/train/csi-s1-e1-a1-49.dat the length of this file is:303


 19%|█▉        | 147/780 [00:02<00:14, 43.55it/s]

processing ../data/raw/train/csi-s1-e1-a4-41.dat the length of this file is:413
processing ../data/raw/train/csi-s1-e1-a1-15.dat the length of this file is:294
processing ../data/raw/train/csi-s1-e1-a1-7.dat the length of this file is:319
processing ../data/raw/train/csi-s1-e1-a4-29.dat the length of this file is:386


 23%|██▎       | 181/780 [00:03<00:11, 53.28it/s]

processing ../data/raw/train/csi-s1-e1-a5-18.dat the length of this file is:157
processing ../data/raw/train/csi-s1-e1-a1-8.dat the length of this file is:287
processing ../data/raw/train/csi-s1-e1-a6-23.dat the length of this file is:125
processing ../data/raw/train/csi-s1-e1-a5-34.dat the length of this file is:167
processing ../data/raw/train/csi-s1-e1-a2-18.dat the length of this file is:281


 25%|██▍       | 192/780 [00:03<00:12, 46.36it/s]

processing ../data/raw/train/csi-s1-e1-a2-4.dat the length of this file is:274
processing ../data/raw/train/csi-s1-e1-a3-26.dat the length of this file is:281
processing ../data/raw/train/csi-s1-e1-a6-50.dat the length of this file is:173
processing ../data/raw/train/csi-s1-e1-a3-11.dat the length of this file is:235
processing ../data/raw/train/csi-s1-e1-a3-25.dat the length of this file is:255
processing ../data/raw/train/csi-s1-e1-a6-15.dat the length of this file is:137

 27%|██▋       | 208/780 [00:03<00:10, 56.58it/s]


processing ../data/raw/train/csi-s1-e1-a1-26.dat the length of this file is:299
processing ../data/raw/train/csi-s1-e1-a2-24.dat the length of this file is:241
processing ../data/raw/train/csi-s1-e1-a1-42.dat the length of this file is:367
processing ../data/raw/train/csi-s1-e1-a4-12.dat the length of this file is:478
processing ../data/raw/train/csi-s1-e1-a4-26.dat the length of this file is:364
processing ../data/raw/train/csi-s1-e1-a1-48.dat the length of this file is:311
processing ../data/raw/train/csi-s1-e1-a1-16.dat the length of this file is:301


 29%|██▉       | 226/780 [00:04<00:12, 44.18it/s]

processing ../data/raw/train/csi-s1-e1-a4-34.dat the length of this file is:412
processing ../data/raw/train/csi-s1-e1-a2-16.dat the length of this file is:258
processing ../data/raw/train/csi-s1-e1-a3-19.dat the length of this file is:181
processing ../data/raw/train/csi-s1-e1-a2-30.dat the length of this file is:264


 30%|██▉       | 233/780 [00:04<00:12, 44.09it/s]

processing ../data/raw/train/csi-s1-e1-a1-45.dat the length of this file is:298
processing ../data/raw/train/csi-s1-e1-a3-23.dat the length of this file is:210
processing ../data/raw/train/csi-s1-e1-a1-9.dat the length of this file is:291
processing ../data/raw/train/csi-s1-e1-a6-40.dat the length of this file is:136


 32%|███▏      | 248/780 [00:04<00:11, 46.00it/s]

processing ../data/raw/train/csi-s1-e1-a4-48.dat the length of this file is:417
processing ../data/raw/train/csi-s1-e1-a3-30.dat the length of this file is:246
processing ../data/raw/train/csi-s1-e1-a4-38.dat the length of this file is:416
processing ../data/raw/train/csi-s1-e1-a6-45.dat the length of this file is:166


 34%|███▍      | 267/780 [00:04<00:09, 52.81it/s]

processing ../data/raw/train/csi-s1-e1-a5-13.dat the length of this file is:162
processing ../data/raw/train/csi-s1-e1-a3-34.dat the length of this file is:260
processing ../data/raw/train/csi-s1-e1-a2-33.dat the length of this file is:247
processing ../data/raw/train/csi-s1-e1-a6-46.dat the length of this file is:166
processing ../data/raw/train/csi-s1-e1-a3-36.dat the length of this file is:290


 36%|███▌      | 282/780 [00:04<00:08, 61.85it/s]

processing ../data/raw/train/csi-s1-e1-a6-20.dat the length of this file is:158
processing ../data/raw/train/csi-s1-e1-a1-6.dat the length of this file is:291
processing ../data/raw/train/csi-s1-e1-a3-40.dat the length of this file is:267
processing ../data/raw/train/csi-s1-e1-a2-20.dat the length of this file is:260
processing ../data/raw/train/csi-s1-e1-a6-26.dat the length of this file is:172


 37%|███▋      | 290/780 [00:05<00:08, 56.06it/s]

processing ../data/raw/train/csi-s1-e1-a2-27.dat the length of this file is:272
processing ../data/raw/train/csi-s1-e1-a6-39.dat the length of this file is:140
processing ../data/raw/train/csi-s1-e1-a1-25.dat the length of this file is:299
processing ../data/raw/train/csi-s1-e1-a2-9.dat the length of this file is:260


 39%|███▉      | 308/780 [00:05<00:08, 55.02it/s]

processing ../data/raw/train/csi-s1-e1-a1-47.dat the length of this file is:291
processing ../data/raw/train/csi-s1-e1-a5-17.dat the length of this file is:192
processing ../data/raw/train/csi-s1-e1-a2-38.dat the length of this file is:263
processing ../data/raw/train/csi-s1-e1-a2-3.dat the length of this file is:282
processing ../data/raw/train/csi-s1-e1-a6-24.dat the length of this file is:190


 40%|████      | 315/780 [00:05<00:10, 45.66it/s]

processing ../data/raw/train/csi-s1-e1-a4-15.dat the length of this file is:375
processing ../data/raw/train/csi-s1-e1-a4-49.dat the length of this file is:494
processing ../data/raw/train/csi-s1-e1-a6-12.dat the length of this file is:135
processing ../data/raw/train/csi-s1-e1-a3-16.dat the length of this file is:251


 43%|████▎     | 334/780 [00:05<00:08, 54.11it/s]

processing ../data/raw/train/csi-s1-e1-a5-21.dat the length of this file is:151
processing ../data/raw/train/csi-s1-e1-a5-39.dat the length of this file is:209
processing ../data/raw/train/csi-s1-e1-a6-28.dat the length of this file is:156
processing ../data/raw/train/csi-s1-e1-a2-10.dat the length of this file is:254
processing ../data/raw/train/csi-s1-e1-a6-11.dat the length of this file is:127


 44%|████▎     | 341/780 [00:06<00:08, 49.26it/s]

processing ../data/raw/train/csi-s1-e1-a1-50.dat the length of this file is:304
processing ../data/raw/train/csi-s1-e1-a5-49.dat the length of this file is:271
processing ../data/raw/train/csi-s1-e1-a1-19.dat the length of this file is:283
processing ../data/raw/train/csi-s1-e1-a2-37.dat the length of this file is:275


 46%|████▌     | 356/780 [00:06<00:08, 51.03it/s]

processing ../data/raw/train/csi-s1-e1-a1-30.dat the length of this file is:295
processing ../data/raw/train/csi-s1-e1-a4-44.dat the length of this file is:402
processing ../data/raw/train/csi-s1-e1-a5-37.dat the length of this file is:187
processing ../data/raw/train/csi-s1-e1-a6-31.dat the length of this file is:160
processing ../data/raw/train/csi-s1-e1-a6-36.dat the length of this file is:174


 46%|████▋     | 362/780 [00:06<00:09, 44.27it/s]

processing ../data/raw/train/csi-s1-e1-a6-16.dat the length of this file is:146
processing ../data/raw/train/csi-s1-e1-a3-39.dat the length of this file is:256
processing ../data/raw/train/csi-s1-e1-a1-41.dat the length of this file is:306
processing ../data/raw/train/csi-s1-e1-a5-36.dat the length of this file is:181


 47%|████▋     | 370/780 [00:06<00:08, 46.48it/s]

processing ../data/raw/train/csi-s1-e1-a1-14.dat the length of this file is:300
processing ../data/raw/train/csi-s1-e1-a2-31.dat the length of this file is:272
processing ../data/raw/train/csi-s1-e1-a6-27.dat the length of this file is:175
processing ../data/raw/train/csi-s1-e1-a1-38.dat the length of this file is:287


 49%|████▉     | 381/780 [00:06<00:09, 44.11it/s]

processing ../data/raw/train/csi-s1-e1-a2-32.dat the length of this file is:291
processing ../data/raw/train/csi-s1-e1-a2-50.dat the length of this file is:307
processing ../data/raw/train/csi-s1-e1-a2-36.dat the length of this file is:274
processing ../data/raw/train/csi-s1-e1-a2-11.dat the length of this file is:257


 51%|█████     | 394/780 [00:07<00:08, 46.01it/s]

processing ../data/raw/train/csi-s1-e1-a3-21.dat the length of this file is:275
processing ../data/raw/train/csi-s1-e1-a5-32.dat the length of this file is:181
processing ../data/raw/train/csi-s1-e1-a1-12.dat the length of this file is:362
processing ../data/raw/train/csi-s1-e1-a2-40.dat the length of this file is:249
processing ../data/raw/train/csi-s1-e1-a5-48.dat the length of this file is:201


 53%|█████▎    | 410/780 [00:07<00:06, 55.07it/s]

processing ../data/raw/train/csi-s1-e1-a3-38.dat the length of this file is:244
processing ../data/raw/train/csi-s1-e1-a2-34.dat the length of this file is:328
processing ../data/raw/train/csi-s1-e1-a5-15.dat the length of this file is:171
processing ../data/raw/train/csi-s1-e1-a6-37.dat the length of this file is:134
processing ../data/raw/train/csi-s1-e1-a3-46.dat the length of this file is:270


 54%|█████▍    | 424/780 [00:07<00:06, 51.75it/s]

processing ../data/raw/train/csi-s1-e1-a6-48.dat the length of this file is:181
processing ../data/raw/train/csi-s1-e1-a1-34.dat the length of this file is:299
processing ../data/raw/train/csi-s1-e1-a5-23.dat the length of this file is:179
processing ../data/raw/train/csi-s1-e1-a5-41.dat the length of this file is:198
processing ../data/raw/train/csi-s1-e1-a5-22.dat the length of this file is:178


 55%|█████▌    | 430/780 [00:07<00:06, 53.90it/s]

processing ../data/raw/train/csi-s1-e1-a1-13.dat the length of this file is:309
processing ../data/raw/train/csi-s1-e1-a1-3.dat the length of this file is:291
processing ../data/raw/train/csi-s1-e1-a4-43.dat the length of this file is:386
processing ../data/raw/train/csi-s1-e1-a5-19.dat the length of this file is:176


 57%|█████▋    | 446/780 [00:08<00:06, 49.56it/s]

processing ../data/raw/train/csi-s1-e1-a5-43.dat the length of this file is:172
processing ../data/raw/train/csi-s1-e1-a6-44.dat the length of this file is:165
processing ../data/raw/train/csi-s1-e1-a6-49.dat the length of this file is:177
processing ../data/raw/train/csi-s1-e1-a3-35.dat the length of this file is:293
processing ../data/raw/train/csi-s1-e1-a3-12.dat the length of this file is:208


 59%|█████▊    | 457/780 [00:08<00:05, 55.63it/s]

processing ../data/raw/train/csi-s1-e1-a3-18.dat the length of this file is:245
processing ../data/raw/train/csi-s1-e1-a2-17.dat the length of this file is:246
processing ../data/raw/train/csi-s1-e1-a4-33.dat the length of this file is:401
processing ../data/raw/train/csi-s1-e1-a3-13.dat the length of this file is:250


 60%|██████    | 471/780 [00:08<00:05, 52.36it/s]

processing ../data/raw/train/csi-s1-e1-a3-50.dat the length of this file is:253
processing ../data/raw/train/csi-s1-e1-a2-48.dat the length of this file is:263
processing ../data/raw/train/csi-s1-e1-a6-33.dat the length of this file is:121
processing ../data/raw/train/csi-s1-e1-a5-30.dat the length of this file is:163
processing ../data/raw/train/csi-s1-e1-a2-45.dat the length of this file is:258
processing ../data/raw/train/csi-s1-e1-a6-35.dat the length of this file is:166


 62%|██████▏   | 486/780 [00:08<00:05, 54.67it/s]

processing ../data/raw/train/csi-s1-e1-a1-28.dat the length of this file is:297
processing ../data/raw/train/csi-s1-e1-a5-44.dat the length of this file is:187
processing ../data/raw/train/csi-s1-e1-a5-29.dat the length of this file is:170
processing ../data/raw/train/csi-s1-e1-a2-7.dat the length of this file is:244
processing ../data/raw/train/csi-s1-e1-a4-23.dat the length of this file is:418


 64%|██████▍   | 500/780 [00:09<00:04, 57.94it/s]

processing ../data/raw/train/csi-s1-e1-a5-27.dat the length of this file is:156
processing ../data/raw/train/csi-s1-e1-a2-47.dat the length of this file is:300
processing ../data/raw/train/csi-s1-e1-a3-48.dat the length of this file is:273
processing ../data/raw/train/csi-s1-e1-a5-31.dat the length of this file is:164
processing ../data/raw/train/csi-s1-e1-a3-42.dat the length of this file is:274


 66%|██████▌   | 514/780 [00:09<00:03, 68.98it/s]

processing ../data/raw/train/csi-s1-e1-a6-14.dat the length of this file is:117
processing ../data/raw/train/csi-s1-e1-a3-32.dat the length of this file is:263
processing ../data/raw/train/csi-s1-e1-a1-31.dat the length of this file is:274
processing ../data/raw/train/csi-s1-e1-a2-28.dat the length of this file is:255
processing ../data/raw/train/csi-s1-e1-a6-47.dat the length of this file is:150


 67%|██████▋   | 522/780 [00:09<00:05, 51.35it/s]

processing ../data/raw/train/csi-s1-e1-a1-5.dat the length of this file is:297
processing ../data/raw/train/csi-s1-e1-a1-1.dat the length of this file is:312
processing ../data/raw/train/csi-s1-e1-a4-47.dat the length of this file is:407


 69%|██████▊   | 535/780 [00:09<00:05, 42.81it/s]

processing ../data/raw/train/csi-s1-e1-a4-13.dat the length of this file is:382
processing ../data/raw/train/csi-s1-e1-a2-2.dat the length of this file is:250
processing ../data/raw/train/csi-s1-e1-a4-14.dat the length of this file is:377
processing ../data/raw/train/csi-s1-e1-a6-21.dat the length of this file is:144
processing ../data/raw/train/csi-s1-e1-a3-24.dat the length of this file is:257


 72%|███████▏  | 563/780 [00:10<00:03, 56.81it/s]

processing ../data/raw/train/csi-s1-e1-a5-20.dat the length of this file is:164
processing ../data/raw/train/csi-s1-e1-a2-25.dat the length of this file is:268
processing ../data/raw/train/csi-s1-e1-a3-17.dat the length of this file is:235
processing ../data/raw/train/csi-s1-e1-a4-11.dat the length of this file is:404
processing ../data/raw/train/csi-s1-e1-a3-45.dat the length of this file is:294
processing ../data/raw/train/csi-s1-e1-a6-38.dat the length of this file is:131
processing ../data/raw/train/csi-s1-e1-a4-19.dat the length of this file is:415
processing ../data/raw/train/csi-s1-e1-a2-44.dat the length of this file is:269


 74%|███████▍  | 578/780 [00:10<00:04, 43.12it/s]

processing ../data/raw/train/csi-s1-e1-a2-19.dat the length of this file is:267
processing ../data/raw/train/csi-s1-e1-a5-47.dat the length of this file is:223
processing ../data/raw/train/csi-s1-e1-a3-31.dat the length of this file is:254
processing ../data/raw/train/csi-s1-e1-a4-17.dat the length of this file is:383


 76%|███████▌  | 594/780 [00:10<00:03, 53.07it/s]

processing ../data/raw/train/csi-s1-e1-a2-41.dat the length of this file is:247
processing ../data/raw/train/csi-s1-e1-a5-26.dat the length of this file is:193
processing ../data/raw/train/csi-s1-e1-a2-8.dat the length of this file is:256
processing ../data/raw/train/csi-s1-e1-a2-14.dat the length of this file is:264


 80%|███████▉  | 621/780 [00:10<00:02, 68.63it/s]

processing ../data/raw/train/csi-s1-e1-a4-40.dat the length of this file is:482
processing ../data/raw/train/csi-s1-e1-a5-35.dat the length of this file is:196
processing ../data/raw/train/csi-s1-e1-a6-30.dat the length of this file is:166
processing ../data/raw/train/csi-s1-e1-a1-22.dat the length of this file is:313
processing ../data/raw/train/csi-s1-e1-a6-43.dat the length of this file is:149
processing ../data/raw/train/csi-s1-e1-a6-17.dat the length of this file is:146


 83%|████████▎ | 646/780 [00:11<00:01, 80.68it/s]

processing ../data/raw/train/csi-s1-e1-a4-36.dat the length of this file is:393
processing ../data/raw/train/csi-s1-e1-a1-46.dat the length of this file is:296
processing ../data/raw/train/csi-s1-e1-a3-47.dat the length of this file is:263
processing ../data/raw/train/csi-s1-e1-a5-14.dat the length of this file is:197
processing ../data/raw/train/csi-s1-e1-a5-46.dat the length of this file is:240


 84%|████████▍ | 656/780 [00:11<00:01, 65.90it/s]

processing ../data/raw/train/csi-s1-e1-a3-20.dat the length of this file is:234
processing ../data/raw/train/csi-s1-e1-a3-15.dat the length of this file is:237
processing ../data/raw/train/csi-s1-e1-a5-42.dat the length of this file is:199
processing ../data/raw/train/csi-s1-e1-a6-25.dat the length of this file is:166
processing ../data/raw/train/csi-s1-e1-a1-32.dat the length of this file is:290


 85%|████████▌ | 666/780 [00:11<00:01, 67.74it/s]

processing ../data/raw/train/csi-s1-e1-a3-14.dat the length of this file is:233
processing ../data/raw/train/csi-s1-e1-a5-25.dat the length of this file is:168
processing ../data/raw/train/csi-s1-e1-a4-27.dat the length of this file is:412
processing ../data/raw/train/csi-s1-e1-a1-43.dat the length of this file is:287


 89%|████████▊ | 692/780 [00:11<00:01, 61.14it/s]

processing ../data/raw/train/csi-s1-e1-a4-28.dat the length of this file is:426
processing ../data/raw/train/csi-s1-e1-a5-33.dat the length of this file is:176
processing ../data/raw/train/csi-s1-e1-a3-33.dat the length of this file is:240
processing ../data/raw/train/csi-s1-e1-a3-44.dat the length of this file is:323
processing ../data/raw/train/csi-s1-e1-a6-42.dat the length of this file is:139


 90%|████████▉ | 701/780 [00:12<00:01, 53.33it/s]

processing ../data/raw/train/csi-s1-e1-a1-11.dat the length of this file is:280
processing ../data/raw/train/csi-s1-e1-a2-46.dat the length of this file is:270
processing ../data/raw/train/csi-s1-e1-a4-37.dat the length of this file is:428
processing ../data/raw/train/csi-s1-e1-a5-11.dat the length of this file is:218


 91%|█████████ | 709/780 [00:12<00:01, 52.45it/s]

processing ../data/raw/train/csi-s1-e1-a1-10.dat the length of this file is:296
processing ../data/raw/train/csi-s1-e1-a1-27.dat the length of this file is:294
processing ../data/raw/train/csi-s1-e1-a3-37.dat the length of this file is:257
processing ../data/raw/train/csi-s1-e1-a6-18.dat the length of this file is:158
processing ../data/raw/train/csi-s1-e1-a2-49.dat the length of this file is:290


 93%|█████████▎| 729/780 [00:12<00:01, 49.31it/s]

processing ../data/raw/train/csi-s1-e1-a5-28.dat the length of this file is:211
processing ../data/raw/train/csi-s1-e1-a2-1.dat the length of this file is:252
processing ../data/raw/train/csi-s1-e1-a4-50.dat the length of this file is:411
processing ../data/raw/train/csi-s1-e1-a2-13.dat the length of this file is:257


 94%|█████████▍| 736/780 [00:12<00:00, 44.73it/s]

processing ../data/raw/train/csi-s1-e1-a4-39.dat the length of this file is:377
processing ../data/raw/train/csi-s1-e1-a1-44.dat the length of this file is:289
processing ../data/raw/train/csi-s1-e1-a1-24.dat the length of this file is:308
processing ../data/raw/train/csi-s1-e1-a2-22.dat the length of this file is:241
processing ../data/raw/train/csi-s1-e1-a6-13.dat the length of this file is:150


 96%|█████████▌| 747/780 [00:13<00:00, 43.26it/s]

processing ../data/raw/train/csi-s1-e1-a6-32.dat the length of this file is:104
processing ../data/raw/train/csi-s1-e1-a2-23.dat the length of this file is:248
processing ../data/raw/train/csi-s1-e1-a1-21.dat the length of this file is:299
processing ../data/raw/train/csi-s1-e1-a2-35.dat the length of this file is:302


 97%|█████████▋| 758/780 [00:13<00:00, 44.63it/s]

processing ../data/raw/train/csi-s1-e1-a2-43.dat the length of this file is:257
processing ../data/raw/train/csi-s1-e1-a1-20.dat the length of this file is:293
processing ../data/raw/train/csi-s1-e1-a4-18.dat the length of this file is:394
processing ../data/raw/train/csi-s1-e1-a3-29.dat the length of this file is:274


 98%|█████████▊| 763/780 [00:13<00:00, 39.55it/s]

processing ../data/raw/train/csi-s1-e1-a4-46.dat the length of this file is:448
processing ../data/raw/train/csi-s1-e1-a4-31.dat the length of this file is:372
processing ../data/raw/train/csi-s1-e1-a4-20.dat the length of this file is:369
processing ../data/raw/train/csi-s1-e1-a6-34.dat the length of this file is:178


100%|██████████| 780/780 [00:13<00:00, 56.43it/s]
0it [00:00, ?it/s]

processing ../data/raw/train/csi-s1-e1-a5-12.dat the length of this file is:219
processing ../data/raw/train/csi-s1-e1-a3-22.dat the length of this file is:261
processing ../data/raw/train/csi-s1-e1-a1-2.dat the length of this file is:353
processing ../data/raw/train/csi-s1-e1-a2-26.dat the length of this file is:292
all raw file processed
all raw file processed





In [5]:
# src_path=r'/content/drive/My Drive/SR_CSI/Gestures_data/train'
# target_path=r'/content/drive/My Drive/SR_CSI/Gestures_data/test'
# a=[]
# for i in os.listdir(target_path):
#   a.append(i)
# print(len(a))
#   if i.endswith('.mat'):
#     os.remove(os.path.join(src_path,i))
# import os,shutil
# for i in range(1,7):
#   for j in range(6, 11):
#     shutil.move(src_path+r'/csi-s1-e1-a{}-{}.txt'.format(i,j),target_path+r'/csi-s1-e1-a{}-{}.txt'.format(i,j))   
# # X_train = process_data(train_raw_file_path, train_save_path)
# # test_raw_file_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/test'

In [6]:


# 重新构造mat文件选择最佳子载波上的CSI
best_index_list = [1, 3, 3, 3, 2, 3]
#
for f in tqdm(os.listdir(train_save_path)):
    if f.endswith('.mat'):
      file_name = os.path.join(train_save_path, f)
      activate_index = int(re.findall(re.compile(r'csi-s1-e1-a(.*?)-.*?.mat', re.S), f)[0])
      csi = np.squeeze(scio.loadmat(file_name)['csi'])
      best_csi = csi[:, best_index_list[activate_index-1]]     
      scio.savemat(file_name, {'csi': best_csi})
for f in tqdm(os.listdir(test_save_path)):
    if f.endswith('.mat'):
      file_name = os.path.join(test_save_path, f)
      activate_index = int(re.findall(re.compile(r'csi-s1-e1-a(.*?)-.*?.mat', re.S), f)[0])
      best_csi = np.squeeze(scio.loadmat(file_name)['csi'])[:, best_index_list[activate_index-1]]
      scio.savemat(file_name, {'csi': best_csi})

100%|██████████| 780/780 [00:00<00:00, 14165.96it/s]
0it [00:00, ?it/s]


In [7]:
# --抖动,添加噪声超参数：sigma =噪声的标准偏差（STD）
def add_jitter(X, sigma=0.1):  
    myNoise = np.random.normal(loc=0, scale=sigma, size=X.shape)
    # plt.plot(myNoise, label='noise')
    return X+myNoise

# --缩放¶超参数:σ=放大/缩小系数的标准值通过乘以一个随机标量来更改窗口中数据的大小
def add_scaling(X, sigma=0.2):
    scalingFactor = np.random.normal(loc=1.0, scale=sigma, size=(1))
    myNoise = np.matmul(np.ones((X.shape[0],1)), scalingFactor)
    return X*myNoise

# 降采样,使用一组降采样因子 k1, k2, k3，每隔 ki-1 个数据取一个。
def down_sampling(data, rate=1):
    # down sampling by rate k
    if rate > data.shape[0] / 3:
        print('sampling rate is too high')
        return None
    ds_data = data[::rate]  # temp after down sampling
    ds_data_len = ds_data.shape[0]  # remark the length info
    return ds_data 

# --滑动平均 使用一组滑动窗口l1, l2, l3，每li个数据取平均
def moving_average(data, moving_wl=10):
    data_len = data.shape[0]
    if  moving_wl > data.shape[0] / 3:
        print('moving window is too high')
        return None
    ma_data = np.zeros(data_len-moving_wl+1)
    for i in range(data_len-moving_wl+1):
        ma_data[i] = np.mean(data[i: i+moving_wl])
    return ma_data

# ------------------裁剪（Crop） 使用滑动窗口在时间序列上截取数据
def data_crop(data, wl_ratio=0.8):
    data_len = data.shape[0]
    wl = int(data_len*wl_ratio)
    start = int(data_len*(1-wl_ratio)//2)
    end = start + wl
    #print(start, end)
    crop_data = data[start:end]
    return crop_data

In [8]:
#　加载数据生成特征，供网络训练
# 读取文件,生成音频文件和标签文件列表
def get_file_list(train_file):
    label_f_list = []
    wav_f_list = []
    for root, dirs, files in os.walk(train_file):
        for file in files:
            if file.endswith('.mat') or file.endswith('.MAT'):
                wav_file = os.sep.join([root, file])
                label_file = wav_file.split('.mat')[0] + '.txt'
                wav_f_list.append(wav_file)
                label_f_list.append(label_file)
    return label_f_list, wav_f_list

# label数据处理
def get_label_data(label_f_list):
    # 生成label_data每个文件里的声音标签集合
    l_list = []
    for label_file in tqdm(label_f_list):
        with open(label_file, 'r', encoding='utf8') as ff:
            try:
                data = ff.read()
            except:
                print(label_file, 'not get label data')
            l_list.append(data)
    return l_list

# 为label建立词典
def gen_py_list(label_batch):
    p_list = []
    for li in label_batch:
        l = li.split('  ')
        for pny in l:
            if pny not in p_list:
                p_list.append(pny)
    p_list.append('_')  # 该帧可能是空
    return p_list

# 将读取到的label映射到对应的id
def py2id(l, p_list):
    ids = []
    for py in l.split('  '):
        try:
          ids.append(p_list.index(py))
        except ValueError:
          print(py, 'is not in py_llist')
    return ids

# 对label进行padding和长度获取，不同的是数据维度不同，且label的长度就是输入给ctc的长度，不需要额外处理
def label_padding(label_batch):
    label_lens = np.array([len(l) for l in label_batch])
    label_max_len = max(label_lens)
    new_label_batch = np.zeros((len(label_batch), label_max_len))
    for j in range(len(label_batch)):
        new_label_batch[j][:len(label_batch[j])] = label_batch[j]
    return new_label_batch, label_lens

# 统一batch内数据：[batch_size, time_step, feature_dim],除此之外，ctc需要获得的信息还有输入序列的长度。
def wav_padding(image_size, wav_batch):
    wav_lens = [len(w) for w in wav_batch]
    wav_max_len = max(wav_lens)
    # 每一个sample的时间长都不一样，选择batch内最长的那个时间为基准，进行padding。
    new_wav_batch = np.zeros((len(wav_batch), wav_max_len, image_size,1))#wav_max_len, len(wav_batch[0][0], 1))) 
    # 需要构成成一个tensorflow块，这就要求每个样本数据形式是一样的。
    for j in range(len(wav_batch)):
        new_wav_batch[j, :wav_batch[j].shape[0], :, 0] = wav_batch[j]
    # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!
    wav_length = np.array([j // 8 for j in wav_lens])
    return new_wav_batch, wav_length

# 生成batch_size的信号时频图和标签数据，存放到两个list中去
def get_batch_generator(b_size, w_list, l_list, p_list, image_size):
    shuffle_list = [i for i in range(len(w_list))]
    while True:
        for j in range(len(w_list)//b_size):
            shuffle(shuffle_list)  # 打乱数据的顺序，我们通过查询乱序的索引值，来确定训练数据的顺序
            wav_batch = []
            label_batch = []
            begin = j*b_size
            end = begin+b_size
            for index in shuffle_list[begin:end]:
                fbank = compute_fbank_filt(w_list[index], image_size)
                # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
                pad_fbank = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
                pad_fbank[:fbank.shape[0], :] = fbank
                label = py2id(l_list[index], p_list)
                wav_batch.append(pad_fbank)
                label_batch.append(label)
            pad_wav_data, wav_length = wav_padding(image_size, wav_batch)
            pad_label_data, label_length = label_padding(label_batch)
            input_batch = {'the_inputs': pad_wav_data,
                    'the_labels': pad_label_data,
                    'input_length': wav_length,
                    'label_length': label_length}
            output_batch = {'ctc': np.zeros(pad_wav_data.shape[0])}

            yield input_batch, output_batch
def get_train_data(w_list, l_list, p_list, image_size):
    wav_batch, label_batch = [], []
    for index in range(len(w_list)):
        X = np.squeeze(scio.loadmat(w_list[index])['csi'])
        #print(type(X),X.shape)
        gasf = GADF(image_size)
        X_gasf = gasf.fit_transform(X.reshape(1, -1))
        fbank = X_gasf[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank[:fbank.shape[0], :] = fbank
        wav_batch.append(pad_fbank)
        label = py2id(l_list[index], p_list)
        label_batch.append(label)

        X1 = add_jitter(X)
        X_gasf1 = gasf.fit_transform(X1.reshape(1, -1))
        fbank1 = X_gasf1[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank1 = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank1[:fbank.shape[0], :] = fbank1
        wav_batch.append(pad_fbank1)
        label_batch.append(label)

        X2 = add_scaling(X)
        X_gasf2 = gasf.fit_transform(X2.reshape(1, -1))
        fbank2 = X_gasf2[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank2 = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank2[:fbank.shape[0], :] = fbank2
        wav_batch.append(pad_fbank2)
        label_batch.append(label)

        X3 = down_sampling(X)
        X_gasf3 = gasf.fit_transform(X3.reshape(1, -1))
        fbank3 = X_gasf3[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank3 = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank3[:fbank.shape[0], :] = fbank3
        wav_batch.append(pad_fbank3)
        label_batch.append(label)

        X4 = moving_average(X)
        X_gasf4 = gasf.fit_transform(X4.reshape(1, -1))
        fbank4 = X_gasf4[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank4 = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank4[:fbank.shape[0], :] = fbank4
        wav_batch.append(pad_fbank4)
        label_batch.append(label)

        X5 = data_crop(X)
        X_gasf5 = gasf.fit_transform(X5.reshape(1, -1))
        fbank5 = X_gasf5[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank5 = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank5[:fbank.shape[0], :] = fbank5
        wav_batch.append(pad_fbank5)
        label_batch.append(label)
          
    pad_wav_data, wav_length = wav_padding(image_size, wav_batch)
    pad_label_data, label_length = label_padding(label_batch)
    inputs = {'the_inputs': pad_wav_data,
          'the_labels': pad_label_data,
          'input_length': wav_length,
          'label_length': label_length}
    outputs = {'ctc': np.zeros(pad_wav_data.shape[0])}
    return inputs, outputs


 # 生成的信号时频图和标签数据，存放到两个list中去
def get_test_data(image_size, w_list, l_list, p_list):
    input_data = []
    output_data = []
    for i in range(len(w_list)):
        wav_batch = []
        label_batch = []

        X = np.squeeze(scio.loadmat(w_list[i])['csi'])
        #print(type(X),X.shape)
        gasf = GADF(image_size)
        X_gasf = gasf.fit_transform(X.reshape(1, -1))
        fbank = X_gasf[0]
        # !!!!!!!3个maxpooling层数据的每个维度需要能够被8整除。因此我们训练实际输入的数据为wav_len//8。!!!!!!!!!!!!             
        pad_fbank = np.zeros((image_size//8*8+8,image_size))#fbank.shape[0]//8*8+8, fbank.shape[1], 3))
        pad_fbank[:fbank.shape[0], :] = fbank
        wav_batch.append(pad_fbank)
        label = py2id(l_list[i], p_list)
        label_batch.append(label[0])
        pad_wav_data, wav_length = wav_padding(image_size, wav_batch)
        input_data.append(pad_wav_data)                                        
        output_data.append(label[0])
    return input_data, output_data

# 加载数据

In [9]:
train_file_path = train_save_path
train_label_file_list, train_wav_file_list = get_file_list(train_file_path)

train_label_list = get_label_data(train_label_file_list)

# 用迭代器的时候需要
# train_wav_file_list, validate_wav_file_list, train_label_list, validate_label_list = train_test_split(train_wav_file_list, train_label_list, test_size=0.2, random_state=0)

# 每个文件拼音标签的集合
with open('label_list.txt', 'w') as f:  
    f.write('\n'.join(train_label_list))  
py_list = gen_py_list(train_label_list)  # 拼音的集合

with open('pinyin_list.txt', 'w') as f:
    f.write('\n'.join(py_list))  # 保存拼音列表

train_file_nums = len(train_wav_file_list)
# validate_file_nums = len(validate_wav_file_list)

py_list_size = len(py_list)  # 模型输出的维度
print('py_list:', py_list)
print('train files amount:', len(train_label_list), train_file_nums, 'label amount:', py_list_size)#, 'validate files amount:', len(validate_label_list), validate_file_nums,)

# 测试数据
test_file_path = test_save_path # r'/content/drive/My Drive/data_thchs30'
test_label_file_list, test_wav_file_list = get_file_list(test_file_path)

test_label_list = get_label_data(test_label_file_list)
print('test files amount:',len(test_label_file_list), len(test_wav_file_list))
test_data_num = len(test_label_file_list)

100%|██████████| 260/260 [00:00<00:00, 114274.24it/s]
0it [00:00, ?it/s]

py_list: ['6', '1', '5', '2', '3', '4', '_']
train files amount: 260 260 label amount: 7
test files amount: 0 0





# 搭建模型

In [10]:
# 添加CTC损失函数
def ctc_lambda(args):
  labels, y_pred, input_length, label_length = args
  y_pred = y_pred[:, :, :]
  return tf.keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)
# 定义解码器
def decode_ctc(preds, py_list):
    window_num = np.zeros((1), dtype=np.int32)
    window_num[0] = preds.shape[1]
    decode = keras.backend.ctc_decode(preds, window_num, greedy=True, beam_width=100, top_paths=1)
    result_index = keras.backend.get_value(decode[0][0])[0]
    result_py = []
    for i in result_index:
        try:
            result_py.append(py_list[i])
        except IndexError:
            print(i, 'not in py_list')
    return result_index, result_py

def create_model(input_size, output_size):
    inputs = Input(name='the_inputs', shape=input_size)
    # 1
    h1_1 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-1')(inputs)
    h1_2 = BatchNormalization(name='BatchNormal_1-1')(h1_1)
    h1_3 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-2')(h1_2)
    h1_4 = BatchNormalization(name='BatchNormal_1-2')(h1_3)
    h1_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_1')(h1_4)
    # 2
    h2_1 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-1')(h1_5)
    h2_2 = BatchNormalization(name='BatchNormal_2-1')(h2_1)
    h2_3 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-2')(h2_2)
    h2_4 = BatchNormalization(name='BatchNormal_2-2')(h2_3)
    h2_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_2')(h2_4)
    # 3
    h3_1 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-1')(h2_5)
    h3_2 = BatchNormalization(name='BatchNormal_3-1')(h3_1)
    h3_3 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-2')(h3_2)
    h3_4 = BatchNormalization(name='BatchNormal_3-2')(h3_3)
    h3_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_3')(h3_4)
    # 4
    h4_1 = Conv2D(512, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-1')(h3_5)
    h4_2 = BatchNormalization(name='BatchNormal_4-1')(h4_1)
    h4_3 = Conv2D(512, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-2')(h4_2)
    h4_4 = BatchNormalization(name='BatchNormal_4-2')(h4_3)
    # 由于声学模型网络结构原因（3个maxpooling层），我们的音频数据的每个维度需要能够被8整除。这里输入序列经过卷积网络后，长度缩短了8倍，因此我们训练实际输入的数据为wav_len//8。
    h5_1 = Reshape((-1, int(input_size[1]//8*512)), name='Reshape_1')(h4_4)
    lstm_1 = LSTM(128, return_sequences=True, kernel_initializer='he_normal', name='lstm1')(h5_1)
    lstm_2 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm2')(lstm_1)
    
    h5_2 = Dense(512, activation='relu', use_bias=True, kernel_initializer='he_normal', name='Dense_1')(lstm_2)
    h5_3 = Dense(output_size, activation="relu", use_bias=True, kernel_initializer='he_normal', name='Dense_2')(h5_2)#(layer_h15)

    outputs = Activation('softmax', name='Activation_1')(h5_3)

    base_model = keras.Model(inputs=inputs, outputs=outputs)

    labels = Input(name='the_labels', shape=[None], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    # keras.Lambda(function, output_shape=None, mask=None, arguments=None)
    # 将任意表达式封装为 Layer 对象
    loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, outputs, input_length, label_length])
    ctc_model = keras.Model(inputs=[labels, inputs, input_length, label_length], outputs=loss_out)

    opt = keras.optimizers.Adam(lr=0.0008, beta_1=0.9, beta_2=0.999, decay=0.01, epsilon=10e-8)
    # ctc_model=multi_gpu_model(ctc_model,gpus=2)
    ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt, metrics=['accuracy'])

    return base_model, ctc_model

In [11]:
# 设置GPU内存按需分配
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# 开始训练、测试

In [14]:
image_size=64  
batch_size = 16
logs_path = '../logs'

input_size=(None, image_size, 1)
base_model, ctc_model = create_model(input_size, output_size=py_list_size) 
print(ctc_model.summary())

# keras.utils.plot_model(base_model, show_shapes=True)

if not os.path.exists(logs_path):  # 判断保存模型的目录是否存在
    os.makedirs(logs_path)  # 如果不存在，就新建一个，避免之后保存模型的时候炸掉

# train_batch_gen = get_batch_generator(batch_size, train_wav_file_list, train_label_list, py_list, image_size)
# validate_batch_gen = get_batch_generator(batch_size, validate_wav_file_list, validate_label_list, py_list, image_size)
# input_data = next(train_batch_gen)[0]
# plt.imshow(input_data['the_inputs'][0].T[0])
# plt.show()
# print(input_data['the_inputs'].shape, input_data['the_labels'].shape, input_data['input_length'].shape, input_data['label_length'].shape)

train_data = get_train_data(train_wav_file_list, train_label_list, py_list, image_size)

cb = []
cb.append(keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0))
# 当监测值不再改善时，该回调函数将中止训练可防止过拟合
cb.append(keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20, verbose=1, mode='auto'))
# his = ctc_model.fit_generator(train_batch_gen, verbose=1, steps_per_epoch=train_file_nums//batch_size, validation_data=validate_batch_gen, validation_steps=validate_file_nums//batch_size, epochs=1000, callbacks=cb)  

his = ctc_model.fit(train_data[0], train_data[1], validation_split=0.1, batch_size=32, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
ctc_model.save_weights(r'save_weights.h5')
with open(r'model_struct.json', 'w') as f:
    json_string = base_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
the_inputs (InputLayer)         [(None, None, 64, 1) 0                                            
__________________________________________________________________________________________________
Conv2D_1-1 (Conv2D)             (None, None, 64, 64) 640         the_inputs[0][0]                 
__________________________________________________________________________________________________
BatchNormal_1-1 (BatchNormaliza (None, None, 64, 64) 256         Conv2D_1-1[0][0]                 
__________________________________________________________________________________________________
Conv2D_1-2 (Conv2D)             (None, None, 64, 64) 36928       BatchNormal_1-1[0][0]            
____________________________________________________________________________________________

In [15]:
# 加载权重
with open(r'model_struct.json') as f:
  model_struct = f.read()
test_model = keras.models.model_from_json(model_struct)
test_model.load_weights(r'save_weights.h5')
# model = keras.models.load_model('all_model.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')

# 对模型进行评价
def evaluate(kind, wavs, labels):
  data_num = len(wavs)
  error_cnt = 0
  for i in range(data_num):
    pre = test_model.predict(wavs[i]) # (1, 20, 11)
    pre_index, pre_label = decode_ctc(pre, py_list) # ['5']
    try:
      pre_label = int(pre_label[0])
    except:
      pre_label = None
    label = int(py_list[labels[i]])
    if label != pre_label:
      error_cnt += 1
      print('真实标签：', label, '预测结果', pre_label)
  print('{}:样本数{}错误数{}准确率：{:%}'.format(kind, data_num, error_cnt, (1-error_cnt/data_num)))

# 训练集
train_wavs, train_labels = get_test_data(image_size, train_wav_file_list, train_label_list, py_list)
# 测试集
test_wavs, test_labels = get_test_data(image_size, test_wav_file_list, test_label_list, py_list)

evaluate('trian', train_wavs, train_labels)
evaluate('test', test_wavs, test_labels)

模型已加载
py_list已加载
trian:样本数260错误数0准确率：100.000000%


ZeroDivisionError: division by zero

# ---------------------------------------------------7.对比实验

# CNN

In [None]:
def create_cnn_model(input_size, output_size):
    inputs = Input(name='the_inputs', shape=input_size)
    # 1
    h1_1 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-1')(inputs)
    h1_2 = BatchNormalization(name='BatchNormal_1-1')(h1_1)
    h1_3 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-2')(h1_2)
    h1_4 = BatchNormalization(name='BatchNormal_1-2')(h1_3)
    h1_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_1')(h1_4)
    # 2
    h2_1 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-1')(h1_5)
    h2_2 = BatchNormalization(name='BatchNormal_2-1')(h2_1)
    h2_3 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-2')(h2_2)
    h2_4 = BatchNormalization(name='BatchNormal_2-2')(h2_3)
    h2_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_2')(h2_4)
    # 3
    h3_1 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-1')(h2_5)
    h3_2 = BatchNormalization(name='BatchNormal_3-1')(h3_1)
    h3_3 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-2')(h3_2)
    h3_4 = BatchNormalization(name='BatchNormal_3-2')(h3_3)
    h3_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_3')(h3_4)
    # 4
    h4_1 = Conv2D(521, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-1')(h3_5)
    h4_2 = BatchNormalization(name='BatchNormal_4-1')(h4_1)
    h4_3 = Conv2D(512, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-2')(h4_2)
    h4_4 = BatchNormalization(name='BatchNormal_4-2')(h4_3)
    # 由于声学模型网络结构原因（3个maxpooling层），我们的音频数据的每个维度需要能够被8整除。这里输入序列经过卷积网络后，长度缩短了8倍，因此我们训练实际输入的数据为wav_len//8。
    h5_1 = Reshape((-1, int(input_size[1]//8*512)), name='Reshape_1')(h4_4)
    h5_2 = Dense(512, activation='relu', use_bias=True, kernel_initializer='he_normal', name='Dense_1')(h5_1)
    h5_3 = Dense(output_size, activation="relu", use_bias=True, kernel_initializer='he_normal', name='Dense_2')(h5_2)#(layer_h15)

    outputs = Activation('softmax', name='Activation_1')(h5_3)

    base_model = keras.Model(inputs=inputs, outputs=outputs)

    labels = Input(name='the_labels', shape=[None], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    # keras.Lambda(function, output_shape=None, mask=None, arguments=None)
    # 将任意表达式封装为 Layer 对象
    loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, outputs, input_length, label_length])
    ctc_model = keras.Model(inputs=[labels, inputs, input_length, label_length], outputs=loss_out)

    opt = keras.optimizers.Adam(lr=0.0008, beta_1=0.9, beta_2=0.999, decay=0.01, epsilon=10e-8)
    # ctc_model=multi_gpu_model(ctc_model,gpus=2)
    ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt, metrics=['accuracy'])

    return base_model, ctc_model
base_cnn_model, ctc_cnn_model = create_cnn_model(input_size, output_size=py_list_size) 

his = ctc_cnn_model.fit(train_data[0], train_data[1], validation_split=0.1, batch_size=32, verbose=0, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
ctc_cnn_model.save_weights(r'save_cnn_weights.h5')
with open(r'model_cnn_struct.json', 'w') as f:
    json_string = base_cnn_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

In [None]:
# 加载权重
with open(r'model_cnn_struct.json') as f:
  model_cnn_struct = f.read()
test_cnn_model = keras.models.model_from_json(model_cnn_struct)
# model.summary()
test_cnn_model.load_weights(r'save_cnn_weights.h5')
# model = keras.models.load_model('all_model.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')

# 对模型进行评价
def evaluate(kind, wavs, labels):
  data_num = len(wavs)
  error_cnt = 0
  for i in range(data_num):
    pre = test_cnn_model.predict(wavs[i]) # (1, 20, 11)
    pre_index, pre_label = decode_ctc(pre, py_list) # ['5']
    try:
      pre_label = int(pre_label[0])
    except:
      pre_label = None
    label = int(py_list[labels[i]])
    if label != pre_label:
      error_cnt += 1
      print('真实标签：', label, '预测结果', pre_label)
  print('{}:样本数{}错误数{}准确率：{:%}'.format(kind, data_num, error_cnt, (1-error_cnt/data_num)))

evaluate('trian', train_wavs, train_labels)
evaluate('test', test_wavs, test_labels)

#LSTM

In [None]:
# #LSTM minist分类
# from keras.datasets import mnist
# n_input = 28
# n_step = 28
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# x_train = x_train.reshape(-1, n_step, n_input, 1)
# x_test = x_test.reshape(-1, n_step, n_input, 1)
# x_train = x_train.astype('float32')
# x_test = x_test.astype('float32')
# x_train /= 255
# x_test /= 255

# y_train = keras.utils.to_categorical(y_train, n_classes)
# y_test = keras.utils.to_categorical(y_test, n_classes)
# inputs = Input(name='the_inputs', shape=(n_step, n_input, 1))
# inner = Reshape((n_step, n_input),name='Reshape_1')(inputs)
# # expected input data shape: (batch_size, timesteps, data_dim)
# lstm_1 = LSTM(128, kernel_initializer='he_normal', name='lstm1')(inner)
# d1 = Dense(10)(lstm_1)
# outputs = Activation('softmax')(d1)
# model = keras.Model(inputs, outputs)

# model.summary()
# model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
# model.fit(x_train, y_train,batch_size=128,epochs=20,verbose=1,validation_data=(x_test, y_test))
# scores = model.evaluate(x_test, y_test, verbose=0)
# print('LSTM test score:', scores[0])
# print('LSTM test accuracy:', scores[1])

In [None]:
# LSTM太差
def create_lstm_model(input_size, output_size):
    inputs = Input(name='the_inputs', shape=input_size)
    inner = Reshape((72, 64),name='Reshape_1')(inputs)
    # expected input data shape: (batch_size, timesteps, data_dim)
    lstm_1 = Bidirectional(LSTM(32, return_sequences=True, kernel_initializer='he_normal', name='lstm1'))(inner)
    
    lstm_2 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='he_normal', name='lstm3'))(lstm_1)

    
    h5_2 = Dense(512, activation='relu', use_bias=True, kernel_initializer='he_normal', name='Dense_4')(lstm_2)
    h5_3 = Dense(output_size, activation="relu", use_bias=True, kernel_initializer='he_normal', name='Dense_5')(h5_2)#(layer_h15)

    outputs = Activation('softmax', name='Activation_1')(h5_3)

    base_model = keras.Model(inputs=inputs, outputs=outputs)

    labels = Input(name='the_labels', shape=[None], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    # keras.Lambda(function, output_shape=None, mask=None, arguments=None)
    # 将任意表达式封装为 Layer 对象
    loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, outputs, input_length, label_length])
    ctc_model = keras.Model(inputs=[labels, inputs, input_length, label_length], outputs=loss_out)

    opt = keras.optimizers.Adam(lr=0.0008, beta_1=0.9, beta_2=0.999, decay=0.01, epsilon=10e-8)
    # ctc_model=multi_gpu_model(ctc_model,gpus=2)
    ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt, metrics=['accuracy'])

    return base_model, ctc_model

base_lstm_model, ctc_lstm_model = create_lstm_model(input_size=(72, 64, 1), output_size=py_list_size) 
#ctc_lstm_model.summary()
train_data = get_train_data(train_wav_file_list, train_label_list, py_list, image_size) 
his = ctc_lstm_model.fit(train_data[0], train_data[1], validation_split=0.1, batch_size=32, verbose=0, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
ctc_lstm_model.save_weights(r'save_lstm_weights.h5')
with open(r'model_lstm_struct.json', 'w') as f:
    json_string = base_lstm_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

# 加载权重
with open(r'model_lstm_struct.json') as f:
  model_lstm_struct = f.read()
test_lstm_model = keras.models.model_from_json(model_lstm_struct)

test_lstm_model.load_weights(r'save_lstm_weights.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')

In [None]:
# 对模型进行评价
def evaluate(kind, wavs, labels):
  data_num = len(wavs)
  error_cnt = 0
  for i in range(data_num):
    pre = test_lstm_model.predict(wavs[i]) # (1, 20, 11)
    pre_index, pre_label = decode_ctc(pre, py_list) # ['5']
    try:
      pre_label = int(pre_label[0])
    except:
      pre_label = None
    label = int(py_list[labels[i]])
    if label != pre_label:
      error_cnt += 1
      print('真实标签：', label, '预测结果', pre_label)
  print('{}:样本数{}错误数{}准确率：{:%}'.format(kind, data_num, error_cnt, (1-error_cnt/data_num)))

# 训练集
train_wavs, train_labels = get_test_data(image_size, train_wav_file_list, train_label_list, py_list)
# 测试集
test_wavs, test_labels = get_test_data(image_size, test_wav_file_list, test_label_list, py_list)

evaluate('trian', train_wavs, train_labels)
evaluate('test', test_wavs, test_labels)

# 交叉熵损失函数的CRNN


In [None]:
def create_cross_model(input_size, output_size):
    inputs = Input(name='the_inputs', shape=input_size)
    # 1
    h1_1 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-1')(inputs)
    h1_2 = BatchNormalization(name='BatchNormal_1-1')(h1_1)
    h1_3 = Conv2D(64, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_1-2')(h1_2)
    h1_4 = BatchNormalization(name='BatchNormal_1-2')(h1_3)
    h1_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_1')(h1_4)
    # 2
    h2_1 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-1')(h1_5)
    h2_2 = BatchNormalization(name='BatchNormal_2-1')(h2_1)
    h2_3 = Conv2D(128, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_2-2')(h2_2)
    h2_4 = BatchNormalization(name='BatchNormal_2-2')(h2_3)
    h2_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_2')(h2_4)
    # 3
    h3_1 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-1')(h2_5)
    h3_2 = BatchNormalization(name='BatchNormal_3-1')(h3_1)
    h3_3 = Conv2D(256, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_3-2')(h3_2)
    h3_4 = BatchNormalization(name='BatchNormal_3-2')(h3_3)
    h3_5 = MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid", name='MaxPooling2D_3')(h3_4)
    # 4
    h4_1 = Conv2D(512, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-1')(h3_5)
    h4_2 = BatchNormalization(name='BatchNormal_4-1')(h4_1)
    h4_3 = Conv2D(512, (3, 3), use_bias=True, activation='relu', padding='same', kernel_initializer='he_normal', name='Conv2D_4-2')(h4_2)
    h4_4 = BatchNormalization(name='BatchNormal_4-2')(h4_3)
    # 由于声学模型网络结构原因（3个maxpooling层），我们的音频数据的每个维度需要能够被8整除。这里输入序列经过卷积网络后，长度缩短了8倍，因此我们训练实际输入的数据为wav_len//8。
    h5_1 = Reshape((9, int(input_size[1]//8*512)), name='Reshape_1')(h4_4)

    lstm_1 = LSTM(128, return_sequences=True, kernel_initializer='he_normal', name='lstm1')(h5_1)
    lstm_2 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm2')(lstm_1)
    h5_11 = Flatten()(lstm_2)
    h5_2 = Dense(512, activation='relu', use_bias=True, kernel_initializer='he_normal', name='Dense_1')(h5_11)
    h5_3 = Dense(output_size, activation="relu", use_bias=True, kernel_initializer='he_normal', name='Dense_2')(h5_2)#(layer_h15)
    outputs = Activation('softmax', name='Activation_1')(h5_3)
    base_model = keras.Model(inputs=inputs, outputs=outputs)

    opt = keras.optimizers.Adam(lr=0.001, beta_1=0.09, beta_2=0.999, decay=0.1, epsilon=10e-8)
    # ctc_model=multi_gpu_model(ctc_model,gpus=2)
    base_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    return base_model

crnn_cross_model = create_cross_model(input_size=(72, 64, 1), output_size=6) 

train_data = get_train_data(train_wav_file_list, train_label_list, py_list, image_size) 
from keras.utils import np_utils

l = np_utils.to_categorical(train_data[0]['the_labels'], num_classes=6)
his = crnn_cross_model.fit(train_data[0]['the_inputs'], l, validation_split=0.2, batch_size=32, verbose=0, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
crnn_cross_model.save_weights(r'save_crnn_cross_weights.h5')
with open(r'model_crnn_cross_struct.json', 'w') as f:
    json_string = crnn_cross_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

In [None]:
# 加载权重
with open(r'model_crnn_cross_struct.json') as f:
  crnn_cross_struct = f.read()
test_crnn_cross_model = keras.models.model_from_json(crnn_cross_struct)
test_crnn_cross_model.load_weights(r'save_crnn_cross_weights.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')
loss, accuracy = crnn_cross_model.evaluate(np.array(test_wavs).reshape((60, 72, 64, 1)), np_utils.to_categorical(np.array(test_labels),num_classes=6))

print('准确率：{:%}'.format(accuracy))

In [None]:
# shown_offset = re.findall(re.compile(r'<ul.*?id="feedlist_id" shown-offset="(.*?)">', re.S), response.text)[0]


# 选择方差第二大的子载波

In [None]:
def process_data(raw_file_path, save_path):
  X={}
  for f in tqdm(os.listdir(raw_file_path)):
    if f.endswith('.dat'):
      file_name = os.path.join(raw_file_path, f)
      extracted_data = extract_csi(file_name)
      print('processing {} the length of this file is:{}'.format(file_name, len(extracted_data)))
      tx, rx, sub = extracted_data[0]['csi'].shape
      data_csi = np.zeros((len(extracted_data), tx, rx, sub), dtype=np.complex64)
      for i in range(len(extracted_data)):
        data_csi[i] = get_scaled_csi(extracted_data[i])
      data_csi = np.clip(np.abs(np.squeeze(data_csi)), 1e-8, 1e100)[:,:,:2,:].reshape(-1, 120)   # (205, 2, 2, 30)
      data = np.zeros((data_csi.shape[0],120)) 
      var = [] 
      for i in range(120):
        data_csi_sub = data_csi[:, i]  
        b, a = signal.butter(5, 4*2/30, 'low')
        carrier_data = signal.lfilter(b, a, data_csi_sub) # N*1
        data[:, i] = carrier_data
        
        # length = len(carrier_data)
        # var_temp = np.var(carrier_data[length//5:3*length//5]) 
        # var.append(var_temp)
      
      #data = data_csi[:, np.argsort(var)[1]]
      #print(np.argsort(var)[1])
      scio.savemat(os.path.join(save_path, f.split('.')[0]+'.mat'), {'csi': data})
      X[f]=data
  print('all raw file processed')
  return X

train_raw_file_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/train'
train_save_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/train'
X_train = process_data(train_raw_file_path, train_save_path)
test_raw_file_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/test'
test_save_path = r'/content/drive/My Drive/SR_CSI//Gestures_data/test'
X_test = process_data(test_raw_file_path, test_save_path)

In [None]:
# 重新构造mat文件选择最佳子载波上的CSI
best_index_list = [58, 1, 29, 29, 28, 29]

#
for f in tqdm(os.listdir(train_save_path)):
    if f.endswith('.mat'):
      file_name = os.path.join(train_save_path, f)
      activate_index = int(re.findall(re.compile(r'csi-s1-e1-a(.*?)-.*?.mat', re.S), f)[0])
      csi = np.squeeze(scio.loadmat(file_name)['csi'])
      best_csi = csi[:, best_index_list[activate_index-1]]     
      scio.savemat(file_name, {'csi': best_csi})
for f in tqdm(os.listdir(test_save_path)):
    if f.endswith('.mat'):
      file_name = os.path.join(test_save_path, f)
      activate_index = int(re.findall(re.compile(r'csi-s1-e1-a(.*?)-.*?.mat', re.S), f)[0])
      best_csi = np.squeeze(scio.loadmat(file_name)['csi'])[:, best_index_list[activate_index-1]]
      scio.savemat(file_name, {'csi': best_csi})

In [None]:
train_file_path = train_save_path
train_label_file_list, train_wav_file_list = get_file_list(train_file_path)

train_label_list = get_label_data(train_label_file_list)

# 用迭代器的时候需要
# train_wav_file_list, validate_wav_file_list, train_label_list, validate_label_list = train_test_split(train_wav_file_list, train_label_list, test_size=0.2, random_state=0)

# 每个文件拼音标签的集合
with open('label_list.txt', 'w') as f:  
    f.write('\n'.join(train_label_list))  
py_list = gen_py_list(train_label_list)  # 拼音的集合

with open('pinyin_list.txt', 'w') as f:
    f.write('\n'.join(py_list))  # 保存拼音列表

train_file_nums = len(train_wav_file_list)
# validate_file_nums = len(validate_wav_file_list)

py_list_size = len(py_list)  # 模型输出的维度
print('py_list:', py_list)
print('train files amount:', len(train_label_list), train_file_nums, 'label amount:', py_list_size)#, 'validate files amount:', len(validate_label_list), validate_file_nums,)

# 测试数据
test_file_path = test_save_path # r'/content/drive/My Drive/data_thchs30'
test_label_file_list, test_wav_file_list = get_file_list(test_file_path)

test_label_list = get_label_data(test_label_file_list)
print('test files amount:',len(test_label_file_list), len(test_wav_file_list))
test_data_num = len(test_label_file_list)

In [None]:
base_sec_model, ctc_sec_model = create_model(input_size, output_size=py_list_size) 

train_data = get_train_data(train_wav_file_list, train_label_list, py_list, image_size)


cb = []
cb.append(keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0))
# 当监测值不再改善时，该回调函数将中止训练可防止过拟合
cb.append(keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20, verbose=1, mode='auto'))
# his = ctc_model.fit_generator(train_batch_gen, verbose=1, steps_per_epoch=train_file_nums//batch_size, validation_data=validate_batch_gen, validation_steps=validate_file_nums//batch_size, epochs=1000, callbacks=cb)  


his = ctc_sec_model.fit(train_data[0], train_data[1], validation_split=0.1, batch_size=32, verbose=0, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
ctc_sec_model.save_weights(r'save_sec_weights.h5')
with open(r'model_sec_struct.json', 'w') as f:
    json_string = base_sec_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

In [None]:
# 加载权重
with open(r'model_sec_struct.json') as f:
  model_struct = f.read()
test_sec_model = keras.models.model_from_json(model_struct)
test_sec_model.load_weights(r'save_sec_weights.h5')
# model = keras.models.load_model('all_model.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')

# 对模型进行评价
def evaluate(kind, wavs, labels):
  data_num = len(wavs)
  error_cnt = 0
  for i in range(data_num):
    pre = test_sec_model.predict(wavs[i]) # (1, 20, 11)
    pre_index, pre_label = decode_ctc(pre, py_list) # ['5']
    try:
      pre_label = int(pre_label[0])
    except:
      pre_label = None
    label = int(py_list[labels[i]])
    if label != pre_label:
      error_cnt += 1
      print('真实标签：', label, '预测结果', pre_label)
  print('{}:样本数{}错误数{}准确率：{:%}'.format(kind, data_num, error_cnt, (1-error_cnt/data_num)))

# 训练集
train_wavs, train_labels = get_test_data(image_size, train_wav_file_list, train_label_list, py_list)
# 测试集
test_wavs, test_labels = get_test_data(image_size, test_wav_file_list, test_label_list, py_list)

evaluate('trian', train_wavs, train_labels)
evaluate('test', test_wavs, test_labels)

# 所有子载波求平均

In [None]:
def process_data(raw_file_path, save_path):
  X={}
  for f in tqdm(os.listdir(raw_file_path)):
    if f.endswith('.dat'):
      file_name = os.path.join(raw_file_path, f)
      extracted_data = extract_csi(file_name)
      print('processing {} the length of this file is:{}'.format(file_name, len(extracted_data)))
      tx, rx, sub = extracted_data[0]['csi'].shape
      data_csi = np.zeros((len(extracted_data), tx, rx, sub), dtype=np.complex64)
      for i in range(len(extracted_data)):
        data_csi[i] = get_scaled_csi(extracted_data[i])
      data_csi = np.clip(np.abs(np.squeeze(data_csi)), 1e-8, 1e100)[:,:,:2,:].reshape(-1, 4, 30)   # (205, 2, 2, 30)
      data = []
      for i in range(4):
        data_ave = np.average(data_csi[:, i, :],axis=1)
        data.extend(data_ave)
      scio.savemat(os.path.join(save_path, f.split('.')[0]+'.mat'), {'csi': data})
      X[f]=data
  print('all raw file processed')
  return X

train_raw_file_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/train'
train_save_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/train'
X_train = process_data(train_raw_file_path, train_save_path)
test_raw_file_path = r'/content/drive/My Drive/SR_CSI/Gestures_data/test'
test_save_path = r'/content/drive/My Drive/SR_CSI//Gestures_data/test'
X_test = process_data(test_raw_file_path, test_save_path)

In [None]:
train_file_path = train_save_path
train_label_file_list, train_wav_file_list = get_file_list(train_file_path)

train_label_list = get_label_data(train_label_file_list)

# 用迭代器的时候需要
# train_wav_file_list, validate_wav_file_list, train_label_list, validate_label_list = train_test_split(train_wav_file_list, train_label_list, test_size=0.2, random_state=0)

# 每个文件拼音标签的集合
with open('label_list.txt', 'w') as f:  
    f.write('\n'.join(train_label_list))  
py_list = gen_py_list(train_label_list)  # 拼音的集合

with open('pinyin_list.txt', 'w') as f:
    f.write('\n'.join(py_list))  # 保存拼音列表

train_file_nums = len(train_wav_file_list)
# validate_file_nums = len(validate_wav_file_list)

py_list_size = len(py_list)  # 模型输出的维度
print('py_list:', py_list)
print('train files amount:', len(train_label_list), train_file_nums, 'label amount:', py_list_size)#, 'validate files amount:', len(validate_label_list), validate_file_nums,)

# 测试数据
test_file_path = test_save_path # r'/content/drive/My Drive/data_thchs30'
test_label_file_list, test_wav_file_list = get_file_list(test_file_path)

test_label_list = get_label_data(test_label_file_list)
print('test files amount:',len(test_label_file_list), len(test_wav_file_list))
test_data_num = len(test_label_file_list)

In [None]:
base_ave_model, ctc_ave_model = create_model(input_size=(None, 256, 1), output_size=py_list_size) 

train_data = get_train_data(train_wav_file_list, train_label_list, py_list, image_size=256)

his = ctc_ave_model.fit(train_data[0], train_data[1], validation_split=0.1, batch_size=32, verbose=1, epochs=1000, callbacks=cb)  # callback的epoch都是对fit里的参数来说

#  保存模型结构及权重
ctc_ave_model.save_weights(r'save_ave_weights.h5')
with open(r'model_ave_struct.json', 'w') as f:
    json_string = base_ave_model.to_json()
    f.write(json_string)  # 保存模型信息
print('模型结构及权重已保存')

In [None]:
# 加载权重
with open(r'model_ave_struct.json') as f:
  model_struct = f.read()
test_ave_model = keras.models.model_from_json(model_struct)
test_ave_model.load_weights(r'save_ave_weights.h5')
# model = keras.models.load_model('all_model.h5')
print('模型已加载')
py_list = []
with open(r'pinyin_list.txt', 'r') as f:
    contents = f.readlines()
for line in contents:
    i = line.strip('\n')
    py_list.append(i)
print('py_list已加载')

# 对模型进行评价
def evaluate(kind, wavs, labels):
  data_num = len(wavs)
  error_cnt = 0
  for i in range(data_num):
    pre = test_sec_model.predict(wavs[i]) # (1, 20, 11)
    pre_index, pre_label = decode_ctc(pre, py_list) # ['5']
    try:
      pre_label = int(pre_label[0])
    except:
      pre_label = None
    label = int(py_list[labels[i]])
    if label != pre_label:
      error_cnt += 1
      print('真实标签：', label, '预测结果', pre_label)
  print('{}:样本数{}错误数{}准确率：{:%}'.format(kind, data_num, error_cnt, (1-error_cnt/data_num)))

# 训练集
train_wavs, train_labels = get_test_data(image_size, train_wav_file_list, train_label_list, py_list)
# 测试集
test_wavs, test_labels = get_test_data(image_size, test_wav_file_list, test_label_list, py_list)

evaluate('trian', train_wavs, train_labels)
evaluate('test', test_wavs, test_labels)

#### ROC、AUC曲线 

In [None]:
# from sklearn.datasets import make_classification
# from sklearn.preprocessing import label_binarize
# from keras.models import Sequential
# from keras.layers import Dense
# import numpy as np
# from scipy import interp
# import matplotlib.pyplot as plt
# from itertools import cycle
# from sklearn.model_selection import train_test_split
# from sklearn.metrics import roc_curve, auc

# # 标签共三类
# n_classes = py_list_size-1

# X, y = make_classification(n_samples=80000, n_features=20, n_informative=3, n_redundant=0, n_classes=n_classes,
#     n_clusters_per_class=2)
# # print(X.shape, y.shape)
# # print(X[0], y[0])
# # (80000, 20) (80000,)
# # [-1.90920853 -1.30052757 -0.76903467 -3.2546519  -0.02947816  0.14105006
# #   0.43556031 -0.81300607 -0.94553296 -0.92774495  1.49041451 -0.4443121
# #  -1.16342165 -0.32997815 -1.02907045 -0.39950447 -0.711287    0.51382424
# #   2.88822258 -2.0935274 ] 
# # 1

# # Binarize the output相当于one_hot
# y = label_binarize(y, classes=[0, 1, 2])
# # print(y.shape, y[0])
# # (80000, 3) [0 1 0]

# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
# model = Sequential()
# model.add(Dense(20, input_dim=20, activation='relu'))
# model.add(Dense(40, activation='relu'))
# model.add(Dense(3, activation='softmax'))
# model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# model.fit(X_train, y_train, epochs=1, batch_size=100, verbose=1)

# y_pred = model.predict(X_test)
# # print(y_pred.shape)
# # (40000, 3)

# # Compute ROC curve and ROC area for each class
# fpr = dict()
# tpr = dict()
# roc_auc = dict()
# for i in range(n_classes):
#     # scores = np.array([0.1, 0.4, 0.35, 0.8])
#     # fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
#     # y 就是标准值，scores 是每个预测值对应的阳性概率，比如0.1就是指第一个数预测为阳性的概率为0.1，很显然，
#     # y 和 socres应该有相同多的元素，都等于样本数。pos_label=2 是指在y中标签为2的是标准阳性标签，其余值是阴性。
#     # 接下来选取一个阈值计算TPR/FPR,阈值的选取规则是在scores值中从大到小的以此选取，于是第一个选取的阈值是0.8
#     # label=[1,1,2,2] scores=[0.1,0.4,0.35,0.8] thresholds=[0.8,0.4,0.35,0.1] 以threshold为0.8为例，将0.8与
#     # scores 中所有值比较大小得到预测值，[0,0,0,1].对于label中两个1，其概率分别为0.1，0.4，小于阈值0.8，判定为
#     # 负样本，而他们的label是1，说明他们确实是负样本，判断正确，是两个TN；两个2，对应概率为0.35，0.8，0.35小于
#     # 0.8，判定为负样本，但是label是2，应该是个正样本，所以这是个FN；最后0.8>=0.8,这是个TP，所以最后的结果是
#     # ：1个TP，2个TN，1个FN，0个FP
#     fpr[i], tpr[i], thresholds = roc_curve(y_test[:, i], y_pred[:, i])  # (40000,)
#     # print(fpr[i].shape)# (5491,)# (6562,)# (4271,)
#     roc_auc[i] = auc(fpr[i], tpr[i])
    

# # 计算microROC曲线和ROC面积 
# # .ravel()将多维数组转换为一维数组
# fpr["micro"], tpr["micro"]  , thresholds = roc_curve(y_test.ravel(), y_pred.ravel())  #  (120000,)
# roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# # 计算macroROC曲线和ROC面积
# # 首先，汇总所有的假阳性率
# # np.unique() 该函数是去除数组中的重复数字，并进行排序之后输出。
# # print(np.concatenate([fpr[i] for i in range(n_classes)]).shape) (16324,)
# all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))  # (7901,)
# # 然后插值所有的ROC曲线在这一点
# # np.zeros_like() 这个函数的意思就是生成一个和你所给数组a相同shape的全0数组。
# mean_tpr = np.zeros_like(all_fpr)
# for i in range(n_classes):
#     mean_tpr += interp(all_fpr, fpr[i], tpr[i])
    
# # 最后求平均值并计算AUC
# mean_tpr /= n_classes
# fpr["macro"] = all_fpr
# tpr["macro"] = mean_tpr
# roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# # Plot all ROC curves
# plt.figure(1)
# plt.plot(fpr["micro"], tpr["micro"], color='deeppink', linestyle=':', linewidth=4,
#          label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]))

# plt.plot(fpr["macro"], tpr["macro"],color='navy', linestyle=':', linewidth=4,
#          label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]))

# colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
# for i, color in zip(range(n_classes), colors):
#     plt.plot(fpr[i], tpr[i], color=color, linewidth=2,
#              label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))

# plt.plot([0, 1], [0, 1], 'k--', linewidth=2)
# plt.xlim([0.0, 1.0])
# plt.ylim([0.0, 1.05])
# plt.xlabel('False Positive Rate')
# plt.ylabel('True Positive Rate')
# plt.title('Some extension of Receiver Operating Characteristic to multi-class')
# plt.legend(loc='best')
# plt.show()


# # Zoom in view of the upper left corner.
# plt.figure(2)
# plt.xlim(0, 0.2)
# plt.ylim(0.8, 1)
# plt.plot(fpr["micro"], tpr["micro"],color='deeppink', linestyle=':', linewidth=4,
#          label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]))

# plt.plot(fpr["macro"], tpr["macro"],color='navy', linestyle=':', linewidth=4,
#          label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]))

# colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
# for i, color in zip(range(n_classes), colors):
#     plt.plot(fpr[i], tpr[i], color=color, linewidth=2,
#              label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))

# plt.plot([0, 1], [0, 1], 'k--', linewidth=2)
# plt.xlabel('False Positive Rate')
# plt.ylabel('True Positive Rate')
# plt.title('ROC curve (zoomed in at top left)')
# plt.legend(loc='best')
# plt.show()

#### 混淆矩阵 

In [None]:
# import matplotlib.pyplot as plt
# from sklearn.metrics import confusion_matrix
# def plot_confusion_matrix(title, y_true, y_pred, labels):
#     cm = confusion_matrix(y_true, y_pred)
    
#     # np.newaxis的作用就是在这一位置增加一个一维，这一位置指的是np.newaxis所在的位置，比较抽象，需要配合例子理解。
#     # x1 = np.array([1, 2, 3, 4, 5])
#     # the shape of x1 is (5,)
#     # x1_new = x1[:, np.newaxis]
# # now, the shape of x1_new is (5, 1)


#     cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#     # print (cm, '\n\n', cm_normalized)
#     # [[1 0 0 0 0]                           
#     #  [0 1 0 0 0]
#     #  [0 0 1 0 0]
#     #  [0 0 0 1 0]
#     #  [0 0 0 0 1]]

#     #  [[1. 0. 0. 0. 0.]
#     #  [0. 1. 0. 0. 0.]
#     #  [0. 0. 1. 0. 0.]
#     #  [0. 0. 0. 1. 0.]
#     #  [0. 0. 0. 0. 1.]]
#     tick_marks = np.array(range(len(labels))) + 0.5
#     #  [0.5 1.5 2.5 3.5 4.5 5.5]
#     np.set_printoptions(precision=2)
    
#     plt.figure(figsize=(10, 8), dpi=120)
#     ind_array = np.arange(len(labels))
#     x, y = np.meshgrid(ind_array, ind_array)
#     # print(ind_ａrray, '\n\n', x, '\n\n', y)
#     # [0 1 2 3 4 5] 

#     #  [[0 1 2 3 4 5]
#     #  [0 1 2 3 4 5]
#     #  [0 1 2 3 4 5]
#     #  [0 1 2 3 4 5]
#     #  [0 1 2 3 4 5]
#     #  [0 1 2 3 4 5]] 

#     #  [[0 0 0 0 0 0]
#     #  [1 1 1 1 1 1]
#     #  [2 2 2 2 2 2]
#     #  [3 3 3 3 3 3]
#     #  [4 4 4 4 4 4]
#     #  [5 5 5 5 5 5]]
#     intFlag = 0 # 标记在图片中对文字是整数型还是浮点型
#     for x_val, y_val in zip(x.flatten(), y.flatten()):
#         # plt.text()函数用于设置文字说明。

#         if (intFlag):
#             c = cm[y_val][x_val]
#             plt.text(x_val, y_val, "%d" % (c,), color='red', fontsize=8, va='center', ha='center')

#         else:
#             c = cm_normalized[y_val][x_val]
#             if (c > 0.01):
#                 plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=7, va='center', ha='center')
#             else:
#                 plt.text(x_val, y_val, "%d" % (0,), color='red', fontsize=7, va='center', ha='center')
#     cmap = plt.cm.binary
#     if(intFlag):
#         plt.imshow(cm, interpolation='nearest', cmap=cmap)
#     else:
#         plt.imshow(cm_normalized, interpolation='nearest', cmap=cmap)
#     plt.gca().set_xticks(tick_marks, minor=True)
#     plt.gca().set_yticks(tick_marks, minor=True)
#     plt.gca().xaxis.set_ticks_position('none')
#     plt.gca().yaxis.set_ticks_position('none')
#     plt.grid(True, which='minor', linestyle='-')
#     plt.gcf().subplots_adjust(bottom=0.15)
#     plt.title(title)
#     plt.colorbar()
#     xlocations = np.array(range(len(labels)))
#     plt.xticks(xlocations, labels, rotation=90)
#     plt.yticks(xlocations, labels)
#     plt.ylabel('Index of True Classes')
#     plt.xlabel('Index of Predict Classes')
#     plt.savefig('confusion_matrix.jpg', dpi=300)
#     plt.show()
# title='Confusion Matrix'
# labels = ['A', 'B', 'C', 'F', 'G']
# y_true = 
# y_pred = [1, 2, 3, 4, 5]# np.loadtxt(r'/home/dingtom/b.txt')
# plot＿confusion_matrix(title, y_true,y_pred, labels)