# Train 🏋️
- this notebook is used to train different models
- we also train different versions of the same model
- we save the corresponding weights under the `weights` folder

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Custom libraries
from utils.preprocessing import Preprocessing
from utils.decoders import DecoderBase
from utils.metrics import batch_cer, batch_wer
from utils.misc import pretty_params

# Plots
import wandb

# Models
from models.deep_speech_2 import MyDeepSpeech

In [2]:
# Authenticate this machine with Weights & Biases (shell command run from the notebook)
! wandb login

wandb: Currently logged in as: iu4hry (fantastic_4). Use `wandb login --relogin` to force relogin


In [13]:
# Run-level settings shared by the cells below
epochs = 10       # number of passes of the training loop
batch_size = 16   # batch size given to the DataLoader
seed = 42         # value given to torch.manual_seed
blank_id = 28     # index given to nn.CTCLoss as the blank label

# Deep speech hyperparams
# In the cells visible in this notebook, the first five entries are passed
# positionally to MyDeepSpeech and "lr" feeds both AdamW and OneCycleLR.
# NOTE(review): "stride" and "drop_rate" are defined here but not read by any
# visible cell - confirm whether they should be forwarded to the model.
dp_hp = dict(
    n_conv2d=3,
    n_rnns=5,
    rnn_dim=512,
    n_class=29,
    n_bins=128,
    stride=2,
    drop_rate=0.1,
    lr=0.0005,
)


In [14]:
# Seed PyTorch's RNG, then pick the GPU when one is available (CPU otherwise)
torch.manual_seed(seed)

use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [15]:
# Build the project's preprocessing helper and fetch the 'train' split from it.
# download=False is passed through to Preprocessing.download.
# NOTE(review): presumably this means the data is expected to exist locally - confirm.
prep = Preprocessing()
train_set = prep.download(download=False, split='train')



In [16]:
# Data loading
def _collate_train(batch):
    """Collate one batch by delegating to prep.preprocess in "train" mode."""
    return prep.preprocess(batch, "train")


train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=_collate_train,
)

# len() of the loader counts batches; len() of its dataset counts samples
n_batches = len(train_loader)
n_samples = len(train_loader.dataset)

print(f"Number of batches: {n_batches}")
print(f"Number of samples: {n_samples}")

Number of batches: 1784
Number of samples: 28539


In [17]:
# Model
# Constructor arguments are taken from dp_hp in the positional order the
# original call used.
# NOTE(review): dp_hp["stride"] and dp_hp["drop_rate"] are not passed here -
# confirm whether MyDeepSpeech should receive them.
model_args = [dp_hp[key] for key in ('n_conv2d', 'n_rnns', 'rnn_dim', 'n_class', 'n_bins')]
deep_speech = MyDeepSpeech(*model_args).to(device)

# Total number of scalar parameters across all parameter tensors
tot_params = sum(p.numel() for p in deep_speech.parameters())

print(f"#params: {pretty_params(tot_params)} ({tot_params})")

#params: 23.71M (23705373)


In [18]:
# Decoder: used by the training loop (decode_prob / decode_labels) to turn model
# outputs and label indices into text for the WER/CER metrics
decoder = DecoderBase()

In [19]:
# Loss, optimizer, scheduler
peak_lr = dp_hp['lr']

# CTC loss with the blank label at index blank_id
criterion = nn.CTCLoss(blank=blank_id).to(device)

optimizer = optim.AdamW(deep_speech.parameters(), lr=peak_lr)

# One-cycle schedule sized for epochs * n_batches steps; the training loop
# steps it once per batch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=peak_lr,
    epochs=epochs,
    steps_per_epoch=n_batches,
    anneal_strategy="linear",
)

In [20]:
# Wandb
# Start a run in the "asr_librispeech" project. The config records the
# hyperparameters next to the model/dataset names so that runs of different
# model versions can be told apart and compared in the W&B UI.
wandb.init(
    project="asr_librispeech",

    config={
        "model": "deep_speech_2_base",
        "dataset": "librispeech_train",
        "seed": seed,
        "batch_size": batch_size,
        "epochs": epochs,
        "blank_id": blank_id,
        **dp_hp,  # n_conv2d, n_rnns, rnn_dim, n_class, n_bins, stride, drop_rate, lr
    }
)

0,1
cer,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wer,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
cer,1.0
loss,3.00177
wer,1.0


In [None]:
# Train loop
# Each epoch iterates once over train_loader. Every batch does one forward and
# backward pass, one optimizer step and one scheduler step, then decodes the
# predictions to update the running WER/CER. Progress is printed and sent to
# wandb every 64 batches and on the last batch of the epoch.
for epoch in range(1, epochs+1):
    # Per-batch metric history; reset every epoch so the averages are per-epoch
    cer_list = []
    wer_list = []
    # Number of samples processed so far in this epoch (for the progress print)
    samples_seen = 0

    # Train mode ON
    deep_speech.train()

    for idx, audio_data in enumerate(train_loader):

        # Get audio data; only the tensors fed to the model/loss as data are
        # moved to the device, the two length objects are passed on as-is
        spectograms, indices, len_spectograms, len_indices = audio_data
        spectograms, indices = spectograms.to(device), indices.to(device)

        optimizer.zero_grad()

        # Forward pass
        out = deep_speech(spectograms) # [batch, time, n_class]
        out = F.log_softmax(out, dim=2)
        out = out.transpose(0, 1) # [time, batch, n_class], the layout nn.CTCLoss expects

        # Backward pass
        loss = criterion(out, indices, len_spectograms, len_indices)
        loss.backward()

        # Step (the scheduler is stepped once per batch, matching
        # steps_per_epoch=n_batches in its construction)
        optimizer.step()
        scheduler.step()

        # Metrics
        decode_preds = decoder.decode_prob(out, prep.tokenizer)
        decode_targets = decoder.decode_labels(indices, len_indices, prep.tokenizer)

        wer_list.append(batch_wer(decode_targets, decode_preds, average=True))
        cer_list.append(batch_cer(decode_targets, decode_preds, average=True))

        # Count actual samples rather than idx * current batch length, which
        # under-reports when the final batch is smaller than the others
        samples_seen += len(spectograms)

        # Log every 64 batches and on the final batch of the epoch. idx is a
        # batch index, so it has to be compared with n_batches, not n_samples.
        if idx % 64 == 0 or idx == n_batches - 1:
            avg_wer = sum(wer_list) / len(wer_list)
            avg_cer = sum(cer_list) / len(cer_list)
            print("Epoch: {}, [{}/{}], Loss: {:.6f}, avg_WER: {:.2f}, avg_CER: {:.2f}".format(
                epoch,
                samples_seen,
                n_samples,
                loss.item(),
                avg_wer,
                avg_cer))

            wandb.log({
                "loss": loss.item(),
                "wer": avg_wer,
                "cer": avg_cer
            })

# Close the wandb run once all epochs are done
wandb.finish()