# Train Deep Speech🏋️
- in this notebook we train Deep Speech-based models
- we vary the number of conv and res blocks and check their performance

## 1/5 Setup

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Custom libraries
from utils.preprocessing import Preprocessing
from utils.word_model import WordModel
from utils.decoders import DecoderBase
from utils.metrics import avg_cer, avg_wer
from utils.misc import pretty_params

# Plots
import wandb

# Models
from models.deep_speech_base import DeepSpeechBase
from models.deep_speech_attention import DeepSpeechAttention

Notes
- if you set `word_model` to `unigram`, set `stride=2`
- if you set `word_model` to `bigram`, set `stride=4`
- striding on the frequencies is set to `2` (left unchanged)

In [2]:
# Hyperparams
# Everything tunable lives in this cell; later cells read these names directly.
seed = 42               # passed to torch.manual_seed
batch_size = 16         # batch size of both the train and the test DataLoader
epochs = 10             # number of train+test passes; also sizes the OneCycleLR schedule

# Number of output classes per word model:
# Unigram: apostrophe (+1), space (+1), alphabet (+26), blank (+1) = 29
# Bigram: apostrophe (+1), space (+1), alphabet (+702), blank (+1) = 705
#   (702 is presumably 26 single letters + 26*26 letter pairs -- confirm in utils/word_model)
# To switch word model, swap which of the two lines below is commented out.
#word_model = WordModel("unigram")
word_model = WordModel("bigram")

n_features = 128        # input feature size handed to the model constructors
stride = 4              # time-axis striding

# Stage sizes handed to the model constructors as their first two arguments
# (presumably the number of blocks in each stage -- confirm in models/).
stage_1 = 3                 # 1st stage
stage_2 = 6                 # 2nd stage

rnn_dim = 512           # for DeepSpeechBase
emb_dim = 512           # for DeepSpeechAttention

drop_rate = 0.2         # dropout rate passed to the model as `drop_rate`
lr = 0.0005             # AdamW learning rate, also used as OneCycleLR's max_lr

In [3]:
# Seed torch's RNG, then pick the compute device (GPU when one is available)
torch.manual_seed(seed)

use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


## 2/5 Data processing

In [4]:
# Data processing using custom library
prep = Preprocessing()
# download=False: presumably the data is expected to be on disk already
# -- TODO confirm in utils/preprocessing.
train_set = prep.download(split='train', download=False)
# Evaluation split. The original note here read "dev test", which suggests a
# 'dev' split may also be available -- confirm before switching.
test_set = prep.download(split='test', download=False)



In [5]:
# Data loading
def _collate_for(split):
    """Return a collate function that runs `prep.preprocess` for `split`.

    `stride` and `word_model` are looked up when a batch is collated, not when
    the loader is built.
    """
    return lambda batch: prep.preprocess(batch, split, stride, word_model)


# Training batches are reshuffled every epoch
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=_collate_for("train"),
)

# Evaluation batches keep the dataset order
test_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=_collate_for("test"),
)

In [6]:
# Sizes of both loaders; n_batches_train is reused later to size the LR schedule
n_batches_train, n_samples_train = len(train_loader), len(train_loader.dataset)
n_batches_test, n_samples_test = len(test_loader), len(test_loader.dataset)

for title, n_batches, n_samples in (
    ("~ TRAIN ~", n_batches_train, n_samples_train),
    ("~ TEST ~", n_batches_test, n_samples_test),
):
    print(title)
    print(f"Number of batches: {n_batches}")
    print(f"Number of samples: {n_samples}")
    print()

# Summary of one training batch, printed by the preprocessing helper
prep.print_loader_info(train_loader)

~ TRAIN ~
Number of batches: 1784
Number of samples: 28539

~ TEST ~
Number of batches: 164
Number of samples: 2620

+------ Dataloader length: 28539 ------+
# Batches: 1784
Spectogram shape: [16, 1, 128, 1308]
Label shape: [16, 164]
Mel length (length of each spectogram): [288, 327, 150, 242, 271, 292] ...
Idx length (length of each label): [123, 164, 81, 110, 128, 130] ...
+------------------------------------+


## 3/5 Model
- there are 2 available models: DeepSpeechBase and DeepSpeechAttention; run only one of the two cells below, since both assign `deep_speech` and `model_name`

In [7]:
# DeepSpeechBase
# Build the recurrent variant and move it to the selected device.
deep_speech = DeepSpeechBase(
    stage_1,
    stage_2,
    rnn_dim,
    n_features,
    word_model.get_n_class(),
    stride=stride,
    drop_rate=drop_rate,
)
deep_speech = deep_speech.to(device)

# Total element count over all parameter tensors
tot_params = sum(p.numel() for p in deep_speech.parameters())
model_name = "deep_speech_base"
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 28.78M (28778945)


In [7]:
# DeepSpeechAttention
# Build the attention variant and move it to the selected device.
# The dropout rate comes from the hyperparameter cell (it used to be a
# hard-coded 0.2 here), so changing `drop_rate` affects both model variants.
deep_speech = DeepSpeechAttention(stage_1,
                                  stage_2,
                                  n_features,
                                  word_model.get_n_class(),
                                  emb_dim,
                                  n_heads=4,          # number of attention heads
                                  stride=stride,
                                  drop_rate=drop_rate).to(device)

# Total element count over all parameter tensors
tot_params = sum(p.numel() for p in deep_speech.parameters())
model_name = "deep_speech_attention"
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 11.20M (11197889)


In [None]:
# Load model weights
# The checkpoint must match the architecture currently held in `deep_speech`.
path_to_weights = "./weights/deep_speech_12_04.pth"
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
deep_speech.load_state_dict(torch.load(path_to_weights, map_location=device))

In [8]:
# Decoder 
# Shared decoder instance. test() reads this module-level name to build the
# hypothesis/reference pairs that avg_wer and avg_cer score.
decoder = DecoderBase()

## 4/5 Optimizer, loss, scheduler

In [9]:
# Optimizer, loss, scheduler
adamW = optim.AdamW(params=deep_speech.parameters(), lr=lr)

# CTC loss; the blank label index is supplied by the word model
ctc_loss = nn.CTCLoss(blank=word_model.get_blank_id())
ctc_loss = ctc_loss.to(device)

# One-cycle schedule peaking at `lr`, sized for epochs * n_batches_train steps
one_cycle_lr = optim.lr_scheduler.OneCycleLR(
    optimizer=adamW,
    max_lr=lr,
    epochs=epochs,
    steps_per_epoch=n_batches_train,
    anneal_strategy="linear",
)

## 5/5 Train & Test

In [11]:
# Online fancy plots
# Authenticate with Weights & Biases (IPython shell escape).
! wandb login

# Start the run that train() logs its per-batch "loss" and "lr" to.
wandb.init(
    project="asr_librispeech",

    # Settings stored with the run so different configurations can be compared
    config= {
        "model": model_name,
        "stage_1": stage_1,
        "stage_2": stage_2,
        "word_model": word_model.get_name(),
        "stride": stride
    }
)

wandb: Currently logged in as: iu4hry (fantastic_4). Use `wandb login --relogin` to force relogin
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33miu4hry[0m ([33mfantastic_4[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Train loop
def train(epoch, dataset_loader, model, optimizer, scheduler, fn_loss):
    """Train `model` for one pass over `dataset_loader`.

    Args:
        epoch: Epoch number; only used in the progress messages.
        dataset_loader: DataLoader yielding
            (spectograms, indices, len_spectograms, len_indices) batches.
        model: Network called as `model(spectograms)`; its output is treated
            as [batch, time, n_class] scores.
        optimizer: Optimizer, stepped once per batch.
        scheduler: LR scheduler, stepped once per batch after the optimizer.
        fn_loss: Loss called as
            `fn_loss(log_probs, indices, len_spectograms, len_indices)`, with
            log_probs laid out as [time, batch, n_class].

    Side effects:
        Logs the per-batch loss and learning rate to the active W&B run, if
        there is one. The run is intentionally NOT finished here: closing it
        after every call made `wandb.log` fail from the second epoch on. Call
        `wandb.finish()` once, after the last epoch.
    """
    print(f"Traininig... (e={epoch})")

    # Train mode ON
    model.train()
    n_samples = len(dataset_loader.dataset)
    n_batches = len(dataset_loader)
    n_seen = 0  # samples processed so far in this epoch

    for idx, audio_data in enumerate(dataset_loader):

        # Get audio data (the two length sequences are passed to the loss as-is)
        spectograms, indices, len_spectograms, len_indices = audio_data
        spectograms, indices = spectograms.to(device), indices.to(device)

        optimizer.zero_grad()

        # Forward pass
        out = model(spectograms) # [batch, time, n_class]
        out = F.log_softmax(out, dim=2)
        out = out.transpose(0, 1) # [time, batch, n_class]

        # Backward pass
        loss = fn_loss(out, indices, len_spectograms, len_indices)
        loss.backward()
        loss_value = loss.item()

        # Log before stepping, so "lr" is the rate used for this very update.
        # get_last_lr() returns one value per parameter group; the first one
        # is logged so that W&B receives a scalar instead of a list.
        if wandb.run is not None:
            wandb.log({
                "loss": loss_value,
                "lr": scheduler.get_last_lr()[0]
            })

        # Step
        optimizer.step()
        scheduler.step()

        # Progress: every 10th batch plus the final batch of the epoch
        n_seen += len(spectograms)
        if idx % 10 == 0 or idx == n_batches - 1:
            print("Epoch: {}, [{}/{}], Loss: {:.6f}".format(
                epoch,
                n_seen,
                n_samples,
                loss_value))

In [11]:
def test(epoch, dataset_loader, model, optimizer, fn_loss, debug=False):
    """Evaluate `model` on `dataset_loader` and print mean loss, WER and CER.

    Args:
        epoch: Epoch number; only used in the progress message.
        dataset_loader: DataLoader yielding
            (spectograms, indices, len_spectograms, len_indices) batches.
        model: Network called as `model(spectograms)`; its output is treated
            as [batch, time, n_class] scores.
        optimizer: Unused. Kept so existing calls keep working; evaluation
            runs under `torch.no_grad()` and never touches gradients.
        fn_loss: Loss called as
            `fn_loss(log_probs, indices, len_spectograms, len_indices)`, with
            log_probs laid out as [time, batch, n_class].
        debug: When True, stop after the first batch.

    All printed values are averages over the batches that were actually
    evaluated, so the loss is also correct when `debug=True`. Relies on the
    module-level `decoder`, `word_model` and `device`.
    """
    print(f"Testing... (e={epoch})")
    model.eval()

    loss_sum = 0.0
    wer_list = []
    cer_list = []

    with torch.no_grad():
        for audio_data in dataset_loader:

            # Get audio data
            spectograms, indices, len_spectograms, len_indices = audio_data
            spectograms, indices = spectograms.to(device), indices.to(device)

            # Forward pass
            out = model(spectograms) # [batch, time, n_class]
            out = F.log_softmax(out, dim=2)
            out = out.transpose(0, 1) # [time, batch, n_class]

            # Compute loss (but do not backprop)
            loss = fn_loss(out, indices, len_spectograms, len_indices)
            loss_sum += loss.item()

            # Metrics
            decode_hypothesis = decoder.decode_prob(out, word_model)
            decode_reference = decoder.decode_labels(indices, len_indices, word_model)

            wer_list.append(avg_wer(decode_hypothesis, decode_reference))
            cer_list.append(avg_cer(decode_hypothesis, decode_reference))

            if debug:
                break

    # One WER entry is appended per evaluated batch
    n_done = len(wer_list)
    if n_done == 0:
        # Empty loader: nothing to average, avoid a division by zero
        print("No batches to evaluate.")
        return

    print(f"Loss: {loss_sum / n_done:.6f}")
    print(f"WER: {sum(wer_list)/n_done:.4f}")
    print(f"CER: {sum(cer_list)/n_done:.4f}")

In [None]:
# Full run: one training pass followed by one evaluation pass per epoch.
# Be careful: this will kill your GPU
for epoch in range(1, epochs + 1):
    train(
        epoch=epoch,
        dataset_loader=train_loader,
        model=deep_speech,
        optimizer=adamW,
        scheduler=one_cycle_lr,
        fn_loss=ctc_loss,
    )
    test(
        epoch=epoch,
        dataset_loader=test_loader,
        model=deep_speech,
        optimizer=adamW,
        fn_loss=ctc_loss,
    )

In [None]:
# Dummy train (one epoch)
# Single training pass, handy as a quick check of the pipeline.
train(
    epoch=1,
    dataset_loader=train_loader,
    model=deep_speech,
    optimizer=adamW,
    scheduler=one_cycle_lr,
    fn_loss=ctc_loss,
)

In [None]:
# Evaluation
# debug=True stops after the first test batch.
test(
    epoch=1,
    dataset_loader=test_loader,
    model=deep_speech,
    optimizer=adamW,
    fn_loss=ctc_loss,
    debug=True,
)