In [1]:
!pip install arpa


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting arpa
  Downloading arpa-0.1.0b4-py3-none-any.whl (9.6 kB)
Installing collected packages: arpa
Successfully installed arpa-0.1.0b4


In [2]:
import math
import os
import shutil
import string
import time
from collections import defaultdict
from typing import List, Tuple, TypeVar, Optional, Callable, Iterable

import arpa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from matplotlib.colors import LogNorm
from torch import optim
from tqdm.notebook import tqdm


In [3]:
!mkdir week_05_files
!wget -O week_05_files/utils.py https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/utils.py
!wget -O week_05_files/test_matrix.txt -q https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/test_matrix.txt
!wget -O week_05_files/soft_alignment.txt -q https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/soft_alignment.txt
!wget -O week_05_files/test_decode.txt -q https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/test_decode.txt
!wget -O week_05_files/test_labels.txt -q https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/test_labels.txt
!wget -O week_05_files/3-gram.pruned.1e-7.arpa -q https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/3-gram.pruned.1e-7.arpa


--2022-08-11 06:27:20--  https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_05/utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6770 (6.6K) [text/plain]
Saving to: ‘week_05_files/utils.py’


2022-08-11 06:27:21 (71.8 MB/s) - ‘week_05_files/utils.py’ saved [6770/6770]



In [4]:
import os
import torchaudio
if not os.path.isdir("./data"):
    os.makedirs("./data")

train_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="train-clean-100", download=True)
test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)


  0%|          | 0.00/5.95G [00:00<?, ?B/s]

  0%|          | 0.00/331M [00:00<?, ?B/s]

In [5]:
train_dataset[0]

(tensor([[-0.0065, -0.0055, -0.0062,  ...,  0.0033,  0.0005, -0.0095]]),
 16000,
 'CHAPTER ONE MISSUS RACHEL LYNDE IS SURPRISED MISSUS RACHEL LYNDE LIVED JUST WHERE THE AVONLEA MAIN ROAD DIPPED DOWN INTO A LITTLE HOLLOW FRINGED WITH ALDERS AND LADIES EARDROPS AND TRAVERSED BY A BROOK',
 103,
 1240,
 0)

In [6]:
import torch 
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [7]:
import torch.nn as nn

train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
).to(device)

test_audio_transforms = nn.Sequential(
    torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=128)
).to(device)


  "At least one mel filterbank has all zero values. "


In [8]:
BLANK_SYMBOL = "_"

class Tokenizer:
    """
    Maps characters to integers and vice versa
    """
    def __init__(self):
        self.char_map = {}
        self.index_map = {}
        for i, ch in enumerate(["'", " "] + list(string.ascii_lowercase) + [BLANK_SYMBOL]):
            self.char_map[ch] = i
            self.index_map[i] = ch
        
    def text_to_indices(self, text: str) -> List[int]:
        return [self.char_map[ch] for ch in text]

    def indices_to_text(self, labels: List[int]) -> str:                                                                                                                                                                                                                                 
        return "".join([self.index_map[i] for i in labels])
    
    def get_symbol_index(self, sym: str) -> int:
        return self.char_map[sym]
    

tokenizer = Tokenizer()

In [9]:
class Collate:
  def __init__(self, data_type = 'test') -> None:
        super(Collate, self).__init__() 

        self.data_type = data_type

  def __call__(self, data: torchaudio.datasets.librispeech.LIBRISPEECH):
        """
        :param data: is a list of tuples of [features, label], where features has dimensions [n_features, length]
        "returns features, lengths, labels: 
              features is a Tensor [batchsize, features, max_length]
              lengths is a Tensor of lengths [batchsize]
              labels is a Tesnor of targets [batchsize]
        """

        spectrograms = []
        labels = []
        input_lengths = []
        label_lengths = []
        for (waveform, _, utterance, _, _, _) in data:
            if self.data_type == 'train':
                spec = train_audio_transforms(waveform.to(device)).squeeze(0).transpose(0, 1)
            elif self.data_type == 'test':
                spec = test_audio_transforms(waveform.to(device)).squeeze(0).transpose(0, 1)
            else:
                raise Exception('data_type should be train or valid')
            spectrograms.append(spec)
            label = torch.Tensor(tokenizer.text_to_indices(utterance.lower())).to(device)
            labels.append(label)
            input_lengths.append(spec.shape[0] // 2)
            label_lengths.append(len(label))

        spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

        return spectrograms, labels, input_lengths, label_lengths


In [10]:
train_collate_fn = Collate(data_type='train')
test_collate_fn = Collate(data_type='test')

kwargs = {'num_workers': 0}
train_loader = data.DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=train_collate_fn, **kwargs)

kwargs = {'num_workers': 0} 
test_loader = data.DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=test_collate_fn, **kwargs)


In [11]:
lim2show = 3
for i, batch in enumerate(train_loader):
    for j in range(len(batch)):
        if isinstance(batch[j], torch.Tensor):
          print(batch[j].shape)
        else:
          print(batch[j])
    if i >= lim2show - 1:
        break    

torch.Size([1, 1, 128, 1055])
torch.Size([1, 195])
[527]
[195]
torch.Size([1, 1, 128, 1252])
torch.Size([1, 219])
[626]
[219]
torch.Size([1, 1, 128, 1084])
torch.Size([1, 151])
[542]
[151]


In [12]:
class BiRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
          nn.Conv2d(in_channels=1, out_channels = 32, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.Conv2d(in_channels=32, out_channels = 32, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.Conv2d(in_channels=32, out_channels = 32, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.Conv2d(in_channels=32, out_channels = 1, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(1),
        )
        self.rnn = nn.GRU(128, 16, bidirectional = True, batch_first=True)
        self.last = nn.Linear(32, len(tokenizer.char_map.keys()))
        self.softmax = nn.Softmax(dim=2)
    def forward(self, input):
        features = self.features(input)
        # 1, 1, 128, 1237
        features = features.squeeze(0)
        # 1, 128, 1237
        features = features.permute(0, 2, 1)
        # 1, 1237, 128
        output, (hidden) = self.rnn(features)
        # 1, 1237, 32
        logits = self.last(output)
        # 1, 1237, num_words
        return self.softmax(logits)


In [13]:
model = BiRNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CTCLoss()

In [14]:
def decode(output_model: torch.Tensor):
    output = output_model.squeeze(0)
    output = torch.argmax(output, dim=1).cpu().detach().numpy()

    text = tokenizer.indices_to_text(output)
    result_text = ""
    for symb in text:
        if len(result_text) != 0 and result_text[-1] == symb:
            continue
        result_text += symb

    return result_text
    

In [16]:
from tqdm.auto import tqdm

num_epochs = 5

losses = []

for epoch in tqdm(range(num_epochs)):

    sum_loss, cnt_loss = 0, 0
    for batch in tqdm(train_loader):
        spectrograms, labels, input_lengths, label_lengths = batch
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(spectrograms)
        output = output.transpose(0, 1)

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()
        sum_loss += loss.item()
        cnt_loss += 1
        optimizer.step()
    print(f"MEAN TRAIN LOSS PER {epoch + 1}: {sum_loss / cnt_loss}")
    sum_loss, cnt_loss = 0, 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            spectrograms, labels, input_lengths, label_lengths = batch
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            
            output = model(spectrograms)
            output = output.transpose(0, 1)
            loss = criterion(output, labels, input_lengths, label_lengths)
            sum_loss += loss.item()
            cnt_loss += 1

    print(f"MEAN TEST LOSS PER {epoch + 1}: {sum_loss / cnt_loss}")


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/28539 [00:00<?, ?it/s]

MEAN TRAIN LOSS PER 1: -3.03473911137031


  0%|          | 0/2620 [00:00<?, ?it/s]

MEAN TEST LOSS PER 1: -3.1307193499485044


  0%|          | 0/28539 [00:00<?, ?it/s]

MEAN TRAIN LOSS PER 2: -3.0376002788919636


  0%|          | 0/2620 [00:00<?, ?it/s]

MEAN TEST LOSS PER 2: -3.120141229374718


  0%|          | 0/28539 [00:00<?, ?it/s]

MEAN TRAIN LOSS PER 3: -3.038194098148956


  0%|          | 0/2620 [00:00<?, ?it/s]

MEAN TEST LOSS PER 3: -3.117859229240709


  0%|          | 0/28539 [00:00<?, ?it/s]

MEAN TRAIN LOSS PER 4: -3.0384568298968677


  0%|          | 0/2620 [00:00<?, ?it/s]

MEAN TEST LOSS PER 4: -3.1218641996838663


  0%|          | 0/28539 [00:00<?, ?it/s]

MEAN TRAIN LOSS PER 5: -3.0386496254530555


  0%|          | 0/2620 [00:00<?, ?it/s]

MEAN TEST LOSS PER 5: -3.1237300250366444


In [17]:
lim2show = 5
with torch.no_grad():
    for batch in tqdm(test_loader):
        spectrograms, labels, input_lengths, label_lengths = batch
        spectrograms, labels = spectrograms.to(device), labels.to(device)
            
        output = model(spectrograms)
        print(output)
        print(decode(output.cpu().detach()))
        lim2show -= 1
        if lim2show == 0:
            break        

  0%|          | 0/2620 [00:00<?, ?it/s]

tensor([[[0.4282, 0.0187, 0.0865,  ..., 0.0095, 0.0030, 0.0032],
         [0.3603, 0.0163, 0.1020,  ..., 0.0102, 0.0026, 0.0023],
         [0.3292, 0.0311, 0.0721,  ..., 0.0118, 0.0029, 0.0020],
         ...,
         [0.3324, 0.0593, 0.0395,  ..., 0.0149, 0.0049, 0.0057],
         [0.2834, 0.0627, 0.0397,  ..., 0.0154, 0.0053, 0.0061],
         [0.2317, 0.0629, 0.0412,  ..., 0.0171, 0.0061, 0.0072]]],
       device='cuda:0')
'
tensor([[[0.4710, 0.0162, 0.0737,  ..., 0.0099, 0.0035, 0.0037],
         [0.4020, 0.0149, 0.0876,  ..., 0.0108, 0.0031, 0.0027],
         [0.3865, 0.0258, 0.0620,  ..., 0.0119, 0.0031, 0.0022],
         ...,
         [0.3536, 0.0614, 0.0375,  ..., 0.0142, 0.0050, 0.0055],
         [0.2881, 0.0674, 0.0368,  ..., 0.0148, 0.0058, 0.0065],
         [0.2351, 0.0677, 0.0378,  ..., 0.0169, 0.0067, 0.0077]]],
       device='cuda:0')
'
tensor([[[0.4150, 0.0177, 0.0941,  ..., 0.0101, 0.0033, 0.0034],
         [0.3509, 0.0161, 0.1078,  ..., 0.0109, 0.0028, 0.0025],
      