In [1]:
import os
import string
from typing import Tuple, List, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import ipywidgets as widgets
import itertools
from torch import optim
from torchaudio.transforms import RNNTLoss
from tqdm import tqdm_notebook, tqdm
from IPython.display import display, clear_output

In [2]:
if not os.path.isdir("./data"):
    os.makedirs("./data")

train_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="train-clean-100", download=True)
test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)


  0%|          | 0.00/5.95G [00:00<?, ?B/s]

  0%|          | 0.00/331M [00:00<?, ?B/s]

In [3]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.1-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 32.5 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 59.8 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 13.0 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 56.2 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling 

In [4]:
from transformers import DistilBertTokenizer, DistilBertModel
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
text = "heeelllllooooooooooo"
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
output['last_hidden_state'].shape

torch.Size([1, 11, 768])

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 
device

device(type='cuda', index=0)

In [7]:
len(tokenizer.vocab)

30522

In [8]:
dataset_transforms = nn.Sequential(
    torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
).to(device)


  "At least one mel filterbank has all zero values. "


In [10]:
class RNNTNet(nn.Module):
    def __init__(self, model, tokenizer):
        super().__init__()

        self.model = model
        self.tokenizer = tokenizer 
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.model_cnn = nn.Sequential(
            nn.Conv2d(1, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 1, 4, 2, 1),
        )

        self.mat1 = nn.Parameter(torch.randn(4, 1).to(device))
        self.activation = nn.LeakyReLU(0.01)
        self.mat2 = nn.Linear(768, len(self.tokenizer.vocab))
        
    def forward(self, input_spec, input_seq):
        output_seq = self.model(**input_seq)['last_hidden_state'] # 1, u, 768
        output_spec = self.model_cnn(input_spec) # 1, 1, 4, len

        output_spec = output_spec.squeeze(0).squeeze(0) # 4, len
        output_spec = output_spec.T # len, 4
        output_seq = output_seq.reshape(1, -1)# 1, u * 768
        output_seq = self.mat1 @ output_seq # 4, u * 768
        output_seq = self.activation(output_seq) # 4, u * 768

        output_final = output_spec @ output_seq # len, u * 768

        len = output_final.shape[0]
        output_final = output_final.reshape(len, -1, 768)
        logits = self.mat2(output_final)
        return logits



In [16]:
speech_model = RNNTNet(model, tokenizer).to(device)

In [17]:
def sanity_check(index):
    input_wav = train_dataset[index][0].to(device)
    input_spec = dataset_transforms(input_wav).unsqueeze(0)

    input_text = tokenizer(train_dataset[index][2], return_tensors='pt')
    for key in input_text.keys():
        input_text[key] = input_text[key].to(device)
    output = speech_model(input_spec, input_text)
    print(input_text['input_ids'].shape)
    print(f"blank index: {input_text['input_ids'][0][0].item()}")
    print(output.shape)
    return input_text['input_ids'][0][0].item()

blank_index = sanity_check(0)

torch.Size([1, 47])
blank index: 101
torch.Size([35, 47, 30522])


In [18]:
speech_model = speech_model.to(device)
optimizer = torch.optim.Adam(speech_model.parameters(), lr=5e-4)
criterion = torchaudio.transforms.RNNTLoss(reduction='mean', blank = blank_index)

In [23]:
from tqdm.auto import trange, tqdm

num_epochs = 1
speech_model.train()
for epoch in trange(num_epochs):
    sum_loss, cnt_loss = 0, 0
    pbar = tqdm(train_dataset)
    for batch in pbar:
        optimizer.zero_grad()
        input_wav = batch[0].to(device)
        input_spec = dataset_transforms(input_wav).unsqueeze(0)

        input_text = tokenizer(batch[2], return_tensors='pt')
        for key in input_text.keys():
            input_text[key] = input_text[key].to(device)
        output = speech_model(input_spec, input_text).unsqueeze(0)
        target = input_text['input_ids'][:, 1:].type(torch.int32)
        logit_lengths = torch.tensor([int(output.shape[1])], dtype=torch.int).to(device)
        target_lengths = torch.tensor([int(target.shape[1])], dtype=torch.int).to(device)
        loss = criterion(output, target, logit_lengths, target_lengths)
        loss.backward()
        pbar.set_description(f"train loss: {round(loss.item(), 3)}")
        optimizer.step()
    print(f"MEAN LOSS:{sum_loss / cnt_loss}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/28539 [00:00<?, ?it/s]

ZeroDivisionError: ignored

In [24]:
tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(elem.item()) for elem in encoded_input['input_ids'].squeeze(0)])

'[CLS] heeelllllooooooooooo [SEP]'

In [65]:
import random
@torch.no_grad()
def inference(index, iterations=10):
    speech_model.eval()
    all_tokens = ['[CLS]']
    for it in range(iterations):
        input_text = tokenizer.convert_tokens_to_string(all_tokens)
        input_wav = test_dataset[index][0].to(device)
        input_spec = dataset_transforms(input_wav).unsqueeze(0)

        input_text = tokenizer(test_dataset[index][2], return_tensors='pt')
        for key in input_text.keys():
            input_text[key] = input_text[key].to(device)
        output = speech_model(input_spec, input_text)
        last_dim = int(output.shape[2])
        index_random =random.randint(200, 400)
        pos = torch.topk(output.flatten(), index_random).indices[-1].item()

        pos %= last_dim
        if tokenizer.convert_ids_to_tokens(pos)[:5] == "[unus":
            break
        all_tokens.append(tokenizer.convert_ids_to_tokens(pos))
    return tokenizer.convert_tokens_to_string(all_tokens)


In [69]:
inference(0, 40)

'[CLS] [CLS] he ll ve he t add ll as ll would [CLS] an an had he in all were d in ll for he d ll d what could already ll and [CLS] there d [CLS] this and had he'

In [59]:
import IPython

IPython.display.Audio(test_dataset[0][0], rate = 16000)

In [70]:
inference(1, 20)

'[CLS] d head t themselves mister themselves often call themselves her the things believe their instead are head d t pushed'

In [71]:
import IPython

IPython.display.Audio(test_dataset[1][0], rate = 16000)

In [72]:
inference(2, 20)

'[CLS] around haze they lay over the whom managed her enough next behind soon others the sometimes t let o one'

In [73]:
import IPython

IPython.display.Audio(test_dataset[2][0], rate = 16000)