In [1]:
import pandas as pd
import torch
from torch import nn, optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
import torchaudio

from torchaudio import transforms as T
from torchtext.vocab import build_vocab_from_iterator
import matplotlib.pyplot as plt
from IPython.display import Audio
import torchvision
from torchvision.models import resnet18
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from torcheval.metrics import Mean
from string import ascii_lowercase
import math
import re

In [2]:
specials = ["", "B", "E"]
n_ft = 300
h_len = 400
d_model = 240
cnn_out = 240
bs = 50
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def train_one_epoch(model, train_loader, loss_fn, optimizer, epoch=None):
    model.train()
    metric = Mean().to(device)
    with tqdm(train_loader, unit='batch') as tepochs:
        for wave, labels, sr in tepochs:
            if epoch is not None:
                tepochs.set_description(f'epoch:{epoch}')
            yp = model(train_transform(wave.to(device)), labels.to(device)[:, :-1])
            loss = loss_fn(yp.transpose(2, 1), labels.to(device)[:, 1:])
            loss.backward()
            clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            optimizer.zero_grad()
            metric.update(loss)

            tepochs.set_postfix(loss=metric.compute().item())
    return model, metric.compute().item()


def evaluate(model, test_loader, loss_fn):
    model.eval()
    metric = Mean().to(device)
    with torch.no_grad():
        for wave, labels, sr in test_loader:
            yp = model(train_transform(wave.to(device)), labels.to(device)[:, :-1])
            loss = loss_fn(yp.transpose(2, 1), labels.to(device)[:, 1:])
            metric.update(loss)
    print(metric.compute().item())
    return metric.compute().item()


In [4]:
def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(T.AmplitudeToDB()(specgram), origin="lower", aspect="auto")
    fig.colorbar(im, ax=axs)
    plt.show(block=False)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

In [5]:
data = pd.read_csv('new_data.csv')

In [6]:
get_lower = lambda x: re.sub(r"[^1-9a-z\s']", '', x.lower())
vocab = build_vocab_from_iterator(
    data['info'].apply(get_lower), special_first=True, specials=specials)
vocab.set_default_index(vocab[''])

In [7]:
vocab.get_itos()

['',
 'B',
 'E',
 ' ',
 'e',
 't',
 'a',
 'o',
 'n',
 'i',
 's',
 'r',
 'h',
 'd',
 'l',
 'c',
 'f',
 'u',
 'm',
 'w',
 'p',
 'g',
 'b',
 'y',
 'v',
 'k',
 'x',
 'q',
 'j',
 "'",
 '1',
 '2',
 'z',
 '3',
 '6',
 '9',
 '8',
 '5',
 '4',
 '7']

In [8]:
train_transform = nn.Sequential(T.MelSpectrogram(n_fft=n_ft, hop_length=h_len, n_mels=80), T.FrequencyMasking(10), 
                                T.TimeMasking(10)).to(device)
valid_transform = T.MelSpectrogram(n_fft=n_ft, hop_length=h_len, n_mels=80).to(device)



In [9]:
train_data = data.iloc[:10000]

In [10]:
valid_data = data.iloc[10000:11500]

In [11]:
class LJSpeechSet(Dataset):
    def __init__(self, data, phase='train'):
        
        pathx = 'LJSpeech-1.1/wavs/'
        self.data = dict()
        self.wave_forms, self.labels, self.sample_rates = list(), list(), list()

        for i, row in data.iterrows():

            wave_form, s_r = torchaudio.load(pathx + str(row['f_name']))
            label = torch.LongTensor([vocab[c] for c in "B"+
                                      re.sub(r"[^1-9a-z\s']", '', str(row['info'].lower()))+"E"])

            self.wave_forms.append(wave_form.squeeze()), self.labels.append(label), self.sample_rates.append(s_r)


        self.phase = phase
        self.wave_forms = pad_sequence(self.wave_forms, batch_first=True)

    def __len__(self):
        return len(self.wave_forms)
    
    def __getitem__(self, ind):
        return self.wave_forms[ind], self.labels[ind], self.sample_rates[ind]

In [12]:
train_set = LJSpeechSet(train_data)

In [13]:
valid_set = LJSpeechSet(valid_data)

In [14]:
train_transform(train_set[0][0].to(device)).shape

torch.Size([80, 557])

In [15]:
def collate(batch):
    mel_waves = pad_sequence([b[0] for b in batch], batch_first=True).unsqueeze(1)
    labels = pad_sequence([b[1] for b in batch], batch_first=True, padding_value=vocab[''])
    rates = [b[2] for b in batch]
    return mel_waves, labels, rates

In [17]:
train_loader = DataLoader(train_set, bs, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(valid_set, bs, shuffle=True, collate_fn=collate)

In [18]:
train_transform = nn.Sequential(T.Spectrogram(n_fft=n_ft, hop_length=h_len), T.FrequencyMasking(20), 
                                T.TimeMasking(20)).to(device)
valid_transform = T.Spectrogram(n_fft=n_ft, hop_length=h_len).to(device)

In [19]:
class Block(nn.Module):
    
    def __init__(self, inp, out, kernel, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(inp, out, kernel_size=kernel, padding=kernel//2)
        self.conv2 = nn.Conv2d(out, out, kernel_size=kernel, padding=kernel//2)
        self.maxpool = nn.MaxPool2d(kernel, stride=stride, padding=kernel//2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.batch_n = nn.BatchNorm2d(out)
    
    def forward(self, x):  
        
        x = self.relu(self.batch_n(self.conv1(x)))
        x = self.relu(self.batch_n(self.conv2(x)))
        x = self.maxpool(x)
        return x

class CNNFeatureExtractor(nn.Module):
    
    def __init__(self, output):
        
        super().__init__()
        self.block1 = Block(1, 8, 7, 4)
        self.block2 = Block(8, 12, 5, 2)
        self.block3 = Block(12, 16, 4, 2)
        self.block4 = Block(16, 14, 3, 1)
        self.output = nn.LazyLinear(output)
        self.dropout = nn.Dropout()
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        # reshaping to bs, c*frames* bin which has strided for 8 times
        # and using linear to increase bin features
        x = x.view(x.size(0), -1, x.size(2))
        y = self.output(x)
        return y

In [20]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 200):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


In [28]:
class Transformer(nn.Module):
    
    def __init__(self, d_model, cnn_out):
        
        super().__init__()
        self.embedding = nn.Embedding(len(vocab), d_model, padding_idx=0)

        self.d_model = d_model
        self.positional_encoding = PositionalEncoding(d_model, dropout=0.1)
        enc_layer = nn.TransformerEncoderLayer(d_model=cnn_out, nhead=6, dim_feedforward=cnn_out*4, 
                                               batch_first=True, activation='gelu')
        
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=8)
        dec_layer = nn.TransformerDecoderLayer(d_model, 6, d_model*4, batch_first=True, activation='gelu')
        self.dec = nn.TransformerDecoder(dec_layer, 8)


    def forward(self, x, inp):
        tgt = self.embedding(inp)
        y = self.enc(x)
#         z = self.positional_encoding(tgt)
        y = self.dec(tgt, y)
        
        return y

In [29]:
class ASRNeural(nn.Module):
    
    def __init__(self, d_model, cnn_out):
        
        super().__init__()
        self.cnn_ = CNNFeatureExtractor(cnn_out)
        self.transformer = Transformer(d_model, cnn_out)
        self.head = nn.Linear(d_model, len(vocab))
        
    def forward(self, x, inp):
        y = self.cnn_(x)
        y = self.transformer(y, inp)
        y = self.head(y)
        return y

In [30]:
model = ASRNeural(d_model, cnn_out).to(device)

In [32]:
torch.cuda.empty_cache()

In [33]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=5)

In [34]:
loss_train_hist = list()
loss_valid_hist = list()
pre_train_hist = list()
pre_valid_hist = list()
best_loss_valid = 1e+4
epoch_counter = 0

In [None]:
n = 100
for epoch in range(n):
    model, train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch)
    valid_loss = evaluate(model, valid_loader, loss_fn)
    
    
    loss_train_hist.append(train_loss)
    loss_valid_hist.append(valid_loss)

    if valid_loss < best_loss_valid:
        torch.save(model,'modelx1.pt')
        best_loss_valid =  valid_loss
        print('Model SAVED') 

    epoch_counter +=1

epoch:0: 100%|██████████████████| 200/200 [01:22<00:00,  2.41batch/s, loss=3.57]


3.7860358079274494
Model SAVED


epoch:1: 100%|██████████████████| 200/200 [01:21<00:00,  2.44batch/s, loss=3.04]


2.930072816212972
Model SAVED


epoch:2:  48%|█████████▏         | 97/200 [00:40<00:43,  2.36batch/s, loss=2.91]

In [None]:
torch.cuda.empty_cache()

In [None]:
sum(p.numel() for p in model.parameters())

In [None]:
wave, label, sr = next(iter(train_loader))

In [None]:
model.eval()
generated = [vocab['B']]

for i in range(90):
    with torch.no_grad():
        preded = model(valid_transform(wave[0].to(device)).unsqueeze(0), 
                       torch.LongTensor(generated).to(device).unsqueeze(0))
    argm = torch.multinomial((preded.squeeze(0) / .4).softmax(-1), 1)[-1]
    if argm.item() == vocab['E']:
        print('I have predicted the last item bro i cant take it more')
        break
    generated.append(argm.squeeze())

In [None]:
itos = vocab.get_itos()

In [None]:
''.join([itos[d] for d in torch.unique_consecutive(torch.LongTensor(generated))])