In [1]:
import time
import yaml
import os
import torch
from utils.IO_func import read_file_list, load_binary_file, array_to_binary_file, load_Haskins_SSR_data
from shutil import copyfile
from utils.transforms import Transform_Compose
from utils.transforms import FixMissingValues
import IPython
import matplotlib.pyplot as plt

import argparse
import pickle
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
from comet_ml import Experiment
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np

In [3]:
class CNNLayerNorm(nn.Module):
    """Layer normalization built for cnns input"""
    def __init__(self, n_feats):
        super(CNNLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # x (batch, channel, feature, time)
        x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
        x = self.layer_norm(x)
        return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) 


class ResidualCNN(nn.Module):
    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
        except with layer norm instead of batch norm
    """
    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super(ResidualCNN, self).__init__()

        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        residual = x  # (batch, channel, feature, time)
        x = self.layer_norm1(x)
        x = F.gelu(x)
        x = self.dropout1(x)
        x = self.cnn1(x)
        x = self.layer_norm2(x)
        x = F.gelu(x)
        x = self.dropout2(x)
        x = self.cnn2(x)
        x += residual
        return x # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()

        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)
        x = self.dropout(x)
        return x


class SpeechRecognitionModel(nn.Module):
    
    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        n_feats = n_feats//2
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting heirachal features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) 
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2) # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x


In [4]:
class IterMeter(object):
    """keeps track of total iterations"""
    def __init__(self):
        self.val = 0

    def step(self):
        self.val += 1

    def get(self):
        return self.val

In [230]:
import argparse
import pickle
from torch.utils.data import Dataset, DataLoader

conf_dir = 'conf/SSR_conf.yaml'
buff_dir = 'current_exp'

config = yaml.load(open(conf_dir, 'r'), Loader=yaml.FullLoader)

data_path = os.path.join(buff_dir, 'data_CV')
SPK_list = ['M01']

for test_SPK in SPK_list:
    data_path_SPK = os.path.join(data_path, test_SPK)

    tr = open(os.path.join(data_path_SPK, 'train_data.pkl'), 'rb') 
    va = open(os.path.join(data_path_SPK, 'valid_data.pkl'), 'rb')        
    train_dataset, valid_dataset = pickle.load(tr), pickle.load(va)

In [231]:
def GreedyDecoder(output, labels, label_lengths, blank_label=40, collapse_repeated=True):
    
    from utils.database import PhoneTransform

    text_transform = PhoneTransform()

    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    targets = []
    for i, args in enumerate(arg_maxes):

        decode = []
        targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
        for j, index in enumerate(args):
            if index != blank_label:
                if collapse_repeated and j != 0 and index == args[j -1]:
                    continue
                decode.append(index.item())
        decodes.append(text_transform.int_to_text(decode))
    return decodes, targets

In [232]:
def data_processing(data, transforms = None):
    ema = []
    labels = []
    input_lengths = []
    label_lengths = []
    
    for file_id, x, y in data:
        if transforms is not None:
            x = transforms(x)

        ema.append(torch.FloatTensor(x))
        labels.append(y)
        input_lengths.append(x.shape[0] // 2)
        label_lengths.append(len(y))
        
    ema = torch.nn.utils.rnn.pad_sequence(ema, batch_first=True).unsqueeze(1).transpose(2, 3)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)        
    
    return file_id, ema, labels, input_lengths, label_lengths

In [233]:
learning_rate=1e-4
batch_size=20
epochs=80

In [234]:
hparams = {
    "n_cnn_layers": 3,
    "n_rnn_layers": 2,
    "rnn_dim": 512,
    "n_class": 41,
    "n_feats": 18,
    "stride": 2,
    "dropout": 0.1,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": epochs
}
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                            batch_size=8,
                            shuffle=True,
                            collate_fn=lambda x: data_processing(x, None))

test_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                            batch_size=1,
                            shuffle=False,
                            collate_fn=lambda x: data_processing(x, None))

model = SpeechRecognitionModel(
    hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
    hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
    ).to(device)

#print(model)
#print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
criterion = nn.CTCLoss(blank=40).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                                        steps_per_epoch=int(len(train_loader)),
                                        epochs=hparams['epochs'],
                                        anneal_strategy='linear')

iter_meter = IterMeter()



In [235]:
    model.train()
    data_len = len(train_loader.dataset)
    for epoch in range(epochs):
        for batch_idx, _data in enumerate(train_loader):
            file_id, ema, labels, input_lengths, label_lengths = _data 
            ema, labels = ema.to(device), labels.to(device)
            optimizer.zero_grad()

            output = model(ema)  # (batch, time, n_class)

            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1) # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()

            optimizer.step()
            scheduler.step()
            iter_meter.step()
            if batch_idx % 100 == 0 or batch_idx == data_len:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(ema), data_len,
                    100. * batch_idx / len(train_loader), loss.item()))
        




In [238]:
    model.eval()
    pred = []
    label = []

    from jiwer import wer

    for batch_idx, _data in enumerate(test_loader):
        fid, ema, labels, input_lengths, label_lengths = _data 
        ema, labels = ema.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(ema)  # (batch, time, n_class)

        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
      #  test_loss += loss.item() / len(test_loader)

        decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)

     #   print(decoded_preds)
     #   print(decoded_targets)

        pred.append(' '.join(decoded_preds[0]))
        label.append(' '.join(decoded_targets[0]))
    
print(len(pred))
print(len(label))

error = wer(pred, label)
print(error)

50
50
0.8303094983991463


In [None]:
        for epoch in range(epochs):
            for batch_idx, _data in enumerate(train_loader):
                file_id, ema, labels, input_lengths, label_lengths = _data 
                ema, labels = ema.to(device), labels.to(device)
                optimizer.zero_grad()

                output = model(ema)  # (batch, time, n_class)

                output = F.log_softmax(output, dim=2)
                output = output.transpose(0, 1) # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                loss.backward()

                optimizer.step()
                scheduler.step()
                iter_meter.step()
                if batch_idx % 100 == 0 or batch_idx == data_len:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(ema), data_len, 100. * batch_idx / len(train_loader), loss.item()))

            for 

In [237]:
for i in range(len(pred)):
    print(pred[i])
    print(label[i])
    print('#####################')

SP DH AH S   R B ER R AH Z AE M B ER   D ER M DH AH IY SP
SP DH AH S AO L T B R IY Z K EY M AH K R AO S F ER M DH AH S IY SP
#####################
SP DH AH S   R M ER N Z S K IY P AH ER R AH Z ER M AH S IH SP
SP DH AH S AO L T B R IY Z K EY M AH K R AO S F ER M DH AH S IY SP
#####################
SP AH R IH M IH Z AH L N AE N M   R K N M
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
SP DH AH L AE D DH AH M UW S OW N EH N IY B   N S SP
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
SP   L EY N DH AH M P UW OW EH IH B   Z SP
SP DH AH G ER L AE T DH AH B UW TH S OW L D F IH F T IY B   N D Z SP
#####################
SP DH AH S M AH P   M   T L OY AH Z S AH L SP
SP DH AH S M AO L P AH P N AO D AH SHH OW L AH N DH AH S   K SP
#####################
SP DH AH S M AO   R B M ER   N W S AO D
SP DH AH S M AO L P AH P N AO D AH SHH OW L AH N DH AH S   K SP
#####################
SP DH AH S M   L M P   P AH D AH 