In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ctc-model-2/checkpoint_epoch_60.pt


In [2]:
pip install jiwer

Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.1.0 rapidfuzz-3.13.0
Note: you may need to restart the kernel to use updated packages.


In [3]:
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xvjf LJSpeech-1.1.tar.bz2 > /dev/null 2>&1

--2025-05-12 15:13:07--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 185.93.1.242, 2400:52e0:1a00::1068:1
Connecting to data.keithito.com (data.keithito.com)|185.93.1.242|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [text/plain]
Saving to: ‘LJSpeech-1.1.tar.bz2’


2025-05-12 15:13:19 (234 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]



In [4]:
import pandas as pd
df = pd.read_csv("/kaggle/working/LJSpeech-1.1/metadata.csv",  delimiter="|", 
                 header=None, names=["wav", "drop", "text"]).drop(columns=["drop"]).dropna()
df.head(3)

Unnamed: 0,wav,text
0,LJ001-0001,"Printing, in the only sense with which we are ..."
1,LJ001-0002,in being comparatively modern.
2,LJ001-0003,For although the Chinese took impressions from...


In [5]:
# Apply lambda function to create the new file path
df['file_path'] = df['wav'].apply(lambda x: f"/kaggle/working/LJSpeech-1.1/wavs/{x}.wav")
df.head(1)

Unnamed: 0,wav,text,file_path
0,LJ001-0001,"Printing, in the only sense with which we are ...",/kaggle/working/LJSpeech-1.1/wavs/LJ001-0001.wav


In [6]:
# Reorder columns to make 'file_path' the first column
df = df[['file_path'] + [col for col in df.columns if col != 'file_path']]

In [7]:
df.drop(columns=["wav"], inplace=True)
df.head(3)

Unnamed: 0,file_path,text
0,/kaggle/working/LJSpeech-1.1/wavs/LJ001-0001.wav,"Printing, in the only sense with which we are ..."
1,/kaggle/working/LJSpeech-1.1/wavs/LJ001-0002.wav,in being comparatively modern.
2,/kaggle/working/LJSpeech-1.1/wavs/LJ001-0003.wav,For although the Chinese took impressions from...


In [8]:
from IPython.display import Audio

print(df["text"][3])

# Access the first audio file path
audio_path = df["file_path"][3]

# Use the 'filename' parameter to pass the file path
Audio(filename=audio_path, rate=10000)

produced the block books, which were the immediate predecessors of the true printed book,


In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import pandas as pd
import string

In [10]:
# --- Tokenizer Class ---
class TextTokenizerCTC:
    def __init__(self):
        chars = list(string.ascii_lowercase + " ',.!?-:;")
        self.blank_token = "_"
        self.char_to_id = {c: i + 1 for i, c in enumerate(chars)}  # 0 is reserved for CTC blank
        self.char_to_id[self.blank_token] = 0
        self.id_to_char = {v: k for k, v in self.char_to_id.items()}

    def normalize(self, text):
        return text.lower().strip()

    def encode(self, text):
        text = self.normalize(text)
        return [self.char_to_id[c] for c in text if c in self.char_to_id]

    def decode(self, ids):
        return ''.join([self.id_to_char[i] for i in ids])

In [11]:
class LJSpeechCTCDataset(Dataset):
    def __init__(self, dataframe, sample_rate=16000, n_mels=80):
        self.df = dataframe.reset_index(drop=True)  # <- ensures clean integer indexing
        self.sample_rate = sample_rate
        self.tokenizer = TextTokenizerCTC()

        self.audio_transform = torch.nn.Sequential(
            T.Resample(orig_freq=22050, new_freq=sample_rate),
            T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels),
            T.AmplitudeToDB()
        )

    def __len__(self):
        return len(self.df)

    def preprocess_audio(self, file_path):
        waveform, sample_rate = torchaudio.load(file_path)
        if sample_rate != self.sample_rate:
            waveform = T.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)(waveform)
        mel_spec = self.audio_transform(waveform)
        return mel_spec.squeeze(0).transpose(0, 1)  # [time, n_mels]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        mel = self.preprocess_audio(row["file_path"])
        token_ids = self.tokenizer.encode(row["text"])
        return mel, torch.tensor(token_ids, dtype=torch.long)

In [12]:
# --- Collate Function for Padding ---
def collate_fn(batch):
    # Unpack mel and token_ids
    mels, labels = zip(*batch)

    # Pad mel features
    mel_lengths = [m.shape[0] for m in mels]
    mel_padded = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)  # [B, T, n_mels]

    # Pad token ids
    label_lengths = [len(l) for l in labels]
    label_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)

    return mel_padded, torch.tensor(mel_lengths), label_padded, torch.tensor(label_lengths)

In [13]:
dataset = LJSpeechCTCDataset(df)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

# --- Example Usage ---
for batch in dataloader:
    mel_batch, mel_lengths, label_batch, label_lengths = batch
    print("Mel batch shape:", mel_batch.shape)         # [B, T, n_mels]
    print("Label batch shape:", label_batch.shape)     # [B, L]
    break

Mel batch shape: torch.Size([64, 586, 80])
Label batch shape: torch.Size([64, 165])


In [14]:
# Assuming you have a full dataset object `dataset`
small_dataset = torch.utils.data.Subset(dataset, list(range(50)))
small_loader = DataLoader(small_dataset, batch_size=2, collate_fn=collate_fn)

In [15]:
# --- Updated Tokenizer Class (keep this the same) ---
class TextTokenizerCTC:
    def __init__(self):
        chars = list(string.ascii_lowercase + " ',.!?-:;")
        self.blank_token = "_"
        self.char_to_id = {c: i + 1 for i, c in enumerate(chars)}  # 0 is reserved for CTC blank
        self.char_to_id[self.blank_token] = 0
        self.id_to_char = {v: k for k, v in self.char_to_id.items()}

    def normalize(self, text):
        return text.lower().strip()

    def encode(self, text):
        text = self.normalize(text)
        return [self.char_to_id[c] for c in text if c in self.char_to_id]

    def decode(self, ids):
        return ''.join([self.id_to_char[i] for i in ids if i != 0])

In [16]:
import torch
import torch.nn as nn
import math
# --- Updated Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [17]:
# --- Updated Transformer Model with Aggressive Decoder ---
class TransformerDualHead(nn.Module):
    def __init__(self, input_dim, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.token_pad_idx = 0
        self.d_model = d_model

        # --- Encoder ---
        self.input_fc = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False  # Changed to (T, B, D) format for consistency
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # --- CTC Head (unchanged) ---
        self.ctc_fc = nn.Linear(d_model, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=-1)

        # --- Aggressive Decoder from Second Notebook ---
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=self.token_pad_idx)
        self.pos_decoder = PositionalEncoding(d_model, dropout)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, tgt_seq=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        # --- Encoder ---
        x = x.permute(1, 0, 2)  # Change to (T, B, D) format
        x = self.input_fc(x)
        x = self.pos_encoder(x)
        memory = self.encoder(x, src_key_padding_mask=src_key_padding_mask)  # [T, B, D]

        # --- CTC Output ---
        ctc_logits = self.log_softmax(self.ctc_fc(memory))  # [T, B, vocab]

        decoder_logits = None
        if tgt_seq is not None:
            # Shift tgt for teacher forcing: input: [B, T] => [B, T-1]
            tgt_in = tgt_seq[:, :-1]
            tgt_in = tgt_in.permute(1, 0)  # [L, B]
            
            tgt_embed = self.token_embedding(tgt_in) * math.sqrt(self.d_model)
            tgt_embed = self.pos_decoder(tgt_embed)

            # Generate look-ahead mask for causal decoding
            tgt_mask = self.generate_square_subsequent_mask(tgt_in.size(0)).to(x.device)

            # Decode
            decoder_out = self.decoder(
                tgt=tgt_embed,
                memory=memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask
            )
            decoder_logits = self.lm_head(decoder_out)  # [L, B, vocab]
            decoder_logits = decoder_logits.permute(1, 0, 2)  # Back to [B, L, vocab]

        return ctc_logits.permute(1, 0, 2), decoder_logits  # Return CTC as [B, T, vocab]

    @staticmethod
    def generate_square_subsequent_mask(sz):
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

In [18]:
# --- Updated Training Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = TextTokenizerCTC()
vocab_size = len(tokenizer.char_to_id)

# Initialize with more aggressive parameters
model = TransformerDualHead(
    input_dim=80, 
    vocab_size=vocab_size,
    d_model=512,
    nhead=8,
    num_layers=6,
    dim_feedforward=2048,
    dropout=0.1
).to(device)

# Differential learning rates as in second notebook
optimizer = torch.optim.Adam([
    {'params': model.encoder.parameters(), 'lr': 1e-4},
    {'params': model.ctc_fc.parameters(), 'lr': 1e-4},
    {'params': model.decoder.parameters(), 'lr': 1e-4},
    {'params': model.lm_head.parameters(), 'lr': 1e-4}
])

ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
ce_loss_fn = nn.CrossEntropyLoss(ignore_index=0)

checkpoint_path = "/kaggle/input/ctc-model-2/checkpoint_epoch_60.pt"  

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
best_loss = checkpoint['loss']
print(f"Loaded checkpoint from epoch {start_epoch} with loss {best_loss:.4f}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

  checkpoint = torch.load(checkpoint_path)


Loaded checkpoint from epoch 60 with loss 1.0141


In [19]:
import time
from tqdm import tqdm  # For progress bars
import os

# --- Training Loop with Time Estimation ---
total_epochs = 90
start_time = time.time()

# Create checkpoint directory
checkpoint_dir = "/kaggle/working/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint(epoch, model, optimizer, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)



for epoch in range(start_epoch, total_epochs):
    epoch_start_time = time.time()
    model.train()
    total_ctc_loss = 0
    total_ce_loss = 0
    
    # Use tqdm for batch progress
    batch_iterator = tqdm(dataloader, desc=f'Epoch {epoch+1}/{total_epochs}', leave=False)
    
    for mel_batch, mel_lengths, label_batch, label_lengths in batch_iterator:
        mel_batch = mel_batch.to(device)
        label_batch = label_batch.to(device)

        optimizer.zero_grad()

        # --- Forward ---
        ctc_logits, decoder_logits = model(mel_batch, tgt_seq=label_batch)

        # --- CTC Loss ---
        ctc_input = ctc_logits.transpose(0, 1)
        ctc_loss = ctc_loss_fn(ctc_input, label_batch, mel_lengths, label_lengths)

        # --- Decoder (LM) Loss ---
        decoder_targets = label_batch[:, 1:]
        decoder_logits = decoder_logits.reshape(-1, vocab_size)
        decoder_targets = decoder_targets.reshape(-1)
        ce_loss = ce_loss_fn(decoder_logits, decoder_targets)

        # Combine losses
        loss = ctc_loss + ce_loss
        loss.backward()
        optimizer.step()

        total_ctc_loss += ctc_loss.item()
        total_ce_loss += ce_loss.item()
        
        # Update progress bar
        batch_iterator.set_postfix({
            'CTC Loss': f'{ctc_loss.item():.4f}',
            'CE Loss': f'{ce_loss.item():.4f}'
        })

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pt"
        save_checkpoint(epoch+1, model, optimizer, loss, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

    
    # Calculate time statistics
    epoch_time = time.time() - epoch_start_time
    elapsed_time = time.time() - start_time
    remaining_time = (total_epochs - epoch - 1) * epoch_time
    
    
    # Convert to readable format
    def format_time(seconds):
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)
        return f"{hours}h {minutes}m {seconds}s"
    
    print( f"Epoch {epoch+1}/{total_epochs} | "
           f"CTC Loss: {total_ctc_loss/len(dataloader):.4f} | "
           f"CE Loss: {total_ce_loss/len(dataloader):.4f} | "
           f"Epoch Time: {format_time(epoch_time)} | "
           f"Elapsed: {format_time(elapsed_time)} | "
           f"Remaining: {format_time(remaining_time)}")

                                                                                               

Epoch 61/90 | CTC Loss: 0.7696 | CE Loss: 0.1649 | Epoch Time: 0h 14m 32s | Elapsed: 0h 14m 32s | Remaining: 7h 1m 39s


                                                                                               

Epoch 62/90 | CTC Loss: 0.7578 | CE Loss: 0.1594 | Epoch Time: 0h 14m 50s | Elapsed: 0h 29m 23s | Remaining: 6h 55m 45s


                                                                                               

Epoch 63/90 | CTC Loss: 0.7455 | CE Loss: 0.1532 | Epoch Time: 0h 14m 46s | Elapsed: 0h 44m 9s | Remaining: 6h 38m 50s


                                                                                               

Epoch 64/90 | CTC Loss: 0.7356 | CE Loss: 0.1485 | Epoch Time: 0h 14m 49s | Elapsed: 0h 58m 58s | Remaining: 6h 25m 20s


                                                                                               

Epoch 65/90 | CTC Loss: 0.7230 | CE Loss: 0.1420 | Epoch Time: 0h 14m 46s | Elapsed: 1h 13m 45s | Remaining: 6h 9m 14s


                                                                                               

Epoch 66/90 | CTC Loss: 0.7134 | CE Loss: 0.1388 | Epoch Time: 0h 14m 46s | Elapsed: 1h 28m 31s | Remaining: 5h 54m 37s


                                                                                               

Epoch 67/90 | CTC Loss: 0.7011 | CE Loss: 0.1326 | Epoch Time: 0h 14m 46s | Elapsed: 1h 43m 18s | Remaining: 5h 40m 0s


                                                                                               

Epoch 68/90 | CTC Loss: 0.6907 | CE Loss: 0.1288 | Epoch Time: 0h 14m 49s | Elapsed: 1h 58m 8s | Remaining: 5h 26m 12s


                                                                                               

Epoch 69/90 | CTC Loss: 0.6784 | CE Loss: 0.1227 | Epoch Time: 0h 14m 51s | Elapsed: 2h 12m 59s | Remaining: 5h 11m 57s


                                                                                               

Saved checkpoint to /kaggle/working/checkpoints/checkpoint_epoch_70.pt
Epoch 70/90 | CTC Loss: 0.6697 | CE Loss: 0.1200 | Epoch Time: 0h 14m 52s | Elapsed: 2h 27m 52s | Remaining: 4h 57m 36s


                                                                                               

Epoch 71/90 | CTC Loss: 0.6588 | CE Loss: 0.1147 | Epoch Time: 0h 14m 47s | Elapsed: 2h 42m 39s | Remaining: 4h 40m 54s


                                                                                               

Epoch 72/90 | CTC Loss: 0.6487 | CE Loss: 0.1112 | Epoch Time: 0h 14m 48s | Elapsed: 2h 57m 28s | Remaining: 4h 26m 40s


                                                                                               

Epoch 73/90 | CTC Loss: 0.6389 | CE Loss: 0.1074 | Epoch Time: 0h 14m 53s | Elapsed: 3h 12m 21s | Remaining: 4h 13m 10s


                                                                                               

Epoch 74/90 | CTC Loss: 0.6295 | CE Loss: 0.1051 | Epoch Time: 0h 14m 50s | Elapsed: 3h 27m 12s | Remaining: 3h 57m 35s


                                                                                               

Epoch 75/90 | CTC Loss: 0.6216 | CE Loss: 0.1017 | Epoch Time: 0h 14m 47s | Elapsed: 3h 42m 0s | Remaining: 3h 41m 56s


                                                                                               

Epoch 76/90 | CTC Loss: 0.6125 | CE Loss: 0.0985 | Epoch Time: 0h 14m 50s | Elapsed: 3h 56m 50s | Remaining: 3h 27m 42s


                                                                                               

Epoch 77/90 | CTC Loss: 0.6026 | CE Loss: 0.0956 | Epoch Time: 0h 14m 51s | Elapsed: 4h 11m 41s | Remaining: 3h 13m 4s


                                                                                               

Epoch 78/90 | CTC Loss: 0.5941 | CE Loss: 0.0920 | Epoch Time: 0h 14m 51s | Elapsed: 4h 26m 33s | Remaining: 2h 58m 14s


                                                                                               

Epoch 79/90 | CTC Loss: 0.5851 | CE Loss: 0.0889 | Epoch Time: 0h 14m 49s | Elapsed: 4h 41m 22s | Remaining: 2h 43m 8s


                                                                                               

Saved checkpoint to /kaggle/working/checkpoints/checkpoint_epoch_80.pt
Epoch 80/90 | CTC Loss: 0.5766 | CE Loss: 0.0863 | Epoch Time: 0h 14m 52s | Elapsed: 4h 56m 15s | Remaining: 2h 28m 40s


                                                                                               

Epoch 81/90 | CTC Loss: 0.5692 | CE Loss: 0.0833 | Epoch Time: 0h 14m 48s | Elapsed: 5h 11m 3s | Remaining: 2h 13m 12s


                                                                                               

Epoch 82/90 | CTC Loss: 0.5600 | CE Loss: 0.0811 | Epoch Time: 0h 14m 52s | Elapsed: 5h 25m 55s | Remaining: 1h 58m 57s


                                                                                               

Epoch 83/90 | CTC Loss: 0.5533 | CE Loss: 0.0798 | Epoch Time: 0h 14m 48s | Elapsed: 5h 40m 43s | Remaining: 1h 43m 40s


                                                                                               

Epoch 84/90 | CTC Loss: 0.5442 | CE Loss: 0.0762 | Epoch Time: 0h 14m 49s | Elapsed: 5h 55m 33s | Remaining: 1h 28m 55s


                                                                                               

Epoch 85/90 | CTC Loss: 0.5376 | CE Loss: 0.0737 | Epoch Time: 0h 14m 49s | Elapsed: 6h 10m 22s | Remaining: 1h 14m 6s


                                                                                               

Epoch 86/90 | CTC Loss: 0.5300 | CE Loss: 0.0718 | Epoch Time: 0h 14m 51s | Elapsed: 6h 25m 13s | Remaining: 0h 59m 24s


                                                                                               

Epoch 87/90 | CTC Loss: 0.5225 | CE Loss: 0.0702 | Epoch Time: 0h 14m 49s | Elapsed: 6h 40m 2s | Remaining: 0h 44m 27s


                                                                                               

Epoch 88/90 | CTC Loss: 0.5146 | CE Loss: 0.0672 | Epoch Time: 0h 14m 52s | Elapsed: 6h 54m 55s | Remaining: 0h 29m 44s


                                                                                               

Epoch 89/90 | CTC Loss: 0.5079 | CE Loss: 0.0662 | Epoch Time: 0h 14m 52s | Elapsed: 7h 9m 47s | Remaining: 0h 14m 52s


                                                                                               

Saved checkpoint to /kaggle/working/checkpoints/checkpoint_epoch_90.pt
Epoch 90/90 | CTC Loss: 0.5010 | CE Loss: 0.0642 | Epoch Time: 0h 14m 53s | Elapsed: 7h 24m 40s | Remaining: 0h 0m 0s


In [20]:
import numpy as np

def greedy_decode_ctc(log_probs, mel_lengths, tokenizer, blank_id=0):
    """Greedy decode from CTC log_probs (T, B, V)"""
    _, predicted_ids = log_probs.max(dim=-1)  # [T, B]
    decoded_preds = []
    for i in range(predicted_ids.size(1)):  # For each sample
        pred_ids = predicted_ids[:, i].cpu().numpy()
        pred_ids = pred_ids[:mel_lengths[i]]
        decoded_tokens = tokenizer.decode(pred_ids)
        decoded_preds.append(decoded_tokens)
    return decoded_preds


def greedy_decode_decoder(decoder_logits, tokenizer):
    """Greedy decode from decoder logits [B, T, V]"""
    pred_ids = decoder_logits.argmax(dim=-1)  # [B, T]
    decoded_preds = []
    for ids in pred_ids:
        ids = ids.cpu().numpy()
        # Stop at first padding (optional)
        if 0 in ids:
            ids = ids[:np.where(ids == 0)[0][0]]
        decoded_preds.append(tokenizer.decode(ids))
    return decoded_preds


# --- Sample batch ---
mel_batch, mel_lengths, label_batch, label_lengths = next(iter(small_loader))
mel_batch = mel_batch.to(device)
label_batch = label_batch.to(device)

# --- Forward pass ---
ctc_log_probs, decoder_logits = model(mel_batch, tgt_seq=label_batch)

# Transpose CTC for decoding: [T, B, vocab]
ctc_log_probs = ctc_log_probs.transpose(0, 1)

# Decode both heads
ctc_output = greedy_decode_ctc(ctc_log_probs, mel_lengths, tokenizer)
decoder_output = greedy_decode_decoder(decoder_logits, tokenizer)

# --- Print ---
print("Actual vs CTC vs Decoder Output:\n")
for i in range(len(ctc_output)):
    actual = tokenizer.decode(label_batch[i].cpu().numpy())
    print(f"Example {i+1}:")
    print(f"Actual:   {actual}")
    print(f"CTC:      {ctc_output[i]}")
    print(f"Decoder:  {decoder_output[i]}")
    print()

Actual vs CTC vs Decoder Output:

Example 1:
Actual:   printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition
CTC:      pprriiitttiinnnngng,, iiinttheee oonnnnllyyyt eeeennse  witthh   wwiiiiccchh  www arrrr aatttfeeeseeent cconnnseerrrrrrrrrrnnndt diiifoerrrrrss   fffrrommmmmmmmoooosssstttiinnnoofrrommm  aaallllll   thhee aaarrtttss nnn crraaaaaf  rrreeariisssennted  iin  ttheee aesiibbiiiioon, 
Decoder:  rinting, in the only sense,with weich we are at wresent concerned, differs from most in aot from all the arts and crifts representad in the axhibition

Example 2:
Actual:   in being comparatively modern.
CTC:      innn bbeeeiinngg ccoommppaarrrrrrraatiivvellyy   mmooooodderrrrrnn.
Decoder:  n being comparatively modern. ..'.. .........e......u.et.r...e...o::.ii...ll  l:r.ru:   el  :.omo.:::o:a.ooeaaea:o oo eeeerer:eaete eoaa aoaeeoeaeeeeo



In [21]:
def greedy_decode_ctc(log_probs, mel_lengths, tokenizer, blank_id=0):
    """Greedy decode from CTC log_probs (T, B, V) with proper blank removal and collapsing"""
    _, predicted_ids = log_probs.max(dim=-1)  # [T, B]
    decoded_preds = []
    
    for i in range(predicted_ids.size(1)):  # For each sample
        pred_ids = predicted_ids[:mel_lengths[i], i].cpu().numpy()
        
        # CTC post-processing: remove blanks and collapse repeats
        processed_ids = []
        prev_char = blank_id
        for char_id in pred_ids:
            if char_id != blank_id:
                if char_id != prev_char:
                    processed_ids.append(char_id)
            prev_char = char_id
        
        decoded_tokens = tokenizer.decode(processed_ids)
        decoded_preds.append(decoded_tokens)
    
    return decoded_preds

In [22]:
def greedy_decode_decoder(decoder_logits, tokenizer, max_length=200):
    """Greedy decode from decoder logits [B, T, V] with proper length control"""
    pred_ids = decoder_logits.argmax(dim=-1)  # [B, T]
    decoded_preds = []
    
    for ids in pred_ids:
        ids = ids.cpu().numpy()
        
        # Find the first padding token or stop at max_length
        if 0 in ids:
            end_pos = np.where(ids == 0)[0][0]
            ids = ids[:end_pos]
        else:
            ids = ids[:max_length]
        
        decoded_tokens = tokenizer.decode(ids)
        decoded_preds.append(decoded_tokens)
    
    return decoded_preds

In [23]:
# --- Sample batch ---
mel_batch, mel_lengths, label_batch, label_lengths = next(iter(small_loader))
mel_batch = mel_batch.to(device)
label_batch = label_batch.to(device)

# --- Forward pass ---
with torch.no_grad():
    ctc_log_probs, decoder_logits = model(mel_batch, tgt_seq=label_batch)

# Transpose CTC for decoding: [T, B, vocab]
ctc_log_probs = ctc_log_probs.transpose(0, 1)

# Decode both heads with improved functions
ctc_output = greedy_decode_ctc(ctc_log_probs, mel_lengths, tokenizer)
decoder_output = greedy_decode_decoder(decoder_logits, tokenizer)

# --- Print ---
print("Actual vs Cleaned CTC vs Cleaned Decoder Output:\n")
for i in range(len(ctc_output)):
    actual = tokenizer.decode(label_batch[i].cpu().numpy())
    print(f"Example {i+1}:")
    print(f"Actual:   {actual}")
    print(f"CTC:      {ctc_output[i]}")
    print(f"Decoder:  {decoder_output[i]}")
    print("-" * 80)

Actual vs Cleaned CTC vs Cleaned Decoder Output:

Example 1:
Actual:   printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition
CTC:      prinnting, i  tee onlysense with which wereattpfresent conernd  tdiffoers fromosst if notfrom  all thyeartes dinn crafh  reprisentein the esubiion
Decoder:  rinting, tn the only sensw with which we are at bresent concerned  differs from most fn not from all the arts fnd crifts represented in the exhibition
--------------------------------------------------------------------------------
Example 2:
Actual:   in being comparatively modern.
CTC:      in being comparativly moodern.
Decoder:  n being comparatively modern. ....r.u.t..ururteeo..ue.eeeerrrl::.e,tratritleuuerraererer eeeaeaaa:a urreeieaoaaalaaeeeaeeeee aerrerre erteaarrarreteea
--------------------------------------------------------------------------------


In [24]:
# --- Sample batch ---
mel_batch, mel_lengths, label_batch, label_lengths = next(iter(small_loader))
mel_batch = mel_batch.to(device)
label_batch = label_batch.to(device)

# --- Forward pass ---
with torch.no_grad():
    ctc_log_probs, decoder_logits = model(mel_batch, tgt_seq=label_batch)

# Transpose CTC for decoding: [T, B, vocab]
ctc_log_probs = ctc_log_probs.transpose(0, 1)

# Decode both heads with improved functions
ctc_output = greedy_decode_ctc(ctc_log_probs, mel_lengths, tokenizer)
decoder_output = greedy_decode_decoder(decoder_logits, tokenizer)

# --- Print ---
print("Actual vs Cleaned CTC vs Cleaned Decoder Output:\n")
for i in range(len(ctc_output)):
    actual = tokenizer.decode(label_batch[i].cpu().numpy())
    print(f"Example {i+1}:")
    print(f"Actual:   {actual}")
    print(f"CTC:      {ctc_output[i]}")
    print(f"Decoder:  {decoder_output[i]}")
    print("-" * 80)

Actual vs Cleaned CTC vs Cleaned Decoder Output:

Example 1:
Actual:   printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the exhibition
CTC:      pritn   itheonly sens with wihich wre atpre n concernd differs fromosttif  not from alll hy aars enratreresentie in the exibition
Decoder:  rinting, in the only sense,with which we are at present concerned, differs from most df not from mll the arts and crafts represent d in the axhibition
--------------------------------------------------------------------------------
Example 2:
Actual:   in being comparatively modern.
CTC:      in beieing comparatively modern.
Decoder:  n being comparatively modern. ....u.........e...e..u..au.u...:... .iieeel..aletaaaemileeile r au ea .eaaaaaeaeeeeeeeeeleeeeeoeeeeeeiteeleaereeeaereeee
--------------------------------------------------------------------------------
