In [1]:
%config Completer.use_jedi = False

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import warnings
from pprint import pprint

warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=UserWarning)

In [9]:
import os
import json
import itertools
from time import time
from collections import defaultdict

import regex
import numpy as np
import pandas as pd
import editdistance

from tqdm import tqdm_notebook as tqdm

import matplotlib.pyplot as plt
from ctcdecode import CTCBeamDecoder

import torchaudio

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, RandomSampler

from vocabulary import Vocab

from src.audio_utils import open_audio
from src.audio_utils import make_transform, get_default_audio_transforms
from src.audio_utils import AudioTransformsChain, AudioTransformsExclusive
from src.audio_utils import SpectrogramTransform, compute_log_mel_spectrogram

from src.datasets import AudioDataset
from src.datasets import AudioDatasetSampler, collate_fn

from src.datasets import manifest_train_test_split
from src.datasets import convert_libri_manifest_to_common_voice
from src.datasets import convert_open_stt_manifest_to_common_voice

from src.deepspeech import Model

from src.decoding import calc_wer, calc_wer_for_batch
from src.decoding import decode, greedy_decoder, beam_search_decode, fast_beam_search_decode

from src.optimization import get_prediction, get_model_results
from src.optimization import get_prediction, get_model_results, validate, training

In [44]:
# Set proper device for computations,
dtype, device, cuda_device_id = torch.float32, None, 0
os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(str(cuda_device_id) if cuda_device_id is not None else '')
if cuda_device_id is not None and torch.cuda.is_available():
    device = 'cuda:{0:d}'.format(0)
else:
    device = torch.device('cpu')
    
print(f'dtype: {dtype}, device: {device}, cuda_device_id {cuda_device_id}')

dtype: torch.float32, device: cuda:0, cuda_device_id 0


# Create Vocabulary

In [45]:
def get_num_tokens(vocab):
    ### write your code here ###
    num_tokens = len(vocab.tokens2indices())
    return num_tokens

def get_blank_index(vocab):
    ### write your code here ###
    blank_index = vocab['<blank>']
    return blank_index

In [46]:
alphabet = [
    'а', 'б', 'в', 'г', 'д', 'е', 'ё', 'ж', 'з', 'и', 'й', 'к',
    'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф', 'х', 'ц',
    'ч', 'ш', 'щ', 'ь', 'ы', 'ъ', 'э', 'ю', 'я', ' ', '<blank>'
]

vocab = Vocab(alphabet)

num_tokens = get_num_tokens(vocab) 
blank_index = get_blank_index(vocab)

The `unk_token` '<unk>' wasn't found in the tokens. Adding the `unk_token` to the end of the Vocab.


# Choose Audio Transforms

In [103]:
# audio_transforms = get_default_audio_transforms()
audio_transforms = None

In [104]:
sample_rate = 8000

# Create Datasets

## Load Common Voice dataset

In [105]:
common_voice_val_manifest_path = '/home/e.chuykova/data/val.txt'
common_voice_test_manifest_path = '/home/e.chuykova/data/test.txt'
common_voice_train_manifest_path = '/home/e.chuykova/data/train.txt'

common_voice_val_dataset = AudioDataset(
    common_voice_val_manifest_path, vocab, sample_rate=sample_rate,
#     evaluate_stats=True
)
common_voice_test_dataset = AudioDataset(
    common_voice_test_manifest_path, vocab, sample_rate=sample_rate,
#     evaluate_stats=True
)
common_voice_train_dataset = AudioDataset(
    common_voice_train_manifest_path, vocab, sample_rate=sample_rate,
    audio_transforms=audio_transforms,
#     evaluate_stats=True
)

## Load LibriSpeech dataset

In [20]:
ls_dev_manifest_path = '/data/mnakhodnov/voice_data/libri_speech/dev/manifest.json'
ls_test_manifest_path = '/data/mnakhodnov/voice_data/libri_speech/test/manifest.json'
ls_train_manifest_path = '/data/mnakhodnov/voice_data/libri_speech/train/manifest.json'

ls_dev_manifest_path = convert_libri_manifest_to_common_voice(ls_dev_manifest_path)
ls_test_manifest_path = convert_libri_manifest_to_common_voice(ls_test_manifest_path)
ls_train_manifest_path = convert_libri_manifest_to_common_voice(ls_train_manifest_path)

In [106]:
ls_dev_dataset = AudioDataset(
    ls_dev_manifest_path, vocab=vocab, sample_rate=sample_rate, max_duration=10.0,
#     evaluate_stats=True
)
ls_test_dataset = AudioDataset(
    ls_test_manifest_path, vocab=vocab, sample_rate=sample_rate, max_duration=10.0,
#     evaluate_stats=True
)
ls_train_dataset = AudioDataset(
    ls_train_manifest_path, vocab=vocab, sample_rate=sample_rate, max_duration=10.0,
    audio_transforms=audio_transforms,
#     evaluate_stats=True
)

## Load Open STT (radio_2) dataset

In [22]:
open_stt_manifest_path = '/data/mnakhodnov/voice_data/radio_2/radio_2.csv'

open_stt_manifest_path = convert_open_stt_manifest_to_common_voice(open_stt_manifest_path, min_duration=2.0)
open_stt_test_manifest_path, open_stt_train_manifest_path = manifest_train_test_split(open_stt_manifest_path, ratio=0.005)

In [107]:
open_stt_test_dataset = AudioDataset(
    open_stt_test_manifest_path, vocab=vocab, sample_rate=sample_rate, min_duration=2.0, max_duration=10.0,
#     evaluate_stats=True
)
open_stt_train_dataset = AudioDataset(
    open_stt_train_manifest_path, vocab=vocab, sample_rate=sample_rate, min_duration=2.0, max_duration=10.0,
    audio_transforms=audio_transforms,
#     evaluate_stats=True
)

## Combine all datasets for training

In [108]:
combined_dataset = AudioDataset(
    [common_voice_train_manifest_path, ls_train_manifest_path, open_stt_train_manifest_path], 
    vocab=vocab, sample_rate=sample_rate, max_duration=10.0,
    audio_transforms=audio_transforms,
#     evaluate_stats=True
)

# Create Dataloaders

In [109]:
batch_size = 80
num_workers = 8

## Common Voice

In [110]:
# YOUR CODE
common_voice_val_dataloader = DataLoader(
    common_voice_val_dataset, batch_size=batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)
# YOUR CODE
common_voice_test_dataloader = DataLoader(
    common_voice_test_dataset, batch_size=batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)
# YOUR CODE
common_voice_train_dataloader = DataLoader(
    common_voice_train_dataset, batch_size=batch_size, 
    sampler=AudioDatasetSampler(common_voice_train_dataset, batch_size=batch_size),
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)

## Libri Speech

In [111]:
ls_dev_dataloader = DataLoader(
    ls_dev_dataset, batch_size=batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)
ls_test_dataloader = DataLoader(
    ls_test_dataset, batch_size=batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)
ls_train_dataloader = DataLoader(
    ls_train_dataset, batch_size=batch_size, 
    sampler=AudioDatasetSampler(ls_train_dataset, batch_size=batch_size),
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)

## OpenSTT

In [112]:
open_stt_test_dataloader = DataLoader(
    open_stt_test_dataset, batch_size=batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)
open_stt_train_dataloader = DataLoader(
    open_stt_train_dataset, batch_size=batch_size, 
    sampler=AudioDatasetSampler(open_stt_train_dataset, batch_size=batch_size),
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)

## Combined Dataloader

In [113]:
combined_dataloader = DataLoader(
    combined_dataset, batch_size=batch_size, 
    sampler=AudioDatasetSampler(combined_dataset, batch_size=batch_size),
    num_workers=num_workers, pin_memory=True, collate_fn=collate_fn
)

# Create Model

## Choose LM for beam search decoder

In [114]:
kenlm_bin_path = '/home/mnakhodnov/kenlm/build/bin'

# This models are sorted wr to their size and speed 
# kenlm_data_path = '/data/mnakhodnov/language_data/cc100/xaa.processed.1'
# kenlm_data_path = '/data/mnakhodnov/language_data/cc100/xaa.processed.2'
# kenlm_data_path = '/data/mnakhodnov/language_data/cc100/xaa.processed.3'
# kenlm_data_path = '/data/mnakhodnov/language_data/cc100/xaa.processed.4'
kenlm_data_path = '/data/mnakhodnov/language_data/common_voice/train.txt'
kenlm_arpa_path, kenlm_binary_path = kenlm_data_path + '.arpa', kenlm_data_path + '.binary'

In [115]:
fast_beam_kwargs = {
    'beam_size': 10, 'cutoff_top_n': 5, 'cutoff_prob': 1.0, 
    'ext_scoring_func': kenlm_binary_path, 'alpha': 1.0, 'beta': 0.3, 'num_processes': 32
}

In [116]:
def load_from_ckpt(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

In [117]:
num_tokens = len(vocab.tokens2indices()) - 1
num_mel_bins = 64
hidden_size= 512
num_layers = 4

In [118]:
num_epochs = 7
model_dir = 'models/6'
log_every_n_batch = 10

model = Model(
    num_mel_bins=num_mel_bins,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_tokens=num_tokens
)
# load_from_ckpt(model, '/home/e.chuykova/data/ckpt.pt')
load_from_ckpt(model, '/home/mnakhodnov/sirius-stt/models/2/epoch_5.pt')
model = model.to(device=device)

In [119]:
learning_rate = 2e-4
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CTCLoss(blank=blank_index, reduction='mean')

In [120]:
# ls_train_dataloader.sampler.epoch = 0
# open_stt_train_dataloader.sampler.epoch = 0
# common_voice_train_dataloader.sampler.epoch = 0

In [121]:
# spectrogram_transform = None
# spectrogram_transform_first_epoch = None

# spectrogram_transform = SpectrogramTransform(freq_mask_param=10, time_mask_param=10)
# spectrogram_transform_first_epoch = None

spectrogram_transform = SpectrogramTransform(freq_mask_param=10, time_mask_param=10)
spectrogram_transform_first_epoch = 1

In [None]:
num_epochs = 100

training(
    model=model, optimizer=opt, loss_fn=loss_fn, num_epochs=num_epochs, 
#     train_dataloader=[common_voice_train_dataloader, 'common_voice/train'],
#     train_dataloader=[ls_train_dataloader, 'libre_speech/train'],
#     train_dataloader=[open_stt_train_dataloader, 'open_stt/train'],
    train_dataloader=[combined_dataloader, 'combined/train'],
    val_dataloaders={
        'open_stt/test': ls_test_dataloader,
        'libre_speech/dev': ls_dev_dataloader,
        'libre_speech/test': ls_test_dataloader,
        'common_voice/val': common_voice_val_dataloader,
    }, log_every_n_batch=log_every_n_batch, model_dir=model_dir, vocab=vocab,
    beam_kwargs=fast_beam_kwargs, 
    spectrogram_transform=None, 
    spectrogram_transform_first_epoch=spectrogram_transform_first_epoch
)

  0%|          | 0/6013 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]


Epoch 1 of 100 took 2832.715866088867s, train loss: 0.6435314777989104, val loss: 0.23915940976142883, train wer: 0.4843038323742978, val wer: 0.37617130674880683


  0%|          | 0/6013 [00:00<?, ?it/s]