<a href="https://colab.research.google.com/github/captainkeemo/Dysarthric-Speech-Transcription/blob/main/models/base_CTC_RNNT_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jiwer editdistance --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.1/3.1 MB[0m [31m122.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Mount Google Drive (Colab only) so the paths under /content/drive used later in this
# notebook are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
import torch
import torchaudio
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from jiwer import wer, cer
import editdistance
import glob
from tqdm import tqdm
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import random_split

In [4]:
# Mixed-precision helpers, one gradient scaler per model.
# NOTE(review): neither scaler nor autocast is referenced by the training functions visible
# in this notebook — confirm whether AMP is still planned, otherwise these can be dropped.
from torch.cuda.amp import autocast, GradScaler
scaler_ctc = GradScaler()
scaler_rnnt = GradScaler()

  scaler_ctc = GradScaler()
  scaler_rnnt = GradScaler()


In [5]:
# Character vocabulary: the 26 lowercase letters, space and apostrophe, followed by "|",
# which the loss functions further down use as the CTC blank symbol.
vocab = [*"abcdefghijklmnopqrstuvwxyz '", "|"]
# Forward and reverse lookup tables between characters and their integer ids.
char_to_index = dict(zip(vocab, range(len(vocab))))
index_to_char = dict(enumerate(vocab))


In [6]:
def clean_transcription(text):
    """Normalise a raw prompt and reject prompts that are not plain spoken text.

    Args:
        text: Raw prompt string.

    Returns:
        The lower-cased, whitespace-stripped prompt, or ``None`` when the prompt is
        empty, is the placeholder ``xxx``, refers to a ``.jpg`` picture, or contains
        a ``[say`` instruction.
    """
    normalized = text.lower().strip()
    if normalized in ('', 'xxx'):
        return None
    if any(marker in normalized for marker in ('.jpg', '[say')):
        return None
    return normalized

def decode_sequence(indices, index_to_char):
    """Turn a sequence of token ids into text, dropping the ``|`` symbol.

    Ids missing from ``index_to_char`` contribute nothing; the result is stripped of
    leading and trailing whitespace.
    """
    pieces = []
    for token_id in indices:
        pieces.append(index_to_char.get(token_id, ''))
    text = ''.join(pieces)
    return text.replace('|', '').strip()

In [7]:
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

In [8]:
# Preprocessing Mel spectrograms
def preprocess_and_save_mel(root_dir, save_dir, sample_rate=16000, n_mels=80, use_gpu=False):
    """Pre-compute mel spectrograms for every head-mic recording under ``root_dir``.

    Each ``.wav`` found in a directory whose path contains ``wav_headMic`` is paired
    with the ``.txt`` prompt at the same path with ``wav_headMic`` replaced by
    ``prompts``. Usable pairs are saved to ``save_dir`` as ``(mel_spec, transcript)``
    tuples, where ``mel_spec`` is laid out as ``[frames, n_mels]``.

    The output file name is derived from the recording's path relative to
    ``root_dir`` (separators replaced by ``_``). Using only the basename made
    recordings with the same file name in different sub-folders overwrite/skip
    each other, so only the first one was ever cached.
    NOTE(review): files written by an earlier run under the old basename-only
    naming are not recognised by the skip check — clear ``save_dir`` before
    re-running to avoid duplicates.

    Args:
        root_dir: Root of the corpus tree to scan.
        save_dir: Directory receiving the ``.pt`` files (created if missing).
        sample_rate: Target sample rate; audio at any other rate is resampled.
        n_mels: Number of mel bins.
        use_gpu: Run the transforms on CUDA when it is available.
    """
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels).to(device)
    # One resampler per source rate actually encountered; a single 48 kHz resampler
    # silently produced wrong audio for files recorded at any other rate.
    resamplers = {}

    # Step 1: Gather .wav files with tqdm
    print("Scanning for .wav files...")
    wav_paths = []
    for dirpath, _, filenames in tqdm(os.walk(root_dir), desc="Scanning folders"):
        for fname in filenames:
            if fname.endswith('.wav') and 'wav_headMic' in dirpath:
                wav_paths.append(os.path.join(dirpath, fname))

    print(f"Found {len(wav_paths)} .wav files. Starting preprocessing...")

    total_saved = 0

    # Step 2: Preprocess with tqdm
    for wav_path in tqdm(wav_paths, desc="Preprocessing Mel Spectrograms"):
        # Unique cache name: <relative/path/without/extension> with separators flattened.
        rel_stem = os.path.splitext(os.path.relpath(wav_path, root_dir))[0]
        mel_path = os.path.join(save_dir, rel_stem.replace(os.sep, '_') + '.pt')
        txt_path = wav_path.replace('wav_headMic', 'prompts').replace('.wav', '.txt')

        if not os.path.exists(txt_path) or os.path.exists(mel_path):
            continue

        try:
            with open(txt_path, 'r', encoding='utf-8') as f:
                transcript = f.read().strip().lower()

            # Reject unusable prompts before doing any audio work.
            if '.jpg' in transcript or transcript in ('xxx', '') or '[say' in transcript:
                continue

            waveform, sr = torchaudio.load(wav_path)
            waveform = waveform.to(device)

            # Mix multi-channel audio down to mono so squeeze(0) below yields [n_mels, frames].
            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            if sr != sample_rate:
                if sr not in resamplers:
                    resamplers[sr] = torchaudio.transforms.Resample(
                        orig_freq=sr, new_freq=sample_rate
                    ).to(device)
                waveform = resamplers[sr](waveform)

            mel_spec = mel_transform(waveform).squeeze(0).transpose(0, 1).cpu()

            torch.save((mel_spec, transcript), mel_path)
            total_saved += 1
        except Exception as e:
            # Best effort: report the failing file and keep going with the rest.
            print(f"Error processing {wav_path}: {e}")

    print(f"\nDone: {total_saved} mel .pt files saved to {save_dir}")


In [9]:
# One-off preprocessing run, deliberately disabled by wrapping the call in a string literal.
# Remove the surrounding triple quotes to (re)generate the cached mel features on Drive.
'''
preprocess_and_save_mel(
    root_dir="/content/drive/MyDrive/TORGO",
    save_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"
)
'''

'\npreprocess_and_save_mel(\n    root_dir="/content/drive/MyDrive/TORGO",\n    save_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"\n)\n'

In [9]:
class TorgoDataset(Dataset):
    """Speech dataset yielding ``(mel_spec, target)`` pairs.

    Two modes are supported:

    * ``use_precomputed=True``: every ``.pt`` file in ``mel_dir`` is expected to hold
      a ``(mel_spec, transcript)`` tuple and is loaded as-is.
    * otherwise: audio is read from disk and converted to an 80-bin mel spectrogram
      on the fly. The ``(wav_path, transcript)`` list comes either from
      ``items_file`` (a ``torch.save``d list) or from scanning ``root_dir`` for
      ``.wav`` files in directories whose path contains ``wav_headMic``, paired with
      the ``prompts/*.txt`` file at the mirrored path.

    Args:
        root_dir: Corpus root; only read when scanning for wav files.
        char_to_index: Mapping from character to integer id.
        sample_rate: Target sample rate for on-the-fly feature extraction.
        items_file: Optional path to a saved ``[(wav_path, text), ...]`` list.
        use_precomputed: Load cached features from ``mel_dir`` instead of audio.
        mel_dir: Directory of cached ``.pt`` files (required when ``use_precomputed``).
    """

    def __init__(self, root_dir, char_to_index, sample_rate=16000,
                 items_file=None, use_precomputed=False, mel_dir=None):
        self.char_to_index = char_to_index
        self.sample_rate = sample_rate
        self.use_precomputed = use_precomputed
        self.items = []

        if self.use_precomputed:
            assert mel_dir is not None, "If using precomputed features, you must provide mel_dir."
            # Sorted so that indices are stable across runs.
            self.paths = sorted([
                os.path.join(mel_dir, f)
                for f in os.listdir(mel_dir) if f.endswith('.pt')
            ])
        else:
            self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=80)
            self.resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=sample_rate)
            # Resamplers keyed by source rate. The 48 kHz one is kept for backward
            # compatibility; others are created lazily so audio recorded at a
            # different rate is no longer resampled as if it were 48 kHz.
            self._resamplers = {48000: self.resampler}

            if items_file is not None:
                # NOTE(review): torch.load unpickles the file — only use trusted items files.
                loaded_items = torch.load(items_file)
                for wav_fp, text in loaded_items:
                    cleaned = clean_transcription(text)
                    if cleaned:
                        self.items.append((wav_fp, cleaned))
            else:
                wav_files = []
                for dirpath, _, filenames in os.walk(root_dir):
                    for fname in filenames:
                        if fname.endswith('.wav') and 'wav_headMic' in dirpath:
                            wav_files.append(os.path.join(dirpath, fname))

                for wav_fp in tqdm(wav_files, desc=f"Building dataset from .wav files ({len(wav_files)} found)"):
                    txt_fp = wav_fp.replace('wav_headMic', 'prompts').replace('.wav', '.txt')
                    if not os.path.exists(txt_fp):
                        continue
                    with open(txt_fp, 'r', encoding='utf-8') as f:
                        text = f.read()
                    cleaned = clean_transcription(text)
                    if cleaned:
                        self.items.append((wav_fp, cleaned))

    def __len__(self):
        """Number of examples available in the active mode."""
        return len(self.paths) if self.use_precomputed else len(self.items)

    def _resample(self, waveform, source_rate):
        """Resample ``waveform`` from ``source_rate`` to ``self.sample_rate`` if they differ."""
        if source_rate == self.sample_rate:
            return waveform
        if source_rate not in self._resamplers:
            self._resamplers[source_rate] = torchaudio.transforms.Resample(
                orig_freq=source_rate, new_freq=self.sample_rate
            )
        return self._resamplers[source_rate](waveform)

    def __getitem__(self, idx):
        """Return ``(mel_spec, target)`` for example ``idx``.

        ``mel_spec`` is ``[frames, n_mels]`` when computed here; cached features are
        returned exactly as stored. ``target`` is a 1-D ``long`` tensor of character
        ids; characters absent from ``char_to_index`` are silently dropped.
        """
        if self.use_precomputed:
            mel_path = self.paths[idx]
            mel_spec, transcript = torch.load(mel_path)
        else:
            wav_path, transcript = self.items[idx]
            waveform, sr = torchaudio.load(wav_path)
            # Mix multi-channel audio down to mono so squeeze(0) yields [n_mels, frames].
            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            waveform = self._resample(waveform, sr)
            mel_spec = self.mel_transform(waveform).squeeze(0).transpose(0, 1)

        target = torch.tensor([self.char_to_index[c] for c in transcript if c in self.char_to_index], dtype=torch.long)
        return mel_spec, target


In [10]:
def collate_fn(batch):
    """Collate ``(features, target)`` pairs into padded batch tensors.

    Args:
        batch: Sequence of ``(features, target)`` pairs, where ``features`` is
            ``[T, F]`` and ``target`` is a 1-D tensor of token ids.

    Returns:
        ``(padded_features, input_lengths, padded_targets, target_lengths)``:
        features zero-padded to ``[B, T_max, F]``, targets padded with id 0 to
        ``[B, S_max]``, and the unpadded lengths as ``long`` tensors. The true
        lengths are what tell a consumer which target entries are padding.
    """
    mel_list, label_list = map(list, zip(*batch))

    # True (unpadded) sizes, recorded before padding hides them.
    input_lengths = torch.tensor([mel.size(0) for mel in mel_list], dtype=torch.long)
    target_lengths = torch.tensor([len(label) for label in label_list], dtype=torch.long)

    # Time is dim 0 of every feature matrix, so padding extends the frame axis.
    padded_features = pad_sequence(mel_list, batch_first=True)
    padded_targets = pad_sequence(label_list, batch_first=True, padding_value=0)

    return padded_features, input_lengths, padded_targets, target_lengths


In [11]:
# Define the full RNN-T model with encoder, decoder, and joiner
class RNNTModel(nn.Module):
    """Transducer-style network: BiLSTM encoder, LSTM label decoder and a linear joiner.

    NOTE(review): ``forward`` embeds ``target_input`` as given, with no
    start-of-sequence shift visible here — confirm callers prepare the decoder
    input as intended.
    """

    def __init__(self, input_dim=80, vocab_size=len(vocab), hidden_dim=128, embed_dim=64):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.decoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Joiner input = encoder state (2 * hidden_dim, bidirectional) + decoder state (hidden_dim).
        self.joiner = nn.Linear(3 * hidden_dim, vocab_size)

    def forward(self, encoder_input, target_input):
        """Score every (encoder frame, decoder step) pair.

        Args:
            encoder_input: Acoustic features, ``[B, T_enc, input_dim]``.
            target_input: Token ids fed to the label decoder, ``[B, T_dec]``.

        Returns:
            Unnormalised scores of shape ``[B, T_enc, T_dec, vocab_size]``.
        """
        enc_out, _ = self.encoder(encoder_input)              # [B, T_enc, 2H]
        dec_out, _ = self.decoder(self.embed(target_input))   # [B, T_dec, H]
        assert enc_out.size(0) == dec_out.size(0)

        t_enc, t_dec = enc_out.size(1), dec_out.size(1)

        # Broadcast both streams over the full T_enc x T_dec grid before concatenation.
        enc_grid = enc_out.unsqueeze(2).expand(-1, -1, t_dec, -1)  # [B, T_enc, T_dec, 2H]
        dec_grid = dec_out.unsqueeze(1).expand(-1, t_enc, -1, -1)  # [B, T_enc, T_dec, H]

        return self.joiner(torch.cat((enc_grid, dec_grid), dim=-1))  # [B, T_enc, T_dec, V]


In [12]:
# CTC Model Class
class CTCModel(nn.Module):
    """Stacked bidirectional LSTM followed by a per-frame linear classifier."""

    def __init__(self, input_dim=80, hidden_dim=128, vocab_size=len(vocab), num_layers=2, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True,
        )
        # Forward and backward states are concatenated, hence 2 * hidden_dim inputs.
        self.fc = nn.Linear(2 * hidden_dim, vocab_size)

    def forward(self, x):
        """Map features ``[B, T, input_dim]`` to per-frame scores ``[B, T, vocab_size]``."""
        encoded = self.lstm(x)[0]
        return self.fc(encoded)


In [13]:
class BeamSearch:
    """CTC prefix beam search.

    The previous implementation appended the emitted label at every frame, so it
    never collapsed repeated labels and never merged hypotheses that map to the
    same text, which is inconsistent with CTC (and with ``greedy_decode``). This
    version tracks, for every label prefix, the log-probability of all alignments
    ending in blank and of all alignments ending in a label, which is what makes
    repeat-collapsing and merging correct.

    Args:
        beam_width: Number of prefixes kept after each frame; also the number of
            highest-scoring symbols considered per frame (capped at the vocabulary size).
        vocab_size: Size of the output vocabulary (kept for interface compatibility).
        blank_index: Id of the CTC blank symbol.
    """

    def __init__(self, beam_width, vocab_size, blank_index):
        self.beam_width = beam_width
        self.vocab_size = vocab_size
        self.blank_index = blank_index

    def decode(self, logits, input_lengths=None):
        """Decode a batch of frame-wise scores.

        Args:
            logits: Tensor ``[T, B, V]`` of logits or log-probabilities
                (a log-softmax is applied, which leaves log-probabilities unchanged).
            input_lengths: Optional per-utterance number of valid frames (tensor or
                sequence of ints). Frames beyond it are padding and are ignored.
                ``None`` decodes all ``T`` frames for every utterance.

        Returns:
            A list with one list of label ids (blank removed, repeats collapsed)
            per utterance.
        """
        import math  # local import keeps this block self-contained

        neg_inf = float("-inf")

        def log_add(a, b):
            # log(exp(a) + exp(b)) without overflow; -inf stands for probability 0.
            if a == neg_inf:
                return b
            if b == neg_inf:
                return a
            hi, lo = (a, b) if a >= b else (b, a)
            return hi + math.log1p(math.exp(lo - hi))

        total_frames, batch_size = logits.size(0), logits.size(1)
        # topk raises when k exceeds the vocabulary size, so cap it.
        num_candidates = min(self.beam_width, logits.size(2))
        results = []

        for b in range(batch_size):
            n_frames = total_frames
            if input_lengths is not None:
                n_frames = min(int(input_lengths[b]), total_frames)
            if n_frames <= 0:
                results.append([])
                continue

            frame_log_probs = F.log_softmax(logits[:n_frames, b], dim=-1)
            top_vals, top_idx = torch.topk(frame_log_probs, num_candidates, dim=-1)
            # Convert once to Python lists; per-element .item() calls dominate runtime otherwise.
            top_vals, top_idx = top_vals.tolist(), top_idx.tolist()

            # prefix -> (log P(alignments ending in blank), log P(alignments ending in a label))
            beams = {(): (0.0, neg_inf)}
            for t in range(n_frames):
                candidates = {}
                for prefix, (p_blank, p_label) in beams.items():
                    p_total = log_add(p_blank, p_label)
                    for token, lp in zip(top_idx[t], top_vals[t]):
                        if token == self.blank_index:
                            # Blank keeps the prefix unchanged.
                            cb, cl = candidates.get(prefix, (neg_inf, neg_inf))
                            candidates[prefix] = (log_add(cb, p_total + lp), cl)
                            continue

                        extended = prefix + (token,)
                        eb, el = candidates.get(extended, (neg_inf, neg_inf))
                        if prefix and prefix[-1] == token:
                            # Same label again: it only starts a new character when a
                            # blank separated the two; otherwise it collapses into the prefix.
                            candidates[extended] = (eb, log_add(el, p_blank + lp))
                            sb, sl = candidates.get(prefix, (neg_inf, neg_inf))
                            candidates[prefix] = (sb, log_add(sl, p_label + lp))
                        else:
                            candidates[extended] = (eb, log_add(el, p_total + lp))

                ranked = sorted(candidates.items(), key=lambda kv: log_add(*kv[1]), reverse=True)
                beams = dict(ranked[:self.beam_width])

            best_prefix = max(beams.items(), key=lambda kv: log_add(*kv[1]))[0]
            results.append(list(best_prefix))

        return results

In [14]:
# Compute WER, CER, Edit Distance
from torchaudio.functional import edit_distance as torchaudio_edit_distance
from jiwer import wer, cer

def compute_wer_cer(preds, input_lens, targets, target_lens, index_to_char):
    """Average per-utterance word and character error rate for one batch.

    Args:
        preds: One sequence of predicted token ids per utterance.
        input_lens: Unused; kept for interface compatibility.
        targets: Reference token ids, either padded ``[B, S]`` (one row per
            utterance, as produced by ``pad_sequence``) or a flat 1-D
            concatenation of all references.
        target_lens: True reference length of each utterance.
        index_to_char: Mapping from token id to character; unknown ids are dropped.

    Returns:
        ``(mean_wer, mean_cer)`` over utterances with a non-empty reference, or
        ``(1.0, 1.0)`` when there is no such utterance.
    """
    # A 2-D tensor holds one padded row per utterance. It must be indexed by row:
    # slicing it with a running offset (the flat layout) mixes whole rows of
    # different utterances into a single reference.
    padded = targets.dim() == 2

    total_wer, total_cer = 0.0, 0.0
    offset = 0
    valid_examples = 0

    for i in range(len(preds)):
        tlen = int(target_lens[i])

        if padded:
            target_seq = targets[i, :tlen].tolist()
        else:
            target_seq = targets[offset:offset + tlen].view(-1).tolist()
            offset += tlen

        if tlen == 0:
            continue  # skip empty targets

        pred_seq = preds[i] if isinstance(preds[i], (list, tuple)) else [preds[i]]
        pred_str = ''.join([index_to_char.get(p, '') for p in pred_seq])
        target_str = ''.join([index_to_char.get(t, '') for t in target_seq])

        if not target_str.strip():
            continue  # skip if the reference string is still empty

        total_wer += wer(target_str, pred_str)
        total_cer += cer(target_str, pred_str)
        valid_examples += 1

    if valid_examples == 0:
        return 1.0, 1.0  # fallback if no valid examples
    return total_wer / valid_examples, total_cer / valid_examples



In [15]:
from jiwer import cer

def decode_indices(indices, index_to_char):
    """Turn token ids into text, rendering ``|`` as a space.

    Ids that are not keys of ``index_to_char`` are skipped; the result is stripped
    of leading and trailing whitespace.
    """
    pieces = []
    for token_id in indices:
        if token_id in index_to_char:
            pieces.append(index_to_char[token_id])
    text = ''.join(pieces)
    return text.replace('|', ' ').strip()

def compute_cer_accuracy(preds, targets, target_lens, index_to_char):
    """Character accuracy (``1 - CER``) over a batch.

    Args:
        preds: One sequence of predicted token ids per utterance.
        targets: Reference token ids, either padded ``[B, S]`` (one row per
            utterance) or a flat 1-D concatenation of all references.
        target_lens: True reference length of each utterance.
        index_to_char: Mapping from token id to character; every id must be present.

    Returns:
        ``1.0 - cer(references, hypotheses)`` computed over the whole batch.
    """
    # Padded batches are indexed by row; the running offset only applies to the
    # flat layout (on a 2-D tensor it would select whole rows, not tokens).
    padded = targets.dim() == 2

    offset = 0
    all_refs, all_hyps = [], []
    for i, tlen in enumerate(target_lens):
        tlen = int(tlen)
        if padded:
            ref_ids = targets[i, :tlen].tolist()
        else:
            ref_ids = targets[offset:offset + tlen].tolist()
            offset += tlen
        all_refs.append(''.join([index_to_char[x] for x in ref_ids]))
        all_hyps.append(''.join([index_to_char[x] for x in preds[i]]))
    return 1.0 - cer(all_refs, all_hyps)


In [16]:
def greedy_decode(logits, input_lens, blank_index):
    """Best-path CTC decoding.

    Takes the highest-scoring symbol of each frame, merges consecutive duplicates
    and removes blanks.

    Args:
        logits: Scores of shape ``[B, T, V]``.
        input_lens: Number of frames to read for each utterance.
        blank_index: Id of the blank symbol.

    Returns:
        One list of token ids per utterance.
    """
    best_path = logits.argmax(dim=2)
    decoded = []
    for i in range(best_path.size(0)):
        frames = [best_path[i, t].item() for t in range(input_lens[i])]
        # Pair each frame with its predecessor (None before the first frame).
        tokens = [
            cur for prev, cur in zip([None] + frames, frames)
            if cur != blank_index and cur != prev
        ]
        decoded.append(tokens)
    return decoded

In [17]:
from tqdm import tqdm

def train_ctc(model, train_loader, val_loader, optimizer, loss_fn, index_to_char,
              device, epochs=20, patience=5, blank_idx=1, beam_width=5):
    """Train a CTC model with early stopping on validation loss.

    Every epoch runs one pass over ``train_loader`` (with weight updates) and one
    over ``val_loader``, decoding each batch with ``BeamSearch`` to report WER/CER.
    Whenever the validation loss improves, the weights are written to
    ``best_ctc_model.pt``; training stops after ``patience`` consecutive epochs
    without improvement. The ``model`` object itself is left with the weights of
    the last epoch that ran, not the best ones.

    Args:
        model: Network mapping features ``[B, T, F]`` to scores ``[B, T, V]``.
        train_loader, val_loader: Loaders yielding
            ``(features, feature_lengths, targets, target_lengths)``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Called as ``loss_fn(log_probs[T, B, V], targets, feature_lengths,
            target_lengths)``.
        index_to_char: Token id -> character mapping used for WER/CER.
        device: Device to train on.
        epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping.
        blank_idx: Blank id used for decoding. NOTE(review): the default of 1 is not
            tied to the blank used by ``loss_fn`` — callers should pass the same id.
        beam_width: Beam width used for decoding.

    Returns:
        Six per-epoch lists: train loss, val loss, train WER, val WER, train CER,
        val CER.
    """
    model.to(device)
    best_val_loss = float('inf')
    wait = 0

    train_loss_hist, val_loss_hist = [], []
    train_wer_hist, val_wer_hist = [], []
    train_cer_hist, val_cer_hist = [], []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        model.train()
        train_loss, train_wer, train_cer = 0, 0, 0

        train_bar = tqdm(train_loader, desc="Training", leave=False)
        for feats, feat_lens, targets, target_lens in train_bar:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            optimizer.zero_grad()
            logits = model(feats)
            # The loss expects time-major log-probabilities: [T, B, V].
            log_probs = logits.log_softmax(2).transpose(0, 1)
            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            with torch.no_grad():
                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                # decode() already returns one token list per utterance. The old
                # post-processing probed beam[0] and raised IndexError whenever a
                # hypothesis was empty (all-blank prediction), so it is gone.
                preds = decoder.decode(log_probs.cpu())

                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                train_wer += wer_score
                train_cer += cer_score

            train_bar.set_postfix(loss=loss.item(), WER=wer_score, CER=cer_score)

        train_loss /= len(train_loader)
        train_wer /= len(train_loader)
        train_cer /= len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss, val_wer, val_cer = 0, 0, 0
        val_bar = tqdm(val_loader, desc="Validating", leave=False)
        with torch.no_grad():
            for feats, feat_lens, targets, target_lens in val_bar:
                feats, feat_lens = feats.to(device), feat_lens.to(device)
                targets, target_lens = targets.to(device), target_lens.to(device)

                logits = model(feats)
                log_probs = logits.log_softmax(2).transpose(0, 1)
                loss = loss_fn(log_probs, targets, feat_lens, target_lens)
                val_loss += loss.item()

                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                preds = decoder.decode(log_probs.cpu())

                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                val_wer += wer_score
                val_cer += cer_score

                val_bar.set_postfix(loss=loss.item(), WER=wer_score, CER=cer_score)

        val_loss /= len(val_loader)
        val_wer /= len(val_loader)
        val_cer /= len(val_loader)

        train_loss_hist.append(train_loss)
        val_loss_hist.append(val_loss)
        train_wer_hist.append(train_wer)
        val_wer_hist.append(val_wer)
        train_cer_hist.append(train_cer)
        val_cer_hist.append(val_cer)

        print(f"[CTC] Epoch {epoch+1}: "
              f"Train Loss={train_loss:.4f}, WER={train_wer:.4f}, CER={train_cer:.4f} | "
              f"Val Loss={val_loss:.4f}, WER={val_wer:.4f}, CER={val_cer:.4f}")

        # Checkpoint on improvement, otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_ctc_model.pt')
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    return train_loss_hist, val_loss_hist, train_wer_hist, val_wer_hist, train_cer_hist, val_cer_hist


In [39]:
from tqdm import tqdm

def train_rnnt(model, train_loader, val_loader, optimizer, loss_fn, index_to_char,
               device, epochs=20, patience=5, blank_idx=1, beam_width=5):
    """Train the two-input (features, targets) model with early stopping on validation loss.

    Every epoch runs one training pass and one validation pass, decoding each batch
    with ``BeamSearch`` to report WER/CER. Whenever the validation loss improves the
    weights are saved to ``best_rnnt_model.pt``; training stops after ``patience``
    consecutive epochs without improvement.

    NOTE(review): the 4-D model output is reduced to its diagonal (frame index ==
    label index) and then scored with ``loss_fn`` using the CTC-style call
    ``loss_fn(log_probs, targets, input_lengths, target_lengths)``. This is not a
    transducer loss over the full lattice — confirm this is intended.
    NOTE(review): ``blank_idx`` defaults to 1 and is not tied to the blank used by
    ``loss_fn`` — callers should pass the same id.

    Args:
        model: Called as ``model(feats, targets)``; assumed to return
            ``[B, T_enc, T_dec, V]`` as ``RNNTModel`` does.
        train_loader, val_loader: Loaders yielding
            ``(features, feature_lengths, targets, target_lengths)``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Loss called with the CTC-style signature described above.
        index_to_char: Token id -> character mapping used for WER/CER.
        device: Device to train on.
        epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping.
        blank_idx: Blank id used for decoding.
        beam_width: Beam width used for decoding.

    Returns:
        Six per-epoch lists: train loss, val loss, train WER, val WER, train CER,
        val CER.
    """

    model.to(device)
    best_val_loss = float('inf')
    wait = 0

    train_loss_hist, val_loss_hist = [], []
    train_wer_hist, val_wer_hist = [], []
    train_cer_hist, val_cer_hist = [], []

    for epoch in range(epochs):
        model.train()
        train_loss, train_wer, train_cer = 0, 0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} - Training (RNN-T)")
        for feats, feat_lens, targets, target_lens in train_pbar:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            optimizer.zero_grad()
            logits = model(feats, targets)
            T_enc = logits.shape[1]
            T_dec = logits.shape[2]
            # Longest diagonal that fits inside the lattice and the longest utterance.
            min_len = min(T_enc, T_dec, feat_lens.max().item())
            # Paired index tensors pick the entries (k, k): result is [B, min_len, V].
            logits = logits[:, torch.arange(min_len), torch.arange(min_len), :]
            log_probs = logits.log_softmax(2).transpose(0, 1)
            # Input lengths may not exceed the shortened time axis.
            feat_lens = torch.clamp(feat_lens, max=log_probs.shape[0])

            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            with torch.no_grad():
                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                preds = decoder.decode(log_probs.cpu())
                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                train_wer += wer_score
                train_cer += cer_score

        # Metrics are averaged per batch, not per utterance.
        train_loss /= len(train_loader)
        train_wer /= len(train_loader)
        train_cer /= len(train_loader)

        model.eval()
        val_loss, val_wer, val_cer = 0, 0, 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} - Validation (RNN-T)")
        with torch.no_grad():
            for feats, feat_lens, targets, target_lens in val_pbar:
                feats, feat_lens = feats.to(device), feat_lens.to(device)
                targets, target_lens = targets.to(device), target_lens.to(device)

                # Same diagonal reduction as in the training pass above.
                logits = model(feats, targets)
                T_enc = logits.shape[1]
                T_dec = logits.shape[2]
                min_len = min(T_enc, T_dec, feat_lens.max().item())
                logits = logits[:, torch.arange(min_len), torch.arange(min_len), :]
                log_probs = logits.log_softmax(2).transpose(0, 1)
                feat_lens = torch.clamp(feat_lens, max=log_probs.shape[0])

                loss = loss_fn(log_probs, targets, feat_lens, target_lens)
                val_loss += loss.item()

                decoder = BeamSearch(beam_width, logits.size(2), blank_idx)
                preds = decoder.decode(log_probs.cpu())
                wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
                val_wer += wer_score
                val_cer += cer_score

        val_loss /= len(val_loader)
        val_wer /= len(val_loader)
        val_cer /= len(val_loader)

        train_loss_hist.append(train_loss)
        val_loss_hist.append(val_loss)
        train_wer_hist.append(train_wer)
        val_wer_hist.append(val_wer)
        train_cer_hist.append(train_cer)
        val_cer_hist.append(val_cer)

        print(f"[RNN-T] Epoch {epoch+1}: Train Loss={train_loss:.4f}, WER={train_wer:.4f}, CER={train_cer:.4f} | "
              f"Val Loss={val_loss:.4f}, WER={val_wer:.4f}, CER={val_cer:.4f}")

        # Checkpoint on improvement, otherwise count towards early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_rnnt_model.pt')
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping.")
                break

    return train_loss_hist, val_loss_hist, train_wer_hist, val_wer_hist, train_cer_hist, val_cer_hist


In [19]:
def evaluate_and_plot_model(model, test_loader, loss_fn, index_to_char, device, blank_idx=1,
                            use_beam=True, beam_width=5, model_type="ctc"):
    """Evaluate a trained model on ``test_loader`` and print loss, WER and CER.

    Despite its name, this function produces no plot.

    Args:
        model: Trained network. With ``model_type="ctc"`` it is called as
            ``model(feats)``; with ``model_type="rnnt"`` (also ``"rnn-t"``, any
            case) it is called as ``model(feats, targets)`` and its 4-D output is
            reduced to the diagonal exactly as ``train_rnnt`` does during
            validation. Previously the single-argument call was used
            unconditionally, so a two-input model could not be evaluated.
        test_loader: Loader yielding
            ``(features, feature_lengths, targets, target_lengths)``.
        loss_fn: Called as ``loss_fn(log_probs[T, B, V], targets, feature_lengths,
            target_lengths)``.
        index_to_char: Token id -> character mapping used for WER/CER.
        device: Device to evaluate on; the model is moved there.
        blank_idx: Blank id used for decoding. NOTE(review): the default of 1 is not
            tied to the blank used by ``loss_fn`` — callers should pass the same id.
        use_beam: Decode with ``BeamSearch`` when true, else with ``greedy_decode``.
        beam_width: Beam width for beam-search decoding.
        model_type: ``"ctc"`` or ``"rnnt"``; selects the forward call and labels
            the printed summary.

    Returns:
        ``(avg_loss, avg_wer, avg_cer)``, each averaged over batches.
    """
    is_rnnt = model_type.lower().replace("-", "").replace("_", "") == "rnnt"

    model.to(device)
    model.eval()
    test_loss = 0
    total_wer = 0
    total_cer = 0

    with torch.no_grad():
        for feats, feat_lens, targets, target_lens in test_loader:
            feats, feat_lens = feats.to(device), feat_lens.to(device)
            targets, target_lens = targets.to(device), target_lens.to(device)

            if is_rnnt:
                # Mirror train_rnnt: keep the (k, k) diagonal of the [B, T_enc, T_dec, V]
                # output and clamp the input lengths to the shortened time axis.
                logits = model(feats, targets)
                min_len = min(logits.shape[1], logits.shape[2], feat_lens.max().item())
                diagonal = torch.arange(min_len)
                logits = logits[:, diagonal, diagonal, :]
                feat_lens = torch.clamp(feat_lens, max=min_len)
            else:
                logits = model(feats)

            log_probs = logits.log_softmax(2).transpose(0, 1)
            loss = loss_fn(log_probs, targets, feat_lens, target_lens)
            test_loss += loss.item()

            if use_beam:
                decoder = BeamSearch(beam_width, logits.size(-1), blank_idx)
                preds = decoder.decode(log_probs.cpu())
            else:
                preds = greedy_decode(logits.cpu(), feat_lens.cpu(), blank_idx)

            wer_score, cer_score = compute_wer_cer(preds, feat_lens.cpu(), targets.cpu(), target_lens.cpu(), index_to_char)
            total_wer += wer_score
            total_cer += cer_score

    avg_loss = test_loss / len(test_loader)
    avg_wer = total_wer / len(test_loader)
    avg_cer = total_cer / len(test_loader)

    print(f"[Test - {model_type.upper()}] Loss: {avg_loss:.4f} | WER: {avg_wer:.4f} | CER: {avg_cer:.4f}")
    return avg_loss, avg_wer, avg_cer


In [20]:
# Load data
data_dir = "/content/drive/MyDrive/TORGO"

# Fixed seed so the train/val/test membership is identical on every run; without it the
# test split changes between runs and results (and saved checkpoints) are not comparable.
SPLIT_SEED = 42

# Load dataset from the cached mel features.
# To rebuild from raw audio instead:
#dataset = TorgoDataset(data_dir, char_to_index)
#torch.save(dataset.items, '/content/drive/MyDrive/torgo_items.pt')
dataset = TorgoDataset(
    root_dir=data_dir,  # still required for structure
    char_to_index=char_to_index,
    use_precomputed=True,
    mel_dir="/content/drive/MyDrive/TORGO_mel_preprocessed"
)
print(f"Dataset length: {len(dataset)}")

# Calculate lengths (70 / 10 / 20 split; the test set absorbs the rounding remainder)
total_len = len(dataset)
train_len = int(0.7 * total_len)
val_len = int(0.1 * total_len)
test_len = total_len - train_len - val_len

# Perform split
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn, pin_memory=True)



Dataset length: 993


In [32]:
# Train CTC model
model_ctc = CTCModel()
optimizer_ctc = torch.optim.Adam(model_ctc.parameters(), lr=1e-5)
loss_fn_ctc = nn.CTCLoss(blank=char_to_index['|'], zero_infinity=True)
device = get_device()

ctc_train_loss, ctc_val_loss, ctc_train_wer, ctc_val_wer, ctc_train_cer, ctc_val_cer = train_ctc(
    model=model_ctc,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer_ctc,
    loss_fn=loss_fn_ctc,
    index_to_char=index_to_char,
    device=device,
    # Decode with the same blank the loss was built with; the function default (1) is the
    # id of the letter 'b', which would be dropped from every hypothesis.
    blank_idx=char_to_index['|'],
)



Epoch 1/20




KeyboardInterrupt: 

In [40]:
# Train RNN-T model
model_rnnt = RNNTModel()
optimizer_rnnt = torch.optim.Adam(model_rnnt.parameters(), lr=1e-4)
loss_fn_rnnt = nn.CTCLoss(blank=char_to_index['|'], zero_infinity=True)
device = get_device()

rnnt_train_loss, rnnt_val_loss, rnnt_train_wer, rnnt_val_wer, rnnt_train_cer, rnnt_val_cer = train_rnnt(
    model=model_rnnt,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer_rnnt,
    loss_fn=loss_fn_rnnt,
    index_to_char=index_to_char,
    epochs=30,
    patience=5,
    device=device,
    # Decode with the same blank the loss was built with; the function default (1) is the
    # id of the letter 'b', which would be dropped from every hypothesis.
    blank_idx=char_to_index['|'],
)

Epoch 1 - Training (RNN-T):  27%|██▋       | 3/11 [00:13<00:35,  4.43s/it]


KeyboardInterrupt: 

In [None]:
# Evaluate the CTC model on the held-out test split.
# The call now matches evaluate_and_plot_model's actual signature: the previous keywords
# (blank_index, train_losses, val_losses, train_accuracies, val_accuracies, model_name)
# are not accepted by it, and ctc_train_acc / ctc_val_acc were never defined.
ctc_results = evaluate_and_plot_model(
    model=model_ctc,
    test_loader=test_loader,
    loss_fn=loss_fn_ctc,
    index_to_char=index_to_char,
    device=device,  # resolved by get_device(), so this also runs without a GPU
    blank_idx=char_to_index['|'],
    model_type="ctc",
)

In [None]:
# Evaluate the RNN-T model on the held-out test split.
# The call now matches evaluate_and_plot_model's actual signature: the previous keywords
# (blank_index, train_losses, val_losses, train_accuracies, val_accuracies, model_name)
# are not accepted by it, and rnnt_train_acc / rnnt_val_acc were never defined.
# NOTE(review): RNNTModel.forward takes (features, targets); confirm that
# evaluate_and_plot_model performs the two-argument call for model_type="rnnt" before
# relying on this cell.
rnnt_results = evaluate_and_plot_model(
    model=model_rnnt,
    test_loader=test_loader,
    loss_fn=loss_fn_rnnt,
    index_to_char=index_to_char,
    device=device,  # resolved by get_device(), so this also runs without a GPU
    blank_idx=char_to_index['|'],
    model_type="rnnt",
)