In [1]:
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data



In [2]:
##Paths
train_aud = '/tmp/target-segments/ru/clips/'
train_df = '/tmp/target-segments/ru/train.tsv'

**Transform functions**

Transform functions are the functions that are dealing with necessary input transformations e.g. feature extraction. They are feeded directly in the data loader. It helps to speed up data manipulation in contrast to reading all the file from the hard drive.

In [3]:
def find_maxlen(path, train_df):
    fnames = pd.read_csv(train_df, sep='\t')['path']
    maxlen = 0
    for n in tqdm(fnames):
        waveform, sample_rate = torchaudio.load(os.path.join(path, n))
        mfcc = torchaudio.transforms.MFCC()(waveform)
        size = mfcc.shape[2]
        if size > maxlen:
            maxlen = size
    print("Maxlen:", maxlen)


def extract_feats(path, maxlen=1083):
    '''
    Reads and processes one file at a time.
    Args:
        path: path to the file
        maxlen: maximum length of the spectrogram for padding
    '''
    waveform, sample_rate = torchaudio.load(path)
    #Calculate MFCC
    mfcc = torchaudio.transforms.MFCC()(waveform)
    #Calculate delta and double-delta
    deltas = torchaudio.transforms.ComputeDeltas()(mfcc)
    ddeltas = torchaudio.transforms.ComputeDeltas()(deltas)
    res = torch.cat((mfcc, deltas, ddeltas), dim=1).squeeze(0)
    #Normalize rows
    s = torch.sum(res, dim=1, keepdim=True)
    norm = torch.div(res, s)
    mask = torch.ones(norm.shape[0], norm.shape[1])
    padded_norm = nn.functional.pad(norm, pad=(0, maxlen-norm.shape[1], 0, 0), 
                                          mode="constant",value=0)
    padded_mask = nn.functional.pad(mask, pad=(0, maxlen-mask.shape[1], 0, 0), 
                                          mode="constant",value=0)
    return padded_norm, padded_mask

def alphabet_enc(csv_path):
    char2ind = {}
    sents = pd.read_csv(csv_path, sep='\t')['sentence']
    chars = list(set([char for sent in sents for char in sent]))
    for i in range(len(chars)):
        char2ind[chars[i]] = i
    char2ind["<eos>"] = len(chars)+1 
    return char2ind

In [4]:
class TrainData(data.Dataset):
    def __init__(self, csv_path, aud_path, transform):
        self.df = pd.read_csv(csv_path, sep='\t')
        self.aud_path = aud_path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        fname = os.path.join(self.aud_path, self.df['path'][idx])
        transcript = self.df['sentence'][idx].lower()

        feat, mask = self.transform(fname)

        sample = {'aud':feat, 'trans': transcript, 'mask':mask}
        return sample
    
def weights(m):
    '''
    Intialize random weights
    '''
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight.data)
        nn.init.constant_(m.bias.data,0.1)

**Proposed Architechture**

Attention-based Sequence-to-Sequence model:

![](/img/arch.png)

In [5]:
class Encoder(nn.Module):
    def __init__(self, batch_size):
        super().__init__()
        self.input_layer = nn.Linear(120, 512)
        self.blstm = nn.LSTM(input_size=512, 
                             hidden_size=256, 
                             num_layers=3, 
                             bidirectional=True)
        self.h0 = torch.zeros(3*2, batch_size, 256)
        self.c0 = torch.zeros(3*2, batch_size, 256)
        
    def forward(self, x):
        #Pass through the first linear layer
        outputs=[]
        for i in range(x.shape[2]):
            feature = x[:,:,i]
            out = self.input_layer(feature)
            out = torch.nn.LeakyReLU()(out)
            outputs.append(out)
        outputs = torch.stack(outputs)
        #Pass through LSTM layers
        output, (hn, cn) = self.blstm(outputs, (self.h0, self.c0))
        return output, (hn, cn)
    
    
class Decoder(nn.Module):
    def __init__(self, batch_size, char2ind):
        super().__init__()
        self.char2ind = char2ind
        self.embed_layer = nn.Linear(512, 512)
        self.blstm = nn.LSTM(input_size=512, 
                             hidden_size=512, 
                             num_layers=1)
        self.h0 = torch.zeros(1, batch_size, 512)
        self.c0 = torch.zeros(1, batch_size, 512)
        self.y0 = torch.zeros(1, 512)
        

In [6]:
char2ind = alphabet_enc(train_df)

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder = Encoder(32)
encoder = encoder.to(device)
encoder.apply(weights)

cv_dataset = TrainData(train_df, train_aud, extract_feats)
loader = data.DataLoader(cv_dataset, batch_size=32, shuffle=True)

In [8]:
decoder = Decoder(32, char2ind)

In [None]:
for batch in loader:
    x = batch['aud'].to(device)
    out, (h, c) = encoder(x)
    print(out)

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore


tensor([[[-3.5850e-03,  4.8217e-03, -4.0246e-03,  ...,  1.2783e-02,
          -3.6178e-02,  8.2143e-04],
         [ 5.2252e-03,  3.5812e-03, -7.4774e-03,  ...,  4.8239e-02,
          -4.2366e-02, -6.0497e-03],
         [-1.8553e-03,  4.5822e-03, -6.5938e-03,  ...,  3.4576e-02,
          -3.8906e-02, -1.0214e-02],
         ...,
         [ 9.9936e-03,  2.5506e-03, -6.3852e-03,  ...,  2.8113e-02,
          -4.5586e-02, -2.9782e-02],
         [-3.6925e-04,  8.0831e-03, -2.8792e-03,  ...,  2.1298e-02,
          -4.1732e-02, -4.5197e-03],
         [ 3.9447e-03,  3.7340e-03,  3.4174e-04,  ...,  2.5715e-02,
          -3.9151e-02, -2.1908e-02]],

        [[-5.8371e-03,  4.7725e-03, -5.7921e-03,  ...,  8.9470e-03,
          -3.6790e-02,  1.9270e-03],
         [ 1.0485e-02,  2.1323e-03, -1.1960e-02,  ...,  5.8033e-02,
          -4.3373e-02, -4.0174e-03],
         [-1.3855e-03,  5.2263e-03, -1.0116e-02,  ...,  3.8426e-02,
          -3.7583e-02, -1.0563e-02],
         ...,
         [ 2.0178e-02,  9