<a href="https://colab.research.google.com/github/captainkeemo/Dysarthric-Speech-Transcription/blob/main/models/base_CTC_RNN_T_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jiwer editdistance --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.1/3.1 MB[0m [31m115.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m63.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset directory, the cached item
# list and the model checkpoints used below all live under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


In [18]:
import glob
import os
import warnings

import editdistance
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchaudio
from jiwer import wer, cer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from tqdm import tqdm

In [4]:
# Character vocabulary: the 26 lowercase letters, space and apostrophe, with
# "|" appended as the last entry (the training functions use it as CTC blank).
vocab = [*"abcdefghijklmnopqrstuvwxyz '", "|"]
# Lookup tables in both directions between characters and integer ids.
char_to_index = dict(zip(vocab, range(len(vocab))))
index_to_char = dict(enumerate(vocab))


In [16]:
# Dataset class that loads audio and transcripts from TORGO
class TorgoDataset(Dataset):
    """TORGO utterances served as (mel spectrogram, character-id target) pairs.

    ``self.items`` is a list of ``(wav_path, transcript)`` tuples. It is either
    loaded from ``items_file`` or built by scanning ``root_dir`` for
    ``**/wav_headMic/*.wav`` files that have a matching ``prompts/*.txt`` file.

    Args:
        root_dir: Directory scanned recursively when ``items_file`` is None.
        char_to_index: Mapping from character to integer id; transcript
            characters missing from it are dropped.
        sample_rate: Rate the audio is resampled to before the mel transform.
        items_file: Optional path to an item list saved with ``torch.save``.
    """

    def __init__(self, root_dir, char_to_index, sample_rate=16000, items_file=None):
        self.char_to_index = char_to_index
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=80)

        if items_file is not None:
            # Load from pre-saved file.
            # NOTE: torch.load unpickles its input; only pass trusted files.
            self.items = torch.load(items_file)
        else:
            # Build dataset normally. Sorting makes the item order (and hence
            # any index-based split) independent of filesystem listing order.
            self.items = []
            pattern = os.path.join(root_dir, '**', 'wav_headMic', '*.wav')
            wav_files = sorted(glob.glob(pattern, recursive=True))
            for wav_fp in tqdm(wav_files, desc=f"Building Dataset ({len(wav_files)} files)"):
                txt_fp = wav_fp.replace('wav_headMic', 'prompts').replace('.wav', '.txt')
                if os.path.exists(txt_fp):
                    with open(txt_fp, 'r') as f:
                        text = f.read().strip().lower()
                    self.items.append((wav_fp, text))

    def __len__(self):
        """Number of (wav_path, transcript) items."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(spec, target)`` for item ``idx``, or None if loading fails.

        ``spec`` is the mel spectrogram transposed to (frames, n_mels) and
        ``target`` is a 1-D ``torch.long`` tensor of character ids. A failed
        sample yields None (with a warning) so that ``collate_fn`` can drop it.
        """
        wav_path, transcript = self.items[idx]
        try:
            waveform, sr = torchaudio.load(wav_path)
            # Mix multi-channel audio down to mono; otherwise squeeze(0) below
            # is a no-op and the spectrogram keeps an extra channel axis.
            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sr != self.sample_rate:
                waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
            spec = self.mel_transform(waveform).squeeze(0).transpose(0, 1)
            # Explicit dtype: an empty id list would otherwise become float32
            # and poison the concatenated target tensor in the collate step.
            target = torch.tensor(
                [self.char_to_index[c] for c in transcript if c in self.char_to_index],
                dtype=torch.long,
            )
            return spec, target
        except Exception as exc:
            # Best-effort loading: report the bad file instead of hiding it.
            warnings.warn(f"Skipping {wav_path}: {exc}")
            return None



In [6]:
# Batch collation for DataLoader: pad inputs and concatenate targets
def collate_fn(batch):
    """Merge ``(spec, target)`` samples into one padded batch.

    Samples that are None are discarded first. Returns None when nothing is
    left; otherwise returns ``(specs_padded, targets_flat, spec_lens,
    target_lens)`` where the spectrograms are zero-padded along their first
    axis (batch-first), the targets are concatenated into one 1-D tensor, and
    the two length tensors hold each sample's original sizes.
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    if not usable:
        return None

    specs = tuple(spec for spec, _ in usable)
    targets = tuple(target for _, target in usable)

    spec_lens = torch.tensor([spec.size(0) for spec in specs], dtype=torch.long)
    target_lens = torch.tensor([target.size(0) for target in targets], dtype=torch.long)

    return (
        pad_sequence(specs, batch_first=True),
        torch.cat(targets),
        spec_lens,
        target_lens,
    )


In [7]:
# Define the full RNN-T model with encoder, decoder, and joiner
class RNNTModel(nn.Module):
    """Encoder / label-decoder / joiner network.

    * ``encoder``: bidirectional LSTM over the acoustic frames.
    * ``embed`` + ``decoder``: embedding followed by a unidirectional LSTM
      over the label sequence.
    * ``joiner``: linear layer applied to every (frame, label) pair.

    Args:
        input_dim: Feature size of each acoustic frame.
        vocab_size: Number of output classes (and embedding rows).
        hidden_dim: Hidden size of both LSTMs.
        embed_dim: Size of the label embedding.
    """

    def __init__(self, input_dim=80, vocab_size=len(vocab), hidden_dim=128, embed_dim=64):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.decoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Joiner input = 2 * hidden_dim (bidirectional encoder) + hidden_dim (decoder).
        self.joiner = nn.Linear(hidden_dim * 3, vocab_size)

    def forward(self, encoder_input, target_input):
        """Return joint logits of shape (batch, frames, labels, vocab_size).

        Args:
            encoder_input: Batch-first acoustic features, (batch, frames, input_dim).
            target_input: Batch-first label ids, (batch, labels).

        The output always has four dimensions. The label axis is deliberately
        not squeezed: collapsing it when ``labels == 1`` would change the
        tensor rank and make a later reduction over ``dim=2`` hit the
        vocabulary axis instead of the label axis.
        """
        enc_out, _ = self.encoder(encoder_input)               # (B, T, 2H)
        dec_out, _ = self.decoder(self.embed(target_input))    # (B, U, H)
        num_frames, num_labels = enc_out.size(1), dec_out.size(1)
        # Broadcast both streams onto the full (T, U) grid before joining.
        enc_grid = enc_out.unsqueeze(2).expand(-1, -1, num_labels, -1)
        dec_grid = dec_out.unsqueeze(1).expand(-1, num_frames, -1, -1)
        return self.joiner(torch.cat([enc_grid, dec_grid], dim=-1))

In [8]:
# CTC Model Class
class CTCModel(nn.Module):
    """Bidirectional LSTM followed by a per-frame linear classifier.

    Args:
        input_dim: Feature size of each input frame.
        hidden_dim: Hidden size of the LSTM (per direction).
        vocab_size: Number of output classes per frame.
    """

    def __init__(self, input_dim=80, hidden_dim=128, vocab_size=len(vocab)):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        # The LSTM concatenates both directions, hence 2 * hidden_dim inputs.
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, x):
        """Map batch-first frames ``x`` to per-frame logits over the vocabulary."""
        frame_states, _final_state = self.lstm(x)
        return self.fc(frame_states)

In [9]:
# Compute WER, CER, Edit Distance
def compute_wer_cer(preds, targets, target_lens):
    """Average WER, CER and edit distance over a batch of predictions.

    Args:
        preds: One list of predicted character ids per utterance.
        targets: 1-D tensor holding all reference ids back to back.
        target_lens: Length of each utterance's slice of ``targets``.

    Returns:
        ``(mean_wer, mean_cer, mean_edit_distance)`` over ``len(preds)``
        utterances, or ``(0.0, 0.0, 0.0)`` when ``preds`` is empty.
    """
    n = len(preds)
    if n == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0, 0.0, 0.0

    total_wer = total_cer = total_edit = 0.0
    offset = 0
    for i, p in enumerate(preds):
        # Plain int keeps `offset` an int instead of silently becoming a tensor.
        tlen = int(target_lens[i])
        ref = targets[offset:offset + tlen].tolist()
        offset += tlen
        ref_str = ''.join(index_to_char[r] for r in ref)
        hyp_str = ''.join(index_to_char[x] for x in p if x in index_to_char)
        if ref_str.strip():
            total_wer += wer(ref_str, hyp_str)
            total_cer += cer(ref_str, hyp_str)
        else:
            # jiwer rejects empty references; score them as fully wrong when
            # anything was predicted and as correct when nothing was.
            err = 1.0 if hyp_str.strip() else 0.0
            total_wer += err
            total_cer += err
        total_edit += editdistance.eval(ref_str, hyp_str)
    return total_wer / n, total_cer / n, total_edit / n

In [10]:
# Character-level accuracy between predicted and target sequences
def compute_char_accuracy(preds, targets, target_lens):
    """Fraction of reference positions whose prediction matches.

    Each prediction is compared position by position with its slice of the
    flat ``targets`` tensor (slice sizes come from ``target_lens``). Every
    utterance contributes at least 1 to the denominator, even when its
    reference is empty. Returns 0.0 if the denominator is zero.
    """
    matched = 0
    counted = 0
    start = 0
    for hyp, length in zip(preds, target_lens):
        stop = start + length
        reference = targets[start:stop].tolist()
        start = stop
        matched += sum(1 for h, r in zip(hyp, reference) if h == r)
        counted += len(reference) if reference else 1
    if counted > 0:
        return matched / counted
    return 0.0

In [11]:
# Greedy decoding for character prediction from model output
def greedy_decode(logits, input_lens, blank_index):
    """Best-path decode: argmax per frame, then collapse repeats and blanks.

    Args:
        logits: Tensor of shape (batch, frames, classes).
        input_lens: Number of valid frames for each batch entry.
        blank_index: Class id that is never emitted.

    Returns:
        One list of class ids per batch entry. A frame is emitted only if it
        is not blank and differs from the immediately preceding frame.
    """
    best_path = logits.argmax(dim=2)
    decoded = []
    for row, frame_ids in enumerate(best_path):
        collapsed = []
        last = None
        for t in range(input_lens[row]):
            current = frame_ids[t].item()
            is_blank = current == blank_index
            is_repeat = current == last
            if not (is_blank or is_repeat):
                collapsed.append(current)
            # Blanks also update `last`, so "a | a" decodes to two a's.
            last = current
        decoded.append(collapsed)
    return decoded

In [12]:
# Training function for CTC model
def train_ctc_model(model, loader, epochs=3, checkpoint_dir="/content/drive/MyDrive"):
    """Train ``model`` with CTC loss and save a checkpoint after every epoch.

    Args:
        model: Module mapping batch-first spectrograms to per-frame logits.
        loader: Iterable yielding ``(specs, targets_flat, input_lens,
            target_lens)`` batches, or None for batches to be skipped.
        epochs: Number of passes over ``loader``.
        checkpoint_dir: Directory receiving ``ctc_epoch_<n>.pt`` state dicts.

    Returns:
        ``(losses, accs)``: per-epoch mean loss and mean character accuracy.

    Raises:
        RuntimeError: If an epoch yields no usable batch.
    """
    blank = char_to_index['|']
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CTCLoss(blank=blank, zero_infinity=True)
    losses, accs = [], []

    for epoch in range(epochs):
        model.train()
        total_loss = total_acc = 0
        batches = 0
        for batch in loader:
            # collate_fn returns None when every sample of a batch failed to load.
            if batch is None:
                continue
            specs, targets_flat, input_lens, target_lens = batch

            optimizer.zero_grad()
            logits = model(specs)
            # CTCLoss expects (frames, batch, classes) log-probabilities.
            log_probs = logits.log_softmax(2).transpose(0, 1)
            loss = loss_fn(log_probs, targets_flat, input_lens, target_lens)

            loss.backward()
            optimizer.step()

            # Metrics only: no need to track gradients while decoding.
            with torch.no_grad():
                preds = greedy_decode(logits, input_lens, blank)
                acc = compute_char_accuracy(preds, targets_flat, target_lens)

            total_loss += loss.item()
            total_acc += acc
            batches += 1

        if batches == 0:
            raise RuntimeError("train_ctc_model: loader produced no usable batches")

        avg_loss = total_loss / batches
        avg_acc = total_acc / batches
        losses.append(avg_loss)
        accs.append(avg_acc)
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"ctc_epoch_{epoch+1}.pt"))
        print(f"[CTC Epoch {epoch+1}] Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

    return losses, accs

In [13]:
# Training function for RNN-T model
def train_rnnt_model(model, loader, epochs=3, checkpoint_dir="/content/drive/MyDrive"):
    """Train the encoder/decoder/joiner model and checkpoint every epoch.

    The joint log-probabilities are averaged over the label axis and scored
    with ``nn.CTCLoss`` (this function does not use a transducer loss).

    Args:
        model: Module called as ``model(specs, decoder_inputs)``.
        loader: Iterable yielding ``(specs, targets_flat, input_lens,
            target_lens)`` batches, or None for batches to be skipped.
        epochs: Number of passes over ``loader``.
        checkpoint_dir: Directory receiving ``rnnt_epoch_<n>.pt`` state dicts.

    Returns:
        ``(losses, accs, wers, cers, edits)``: per-epoch means.

    Raises:
        RuntimeError: If an epoch yields no usable batch.
    """
    blank = char_to_index['|']
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CTCLoss(blank=blank, zero_infinity=True)
    losses, accs, wers, cers, edits = [], [], [], [], []

    for epoch in range(epochs):
        model.train()
        total_loss = total_acc = total_wer = total_cer = total_edit = 0
        batches = 0

        for batch in loader:
            # collate_fn returns None when every sample of a batch failed to load.
            if batch is None:
                continue
            specs, targets_flat, input_lens, target_lens = batch

            # Rebuild a padded (batch, max_len) label matrix from the flat targets.
            # Padding uses id 0; at least one column is kept so the decoder
            # never receives an empty sequence.
            B = specs.size(0)
            max_len = max(int(target_lens.max().item()), 1)
            decoder_inputs = torch.zeros(B, max_len, dtype=torch.long, device=specs.device)
            offset = 0
            for i in range(B):
                tlen = int(target_lens[i])
                decoder_inputs[i, :tlen] = targets_flat[offset:offset + tlen]
                offset += tlen

            optimizer.zero_grad()
            joint = model(specs, decoder_inputs)
            if joint.dim() == 3:
                # A model that squeezes a length-1 label axis returns
                # (batch, frames, vocab); restore the axis so the mean below
                # always averages over label positions, never over the vocab.
                joint = joint.unsqueeze(2)
            logits = joint.log_softmax(-1)
            # (batch, frames, labels, vocab) -> (frames, batch, vocab) for CTCLoss.
            logits_ctc = logits.mean(dim=2).transpose(0, 1)
            loss = loss_fn(logits_ctc, targets_flat, input_lens, target_lens)

            loss.backward()
            optimizer.step()

            # Metrics only: no need to track gradients while decoding.
            with torch.no_grad():
                preds = greedy_decode(logits_ctc.transpose(0, 1), input_lens, blank)
                acc = compute_char_accuracy(preds, targets_flat, target_lens)
                w, c, e = compute_wer_cer(preds, targets_flat, target_lens)

            total_loss += loss.item()
            total_acc += acc
            total_wer += w
            total_cer += c
            total_edit += e
            batches += 1

        if batches == 0:
            raise RuntimeError("train_rnnt_model: loader produced no usable batches")

        avg_loss = total_loss / batches
        avg_acc = total_acc / batches
        avg_wer = total_wer / batches
        avg_cer = total_cer / batches
        avg_edit = total_edit / batches

        losses.append(avg_loss)
        accs.append(avg_acc)
        wers.append(avg_wer)
        cers.append(avg_cer)
        edits.append(avg_edit)
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"rnnt_epoch_{epoch+1}.pt"))
        print(f"[RNN-T Epoch {epoch+1}] Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}, WER: {avg_wer:.4f}")

    return losses, accs, wers, cers, edits

In [19]:
# Load data
data_dir = "/content/drive/MyDrive/TORGO"
items_file = '/content/drive/MyDrive/torgo_items.pt'
split_seed = 42  # fixed seed so the train/val/test split is reproducible

# Load dataset. Scanning the TORGO tree on Drive is slow, so the
# (wav_path, transcript) list is cached on first build and reused afterwards.
if os.path.exists(items_file):
    dataset = TorgoDataset(data_dir, char_to_index, items_file=items_file)
else:
    dataset = TorgoDataset(data_dir, char_to_index)
    torch.save(dataset.items, items_file)
print(f"Dataset length: {len(dataset)}")

# Calculate lengths (70% train, 10% validation, remainder test)
total_len = len(dataset)
train_len = int(0.7 * total_len)
val_len = int(0.1 * total_len)
test_len = total_len - train_len - val_len
assert train_len + val_len + test_len == total_len

# Perform split
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)



Building Dataset (8224 files): 100%|██████████| 8224/8224 [44:48<00:00,  3.06it/s]

Dataset length: 8214





In [20]:
# Rebuild `dataset` from the cached item list, which skips the directory scan.
# NOTE: this only rebinds the name `dataset`; the train/val/test subsets and
# loaders created in the previous cell still wrap the earlier dataset object.
dataset = TorgoDataset(data_dir, char_to_index, items_file='/content/drive/MyDrive/torgo_items.pt')


In [None]:
# Train CTC model on the training split (`train_loader` is created in the
# data-loading cell; no variable named `loader` exists in this notebook).
ctc_model = CTCModel()
ctc_losses, ctc_accs = train_ctc_model(ctc_model, train_loader)

In [None]:
# Train RNN-T model on the training split (`train_loader` is created in the
# data-loading cell; no variable named `loader` exists in this notebook).
rnnt_model = RNNTModel()
rnnt_losses, rnnt_accs, rnnt_wers, rnnt_cers, rnnt_edits = train_rnnt_model(rnnt_model, train_loader)

In [None]:
# Plot CTC training loss and accuracy
# Two side-by-side panels: per-epoch loss (left) and character accuracy (right).
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))

loss_ax.plot(ctc_losses, marker='o', label="CTC Loss")
loss_ax.set_title("CTC Model - Loss over Epochs")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.grid()
loss_ax.legend()

acc_ax.plot(ctc_accs, marker='o', label="CTC Accuracy", color='green')
acc_ax.set_title("CTC Model - Accuracy over Epochs")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.grid()
acc_ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Plot RNN-T training metrics
# 2x2 grid: loss, accuracy, WER + CER (shared panel), and edit distance.
fig, axes = plt.subplots(2, 2, figsize=(15, 8))
loss_ax, acc_ax, err_ax, edit_ax = axes.flatten()

# Each panel: (axes, title, y-label, [(series, plot keyword arguments), ...])
panels = [
    (loss_ax, "RNN-T Model - Loss over Epochs", "Loss",
     [(rnnt_losses, dict(label="RNN-T Loss"))]),
    (acc_ax, "RNN-T Model - Accuracy over Epochs", "Accuracy",
     [(rnnt_accs, dict(label="RNN-T Accuracy", color='green'))]),
    (err_ax, "RNN-T Model - WER and CER over Epochs", "Error Rate",
     [(rnnt_wers, dict(label="WER", color='red')),
      (rnnt_cers, dict(label="CER", color='blue'))]),
    (edit_ax, "RNN-T Model - Edit Distance over Epochs", "Edit Distance",
     [(rnnt_edits, dict(label="Edit Distance", color='purple'))]),
]

for ax, title, y_label, curves in panels:
    for values, line_kwargs in curves:
        ax.plot(values, marker='o', **line_kwargs)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.grid()
    ax.legend()

fig.tight_layout()
plt.show()