In [1]:
import os
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
from utils.models import RegressionLoss
from utils.models import save_model
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from utils.utils import EarlyStopping, IterMeter, data_processing_DeepSpeech
import torch.nn.functional as F

import random
from utils.transforms import apply_delta_deltadelta, Transform_Compose
import matplotlib.pyplot as plt
import numpy as np
from utils.transforms import apply_MVN
import torch.nn as nn
from utils.utils import data_processing_DeepSpeech, GreedyDecoder
from jiwer import wer
from utils.database import Alaryngeal_data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time
import yaml
import os
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
from utils.models import MyLSTM, SpeechRecognitionModel
from utils.models import RegressionLoss
from utils.models import save_model
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from utils.utils import EarlyStopping, IterMeter, data_processing_DeepSpeech
import torch.nn.functional as F

import random
from utils.transforms import ema_random_rotate, ema_time_mask, ema_freq_mask, ema_sin_noise, ema_random_scale, ema_time_seg_mask
from utils.transforms import apply_delta_deltadelta, Transform_Compose
from utils.transforms import apply_MVN
import numpy as np
import torchaudio

In [33]:
def augmentation_parsing(config, train_transform):

    random_sin_noise_inj = config['data_augmentation']['random_sin_noise_inj']
    random_rotate_apply = config['data_augmentation']['random_rotate']
    random_time_mask = config['data_augmentation']['random_time_mask']
    random_freq_mask = config['data_augmentation']['random_freq_mask']
    random_scale = config['data_augmentation']['random_scale']
    random_time_seg_mask = config['data_augmentation']['random_time_seg_mask']   
    normalize_input = config['articulatory_data']['normalize_input']    

    if random_sin_noise_inj == True:
        ratio = config['random_sin_noise_inj']['ratio']
        noise_energy_ratio = config['random_sin_noise_inj']['noise_energy_ratio']
        noise_freq = config['random_sin_noise_inj']['noise_freq']
        fs = 100
        train_transform.append(ema_sin_noise(ratio, noise_energy_ratio, noise_freq, fs)) 

    if random_rotate_apply == True:
        ratio = config['random_rotate']['ratio']
        r_min = config['random_rotate']['r_min']
        r_max = config['random_rotate']['r_max']
        train_transform.append(ema_random_rotate(ratio,  [r_min, r_max])) 

    if random_scale == True:
        ratio = config['random_scale']['ratio']
        scale_min = config['random_scale']['scale_min']
        scale_max = config['random_scale']['scale_max']
        train_transform.append(ema_random_scale(ratio, scale_min, scale_max)) 
        
    train_transform.append(apply_delta_deltadelta()) 
    
    if normalize_input == True:
        norm_transform = [apply_delta_deltadelta()]
        norm_transforms_all = Transform_Compose(norm_transform)

        train_loader_norm = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    collate_fn=lambda x: data_processing_DeepSpeech(x, transforms = norm_transforms_all))
        EMA_all = {}
        i = 0
        for batch_idx, _data in enumerate(train_loader_norm):
            file_id, EMA, labels, input_lengths, label_lengths = _data 
            ema = EMA[0][0].T
            EMA_all[i] = ema
            i+=1

        EMA_block = np.concatenate([EMA_all[x] for x in EMA_all], 0)
        EMA_mean, EMA_std  = np.mean(EMA_block, 0), np.std(EMA_block, 0)
        
        train_transform.append(apply_MVN(EMA_mean, EMA_std))

    if random_time_mask == True:
        ratio = config['random_time_mask']['ratio']
        mask_num = config['random_time_mask']['mask_num']
        train_transform.append(ema_time_mask(ratio, mask_num))

    if random_freq_mask == True:
        ratio = config['random_freq_mask']['ratio']
        mask_num = config['random_freq_mask']['mask_num']
        train_transform.append(ema_freq_mask(ratio, mask_num))

    if random_time_seg_mask == True:
        ratio = config['random_time_seg_mask']['ratio']
        mask_num = config['random_time_seg_mask']['mask_num']
        mask_length = config['random_time_seg_mask']['mask_length']
        train_transform.append(ema_time_seg_mask(ratio, mask_num, mask_length))    
        
    return train_transform, EMA_mean, EMA_std
    

In [34]:
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [35]:
import yaml
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Dimension setup ###
config = yaml.load(open('conf/SSR_conf.yaml', 'r'), Loader=yaml.FullLoader)
sel_sensors = config['articulatory_data']['sel_sensors']
sel_dim = config['articulatory_data']['sel_dim'] 
delta = config['articulatory_data']['delta']
d = 3 if delta == True else 1
D_in = len(sel_sensors)*len(sel_dim)*d
D_out = 41

### Model setup ###
n_cnn_layers = config['deep_speech_setup']['n_cnn_layers']
n_rnn_layers = config['deep_speech_setup']['n_rnn_layers']    
rnn_dim = config['deep_speech_setup']['rnn_dim']
stride = config['deep_speech_setup']['stride']
dropout = config['deep_speech_setup']['dropout']

### Training setup ###
learning_rate = config['deep_speech_setup']['learning_rate']
batch_size = config['deep_speech_setup']['batch_size']
epochs = config['deep_speech_setup']['epochs']
early_stop = config['deep_speech_setup']['early_stop']
patient = config['deep_speech_setup']['patient']

In [36]:
from utils.IO_func import read_file_list

data_path = '/home/beiming/RAW_DATA/Haskins_IEEE'
SPK = 'DL001'
data_path_SPK = os.path.join(data_path, SPK)

filesets_path = os.path.join(data_path, 'filesets')
filesets_path_SPK = os.path.join(filesets_path, SPK)

file_id_list = read_file_list(os.path.join(filesets_path_SPK, 'file_id_list.scp'))
train_id_list = read_file_list(os.path.join(filesets_path_SPK, 'train_id_list.scp'))
valid_id_list = read_file_list(os.path.join(filesets_path_SPK, 'valid_id_list.scp'))
test_id_list = read_file_list(os.path.join(filesets_path_SPK, 'test_id_list.scp'))

In [37]:
train_dataset = Alaryngeal_data(data_path, train_id_list, transforms = None)
valid_dataset = Alaryngeal_data(data_path, valid_id_list, transforms = None)
test_dataset = Alaryngeal_data(data_path, test_id_list, transforms = None)

In [38]:
train_transform = []
valid_transform = []

train_transform, EMA_mean, EMA_std = augmentation_parsing(config, train_transform)

valid_transform.append(apply_delta_deltadelta())
valid_transform.append(apply_MVN(EMA_mean, EMA_std))

train_transforms_all = Transform_Compose(train_transform)
valid_transforms_all = Transform_Compose(valid_transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            collate_fn=lambda x: data_processing_DeepSpeech(x, transforms = train_transforms_all))

valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=lambda x: data_processing_DeepSpeech(x, transforms = valid_transforms_all))

In [39]:
model = SpeechRecognitionModel(n_cnn_layers, n_rnn_layers, rnn_dim, D_out, D_in, stride, dropout).to(device)

optimizer = torch.optim.AdamW(model.parameters(), learning_rate)
criterion = torch.nn.CTCLoss(blank=40).to(device)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=int(len(train_loader)), epochs=epochs, anneal_strategy='linear')

data_len = len(train_loader.dataset)
if early_stop == True:
    print('Applying early stop.')
    early_stopping = EarlyStopping(patience=patient)

iter_meter = IterMeter()

Applying early stop.


In [40]:
for epoch in range(epochs):
    model.train()
    loss_train = []
    for batch_idx, _data in enumerate(train_loader):
        file_id, ema, labels, input_lengths, label_lengths = _data 

        ema, labels = ema.to(device), labels.to(device)

        output = model(ema)  # (batch, time, n_class)

        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        iter_meter.step()

        loss_train.append(loss.detach().cpu().numpy())
    avg_loss_train = sum(loss_train)/len(loss_train)

    model.eval()
    loss_valid = []
    for batch_idx, _data in enumerate(valid_loader):  
        file_id, ema, labels, input_lengths, label_lengths = _data 
        ema, labels = ema.to(device), labels.to(device)           

        output = model(ema)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)
        loss = criterion(output, labels, input_lengths, label_lengths)    
        loss_valid.append(loss.detach().cpu().numpy())
    avg_loss_valid = sum(loss_valid)/len(loss_valid) 
    SPK = file_id[0][:3]

    early_stopping(avg_loss_valid)
    if early_stopping.early_stop:
        break

    print('epoch %-3d \t train_loss = %0.5f \t valid_loss = %0.5f' % (epoch, avg_loss_train, avg_loss_valid))

   # model_out_folder = os.path.join(exp_output_folder, 'trained_models')
    model_out_folder = 'trained_models'
    if not os.path.exists(model_out_folder):
        os.makedirs(model_out_folder)
    if early_stopping.save_model == True:
        save_model(model, os.path.join(model_out_folder, 'DL001' + '_DS'))

epoch 0   	 train_loss = 14.32792 	 valid_loss = 11.47437
epoch 1   	 train_loss = 11.10949 	 valid_loss = 5.93669
epoch 2   	 train_loss = 5.53574 	 valid_loss = 4.32457
INFO: Early stopping counter 1 of 10
epoch 3   	 train_loss = 4.24014 	 valid_loss = 5.99731
INFO: Early stopping counter 2 of 10
epoch 4   	 train_loss = 4.70192 	 valid_loss = 4.68415
epoch 5   	 train_loss = 4.00821 	 valid_loss = 3.64048
INFO: Early stopping counter 1 of 10
epoch 6   	 train_loss = 3.72243 	 valid_loss = 3.72334
INFO: Early stopping counter 2 of 10
epoch 7   	 train_loss = 3.60166 	 valid_loss = 3.93245
epoch 8   	 train_loss = 3.57489 	 valid_loss = 3.57667
INFO: Early stopping counter 1 of 10
epoch 9   	 train_loss = 3.51873 	 valid_loss = 3.61357
INFO: Early stopping counter 2 of 10
epoch 10  	 train_loss = 3.52181 	 valid_loss = 3.64482
epoch 11  	 train_loss = 3.47954 	 valid_loss = 3.53833
INFO: Early stopping counter 1 of 10
epoch 12  	 train_loss = 3.47639 	 valid_loss = 3.58028
INFO: Earl

In [43]:
test_transform = []
test_transform.append(apply_delta_deltadelta())
normalize_input = True
if normalize_input == True:
    norm_transform = [apply_delta_deltadelta()]
    norm_transforms_all = Transform_Compose(norm_transform)

    train_loader_norm = torch.utils.data.DataLoader(dataset=train_dataset,
                                batch_size=1,
                                shuffle=True,
                                collate_fn=lambda x: data_processing_DeepSpeech(x, transforms = norm_transforms_all))

    EMA_all = {}
    i = 0
    for batch_idx, _data in enumerate(train_loader_norm):
        file_id, EMA, labels, input_lengths, label_lengths = _data 
        ema = EMA[0][0].T
        EMA_all[i] = ema
        i+=1

    EMA_block = np.concatenate([EMA_all[x] for x in EMA_all], 0)
    EMA_mean, EMA_std  = np.mean(EMA_block, 0), np.std(EMA_block, 0)

    test_transform.append(apply_MVN(EMA_mean, EMA_std))

test_transforms_all = Transform_Compose(test_transform)

### Test ###
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                        batch_size=1,
                        shuffle=False,
                        collate_fn=lambda x: data_processing_DeepSpeech(x, transforms = test_transforms_all))

model_out_folder = 'trained_models'

SPK_model_path = os.path.join(model_out_folder)
model_path = os.path.join(SPK_model_path, 'DL001' + '_DS')
model = SpeechRecognitionModel(n_cnn_layers, n_rnn_layers, rnn_dim, D_out, D_in, stride, dropout)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

pred = []
label = []

for batch_idx, _data in enumerate(test_loader):
    fid, ema, labels, input_lengths, label_lengths = _data 
    ema, labels = ema, labels

    output = model(ema)  # (batch, time, n_class)

    output = F.log_softmax(output, dim=2)
    output = output.transpose(0, 1) # (time, batch, n_class)

    decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)

    pred.append(' '.join(decoded_preds[0]))
    label.append(' '.join(decoded_targets[0]))

error = wer(pred, label)

In [44]:
print(error)

4.378048780487805
