In [None]:
import os
import string
from typing import Tuple, List, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import ipywidgets as widgets
import itertools
from torch import optim
from torchaudio.transforms import RNNTLoss
from tqdm import tqdm_notebook, tqdm
from IPython.display import display, clear_output


In [None]:
# Make sure the download directory exists before torchaudio writes into it.
os.makedirs("./data", exist_ok=True)

# LibriSpeech splits used by the rest of the notebook; fetched on first run.
train_dataset = torchaudio.datasets.LIBRISPEECH(
    "./data",
    url="train-clean-100",
    download=True,
)
test_dataset = torchaudio.datasets.LIBRISPEECH(
    "./data",
    url="test-clean",
    download=True,
)


In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
# Waveform -> MFCC feature pipeline (128 coefficients, 16 kHz audio), kept on `device`.
dataset_transforms = nn.Sequential(
    torchaudio.transforms.MFCC(n_mfcc=128, sample_rate=16000),
)
dataset_transforms = dataset_transforms.to(device)


In [None]:
from gensim.utils import tokenize

class Vocab:
    """Word-level vocabulary mapping tokens to positive integer ids.

    Ids are handed out in insertion order starting at 1, so id 0 is never
    assigned to a word. ``"<UNK>"`` is registered by the constructor and
    therefore always has id 1; it stands in for every token that was not
    registered through ``add_sentence``.
    """

    # Placeholder token returned for out-of-vocabulary words.
    UNKNOWN_TOKEN = "<UNK>"

    def __init__(self, device):
        """Create a vocabulary that contains only the unknown-word token.

        Args:
            device: torch device on which ``tokenize_sentence`` places its output.
        """
        self.word2ind = {}
        self.ind2word = {}
        self.num_words = 0
        self.device = device
        self._add_word(self.UNKNOWN_TOKEN)

    def _add_word(self, word):
        """Register ``word`` under the next free id; known words are left untouched."""
        if word in self.word2ind:
            return
        self.num_words += 1
        self.word2ind[word] = self.num_words
        self.ind2word[self.num_words] = word

    def add_sentence(self, sentence):
        """Split ``sentence`` with ``tokenize`` and register every resulting word."""
        for word in tokenize(sentence):
            self._add_word(word)

    def tokenize_sentence(self, sentence):
        """Convert ``sentence`` into a 1-D ``torch.long`` tensor of word ids.

        Words that were never registered map to the id of ``"<UNK>"``.

        Returns:
            Tensor of shape ``(number_of_tokens,)`` placed on ``self.device``
            (the device given to the constructor, not a notebook global).
        """
        unknown_id = self.word2ind[self.UNKNOWN_TOKEN]
        ids = [self.word2ind.get(word, unknown_id) for word in tokenize(sentence)]
        return torch.tensor(ids, dtype=torch.long, device=self.device)

    def __len__(self):
        """Return ``num_words + 10``.

        This is deliberately larger than the number of stored words: ids run
        from 1 to ``num_words``, so a lookup table with ``len(vocab)`` rows has
        a row for every id plus a few unused spare rows.
        """
        return self.num_words + 10


In [None]:
from tqdm.auto import tqdm

# Register every word of every training transcript (element 2 of each dataset item).
vocab = Vocab(device)

for batch in tqdm(train_dataset):
    transcript = batch[2]
    vocab.add_sentence(transcript)

print(len(vocab))

In [None]:
import gensim
import gensim.downloader as api

# Load the pretrained 'word2vec-google-news-300' vectors through gensim's downloader API.
# NOTE(review): presumably this downloads a large archive on first use and caches it — confirm.
embeddings = api.load('word2vec-google-news-300')

In [None]:
# Quick check that a common word has a pretrained vector. Membership is tested on
# the vectors object itself because the `.vocab` attribute was removed in gensim 4;
# `in` works on both gensim 3.x and 4.x.
'hello' in embeddings

In [None]:
import numpy as np 
from tqdm.auto import trange

class TacotronEncoder(nn.Module):
    """Text encoder: frozen pretrained word embeddings -> 2-D conv stack -> BiLSTM.

    Construction reads two notebook globals: ``vocab`` (id <-> word mapping)
    and ``embeddings`` (pretrained gensim word vectors).
    """

    def __init__(self):
        super().__init__()
        embedding_dim = embeddings.vector_size

        # One row per vocabulary id. Ids without a word, and words without a
        # pretrained vector, keep an all-zero row.
        weights = np.zeros((len(vocab), embedding_dim), dtype=np.float32)
        for i in trange(len(vocab)):
            word = vocab.ind2word.get(i)
            # `word in embeddings` is supported by gensim 3.x and 4.x, unlike
            # the `.vocab` attribute that gensim 4 removed.
            if word is not None and word in embeddings:
                weights[i] = embeddings.get_vector(word)

        # The table is not pinned to a device here: it is a module parameter,
        # so it follows the module when the caller does `.to(device)`.
        self.embedding = nn.Embedding.from_pretrained(torch.from_numpy(weights))
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, 1, 1)
        )
        self.rnn = nn.LSTM(embedding_dim, 64, 1, bidirectional=True, batch_first=True)

    def forward(self, input_text):
        """Encode token ids into per-token feature vectors.

        Args:
            input_text: integer tensor of word ids, shape ``(B, L)``.

        Returns:
            Tensor of shape ``(B, L, 128)`` (two LSTM directions of 64 units).
        """
        # B, L
        hidden = self.embedding(input_text)
        # B, L, E
        # Insert the single conv channel at dim 1 so the batch axis stays at
        # dim 0; inserting it at dim 0 only works when B == 1.
        hidden = hidden.unsqueeze(1)
        # B, 1, L, E
        hidden = self.conv(hidden)
        # B, 1, L, E
        hidden = hidden.squeeze(1)
        # B, L, E
        hidden, _ = self.rnn(hidden)
        # B, L, 128
        return hidden

In [None]:
def encoder_sanity_check():
    """Run the first training transcript through a fresh encoder and print the output shape."""
    model = TacotronEncoder().to(device)
    transcript = train_dataset[0][2]
    token_ids = vocab.tokenize_sentence(transcript)
    # Add the leading batch axis the encoder expects.
    output = model(token_ids.unsqueeze(0))
    print(output.shape)


encoder_sanity_check()

In [None]:
class RNNWithAttention(nn.Module):
    """Bidirectional LSTM driven by dot-product attention over encoder states.

    For every step of ``input_seq`` a query is derived from that step, the
    encoder states are scored against it, and the softmax-weighted sum of the
    encoder states (the context vector) is fed to the LSTM as a length-1
    sequence. The LSTM state is carried from one step to the next.

    Args:
        embedding_dim: size of the last axis of ``input_seq``.
        encoder_dim: size of the last axis of ``encoder_output`` (default 128).
    """

    def __init__(self, embedding_dim, encoder_dim=128):
        super().__init__()
        self.rnn = nn.LSTM(encoder_dim, 64, 1, bidirectional=True, batch_first=True)
        # Maps an input step to a query that lives in the encoder feature space.
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim, 32),
            nn.LeakyReLU(0.05),
            nn.Linear(32, encoder_dim)
        )

    def forward(self, input_seq, encoder_output):
        """Attend over ``encoder_output`` once per step of ``input_seq``.

        Args:
            input_seq: tensor of shape ``(B, L, embedding_dim)``.
            encoder_output: tensor of shape ``(B, T, encoder_dim)``.

        Returns:
            Tensor of shape ``(B, L, 128)``; an empty ``(B, 0, 128)`` tensor
            when ``input_seq`` has no steps.
        """
        batch_size, num_steps = input_seq.shape[0], input_seq.shape[1]
        if num_steps == 0:
            return input_seq.new_zeros((batch_size, 0, 2 * self.rnn.hidden_size))

        # The query projection does not depend on the recurrent state, so it
        # is computed for all steps at once.
        queries = self.fc(input_seq)
        # B, L, encoder_dim

        state = None  # nn.LSTM treats None as the all-zero initial state
        step_outputs = []
        for i in range(num_steps):
            query = queries[:, i, :].unsqueeze(2)
            # B, encoder_dim, 1
            scores = torch.bmm(encoder_output, query)
            # B, T, 1
            weights = F.softmax(scores, dim=1)
            # The weights must actually scale the encoder states; otherwise the
            # context is the same for every query and `fc` receives no gradient.
            context = torch.sum(weights * encoder_output, dim=1, keepdim=True)
            # B, 1, encoder_dim
            rnn_output, state = self.rnn(context, state)
            step_outputs.append(rnn_output)

        # Concatenate once instead of growing the result inside the loop.
        # B, L, 128
        return torch.cat(step_outputs, dim=1)

    

In [None]:
def attention_sanity_check():
    """Feed all-zero inputs through a fresh RNNWithAttention and print the output shape."""
    attention_rnn = RNNWithAttention(128).to(device)
    decoder_steps = torch.zeros((1, 33, 128)).to(device)
    encoder_states = torch.zeros((1, 23, 128)).to(device)
    result = attention_rnn(decoder_steps, encoder_states)
    print(result.shape)


attention_sanity_check()

In [None]:
class TacotronDecoder(nn.Module):
    """Spectrogram decoder: two attention RNNs, a linear skip path and a conv post-net."""

    def __init__(self):
        super().__init__()
        self.rnn1 = RNNWithAttention(128)
        self.fc1 = nn.Linear(128, 128)
        self.rnn2 = RNNWithAttention(128)

        # Post-net: three conv/batch-norm/ReLU stages (1 -> 16 -> 16 -> 16
        # channels) followed by a conv that projects back to one channel.
        postnet_layers = []
        in_channels = 1
        for _ in range(3):
            postnet_layers.extend([
                nn.Conv2d(in_channels, 16, 3, 1, 1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
            ])
            in_channels = 16
        postnet_layers.append(nn.Conv2d(in_channels, 1, 3, 1, 1))
        self.conv1 = nn.Sequential(*postnet_layers)

        self.fc2 = nn.Linear(128, 128)

    def forward(self, input_spectragram, encoder_output):
        """Decode spectrogram frames conditioned on the encoder states.

        The first attention RNN output goes two ways: through ``fc1`` into the
        second attention RNN, and through ``fc2`` as a skip branch. The sum of
        both is refined by the conv post-net, which sees it as a 1-channel image.
        """
        first_pass = self.rnn1(input_spectragram, encoder_output)
        second_pass = self.rnn2(self.fc1(first_pass), encoder_output)
        merged = self.fc2(first_pass) + second_pass
        refined = self.conv1(merged.unsqueeze(1))
        return refined.squeeze(1)

In [None]:
def decoder_sanity_check():
    """Feed all-zero inputs through a fresh TacotronDecoder and print the output shape."""
    decoder = TacotronDecoder().to(device)
    spectrogram_frames = torch.zeros((1, 33, 128)).to(device)
    encoder_states = torch.zeros((1, 23, 128)).to(device)
    result = decoder(spectrogram_frames, encoder_states)
    print(result.shape)
decoder_sanity_check()

In [None]:
class Tacotron(nn.Module):
    """Full model: a text encoder whose output conditions a spectrogram decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = TacotronEncoder()
        self.decoder = TacotronDecoder()

    def forward(self, input_spec, input_text):
        """Encode ``input_text``, then decode ``input_spec`` against that encoding."""
        text_features = self.encoder(input_text)
        return self.decoder(input_spec, text_features)

In [None]:
def generate_spectrogram(t2s_model: "Tacotron", input_text, start_token, end_token, num_iterations, eps):
    """Autoregressively generate spectrogram frames for ``input_text``.

    The text is encoded once. The decoder is then run repeatedly on the frames
    produced so far, and the last frame of each run is appended to the sequence.

    Args:
        t2s_model: model exposing ``encoder`` and ``decoder`` sub-modules.
        input_text: token-id tensor accepted by ``t2s_model.encoder``.
        start_token: initial frame sequence; new frames are concatenated to it
            along dim 1 (each new frame is reshaped to ``(1, 1, -1)``, so a
            batch of one is assumed).
        end_token: frame that marks the end of the utterance.
        num_iterations: maximum number of decoder passes.
        eps: generation stops when the summed absolute difference between the
            newest frame and ``end_token`` falls below this value.

    Returns:
        ``start_token`` followed by every generated frame; the frame that
        triggered the stop condition is not included.
    """
    # Inference should neither track gradients nor update batch-norm
    # statistics; the caller's train/eval mode is restored afterwards.
    was_training = t2s_model.training
    t2s_model.eval()
    try:
        with torch.no_grad():
            encoded_text = t2s_model.encoder(input_text)
            answer = start_token
            for _ in range(num_iterations):
                # Call the decoder directly with the cached text encoding: the
                # model's own forward() needs both the frames and the text.
                output_decoder = t2s_model.decoder(answer, encoded_text)
                last_token = output_decoder[:, -1, :].reshape(1, 1, -1)
                # Absolute differences, so positive and negative deviations
                # cannot cancel out or push the sum below `eps`.
                if torch.sum(torch.abs(end_token - last_token)).item() < eps:
                    break
                answer = torch.cat((answer, last_token), 1)
    finally:
        t2s_model.train(was_training)
    return answer

In [None]:
# Sentinel frames put before the model input and after the training target.
# Both are a single all-zero 128-dimensional frame with a leading batch axis.
START_TOKEN = torch.zeros(1, 1, 128, device=device)
END_TOKEN = torch.zeros(1, 1, 128, device=device)

In [None]:
# Model, regression loss on spectrogram frames, and optimizer.
model = Tacotron()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [None]:
from tqdm.auto import trange, tqdm
import queue

num_epochs = 4
log_window = 10  # the running mean shown in the progress bar restarts after this many steps

for epoch in trange(num_epochs):
    pbar = tqdm(train_dataset)
    sum_loss = cnt_loss = 0
    for batch in pbar:
        # Each dataset item carries the waveform at index 0 and the transcript at index 2.
        waveform, transcript = batch[0], batch[2]

        optimizer.zero_grad()
        tensor_text = vocab.tokenize_sentence(transcript).reshape(1, -1)
        input_spectrogram = dataset_transforms(waveform.to(device)).permute(0, 2, 1)

        # Teacher forcing: the input is the spectrogram shifted right by one
        # frame (START first), the target is the spectrogram followed by END.
        model_input = torch.cat((START_TOKEN, input_spectrogram), 1)
        model_target = torch.cat((input_spectrogram, END_TOKEN), 1)

        loss = criterion(model(model_input, tensor_text), model_target)
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        sum_loss += last_loss
        cnt_loss += 1
        pbar.set_description(
            f"Last loss : {round(last_loss, 2)} | Mean loss : {round(sum_loss/cnt_loss, 2)}"
        )
        if cnt_loss == log_window:
            sum_loss, cnt_loss = 0, 0



In [None]:
from librosa.feature.inverse import mfcc_to_audio

# Round-trip check: compute MFCCs for the last training utterance and invert
# them back to a waveform with librosa.
input_wav = batch[0].to(device)
sample_spec = dataset_transforms(input_wav)
# librosa assumes sr=22050, n_fft=2048 and hop_length=512 unless told otherwise.
# The MFCCs above were computed at 16 kHz with torchaudio's default framing
# (n_fft=400, hop_length=200), so those values are passed explicitly.
# NOTE(review): the mel filter-bank conventions of the two libraries may still
# differ, so the reconstruction is approximate — confirm before relying on it.
mfcc_to_audio(
    sample_spec.detach().cpu().numpy().squeeze(0),
    sr=16000,
    n_fft=400,
    hop_length=200,
)
