In [14]:
import time
import yaml
import os
import torch
from utils.IO_func import read_file_list, load_binary_file, array_to_binary_file, load_Haskins_SSR_data
from shutil import copyfile
from utils.transforms import Transform_Compose
from utils.transforms import FixMissingValues
import IPython
import matplotlib.pyplot as plt

import argparse
import pickle
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.autograd import Variable

In [3]:
import os
from comet_ml import Experiment
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np

In [4]:
class MyLSTM(nn.Module):
    def __init__(self, D_in = 18, H = 256, D_out = 41, num_layers= 3, bidirectional=False):
        super(MyLSTM, self).__init__()
        self.hidden_dim = H
        self.num_layers = num_layers
        self.num_direction =  2 if bidirectional else 1
        self.lstm = nn.LSTM(D_in, H, num_layers, bidirectional=bidirectional)
        self.hidden2out = nn.Linear(self.num_direction*self.hidden_dim, D_out)

    def init_hidden(self, x):
        h, c = (Variable(torch.zeros(self.num_layers * self.num_direction, x.shape[1], self.hidden_dim)),
                Variable(torch.zeros(self.num_layers * self.num_direction, x.shape[1], self.hidden_dim)))
        return h, c

    def forward(self, sequence, h, c):
        output, (h, c) = self.lstm(sequence, (h, c))
        output = self.hidden2out(output)
        return output

In [5]:
class IterMeter(object):
    """keeps track of total iterations"""
    def __init__(self):
        self.val = 0

    def step(self):
        self.val += 1

    def get(self):
        return self.val

In [36]:
import argparse
import pickle
from torch.utils.data import Dataset, DataLoader

conf_dir = 'conf/SSR_conf.yaml'
buff_dir = 'current_exp'

config = yaml.load(open(conf_dir, 'r'), Loader=yaml.FullLoader)

data_path = os.path.join(buff_dir, 'data_CV')
SPK_list = ['F01']

for test_SPK in SPK_list:
    data_path_SPK = os.path.join(data_path, test_SPK)

    tr = open(os.path.join(data_path_SPK, 'train_data.pkl'), 'rb') 
    va = open(os.path.join(data_path_SPK, 'valid_data.pkl'), 'rb')        
    train_dataset, valid_dataset = pickle.load(tr), pickle.load(va)

In [37]:
def GreedyDecoder(output, labels, label_lengths, blank_label=40, collapse_repeated=True):
    
    from utils.database import PhoneTransform

    text_transform = PhoneTransform()

    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    targets = []
    for i, args in enumerate(arg_maxes):

        decode = []
        targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
        for j, index in enumerate(args):
            if index != blank_label:
                if collapse_repeated and j != 0 and index == args[j -1]:
                    continue
                decode.append(index.item())
        decodes.append(text_transform.int_to_text(decode))
    return decodes, targets

In [38]:
def data_processing(data, transforms = None):
    ema = []
    labels = []
    input_lengths = []
    label_lengths = []
    
    for file_id, x, y in data:
        if transforms is not None:
            x = transforms(x)

        ema.append(torch.FloatTensor(x))
        labels.append(y)
        input_lengths.append(x.shape[0])
        label_lengths.append(len(y))
        
    ema = torch.nn.utils.rnn.pad_sequence(ema, batch_first=True)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)        
    
    return file_id, ema, labels, input_lengths, label_lengths

In [45]:
learning_rate=2e-4
batch_size=20
epochs=80

In [46]:
hparams = {
    "n_cnn_layers": 3,
    "n_rnn_layers": 3,
    "rnn_dim": 256,
    "n_class": 41,
    "n_feats": 18,
    "stride": 2,
    "dropout": 0.1,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": epochs
}
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                            batch_size=1,
                            shuffle=True,
                            collate_fn=lambda x: data_processing(x, None))

test_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                            batch_size=1,
                            shuffle=False,
                            collate_fn=lambda x: data_processing(x, None))

model = MyLSTM().to(device)

#print(model)
#print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
criterion = nn.CTCLoss(blank=40).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                                        steps_per_epoch=int(len(train_loader)),
                                        epochs=hparams['epochs'],
                                        anneal_strategy='linear')

iter_meter = IterMeter()



In [47]:
model.train()
data_len = len(train_loader.dataset)
for epoch in range(epochs):
    for batch_idx, _data in enumerate(train_loader):
        file_id, ema, labels, input_lengths, label_lengths = _data 
        ema, labels = ema.to(device), labels.to(device)
        optimizer.zero_grad()

        
        h, c = model.init_hidden(ema)
        h, c = h.to(device), c.to(device)

        output = model(ema, h, c)  # (batch, time, n_class)
        
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        optimizer.step()
        scheduler.step()
        iter_meter.step()
        if batch_idx % 100 == 0 or batch_idx == data_len:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(ema), data_len,
                100. * batch_idx / len(train_loader), loss.item()))
        


















In [48]:
model.eval()
pred = []
label = []

from jiwer import wer

for batch_idx, _data in enumerate(test_loader):
    fid, ema, labels, input_lengths, label_lengths = _data 
    ema, labels = ema.to(device), labels.to(device)

    optimizer.zero_grad()

    h, c = model.init_hidden(ema)
    h, c = h.to(device), c.to(device)

    output = model(ema, h, c)  # (batch, time, n_class)

    output = F.log_softmax(output, dim=2)
    output = output.transpose(0, 1) # (time, batch, n_class)
    
    loss = criterion(output, labels, input_lengths, label_lengths)
  #  test_loss += loss.item() / len(test_loader)

    decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)

    
    pred.append(' '.join(decoded_preds[0]))
    label.append(' '.join(decoded_targets[0]))
    
print(len(pred))
print(len(label))

error = wer(pred, label)
print(error)

50
50
3.149390243902439


In [49]:
for i in range(len(pred)):
    print(pred[i])
    print(label[i])
    print('#####################')

AY L L EY
SP S M OW K IY F AY ER Z L AE K F L EY M AE N D SP SHH IY T SP
#####################
AH K EY R M DH
SP DH AH S AO L T B R IY Z K EY M AH K R AO S F ER M DH AH S IY SP
#####################
AH L D R K R S ER M DH DH
SP DH AH S AO L T B R IY Z K EY M SP AH K R AO S F ER M DH AH S IY SP
#####################
DH AH R K R AO S ER AH M DH DH
SP DH AH S AO L T B R IY Z K EY M AH K R AO S F ER M DH AH S IY SP
#####################
AH OW DH AH AH L AH AH P N
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
AH ER AH AH DH L V
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
AH ER AH DH L AH
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
DH AH M N L S S
SP DH AH S M AO L P AH P N AO D AH SHH OW L IH N DH AH S   K SP
#####################
AH N DH AH S
SP DH AH S M AO L P AH P SP N AO D AH SHH OW L IH N DH AH S   K SP
#####################
DH AH N L DH AH S S
SP 