# Training Our Custom Models 🏋️
- in this notebook we train and test our custom implementations of Deep Speech 2, Jasper, and Conformer
- the models take Mel filterbanks (2D images) as input and output a probability distribution over characters

## 1/5 Setup

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Custom libraries
from utils.preprocessing import Preprocessing
from utils.word_model import WordModel
from utils.decoders import DecoderGreedy
from utils.metrics import avg_cer, avg_wer
from utils.misc import pretty_params

# Plots
import wandb

# Models
from models.deep_speech.deep_speech_base import DeepSpeechBase
from models.deep_speech.deep_speech_attention import DeepSpeechAttention
from models.jasper.jasper_base import Jasper
from models.jasper.jasper_dr import JasperDR


Notes

- to use unigrams: set `WordModel("unigram")` and `stride=2`
- to use bigrams: set `WordModel("bigrams")` and `stride=4`

In [2]:
# General Hyper-params
# These globals are read by the data loaders, model constructors, optimizer
# and scheduler cells further down.
seed = 42                   # passed to torch.manual_seed
batch_size = 16             # batch size of every DataLoader
epochs = 10                 # training epochs; also sizes the OneCycleLR schedule

n_features = 128            # freq axis
stride = 2                  # time-axis striding

# Maps text to label ids; also provides n_class and the CTC blank id.
word_model = WordModel("unigram")
lr = 0.0005                 # AdamW learning rate and OneCycleLR max_lr


In [3]:
# Deep Speech Hyper-params
stage_1 = 3                 # 1st stage: passed as n_cnn to both Deep Speech models
stage_2 = 12                # 2nd stage: n_rnn (DeepSpeechBase) / n_enc (DeepSpeechAttention)

rnn_dim = 512               # for DeepSpeechBase
emb_dim = 512               # for DeepSpeechAttention


In [4]:
# Seed PyTorch first so everything created afterwards is reproducible.
torch.manual_seed(seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


## 2/5 Data processing
- set `download=True` if you are downloading for the first time
- we download the train set (`train-clean-100`), dev set (`dev-clean` and `dev-other`), and test set (`test-clean` and `test-other`)

In [5]:
prep = Preprocessing()

# Fetch every split in a fixed order: train, dev (clean/other), test (clean/other).
# Flip `download=False` to `download=True` on the first run.
train_clean, dev_clean, dev_other, test_clean, test_other = (
    prep.download(split=split_name, download=False)
    for split_name in (
        "train-clean-100",
        "dev-clean",
        "dev-other",
        "test-clean",
        "test-other",
    )
)



In [6]:
def _split_loader(dataset, split, shuffle=False):
    """Build a DataLoader whose collate step runs `prep.preprocess` for `split`.

    `stride` and `word_model` are read from the notebook globals each time a
    batch is collated; `batch_size` is read when the loader is built.
    """
    def _collate(batch):
        return prep.preprocess(batch, split, stride, word_model)

    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      collate_fn=_collate)


# Only the training split is shuffled; evaluation splits keep their order.
train_loader = _split_loader(train_clean, "train-clean-100", shuffle=True)

dev_clean_loader = _split_loader(dev_clean, "dev-clean")
dev_other_loader = _split_loader(dev_other, "dev-other")

test_clean_loader = _split_loader(test_clean, "test-clean")
test_other_loader = _split_loader(test_other, "test-other")

In [7]:
## Sanity check ##

# Report the number of samples behind each loader, one line per split.
for split_label, split_loader in (
    ("TRAIN", train_loader),
    ("DEV CLEAN", dev_clean_loader),
    ("DEV OTHER", dev_other_loader),
    ("TEST CLEAN", test_clean_loader),
    ("TEST OTHER", test_other_loader),
):
    print(f"~ {split_label} \t{len(split_loader.dataset)} samples ~")

# Show batch-level details for the training loader.
prep.print_loader_info(train_loader)

~ TRAIN 	28539 samples ~
~ DEV CLEAN 	2703 samples ~
~ DEV OTHER 	2864 samples ~
~ TEST CLEAN 	2620 samples ~
~ TEST OTHER 	2939 samples ~
+------ Dataloader length: 28539 ------+
# Batches: 1784
Spectrogram shape: [16, 1, 128, 1308]
Label shape: [16, 257]
Mel length (length of each spectrogram): [577, 654, 300, 484, 542, 584] ...
Idx length (length of each label): [197, 257, 121, 173, 196, 205] ...
+------------------------------------+


## 3/5 Model
Available models:

- DeepSpeech-Base
- DeepSpeech-Attention
- Jasper-Base
- Jasper-DR (Dense Residual)

In [15]:
# DeepSpeechBase
# NOTE: this cell binds `deep_speech` and `model_name`; the DeepSpeechAttention
# cell binds the same names, so whichever of the two ran last wins.
deep_speech = DeepSpeechBase(
    n_cnn=stage_1,
    n_rnn=stage_2,
    rnn_dim=rnn_dim,
    n_features=n_features,
    n_class=word_model.get_n_class(),
    stride=stride,
    drop_rate=0.2,
)
deep_speech = deep_speech.to(device)

model_name = "deep_speech_base"
tot_params = sum(param.numel() for param in deep_speech.parameters())
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 56.79M (56792861)


In [16]:
# DeepSpeechAttention
# NOTE: this cell binds `deep_speech` and `model_name`; the DeepSpeechBase
# cell binds the same names, so whichever of the two ran last wins.
deep_speech = DeepSpeechAttention(
    n_cnn=stage_1,
    n_enc=stage_2,
    n_features=n_features,
    n_class=word_model.get_n_class(),
    emb_dim=emb_dim,
    n_heads=4,
    stride=stride,
    drop_rate=0.2,
)
deep_speech = deep_speech.to(device)

model_name = "deep_speech_attention"
tot_params = sum(param.numel() for param in deep_speech.parameters())
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 20.32M (20319005)


In [8]:
# JasperBase
# NOTE: this cell binds `jasper` and `model_name`; the Jasper DR cell binds
# the same names, so whichever of the two ran last wins.
jasper = Jasper()
jasper = jasper.to(device)

model_name = "jasper_base"
tot_params = sum(param.numel() for param in jasper.parameters())
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 107.87M (107873693)


In [8]:
# Jasper DR (Dense Residual)
# NOTE: this cell binds `jasper` and `model_name`; the JasperBase cell binds
# the same names, so whichever of the two ran last wins.
jasper = JasperDR()
jasper = jasper.to(device)

model_name = "jasper_DR"
tot_params = sum(param.numel() for param in jasper.parameters())
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 109.91M (109908125)


In [17]:
# Greedy decoder built from the word model's blank id.
# `test` reads this module-level `decoder` to turn model outputs and label
# indices into the hypothesis/reference text used for WER/CER.
decoder = DecoderGreedy(word_model.get_blank_id())

## 4/5 Optimizer, loss, scheduler
- set here the model you want to train

In [18]:
# Pick the model used by the weight-loading, optimizer, training and
# evaluation cells below (swap which of the two lines is commented out).
# NOTE(review): `model_name` (sent to wandb) is set by the model cells, not
# here; make sure the last model cell you ran matches this choice.
model_to_train = deep_speech
# model_to_train = jasper

- Optionally, load model's weights

In [19]:
#path_to_weights = "./weights/jasper_base_lr_variable.pth"
path_to_weights = "./weights/3invres_12attention.pth"

# `map_location=device` remaps the checkpoint's tensors onto the device chosen
# in the setup cell, so a checkpoint saved on a GPU still loads on a CPU-only
# host (and vice versa) instead of failing during deserialization.
model_to_train.load_state_dict(torch.load(path_to_weights, map_location=device))

<All keys matched successfully>

In [20]:
# Optimizer, CTC criterion and one-cycle LR schedule for the selected model.
# Re-run this cell whenever `model_to_train` changes.
adamW = optim.AdamW(model_to_train.parameters(), lr=lr)

ctc_loss = nn.CTCLoss(blank=word_model.get_blank_id())
ctc_loss = ctc_loss.to(device)

# The scheduler is stepped once per batch, so its length is
# `epochs * len(train_loader)` steps, peaking at `lr`.
one_cycle_lr = optim.lr_scheduler.OneCycleLR(
    adamW,
    max_lr=lr,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
    anneal_strategy="linear",
)

## 5/5 Train & Test

In [None]:
# Online fancy plots
# Authenticate with Weights & Biases through its CLI (IPython shell escape).
! wandb login

# Start a run in the "asr_librispeech" project and record which model and
# word model this session uses. `model_name` is whatever the most recently
# executed model cell assigned.
wandb.init(
    project="asr_librispeech",

    config= {
        "model": model_name,
        "word_model": word_model.get_name()
    }
)

Training info
- After the forward pass, output must be  `[batch_size, seq_len, n_class]`
- CTC loss expects predictions to be `[seq_len, batch_size, n_class]`
- Train your model for `epochs` number of epochs
- At the end of each epoch, get WER/CER on validation/test dataset

Deep Speech
- Works with 3D inputs, so leave `squeeze=False`
- Outputs Tensors with shape `[batch_size, seq_len, n_class]`, so leave `swap_dim=False`

Jasper
- Jasper takes 2D Tensors instead of 3D, so set `squeeze=True`
- It also outputs predictions with shape `[batch_size, n_class, seq_len]`, so set `swap_dim=True`

In [12]:
def train(epoch, dataset_loader, model, optimizer, scheduler, fn_loss, 
          squeeze=False, swap_dim=False):
    """Train `model` for one epoch, stepping optimizer and scheduler per batch.

    Reads the module-level `device` (where batches are moved) and `wandb`
    (per-batch loss / learning-rate logging).

    Args:
        epoch: Epoch number; only used in progress messages.
        dataset_loader: DataLoader yielding
            `(spectrograms, indices, len_spectrograms, len_indices)` batches.
        model: Network to train; switched to train mode here.
        optimizer: Optimizer, zeroed and stepped once per batch.
        scheduler: LR scheduler, stepped once per batch after the optimizer.
        fn_loss: Criterion called as
            `fn_loss(log_probs, indices, len_spectrograms, len_indices)` with
            `log_probs` laid out as `[seq_len, batch_size, n_class]`.
        squeeze: If True, drop the size-1 channel axis (dim 1) of the
            spectrograms before the forward pass (needed by Jasper).
        swap_dim: If True, transpose dims 1 and 2 of the model output so it
            becomes `[batch_size, seq_len, n_class]` (needed by Jasper).

    Returns:
        None.

    Note:
        The W&B run is NOT finished here: this function is called once per
        epoch, and closing the run after the first epoch would break logging
        for every later epoch. The caller should call `wandb.finish()` once,
        after the last epoch.
    """
    print(f"Traininig... (e={epoch})")

    # Train mode ON
    model.train()
    n_samples = len(dataset_loader.dataset)
    last_idx = len(dataset_loader) - 1
    n_seen = 0  # samples processed so far, including the current batch

    for idx, audio_data in enumerate(dataset_loader):

        # Get audio data
        spectrograms, indices, len_spectrograms, len_indices = audio_data
        spectrograms, indices = spectrograms.to(device), indices.to(device)

        if squeeze:
            # Squeeze only the channel axis; a bare squeeze() would also drop
            # the batch axis whenever a batch holds a single sample.
            spectrograms = spectrograms.squeeze(1)

        n_seen += spectrograms.size(0)

        optimizer.zero_grad()

        # Forward pass
        out = model(spectrograms)

        if swap_dim:
            out = out.transpose(1, 2)

        # Log-probabilities over classes, then [seq_len, batch_size, n_class]
        out = F.log_softmax(out, dim=2)
        out = out.transpose(0, 1)

        # Backward pass
        loss = fn_loss(out, indices, len_spectrograms, len_indices)
        loss.backward()
        batch_loss = loss.item()

        # Log the LR that was used for this step (i.e. before scheduler.step()).
        # Skipped when no W&B run is active so training still works offline.
        if wandb.run is not None:
            wandb.log({
                "loss": batch_loss,
                "lr": scheduler.get_last_lr()[0]
            })

        # Step
        optimizer.step()
        scheduler.step()

        # Progress every 20 batches and on the final batch
        if idx % 20 == 0 or idx == last_idx:
            print("Epoch: {}, [{}/{}], Loss: {:.6f}".format(
                epoch,
                n_seen,
                n_samples,
                batch_loss))

In [28]:
def test(epoch, dataset_name, dataset_loader, model, optimizer, fn_loss, 
         debug=False, squeeze=False, swap_dim=False):
    """Evaluate `model` on one split and print mean loss, WER and CER.

    Reads the module-level `device`, `decoder` and `word_model`.

    Args:
        epoch: Epoch number; only used in the header message.
        dataset_name: Human-readable split name for the header message.
        dataset_loader: DataLoader yielding
            `(spectrograms, indices, len_spectrograms, len_indices)` batches.
        model: Network to evaluate; switched to eval mode here.
        optimizer: Unused; kept so existing call sites keep working.
        fn_loss: Criterion called as
            `fn_loss(log_probs, indices, len_spectrograms, len_indices)` with
            `log_probs` laid out as `[seq_len, batch_size, n_class]`.
        debug: If True, stop after the first batch.
        squeeze: If True, drop the size-1 channel axis (dim 1) of the
            spectrograms before the forward pass (needed by Jasper).
        swap_dim: If True, transpose dims 1 and 2 of the model output so it
            becomes `[batch_size, seq_len, n_class]` (needed by Jasper).

    Returns:
        None. Metrics are printed; each is the mean over the batches that
        were actually evaluated.
    """
    print(f"Testing on {dataset_name} (epoch={epoch})")
    model.eval()

    total_loss = 0.0
    wer_list = []
    cer_list = []

    # No gradients are needed for evaluation (so there is nothing to zero).
    with torch.no_grad():
        for idx, audio_data in enumerate(dataset_loader):

            # Get audio data
            spectrograms, indices, len_spectrograms, len_indices = audio_data
            spectrograms, indices = spectrograms.to(device), indices.to(device)

            if squeeze:
                # Squeeze only the channel axis; a bare squeeze() would also
                # drop the batch axis whenever a batch holds a single sample.
                spectrograms = spectrograms.squeeze(1)

            # Forward pass
            out = model(spectrograms)

            if swap_dim:
                out = out.transpose(1, 2)

            # Log-probabilities over classes, then [seq_len, batch_size, n_class]
            out = F.log_softmax(out, dim=2)
            out = out.transpose(0, 1)

            # Compute loss
            loss = fn_loss(out, indices, len_spectrograms, len_indices)
            total_loss += loss.item()

            # Metrics
            decode_hypothesis = decoder.decode_prob(out, word_model)
            decode_reference = decoder.decode_labels(indices, len_indices, word_model)

            wer_list.append(avg_wer(decode_hypothesis, decode_reference))
            cer_list.append(avg_cer(decode_hypothesis, decode_reference))

            if idx % 20 == 0:
                print(f'Idx: {idx}')
                print(f'reference: {decode_reference[0]}')
                print(f'hypothesis: {decode_hypothesis[0]}')
                print(f"WER: {wer_list[-1]:.4f}, CER: {cer_list[-1]:.4f}")
                print()

            if debug: break

    # Average over the batches actually processed, so that a `debug` run
    # (one batch) or an empty loader does not skew or crash the summary.
    n_evaluated = len(wer_list)
    if n_evaluated == 0:
        print("No batches were evaluated.")
        return

    print(f"Loss: {total_loss / n_evaluated:.6f}")
    print(f"WER: {sum(wer_list) / n_evaluated:.4f}")
    print(f"CER: {sum(cer_list) / n_evaluated:.4f}")

In [None]:
# Jasper-family models take 2D inputs and emit `[batch_size, n_class, seq_len]`,
# so they need both the squeeze and the dimension swap; Deep Speech models
# need neither. Derive the flags from the selected model instead of
# hard-coding them, so they cannot drift out of sync with `model_to_train`.
is_jasper = isinstance(model_to_train, (Jasper, JasperDR))

for epoch in range(1, epochs+1):
    train(epoch, train_loader, model_to_train, adamW, one_cycle_lr, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)
    test(epoch, "dev-clean", dev_clean_loader, model_to_train, adamW, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)

# Close the W&B run once, after the final epoch.
wandb.finish()

Evaluation
- pick the split you want to get your metrics on

In [None]:
# Derive the Jasper-only input/output handling from the selected model so the
# flags always match `model_to_train`.
is_jasper = isinstance(model_to_train, (Jasper, JasperDR))
test(1, "dev-clean", dev_clean_loader, model_to_train, adamW, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)

In [None]:
# Derive the Jasper-only input/output handling from the selected model so the
# flags always match `model_to_train`.
is_jasper = isinstance(model_to_train, (Jasper, JasperDR))
test(1, "dev-other", dev_other_loader, model_to_train, adamW, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)

In [None]:
# Derive the Jasper-only input/output handling from the selected model so the
# flags always match `model_to_train`.
is_jasper = isinstance(model_to_train, (Jasper, JasperDR))
test(1, "test-clean", test_clean_loader, model_to_train, adamW, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)

In [None]:
# Evaluate the selected (and weight-loaded) model, like the other splits,
# rather than the `jasper` global, which may not be the model that was trained.
# The Jasper-only input/output handling is derived from the model type.
is_jasper = isinstance(model_to_train, (Jasper, JasperDR))
test(1, "test-other", test_other_loader, model_to_train, adamW, ctc_loss, squeeze=is_jasper, swap_dim=is_jasper)