<a href="https://colab.research.google.com/github/captainkeemo/Dysarthric-Speech-Transcription/blob/main/models/base_CTC_RNNT_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jiwer editdistance --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.1/3.1 MB[0m [31m157.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m78.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import torch
import torchaudio
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from jiwer import wer, cer
import editdistance
import glob
from tqdm import tqdm
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split

In [4]:
from torch.cuda.amp import autocast, GradScaler
scaler_ctc = GradScaler()
scaler_rnnt = GradScaler()

  scaler_ctc = GradScaler()
  scaler_rnnt = GradScaler()


In [5]:
# Define vocabulary
vocab = list("abcdefghijklmnopqrstuvwxyz '") + ["|"]
char_to_index = {c: i for i, c in enumerate(vocab)}
index_to_char = {i: c for c, i in char_to_index.items()}


In [6]:
def clean_transcription(text):
    text = text.lower().strip()
    if '.jpg' in text or '[say' in text or text == 'xxx' or text == '':
        return None
    return text

def decode_sequence(indices, index_to_char):
    return ''.join([index_to_char.get(i, '') for i in indices]).replace('|', '').strip()

In [7]:
# Function to get the device
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Preprocessing Mel spectrograms
def preprocess_and_save_mel(root_dir, save_dir, sample_rate=16000, n_mels=80, use_gpu=False):
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels).to(device)
    resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=sample_rate).to(device)

    # Step 1: Gather .wav files with tqdm
    print("Scanning for .wav files...")
    wav_paths = []
    for dirpath, _, filenames in tqdm(os.walk(root_dir), desc="Scanning folders"):
        for fname in filenames:
            if fname.endswith('.wav') and 'wav_headMic' in dirpath:
                wav_paths.append(os.path.join(dirpath, fname))

    print(f"Found {len(wav_paths)} .wav files. Starting preprocessing...")

    total_saved = 0

    # Step 2: Preprocess with tqdm
    for wav_path in tqdm(wav_paths, desc="Preprocessing Mel Spectrograms"):
        mel_path = os.path.join(save_dir, os.path.basename(wav_path).replace('.wav', '.pt'))
        txt_path = wav_path.replace('wav_headMic', 'prompts').replace('.wav', '.txt')

        if not os.path.exists(txt_path) or os.path.exists(mel_path):
            continue

        try:
            waveform, sr = torchaudio.load(wav_path)
            waveform = waveform.to(device)

            if sr != sample_rate:
                waveform = resampler(waveform)

            mel_spec = mel_transform(waveform).squeeze(0).transpose(0, 1).cpu()

            with open(txt_path, 'r', encoding='utf-8') as f:
                transcript = f.read().strip().lower()

            if '.jpg' in transcript or transcript in ('xxx', '') or '[say' in transcript:
                continue

            torch.save((mel_spec, transcript), mel_path)
            total_saved += 1
        except Exception as e:
            print(f"Error processing {wav_path}: {e}")

    print(f"\nDone: {total_saved} mel .pt files saved to {save_dir}")


In [9]:
'''
preprocess_and_save_mel(
    root_dir="/content/drive/MyDrive/TORGO",
    save_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"
)
'''

'\npreprocess_and_save_mel(\n    root_dir="/content/drive/MyDrive/TORGO",\n    save_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"\n)\n'

In [10]:
class TorgoDataset(Dataset):
    def __init__(self, root_dir, char_to_index, sample_rate=16000,
                 items_file=None, use_precomputed=False, mel_dir=None):
        self.char_to_index = char_to_index
        self.sample_rate = sample_rate
        self.use_precomputed = use_precomputed
        self.items = []

        if self.use_precomputed:
            assert mel_dir is not None, "If using precomputed features, you must provide mel_dir."
            self.paths = sorted([
                os.path.join(mel_dir, f)
                for f in os.listdir(mel_dir) if f.endswith('.pt')
            ])
        else:
            self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=80)
            self.resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=sample_rate)

            if items_file is not None:
                loaded_items = torch.load(items_file)
                for wav_fp, text in loaded_items:
                    cleaned = clean_transcription(text)
                    if cleaned:
                        self.items.append((wav_fp, cleaned))
            else:
                wav_files = []
                for dirpath, _, filenames in os.walk(root_dir):
                    for fname in filenames:
                        if fname.endswith('.wav') and 'wav_headMic' in dirpath:
                            wav_files.append(os.path.join(dirpath, fname))

                for wav_fp in tqdm(wav_files, desc=f"Building dataset from .wav files ({len(wav_files)} found)"):
                    txt_fp = wav_fp.replace('wav_headMic', 'prompts').replace('.wav', '.txt')
                    if not os.path.exists(txt_fp):
                        continue
                    with open(txt_fp, 'r', encoding='utf-8') as f:
                        text = f.read()
                    cleaned = clean_transcription(text)
                    if cleaned:
                        self.items.append((wav_fp, cleaned))

    def __len__(self):
        return len(self.paths) if self.use_precomputed else len(self.items)

    def __getitem__(self, idx):
        if self.use_precomputed:
            mel_path = self.paths[idx]
            mel_spec, transcript = torch.load(mel_path)
        else:
            wav_path, transcript = self.items[idx]
            waveform, sr = torchaudio.load(wav_path)
            if sr != self.sample_rate:
                waveform = self.resampler(waveform)
            mel_spec = self.mel_transform(waveform).squeeze(0).transpose(0, 1)

        target = torch.tensor([self.char_to_index[c] for c in transcript if c in self.char_to_index], dtype=torch.long)
        return mel_spec, target


In [11]:
def collate_fn(batch):
    features, targets = zip(*batch)

    # Compute original lengths
    input_lengths = torch.tensor([feat.size(0) for feat in features], dtype=torch.long)
    target_lengths = torch.tensor([len(tgt) for tgt in targets], dtype=torch.long)

    # Pad features (time dimension is dim 0)
    padded_features = pad_sequence(features, batch_first=True)  # shape: [B, T, F]

    # Pad targets for inspection (not used in loss)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=0)

    return padded_features, input_lengths, padded_targets, target_lengths


In [12]:
# Define the full RNN-T model with encoder, decoder, and joiner
class RNNTModel(nn.Module):
    def __init__(self, input_dim=80, vocab_size=len(vocab), hidden_dim=128, embed_dim=64):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.decoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.joiner = nn.Linear(hidden_dim * 2 + hidden_dim, vocab_size)  # enc (bi) + dec (uni)

    def forward(self, encoder_input, target_input):
        # encoder_input: [B, T_enc, 80]
        # target_input: [B, T_dec]
        enc_out, _ = self.encoder(encoder_input)  # [B, T_enc, 2*H]
        dec_emb = self.embed(target_input)        # [B, T_dec, E]
        dec_out, _ = self.decoder(dec_emb)        # [B, T_dec, H]

        # Expand for joiner broadcasting
        enc_exp = enc_out.unsqueeze(2)            # [B, T_enc, 1, 2H]
        dec_exp = dec_out.unsqueeze(1)            # [B, 1, T_dec, H]

        # Joiner: concat and project
        enc_b, t_enc, _, h_enc = enc_exp.shape
        dec_b, _, t_dec, h_dec = dec_exp.shape
        assert enc_b == dec_b

        enc_rep = enc_exp.expand(-1, t_enc, t_dec, -1)  # [B, T_enc, T_dec, 2H]
        dec_rep = dec_exp.expand(-1, t_enc, t_dec, -1)  # [B, T_enc, T_dec, H]

        joined = torch.cat([enc_rep, dec_rep], dim=-1)  # [B, T_enc, T_dec, 3H]
        logits = self.joiner(joined)                    # [B, T_enc, T_dec, V]
        return logits


In [13]:
# CTC Model Class
class CTCModel(nn.Module):
    def __init__(self, input_dim=80, hidden_dim=128, vocab_size=len(vocab), num_layers=2, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True, num_layers=num_layers, dropout=dropout)
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        logits = self.fc(x)
        return logits  # shape: [B, T, vocab_size]


In [14]:
class BeamSearch:
    def __init__(self, beam_width, vocab_size, blank_index):
        self.beam_width = beam_width
        self.vocab_size = vocab_size
        self.blank_index = blank_index

    def decode(self, logits):
        batch_size = logits.size(1)
        T = logits.size(0)
        results = []

        for b in range(batch_size):
            beams = [([], 0.0)]
            for t in range(T):
                new_beams = []
                log_probs = F.log_softmax(logits[t, b], dim=-1)
                topk = torch.topk(log_probs, self.beam_width)

                for seq, score in beams:
                    for i in range(self.beam_width):
                        token = topk.indices[i].item()
                        new_score = score + topk.values[i].item()
                        if token == self.blank_index:
                            new_beams.append((seq, new_score))
                        else:
                            new_beams.append((seq + [token], new_score))

                new_beams.sort(key=lambda x: x[1], reverse=True)
                beams = new_beams[:self.beam_width]

            best_seq = max(beams, key=lambda x: x[1])[0]
            results.append(best_seq)  # Changed from results.append([best_seq])

        # Ensure results is always a list of lists even with beam_width=1
        results = [[item] if isinstance(item, int) else item for item in results]

        return results

In [15]:
# Compute WER, CER, Edit Distance
from torchaudio.functional import edit_distance as torchaudio_edit_distance
from jiwer import wer, cer

def compute_wer_cer(preds, input_lens, targets, target_lens, index_to_char):
    total_wer, total_cer = 0.0, 0.0
    offset = 0
    valid_examples = 0

    for i in range(len(preds)):
        tlen = target_lens[i].item()
        if tlen == 0:
            continue  # skip empty targets

        target_slice = targets[offset:offset + tlen]
        target_seq = target_slice.view(-1).tolist()
        offset += tlen

        pred_seq = preds[i] if isinstance(preds[i], (list, tuple)) else [preds[i]]
        pred_str = ''.join([index_to_char.get(p, '') for p in pred_seq])
        target_str = ''.join([index_to_char.get(t, '') for t in target_seq])

        if not target_str.strip():
            continue  # skip if the reference string is still empty

        total_wer += wer(target_str, pred_str)
        total_cer += cer(target_str, pred_str)
        valid_examples += 1

    if valid_examples == 0:
        return 1.0, 1.0  # fallback if no valid examples
    return total_wer / valid_examples, total_cer / valid_examples



In [16]:
from jiwer import cer

def decode_indices(indices, index_to_char):
    return ''.join([index_to_char[i] for i in indices if i in index_to_char]).replace('|', ' ').strip()

def compute_cer_accuracy(preds, targets, target_lens, index_to_char):
    offset = 0
    all_refs, all_hyps = [], []
    for i, tlen in enumerate(target_lens):
        ref = ''.join([index_to_char[x.item()] for x in targets[offset:offset + tlen]])
        hyp = ''.join([index_to_char[x] for x in preds[i]])
        all_refs.append(ref)
        all_hyps.append(hyp)
        offset += tlen
    return 1.0 - cer(all_refs, all_hyps)


In [17]:
# Greedy decoding for character prediction from model output
def greedy_decode(logits, input_lens, blank_index):
    pred = logits.argmax(dim=2)
    results = []
    for i in range(pred.size(0)):
        tokens, prev = [], None
        for j in range(input_lens[i]):
            idx = pred[i, j].item()
            if idx != blank_index and idx != prev:
                tokens.append(idx)
            prev = idx
        results.append(tokens)
    return results

In [18]:
from tqdm import tqdm

def train_ctc(model, train_loader, val_loader, optimizer, loss_fn, index_to_char,
              device, epochs=20, patience=5, blank_idx=1, beam_width=5):

    model.to(device)
    best_val_loss = float('inf')
    wait = 0

    train_loss_hist, val_loss_hist = [], []
    train_wer_hist, val_wer_hist = [], []
    train_cer_hist, val_cer_hist = [], []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        model.train()
        train_loss, train_wer, train_cer = 0, 0, 0

        train_bar = tqdm(train_loader, desc="Training", leave=False)
        for feats, feat_lens, targets, target_lens in train_bar:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            optimizer.zero_grad()
            logits = model(feats)
            log_probs = logits.log_softmax(2).transpose(0, 1)
            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            with torch.no_grad():
                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                raw_preds = decoder.decode(log_probs.cpu())

                preds = [beam[0] if isinstance(beam[0], (list, tuple)) else beam for beam in raw_preds]
                preds = [p if isinstance(p, list) else [p] for p in preds]

                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                train_wer += wer_score
                train_cer += cer_score

            train_bar.set_postfix(loss=loss.item(), WER=wer_score, CER=cer_score)

        train_loss /= len(train_loader)
        train_wer /= len(train_loader)
        train_cer /= len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss, val_wer, val_cer = 0, 0, 0
        val_bar = tqdm(val_loader, desc="Validating", leave=False)
        with torch.no_grad():
            for feats, feat_lens, targets, target_lens in val_bar:
                feats, feat_lens = feats.to(device), feat_lens.to(device)
                targets, target_lens = targets.to(device), target_lens.to(device)

                logits = model(feats)
                log_probs = logits.log_softmax(2).transpose(0, 1)
                loss = loss_fn(log_probs, targets, feat_lens, target_lens)
                val_loss += loss.item()

                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                raw_preds = decoder.decode(log_probs.cpu())

                preds = [beam[0] if isinstance(beam[0], (list, tuple)) else beam for beam in raw_preds]
                preds = [p if isinstance(p, list) else [p] for p in preds]

                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                val_wer += wer_score
                val_cer += cer_score

                val_bar.set_postfix(loss=loss.item(), WER=wer_score, CER=cer_score)

        val_loss /= len(val_loader)
        val_wer /= len(val_loader)
        val_cer /= len(val_loader)

        train_loss_hist.append(train_loss)
        val_loss_hist.append(val_loss)
        train_wer_hist.append(train_wer)
        val_wer_hist.append(val_wer)
        train_cer_hist.append(train_cer)
        val_cer_hist.append(val_cer)

        print(f"[CTC] Epoch {epoch+1}: "
              f"Train Loss={train_loss:.4f}, WER={train_wer:.4f}, CER={train_cer:.4f} | "
              f"Val Loss={val_loss:.4f}, WER={val_wer:.4f}, CER={val_cer:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_ctc_model.pt')
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    return train_loss_hist, val_loss_hist, train_wer_hist, val_wer_hist, train_cer_hist, val_cer_hist


In [19]:
from tqdm import tqdm

def train_rnnt(model, train_loader, val_loader, optimizer, loss_fn, index_to_char,
               device, epochs=20, patience=5, blank_idx=1, beam_width=5):

    model.to(device)
    best_val_loss = float('inf')
    wait = 0

    train_loss_hist, val_loss_hist = [], []
    train_wer_hist, val_wer_hist = [], []
    train_cer_hist, val_cer_hist = [], []

    for epoch in range(epochs):
        model.train()
        train_loss, train_wer, train_cer = 0, 0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} - Training (RNN-T)")
        for feats, feat_lens, targets, target_lens in train_pbar:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            optimizer.zero_grad()
            logits = model(feats, targets)
            T_enc = logits.shape[1]
            T_dec = logits.shape[2]
            min_len = min(T_enc, T_dec, feat_lens.max().item())
            logits = logits[:, torch.arange(min_len), torch.arange(min_len), :]
            log_probs = logits.log_softmax(2).transpose(0, 1)
            feat_lens = torch.clamp(feat_lens, max=log_probs.shape[0])

            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            with torch.no_grad():
                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                preds = decoder.decode(log_probs.cpu())
                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                train_wer += wer_score
                train_cer += cer_score

        train_loss /= len(train_loader)
        train_wer /= len(train_loader)
        train_cer /= len(train_loader)

        model.eval()
        val_loss, val_wer, val_cer = 0, 0, 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} - Validation (RNN-T)")
        with torch.no_grad():
            for feats, feat_lens, targets, target_lens in val_pbar:
                feats, feat_lens = feats.to(device), feat_lens.to(device)
                targets, target_lens = targets.to(device), target_lens.to(device)

                logits = model(feats, targets)
                T_enc = logits.shape[1]
                T_dec = logits.shape[2]
                min_len = min(T_enc, T_dec, feat_lens.max().item())
                logits = logits[:, torch.arange(min_len), torch.arange(min_len), :]
                log_probs = logits.log_softmax(2).transpose(0, 1)
                feat_lens = torch.clamp(feat_lens, max=log_probs.shape[0])

                loss = loss_fn(log_probs, targets, feat_lens, target_lens)
                val_loss += loss.item()

                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                preds = decoder.decode(log_probs.cpu())
                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                val_wer += wer_score
                val_cer += cer_score

        val_loss /= len(val_loader)
        val_wer /= len(val_loader)
        val_cer /= len(val_loader)

        train_loss_hist.append(train_loss)
        val_loss_hist.append(val_loss)
        train_wer_hist.append(train_wer)
        val_wer_hist.append(val_wer)
        train_cer_hist.append(train_cer)
        val_cer_hist.append(val_cer)

        print(f"[RNN-T] Epoch {epoch+1}: Train Loss={train_loss:.4f}, WER={train_wer:.4f}, CER={train_cer:.4f} | "
              f"Val Loss={val_loss:.4f}, WER={val_wer:.4f}, CER={val_cer:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_rnnt_model.pt')
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    return train_loss_hist, val_loss_hist, train_wer_hist, val_wer_hist, train_cer_hist, val_cer_hist


In [31]:
def evaluate_and_plot_model(
    model, test_loader, loss_fn, index_to_char, device,
    blank_idx=1, use_beam=True, beam_width=5, model_type="ctc"
):
    model.eval()
    test_loss = 0
    total_wer = 0
    total_cer = 0

    with torch.no_grad():
        for feats, feat_lens, targets, target_lens in test_loader:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            if model_type == "rnnt":
                logits = model(feats, targets)
                T_enc = logits.shape[1]
                T_dec = logits.shape[2]
                min_len = min(T_enc, T_dec, feat_lens.max().item())
                logits = logits[:, torch.arange(min_len), torch.arange(min_len), :]
            else:
                logits = model(feats)

            log_probs = logits.log_softmax(2).transpose(0, 1)

            # Clamp feat_lens in case it exceeds time dim of logits
            feat_lens = torch.clamp(feat_lens, max=log_probs.size(0))

            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            test_loss += loss.item()

            # Decode predictions
            if use_beam:
                decoder = BeamSearch(beam_width, logits.size(-1), blank_idx)
                preds = decoder.decode(log_probs.cpu())
            else:
                preds = logits.cpu().argmax(2)  # greedy decode
                preds = [pred[:feat_lens[i]].tolist() for i, pred in enumerate(preds)]

            wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
            total_wer += wer_score
            total_cer += cer_score

    avg_loss = test_loss / len(test_loader)
    avg_wer = total_wer / len(test_loader)
    avg_cer = total_cer / len(test_loader)

    print(f"[{model_type.upper()}] Test Loss={avg_loss:.4f}, WER={avg_wer:.4f}, CER={avg_cer:.4f}")

    return avg_loss, avg_wer, avg_cer


In [21]:
# Load data
data_dir = "/content/drive/MyDrive/TORGO"

# Load dataset
#dataset = TorgoDataset(data_dir, char_to_index)
#torch.save(dataset.items, '/content/drive/MyDrive/torgo_items.pt')
dataset = TorgoDataset(
    root_dir="/content/drive/MyDrive/TORGO",  # still required for structure
    char_to_index=char_to_index,
    use_precomputed=True,
    mel_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"
)
print(f"Dataset length: {len(dataset)}")

# Calculate lengths
total_len = len(dataset)
train_len = int(0.7 * total_len)
val_len = int(0.1 * total_len)
test_len = total_len - train_len - val_len

# Perform split
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_len, val_len, test_len])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn, pin_memory=True)



Dataset length: 993


In [23]:
import time

# Train CTC model
model_ctc = CTCModel()
optimizer_ctc = torch.optim.Adam(model_ctc.parameters(), lr=1e-5)
loss_fn_ctc = nn.CTCLoss(blank=char_to_index['|'], zero_infinity=True)
device = get_device()

start_time = time.time()
ctc_train_loss, ctc_val_loss, ctc_train_wer, ctc_val_wer, ctc_train_cer, ctc_val_cer = train_ctc(
    model=model_ctc,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer_ctc,
    loss_fn=loss_fn_ctc,
    index_to_char=index_to_char,
    epochs=50,
    patience=3,
    device=device
)
end_time = time.time()
print(f"Baseline CTC Training time: {end_time - start_time} seconds")


Epoch 1/50




[CTC] Epoch 1: Train Loss=109.3613, WER=2.8536, CER=1.7385 | Val Loss=112.8398, WER=1.4667, CER=1.6343

Epoch 2/50




[CTC] Epoch 2: Train Loss=109.1865, WER=1.7649, CER=1.6454 | Val Loss=112.5942, WER=1.1667, CER=1.6357

Epoch 3/50




[CTC] Epoch 3: Train Loss=108.7976, WER=1.2526, CER=1.8312 | Val Loss=112.3103, WER=1.1667, CER=1.6365

Epoch 4/50




[CTC] Epoch 4: Train Loss=108.6299, WER=1.1828, CER=1.5666 | Val Loss=111.9798, WER=1.1667, CER=1.6389

Epoch 5/50




[CTC] Epoch 5: Train Loss=108.3899, WER=1.2594, CER=1.5588 | Val Loss=111.5968, WER=1.1111, CER=1.6419

Epoch 6/50




[CTC] Epoch 6: Train Loss=107.8791, WER=1.0952, CER=1.6558 | Val Loss=111.1653, WER=1.1111, CER=1.6471

Epoch 7/50




[CTC] Epoch 7: Train Loss=107.6187, WER=1.0545, CER=1.6357 | Val Loss=110.6840, WER=1.0556, CER=1.6490

Epoch 8/50




[CTC] Epoch 8: Train Loss=107.0574, WER=1.0928, CER=1.6609 | Val Loss=110.1566, WER=1.0556, CER=1.6513

Epoch 9/50




[CTC] Epoch 9: Train Loss=106.5746, WER=1.0996, CER=1.5028 | Val Loss=109.5877, WER=1.1111, CER=1.6541

Epoch 10/50




[CTC] Epoch 10: Train Loss=106.0998, WER=1.0615, CER=1.5707 | Val Loss=108.9754, WER=1.1111, CER=1.6563

Epoch 11/50




[CTC] Epoch 11: Train Loss=105.2394, WER=1.0899, CER=1.8223 | Val Loss=108.3137, WER=1.0556, CER=1.6569

Epoch 12/50




[CTC] Epoch 12: Train Loss=104.9297, WER=1.0247, CER=1.7580 | Val Loss=107.6030, WER=1.0000, CER=1.6582

Epoch 13/50




[CTC] Epoch 13: Train Loss=104.1397, WER=1.0152, CER=1.7930 | Val Loss=106.8384, WER=1.0000, CER=1.6607

Epoch 14/50




[CTC] Epoch 14: Train Loss=103.2422, WER=1.0000, CER=2.0476 | Val Loss=106.0067, WER=1.0556, CER=1.6619

Epoch 15/50




[CTC] Epoch 15: Train Loss=102.4248, WER=1.0000, CER=1.4546 | Val Loss=105.1074, WER=1.0556, CER=1.6624

Epoch 16/50




[CTC] Epoch 16: Train Loss=101.5875, WER=1.0000, CER=1.6027 | Val Loss=104.1183, WER=1.0000, CER=1.6626

Epoch 17/50




[CTC] Epoch 17: Train Loss=100.6315, WER=1.0000, CER=1.5355 | Val Loss=103.0395, WER=1.0000, CER=1.6627

Epoch 18/50




[CTC] Epoch 18: Train Loss=99.5155, WER=1.0000, CER=1.7014 | Val Loss=101.8442, WER=1.0000, CER=1.6627

Epoch 19/50




[CTC] Epoch 19: Train Loss=98.2418, WER=1.0091, CER=1.9350 | Val Loss=100.5142, WER=1.0000, CER=1.6631

Epoch 20/50




[CTC] Epoch 20: Train Loss=96.8716, WER=1.0000, CER=1.5721 | Val Loss=99.0293, WER=1.0000, CER=1.6631

Epoch 21/50




[CTC] Epoch 21: Train Loss=95.3872, WER=1.0000, CER=1.6689 | Val Loss=97.3591, WER=1.0000, CER=1.6631

Epoch 22/50




[CTC] Epoch 22: Train Loss=93.9920, WER=1.0000, CER=1.7286 | Val Loss=95.5007, WER=1.0000, CER=1.6631

Epoch 23/50




[CTC] Epoch 23: Train Loss=91.8259, WER=1.0000, CER=1.8599 | Val Loss=93.3420, WER=1.0000, CER=1.6631

Epoch 24/50




[CTC] Epoch 24: Train Loss=89.6339, WER=1.0000, CER=1.6848 | Val Loss=90.9074, WER=1.0000, CER=1.6631

Epoch 25/50




[CTC] Epoch 25: Train Loss=87.2854, WER=1.0000, CER=1.6428 | Val Loss=88.1054, WER=1.0000, CER=1.6631

Epoch 26/50




[CTC] Epoch 26: Train Loss=84.3280, WER=1.0000, CER=1.4924 | Val Loss=84.8311, WER=1.0000, CER=1.6631

Epoch 27/50




[CTC] Epoch 27: Train Loss=81.0130, WER=1.0000, CER=1.4174 | Val Loss=81.0193, WER=1.0000, CER=1.6631

Epoch 28/50




[CTC] Epoch 28: Train Loss=77.0639, WER=1.0000, CER=1.5182 | Val Loss=76.5658, WER=1.0000, CER=1.6631

Epoch 29/50




[CTC] Epoch 29: Train Loss=72.4868, WER=1.0000, CER=1.7755 | Val Loss=71.3528, WER=1.0000, CER=1.6631

Epoch 30/50




[CTC] Epoch 30: Train Loss=67.0914, WER=1.0000, CER=1.6398 | Val Loss=65.3599, WER=1.0000, CER=1.6631

Epoch 31/50




[CTC] Epoch 31: Train Loss=60.9146, WER=1.0000, CER=1.6259 | Val Loss=58.4344, WER=1.0000, CER=1.6631

Epoch 32/50




[CTC] Epoch 32: Train Loss=53.9220, WER=1.0000, CER=1.7627 | Val Loss=50.6937, WER=1.0000, CER=1.6631

Epoch 33/50




[CTC] Epoch 33: Train Loss=46.1040, WER=1.0000, CER=1.8512 | Val Loss=42.3982, WER=1.0000, CER=1.6631

Epoch 34/50




[CTC] Epoch 34: Train Loss=37.9517, WER=1.0000, CER=1.6891 | Val Loss=34.1063, WER=1.0000, CER=1.6631

Epoch 35/50




[CTC] Epoch 35: Train Loss=30.0345, WER=1.0000, CER=1.6603 | Val Loss=26.4507, WER=1.0000, CER=1.6631

Epoch 36/50




[CTC] Epoch 36: Train Loss=23.2047, WER=1.0000, CER=1.5000 | Val Loss=20.1858, WER=1.0000, CER=1.6631

Epoch 37/50




[CTC] Epoch 37: Train Loss=17.7532, WER=1.0000, CER=1.6795 | Val Loss=15.5310, WER=1.0000, CER=1.6631

Epoch 38/50




[CTC] Epoch 38: Train Loss=13.8585, WER=1.0000, CER=1.5851 | Val Loss=12.2413, WER=1.0000, CER=1.6631

Epoch 39/50




[CTC] Epoch 39: Train Loss=11.0868, WER=1.0000, CER=1.5611 | Val Loss=10.0523, WER=1.0000, CER=1.6631

Epoch 40/50




[CTC] Epoch 40: Train Loss=9.2588, WER=1.0000, CER=1.7217 | Val Loss=8.5392, WER=1.0000, CER=1.6631

Epoch 41/50




[CTC] Epoch 41: Train Loss=7.9791, WER=1.0000, CER=1.6356 | Val Loss=7.4788, WER=1.0000, CER=1.6631

Epoch 42/50




[CTC] Epoch 42: Train Loss=7.0524, WER=1.0000, CER=1.6977 | Val Loss=6.7387, WER=1.0000, CER=1.6631

Epoch 43/50




[CTC] Epoch 43: Train Loss=6.4079, WER=1.0000, CER=1.7307 | Val Loss=6.1826, WER=1.0000, CER=1.6631

Epoch 44/50




[CTC] Epoch 44: Train Loss=5.9077, WER=1.0000, CER=1.8896 | Val Loss=5.7613, WER=1.0000, CER=1.6631

Epoch 45/50




[CTC] Epoch 45: Train Loss=5.5427, WER=1.0000, CER=1.6540 | Val Loss=5.4354, WER=1.0000, CER=1.6631

Epoch 46/50




[CTC] Epoch 46: Train Loss=5.2422, WER=1.0000, CER=1.5971 | Val Loss=5.1754, WER=1.0000, CER=1.6631

Epoch 47/50




[CTC] Epoch 47: Train Loss=5.0080, WER=1.0000, CER=1.7144 | Val Loss=4.9656, WER=1.0000, CER=1.6631

Epoch 48/50




[CTC] Epoch 48: Train Loss=4.8164, WER=1.0000, CER=1.6323 | Val Loss=4.7971, WER=1.0000, CER=1.6631

Epoch 49/50




[CTC] Epoch 49: Train Loss=4.6583, WER=1.0000, CER=1.7489 | Val Loss=4.6565, WER=1.0000, CER=1.6631

Epoch 50/50


                                                                                     

[CTC] Epoch 50: Train Loss=4.5330, WER=1.0000, CER=1.5944 | Val Loss=4.5413, WER=1.0000, CER=1.6631
Baseline CTC Training time: 4052.1429607868195 seconds




In [24]:
import time

# Train RNN-T model
model_rnnt = RNNTModel()
optimizer_rnnt = torch.optim.Adam(model_rnnt.parameters(), lr=1e-4)
loss_fn_rnnt = nn.CTCLoss(blank=char_to_index['|'], zero_infinity=True)
device = get_device()

start_time = time.time()
rnnt_train_loss, rnnt_val_loss, rnnt_train_wer, rnnt_val_wer, rnnt_train_cer, rnnt_val_cer = train_rnnt(
    model=model_rnnt,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer_rnnt,
    loss_fn=loss_fn_rnnt,
    index_to_char=index_to_char,
    epochs=50,
    patience=3,
    device=device
)
end_time = time.time()
print(f"Baseline RNN-T Training time: {end_time - start_time} seconds")

Epoch 1 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 1 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 1: Train Loss=36.4078, WER=1.0000, CER=0.9906 | Val Loss=39.6374, WER=1.0000, CER=0.9967


Epoch 2 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.01it/s]
Epoch 2 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 2: Train Loss=37.1022, WER=1.0000, CER=0.9983 | Val Loss=36.2681, WER=1.0000, CER=0.9999


Epoch 3 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Epoch 3 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 3: Train Loss=32.2876, WER=1.0000, CER=0.9998 | Val Loss=31.9682, WER=1.0000, CER=1.0000


Epoch 4 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 4 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 4: Train Loss=27.7080, WER=1.0000, CER=1.0000 | Val Loss=26.4101, WER=1.0000, CER=1.0000


Epoch 5 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 5 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 5: Train Loss=21.7209, WER=1.0000, CER=1.0000 | Val Loss=19.5983, WER=1.0000, CER=1.0000


Epoch 6 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.14it/s]
Epoch 6 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 6: Train Loss=14.7944, WER=1.0000, CER=1.0000 | Val Loss=12.9566, WER=1.0000, CER=1.0000


Epoch 7 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.10it/s]
Epoch 7 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 7: Train Loss=9.7065, WER=1.0000, CER=1.0000 | Val Loss=7.2053, WER=1.0000, CER=1.0000


Epoch 8 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.05it/s]
Epoch 8 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 8: Train Loss=5.3527, WER=1.0000, CER=1.0000 | Val Loss=4.0405, WER=1.0000, CER=1.0000


Epoch 9 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.10it/s]
Epoch 9 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 9: Train Loss=3.8283, WER=1.0000, CER=1.0000 | Val Loss=3.7272, WER=1.0000, CER=1.0000


Epoch 10 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.06it/s]
Epoch 10 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


[RNN-T] Epoch 10: Train Loss=3.7368, WER=1.0000, CER=1.0000 | Val Loss=3.6923, WER=1.0000, CER=1.0000


Epoch 11 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 11 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 11: Train Loss=3.6802, WER=1.0000, CER=1.0000 | Val Loss=3.6297, WER=1.0000, CER=1.0000


Epoch 12 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.10it/s]
Epoch 12 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 12: Train Loss=3.6008, WER=1.0000, CER=1.0000 | Val Loss=3.5786, WER=1.0000, CER=1.0000


Epoch 13 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.12it/s]
Epoch 13 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 13: Train Loss=3.5750, WER=1.0000, CER=1.0000 | Val Loss=3.5439, WER=1.0000, CER=1.0000


Epoch 14 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 14 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


[RNN-T] Epoch 14: Train Loss=3.5306, WER=1.0000, CER=1.0000 | Val Loss=3.5162, WER=1.0000, CER=1.0000


Epoch 15 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.11it/s]
Epoch 15 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 15: Train Loss=3.5011, WER=1.0000, CER=1.0000 | Val Loss=3.4924, WER=1.0000, CER=1.0000


Epoch 16 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.11it/s]
Epoch 16 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 16: Train Loss=3.4728, WER=1.0000, CER=1.0000 | Val Loss=3.4717, WER=1.0000, CER=1.0000


Epoch 17 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 17 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


[RNN-T] Epoch 17: Train Loss=3.4534, WER=1.0000, CER=1.0000 | Val Loss=3.4541, WER=1.0000, CER=1.0000


Epoch 18 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Epoch 18 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 18: Train Loss=3.4429, WER=1.0000, CER=1.0000 | Val Loss=3.4384, WER=1.0000, CER=1.0000


Epoch 19 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 19 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


[RNN-T] Epoch 19: Train Loss=3.4261, WER=1.0000, CER=1.0000 | Val Loss=3.4246, WER=1.0000, CER=1.0000


Epoch 20 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 20 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 20: Train Loss=3.4133, WER=1.0000, CER=1.0000 | Val Loss=3.4119, WER=1.0000, CER=1.0000


Epoch 21 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.06it/s]
Epoch 21 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 21: Train Loss=3.3930, WER=1.0000, CER=1.0000 | Val Loss=3.4013, WER=1.0000, CER=1.0000


Epoch 22 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 22 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 22: Train Loss=3.4118, WER=1.0000, CER=1.0000 | Val Loss=3.3915, WER=1.0000, CER=1.0000


Epoch 23 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.13it/s]
Epoch 23 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


[RNN-T] Epoch 23: Train Loss=3.3730, WER=1.0000, CER=1.0000 | Val Loss=3.3831, WER=1.0000, CER=1.0000


Epoch 24 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 24 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 24: Train Loss=3.3707, WER=1.0000, CER=1.0000 | Val Loss=3.3755, WER=1.0000, CER=1.0000


Epoch 25 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 25 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 25: Train Loss=3.3691, WER=1.0000, CER=1.0000 | Val Loss=3.3670, WER=1.0000, CER=1.0000


Epoch 26 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 26 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 26: Train Loss=3.3669, WER=1.0000, CER=1.0000 | Val Loss=3.3587, WER=1.0000, CER=1.0000


Epoch 27 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 27 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 27: Train Loss=3.3448, WER=1.0000, CER=1.0000 | Val Loss=3.3515, WER=1.0000, CER=1.0000


Epoch 28 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 28 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 28: Train Loss=3.3544, WER=1.0000, CER=1.0000 | Val Loss=3.3448, WER=1.0000, CER=1.0000


Epoch 29 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 29 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 29: Train Loss=3.3378, WER=1.0000, CER=1.0000 | Val Loss=3.3377, WER=1.0000, CER=1.0000


Epoch 30 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.06it/s]
Epoch 30 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 30: Train Loss=3.3359, WER=1.0000, CER=1.0000 | Val Loss=3.3306, WER=1.0000, CER=1.0000


Epoch 31 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.13it/s]
Epoch 31 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 31: Train Loss=3.3228, WER=1.0000, CER=1.0000 | Val Loss=3.3247, WER=1.0000, CER=1.0000


Epoch 32 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 32 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 32: Train Loss=3.3292, WER=1.0000, CER=1.0000 | Val Loss=3.3193, WER=1.0000, CER=1.0000


Epoch 33 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 33 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


[RNN-T] Epoch 33: Train Loss=3.3189, WER=1.0000, CER=1.0000 | Val Loss=3.3127, WER=1.0000, CER=1.0000


Epoch 34 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 34 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 34: Train Loss=3.3186, WER=1.0000, CER=1.0000 | Val Loss=3.3064, WER=1.0000, CER=1.0000


Epoch 35 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 35 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 35: Train Loss=3.2952, WER=1.0000, CER=1.0000 | Val Loss=3.2991, WER=1.0000, CER=1.0000


Epoch 36 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 36 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 36: Train Loss=3.2976, WER=1.0000, CER=1.0000 | Val Loss=3.2927, WER=1.0000, CER=1.0000


Epoch 37 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 37 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 37: Train Loss=3.2928, WER=1.0000, CER=1.0000 | Val Loss=3.2871, WER=1.0000, CER=1.0000


Epoch 38 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.18it/s]
Epoch 38 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


[RNN-T] Epoch 38: Train Loss=3.2734, WER=1.0000, CER=1.0000 | Val Loss=3.2815, WER=1.0000, CER=1.0000


Epoch 39 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Epoch 39 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 39: Train Loss=3.2856, WER=1.0000, CER=1.0000 | Val Loss=3.2759, WER=1.0000, CER=1.0000


Epoch 40 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.14it/s]
Epoch 40 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 40: Train Loss=3.2603, WER=1.0000, CER=1.0000 | Val Loss=3.2689, WER=1.0000, CER=1.0000


Epoch 41 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.10it/s]
Epoch 41 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 41: Train Loss=3.2707, WER=1.0000, CER=1.0000 | Val Loss=3.2617, WER=1.0000, CER=1.0000


Epoch 42 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.05it/s]
Epoch 42 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


[RNN-T] Epoch 42: Train Loss=3.2682, WER=1.0000, CER=1.0000 | Val Loss=3.2549, WER=1.0000, CER=1.0000


Epoch 43 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 43 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


[RNN-T] Epoch 43: Train Loss=3.2448, WER=1.0000, CER=1.0000 | Val Loss=3.2486, WER=1.0000, CER=1.0000


Epoch 44 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Epoch 44 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 44: Train Loss=3.2622, WER=1.0000, CER=1.0000 | Val Loss=3.2428, WER=1.0000, CER=1.0000


Epoch 45 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Epoch 45 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


[RNN-T] Epoch 45: Train Loss=3.2451, WER=1.0000, CER=1.0000 | Val Loss=3.2361, WER=1.0000, CER=1.0000


Epoch 46 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 46 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 46: Train Loss=3.2331, WER=1.0000, CER=1.0000 | Val Loss=3.2301, WER=1.0000, CER=1.0000


Epoch 47 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 47 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 47: Train Loss=3.2382, WER=1.0000, CER=1.0000 | Val Loss=3.2234, WER=1.0000, CER=1.0000


Epoch 48 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
Epoch 48 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 48: Train Loss=3.2219, WER=1.0000, CER=1.0000 | Val Loss=3.2167, WER=1.0000, CER=1.0000


Epoch 49 - Training (RNN-T): 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Epoch 49 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


[RNN-T] Epoch 49: Train Loss=3.2211, WER=1.0000, CER=1.0000 | Val Loss=3.2091, WER=1.0000, CER=1.0000


Epoch 50 - Training (RNN-T): 100%|██████████| 11/11 [00:09<00:00,  1.11it/s]
Epoch 50 - Validation (RNN-T): 100%|██████████| 2/2 [00:01<00:00,  1.37it/s]

[RNN-T] Epoch 50: Train Loss=3.2218, WER=1.0000, CER=1.0000 | Val Loss=3.2028, WER=1.0000, CER=1.0000
Baseline RNN-T Training time: 579.5520398616791 seconds





In [29]:
ctc_results = evaluate_and_plot_model(
    model=model_ctc,
    test_loader=test_loader,
    loss_fn=loss_fn_ctc,
    index_to_char=index_to_char,
    device='cuda',
    blank_idx=char_to_index['|'],
    use_beam=True,
    beam_width=5,
    model_type="ctc"
)


[Test - CTC] Loss: 4.4558 | WER: 1.0000 | CER: 1.0000


In [32]:
rnnt_results = evaluate_and_plot_model(
    model=model_rnnt,
    test_loader=test_loader,
    loss_fn=loss_fn_rnnt,
    index_to_char=index_to_char,
    device='cuda',
    blank_idx=char_to_index['|'],
    use_beam=True,
    beam_width=5,
    model_type="rnnt"
)


[RNNT] Test Loss=3.2369, WER=1.0000, CER=1.0000
