# Train Deep Speech🏋️
- In this notebook we train Deep Speech-based models.
- We vary the number of conv and res blocks and compare the performance of the resulting models.

## 1/5 Setup

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Custom libraries
from utils.preprocessing import Preprocessing
from utils.decoders import DecoderBase
from utils.metrics import batch_cer, batch_wer
from utils.misc import pretty_params

# Plots
import wandb

# Models
from models.deep_speech import DeepSpeech

In [2]:
# Hyperparameters, grouped by what they control

# -- Run control --
seed = 42
epochs = 10
batch_size = 16

# -- Output vocabulary --
# alphabet (+26), space (+1), apostrophe (+1), blank (+1) = 29
n_class = 29
blank_id = 28           # label index handed to the CTC loss as "blank"

# -- Model architecture --
n_features = 128
stride = 2              # time-freq downsample
n_conv2d = 6            # 1st stage
n_rnns = 3              # 2nd stage
rnn_dim = 512
drop_rate = 0.1

# -- Optimisation --
lr = 5e-4


In [3]:
# Seed PyTorch's RNG, then pick the GPU when one is visible (CPU otherwise)
torch.manual_seed(seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

print(device)

cuda


## 2/5 Data processing

In [4]:
# Build the project's preprocessing helper and fetch both dataset splits.
# download=False is passed through unchanged; presumably the data must already
# be available locally -- TODO confirm against utils.preprocessing.
prep = Preprocessing()
train_set, test_set = (
    prep.download(split=split_name, download=False)
    for split_name in ("train", "test")
)



In [5]:
# Data loading
def _make_loader(dataset, split, shuffle):
    """Wrap `dataset` in a DataLoader whose batches are collated by `prep.preprocess(batch, split)`."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: prep.preprocess(batch, split),
    )


train_loader = _make_loader(train_set, "train", shuffle=True)
test_loader = _make_loader(test_set, "test", shuffle=False)

In [6]:
# Batch / sample counts for each split (reused later by the scheduler and the loops)
n_batches_train, n_samples_train = len(train_loader), len(train_loader.dataset)
n_batches_test, n_samples_test = len(test_loader), len(test_loader.dataset)

# Same three-line report for each split, followed by a blank line
for header, batch_count, sample_count in (
    ("~ TRAIN ~", n_batches_train, n_samples_train),
    ("~ TEST ~", n_batches_test, n_samples_test),
):
    print(header)
    print(f"Number of batches: {batch_count}")
    print(f"Number of samples: {sample_count}")
    print()

~ TRAIN ~
Number of batches: 1784
Number of samples: 28539

~ TEST ~
Number of batches: 164
Number of samples: 2620



## 3/5 Model

In [7]:
# Instantiate the network and move it to the selected device
deep_speech = DeepSpeech(n_conv2d, n_rnns, rnn_dim, n_class, n_features)
deep_speech = deep_speech.to(device)

# Report the model size: element count summed over every parameter tensor
tot_params = sum(p.numel() for p in deep_speech.parameters())
print(f"Number of parameters: {pretty_params(tot_params)} ({tot_params})")

Number of parameters: 14.25M (14252829)


In [8]:
# Decoder: used by the test loop (together with prep.tokenizer) to decode
# model outputs and target label indices before WER/CER are computed.
decoder = DecoderBase()

## 4/5 Optimizer, loss, scheduler

In [9]:
# Optimizer, loss, scheduler
optimizer = optim.AdamW(deep_speech.parameters(), lr=lr)

# CTC loss with `blank_id` as the blank label
criterion = nn.CTCLoss(blank=blank_id)
criterion = criterion.to(device)

# One-cycle schedule sized for the whole run (epochs * batches per epoch),
# peaking at `lr`; it is stepped once per batch in the train loop.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    epochs=epochs,
    steps_per_epoch=n_batches_train,
    anneal_strategy="linear",
)

## 5/5 Train & Test

In [None]:
# Online fancy plots
! wandb login

wandb.init(
    project="asr_librispeech",

    config= {
        "model": "deep_speech",
        "n_conv2d": n_conv2d,
        "n_rnns": n_rnns,
        "rnn_dim": rnn_dim
    }
)

In [13]:
# Train loop
def train(epoch):
    """Run one training epoch over `train_loader`.

    For every batch: forward pass, CTC loss, backward pass, then one optimizer
    step and one scheduler step. The per-batch loss and learning rate are
    logged to W&B, and a progress line is printed every 10 batches and on the
    last batch of the epoch.

    The W&B run is deliberately NOT closed here: this function is called once
    per epoch, so finishing the run inside it would end the run after the
    first epoch. Call `wandb.finish()` once, after the last epoch.

    Args:
        epoch: Epoch number; only used in the printed progress messages.
    """
    print(f"Training... (e={epoch})")

    # Train mode ON
    deep_speech.train()

    last_idx = len(train_loader) - 1
    n_total = len(train_loader.dataset)
    n_seen = 0  # running sample count, correct even when the last batch is short

    for idx, audio_data in enumerate(train_loader):

        # Get audio data
        spectograms, indices, len_spectograms, len_indices = audio_data
        spectograms, indices = spectograms.to(device), indices.to(device)

        optimizer.zero_grad()

        # Forward pass
        out = deep_speech(spectograms)  # [batch, time, n_class]
        out = F.log_softmax(out, dim=2)
        out = out.transpose(0, 1)       # [time, batch, n_class]

        # Backward pass
        # NOTE(review): CTC needs len_spectograms to be lengths along the time
        # axis of `out`; if the model downsamples time, confirm that
        # prep.preprocess already accounts for it.
        loss = criterion(out, indices, len_spectograms, len_indices)
        loss.backward()
        loss_value = loss.item()

        # Log the learning rate that was used for this step.
        # get_last_lr() returns one value per param group; log it as a scalar.
        wandb.log({
            "loss": loss_value,
            "lr": scheduler.get_last_lr()[0]
        })

        # Step
        optimizer.step()
        scheduler.step()

        n_seen += len(spectograms)

        # Progress: every 10 batches and on the final batch of the epoch
        if idx % 10 == 0 or idx == last_idx:
            print("Epoch: {}, [{}/{}], Loss: {:.6f}".format(
                epoch,
                n_seen,
                n_total,
                loss_value))

In [14]:
def test(epoch):
    """Evaluate the model on `test_loader` and print mean loss, WER and CER.

    Runs in eval mode with gradients disabled. The loss is averaged over the
    batches of `test_loader`; WER and CER are the means of the per-batch
    averages returned by `batch_wer` / `batch_cer`.

    Args:
        epoch: Epoch number; only used in the printed status message.
    """
    print(f"Testing... (e={epoch})")
    deep_speech.eval()

    n_batches = len(test_loader)
    if n_batches == 0:
        # Avoid dividing by zero below when there is nothing to evaluate
        print("Test loader is empty; nothing to evaluate.")
        return

    total_loss = 0.0
    wer_list = []
    cer_list = []

    with torch.no_grad():
        # Evaluate on the held-out split (not on the training data)
        for audio_data in test_loader:

            # Get audio data
            spectograms, indices, len_spectograms, len_indices = audio_data
            spectograms, indices = spectograms.to(device), indices.to(device)

            # Forward pass
            out = deep_speech(spectograms)  # [batch, time, n_class]
            out = F.log_softmax(out, dim=2)
            out = out.transpose(0, 1)       # [time, batch, n_class]

            # Compute loss (but do not backprop)
            loss = criterion(out, indices, len_spectograms, len_indices)
            total_loss += loss.item()

            # Metrics
            # NOTE(review): decode_prob receives the time-major tensor
            # ([time, batch, n_class]); confirm this is the layout
            # DecoderBase expects.
            decode_preds = decoder.decode_prob(out, prep.tokenizer)
            decode_targets = decoder.decode_labels(indices, len_indices, prep.tokenizer)

            wer_list.append(batch_wer(decode_targets, decode_preds, average=True))
            cer_list.append(batch_cer(decode_targets, decode_preds, average=True))

    print(f"Loss: {total_loss / n_batches:.6f}")
    print(f"WER: {sum(wer_list)/len(wer_list):.4f}")
    print(f"CER: {sum(cer_list)/len(cer_list):.4f}")

In [None]:
# Be careful: this will kill you GPU
for epoch in range(1, epochs+1):
    train(epoch)
    test(epoch)