In [1]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [2]:
from IPython.display import display, HTML, Video
# Widen the notebook content area to 90% of the window via injected CSS
# (targets the `.container` class; may have no effect outside the classic Notebook UI).
display(HTML("<style>.container { width:90% !important; }</style>"))

# Load data

In [3]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

import sentencepiece

In [4]:
import os
import glob
import json
import regex

import tqdm.notebook as tqdm

import numpy as np
import pandas as pd

from ipywidgets import GridBox, Audio, HBox, VBox, Box, Label, Layout

import matplotlib.pyplot as plt
import matplotlib_inline

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [5]:
def sample_dataset(dataset, n=4):
    """Build a two-column widget grid previewing the first `n` samples of `dataset`.

    `dataset` is indexed by column name. This function reads 'text', 'audio_path'
    and either 'duration' or, only when that key is absent, 'audio_len'. Each cell
    shows "<index>, <text>, <length>" above an audio player for the file at
    'audio_path'.

    Args:
        dataset: mapping of column name -> indexable sequence (in this notebook both
            the raw split and a collated batch dict are passed in).
        n: maximum number of samples to show; clamped to the number of audio paths.

    Returns:
        An ``HBox`` with two ``VBox`` columns (even indices left, odd indices right).
    """
    # Resolve the length column lazily: `dict.get(key, default)` evaluates `default`
    # eagerly, so the fallback lookup raised KeyError for datasets that have a
    # 'duration' column but no 'audio_len' column.
    if 'duration' in dataset:
        lengths = dataset['duration']
    else:
        # NOTE: in a collated batch this appears to be a sample count, not seconds.
        lengths = dataset['audio_len']

    # Never index past the end of a short dataset or batch.
    count = min(n, len(dataset['audio_path']))

    grid = []
    for idx in range(count):
        caption = '{0:d}, {1}, {2:.1f}'.format(idx, dataset['text'][idx], lengths[idx])
        grid.append(
            VBox([
                Label(caption),
                Audio.from_file(dataset['audio_path'][idx], autoplay=False, loop=False),
            ])
        )

    return HBox([VBox(grid[0::2]), VBox(grid[1::2])])

In [6]:
# Root folder for all datasets, relative to the notebook's working directory.
base_path = './dataset/'

# Data root for the LibriSpeech-style corpus used below. The folder name 'ruls_data' and the
# Russian transcripts in the outputs suggest Russian LibriSpeech (RuLS). Judging by the audio
# paths printed below, split folders such as 'dev/' sit underneath it.
libri_speech_base_path = os.path.join(base_path, 'LibriSpeech/ruls_data')

In [7]:
from src.dataset import get_libri_speech_dataset, get_golos_dataset

In [28]:
# Load the 'dev' split; the result is indexed by column name (e.g. 'audio_path').
libri_speech_dev = get_libri_speech_dataset(libri_speech_base_path, split='dev')

# Report the split size as the number of audio paths.
print('Loaded {0:d} objects'.format(len(libri_speech_dev['audio_path'])))

Loaded 1400 objects


In [9]:
# Load tokenizer model
# SentencePiece model file, resolved relative to the current working directory.
sp_tokenizer = sentencepiece.SentencePieceProcessor(model_file='tokenizer.model')

In [10]:
from src.dataset import AudioDataset, collate_fn

In [11]:
# Wrap the raw dev split in a Dataset that yields loaded audio and tokenized text.
# min_duration/max_duration presumably filter utterances by length in seconds — TODO confirm in src.dataset.
libri_speech_dev_ds = AudioDataset(libri_speech_dev, sp_tokenizer, min_duration=1.36, max_duration=10.96)
# Per the output below, an item looks like (audio_path, waveform, audio_len, text, token_ids, num_tokens).
libri_speech_dev_ds[0]

('./dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_16_pushkin_0039.wav',
 tensor([-4.8218e-03, -4.8828e-03, -4.7913e-03,  ...,  3.0518e-05,
          0.0000e+00,  3.0518e-05]),
 34720,
 'дай бог чтоб просветились мы',
 tensor([ 26,   3,  12,  39,   6,  29, 113,  22,  89,   5,  58,  59,  37,   5,
          42,   1,   9,  20]),
 18)

In [12]:
# Small, unshuffled loader over the dev set so the first batch is reproducible.
batch_size = 8
num_workers = 0  # 0 = load data in the main process (simplest to debug from a notebook)

libri_speech_dev_dl = DataLoader(
    libri_speech_dev_ds, batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=False, collate_fn=collate_fn
)
# Fetch one collated batch. Per the output it is a dict with keys such as 'audio_path' and
# 'audio' (zero-padded waveforms, judging by the trailing zeros).
batch = next(iter(libri_speech_dev_dl))
batch

{'audio_path': ('./dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_16_pushkin_0039.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_19_pushkin_0090.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_31_pushkin_0052.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_20_pushkin_0033.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_19_pushkin_0055.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_13_pushkin_0030.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_17_pushkin_0055.wav',
  './dataset/LibriSpeech/ruls_data/dev/audio/5397/2145/poemi_19_pushkin_0103.wav'),
 'audio': tensor([[-4.8218e-03, -4.8828e-03, -4.7913e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.0071e-03,  1.0071e-03,  8.8501e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-9.3079e-03, -1.1719e-02, -1.3184e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00

In [13]:
# Preview the first samples of the collated batch. The third label field here shows the
# batch's 'audio_len' value (e.g. 34720.0 in the output), which is apparently a sample count rather than seconds.
sample_dataset(batch)

HBox(children=(VBox(children=(VBox(children=(Label(value='0, дай бог чтоб просветились мы, 34720.0'), Audio(va…

In [14]:
from src.preprocessor import AudioToMelSpectrogramPreprocessor

In [15]:
# Waveform -> mel-spectrogram feature extractor (per its name), built with its default settings.
preprocessor = AudioToMelSpectrogramPreprocessor()

In [16]:
# Compute features and per-utterance feature lengths for the batch.
# NOTE(review): the preprocessor is presumably still in its default (training) mode here. If it applies
# train-time randomness (e.g. dither), these features are not reproducible; the check cells call .eval() first.
features, feature_lengths = preprocessor(batch['audio'], batch['audio_len'])

# Implement `src.preprocessor.ConvSubsampling`

In [17]:
import copy

def deterministic_fill(module, seed=64, verbose=False, eps=1e-5):
    """Return a deep copy of `module` whose floating-point state is filled deterministically.

    Every ``state_dict()`` entry is visited in sorted-name order. For each entry, a block
    of uniform [0, 1) values with the same shape is drawn from
    ``np.random.default_rng(seed)`` and cast to the entry's dtype/device. The block is
    standardised to zero mean / unit std and copied into the entry in place.

    Non-floating-point entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) are left untouched. Standardising them is meaningless, and
    ``Tensor.mean`` raises on integer dtypes. Values are still drawn for them, so the
    random stream advances the same way for every entry.

    Args:
        module: the ``torch.nn.Module`` to copy; the original is not modified.
        seed: seed for the NumPy generator; the same seed gives identical weights.
        verbose: if True, print each state-dict key as it is processed.
        eps: added to the standard deviation before dividing.

    Returns:
        The filled deep copy of `module`.
    """
    module = copy.deepcopy(module)

    rng = np.random.default_rng(seed)

    for name, param in sorted(module.state_dict().items()):
        if verbose:
            print(name)
        values = rng.random(param.shape)
        if not param.is_floating_point():
            continue
        data = torch.tensor(values, dtype=param.dtype, device=param.device)
        # NOTE: a single-element tensor has an undefined (NaN) unbiased std, so it becomes NaN here.
        data = (data - data.mean()) / (data.std() + eps)
        # state_dict() tensors share storage with the module's parameters/buffers,
        # so this in-place copy updates the module copy itself.
        param.data.copy_(data)

    return module

def diff_check(left, right, decimals):
    """Return True when `left` and `right` differ by strictly less than 10**-decimals.

    Both arguments are converted with ``float()``, so 0-d tensors, numbers and numeric
    strings are accepted.
    """
    tolerance = 10 ** (-decimals)
    delta = float(left) - float(right)
    return abs(delta) < tolerance

In [18]:
from src.preprocessor import ConvSubsampling

In [19]:
# Convolutional subsampling front-end.
# feat_in=80 is assumed to equal the number of features per frame produced by `preprocessor`
# (TODO confirm); each output frame has feat_out=176 values. sampling_num=2 presumably means
# two stride-2 stages, i.e. roughly 4x fewer time steps — TODO confirm in src.preprocessor.
pre_encode = ConvSubsampling(
    feat_in=80, feat_out=176,
    sampling_num=2, conv_channels=176
)

In [20]:
# Swap axes 1 and 2 before subsampling. This assumes the preprocessor emits (batch, features, time)
# and ConvSubsampling expects (batch, time, features) — TODO confirm against src.preprocessor.
features_enc, feature_enc_lengths = pre_encode(torch.transpose(features, 1, 2), feature_lengths)

In [21]:
# Regression check for ConvSubsampling: fixed random audio plus deterministically filled
# weights must reproduce the reference numbers asserted below.
preprocessor.eval()

rng = np.random.default_rng(42)
rnd_audio = torch.tensor(rng.random([13, 12345]), dtype=torch.float32)
rnd_audio_len = torch.tensor(rng.integers(low=0, high=12345, size=[13]), dtype=torch.long)
rnd_features, rnd_feature_lengths = preprocessor(rnd_audio, rnd_audio_len)

# Copy of `pre_encode` filled with seed-64 weights; `pre_encode` itself is not modified.
rnd_pre_encode = deterministic_fill(pre_encode, seed=64).eval()
rnd_features_enc, rnd_feature_enc_lengths = rnd_pre_encode(torch.transpose(rnd_features, 1, 2), rnd_feature_lengths)

assert sum(p.numel() for p in rnd_pre_encode.parameters()) == 900416
assert list(rnd_features_enc.shape) == [13, 20, 176]
# These are the *input* (pre-subsampling) lengths from the preprocessor; they pin the
# input side of this check, not the ConvSubsampling output.
assert rnd_feature_lengths.cpu().tolist() == [43, 48, 34, 32, 54, 45, 41, 59, 21, 4, 72, 16, 31]
# No reference values exist here for the subsampled lengths returned by ConvSubsampling,
# so check invariants instead: one length per utterance, no growth, and no length beyond
# the padded time axis of the encoded output.
assert rnd_feature_enc_lengths.shape == rnd_feature_lengths.shape
assert torch.all(rnd_feature_enc_lengths <= rnd_feature_lengths)
assert int(rnd_feature_enc_lengths.max()) <= rnd_features_enc.shape[1]
assert diff_check(rnd_features_enc.abs().mean(), 1817.9062, 4)

# Implement `src.encoding.RelPositionalEncoding`

In [22]:
import math
from src.encoding import RelPositionalEncoding

In [23]:
# Relative positional encoding for model width d_model=176.
# xscale=sqrt(d_model) appears to rescale the input embeddings: the asserted input/output means
# in the later check differ by about this factor (1817.9 * sqrt(176) ≈ 24117.3) — TODO confirm in src.encoding.
pos_enc = RelPositionalEncoding(
    d_model=176,
    dropout_rate=0.1,
    max_len=5000,
    xscale=math.sqrt(176),
    dropout_rate_emb=0.0,
)

# Precompute the encoding table for length 5000 on CPU (the same value as max_len above).
pos_enc.extend_pe(length=5000, device=torch.device('cpu'))

In [24]:
# 9999 == 2 * 5000 - 1: presumably one row per relative offset in [-(5000 - 1), 5000 - 1].
assert tuple(pos_enc.pe.shape) == (1, 2 * 5000 - 1, 176)
# 0.6366 is close to 2/pi, the mean absolute value of a sinusoid over whole periods.
assert diff_check(pos_enc.pe.abs().mean(), 0.6366, 4)

In [25]:
# End-to-end check of ConvSubsampling + RelPositionalEncoding on the same fixed input as the
# ConvSubsampling check (seed-42 random audio, seed-64 deterministic weights).
preprocessor.eval()
pre_encode.eval()
pos_enc.eval()  # eval mode, so dropout (dropout_rate=0.1) is presumably inactive

rng = np.random.default_rng(42)
rnd_audio = torch.tensor(rng.random([13, 12345]), dtype=torch.float32)
rnd_audio_len = torch.tensor(rng.integers(low=0, high=12345, size=[13]), dtype=torch.long)
rnd_features, rnd_feature_lengths = preprocessor(rnd_audio, rnd_audio_len)

rnd_pre_encode = deterministic_fill(pre_encode, seed=64).eval()
rnd_features_enc, rnd_feature_enc_lengths = rnd_pre_encode(torch.transpose(rnd_features, 1, 2), rnd_feature_lengths)

# Returns the (scaled) input embeddings and a separate positional-embedding tensor.
rnd_features_emb, rnd_pos_emb = pos_enc(rnd_features_enc)

# The embeddings keep the shape of their input.
assert list(rnd_features_emb.shape) == [13, 20, 176]
# 39 == 2 * 20 - 1: presumably one positional row per relative offset between the 20 frames.
assert list(rnd_pos_emb.shape) == [1, 39, 176]

# 24117.2559 ≈ 1817.9062 (the ConvSubsampling output mean) * sqrt(176), consistent with xscale.
assert diff_check(rnd_features_emb.abs().mean(), 24117.2559, 4)
assert diff_check(rnd_pos_emb.abs().mean(), 0.5670, 4)

# Implement `src.encoder.ConformerEncoder._create_masks`

In [26]:
from src.encoder import ConformerEncoder

In [27]:
# Unit check for the static mask builder on a toy batch of 4 sequences padded to length 10.
max_length, lengths = 10, torch.tensor([4, 5, 2, 7], dtype=torch.long)
pad_mask, att_mask = ConformerEncoder._create_masks(max_length, lengths, device=torch.device('cpu'))

# pad_mask has one flag per (sequence, position); att_mask presumably one per (sequence, query, key).
assert list(pad_mask.shape) == [lengths.shape[0], max_length]
assert list(att_mask.shape) == [lengths.shape[0], max_length, max_length]

# True marks padding: each row has exactly max_length - length padded positions...
assert torch.all(pad_mask.sum(dim=1) == max_length - lengths)
# ...and each sequence has exactly length**2 unmasked (False) attention entries
# (the top-left length x length block, as the printed output shows).
assert torch.all(torch.sum(~att_mask, dim=(1, 2)) == lengths ** 2)

print(pad_mask, att_mask, sep='\n')

tensor([[False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True]])
tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True, 