# 03 — Second Model: Transformers + CTC
Status: **Underfit / failed to generalize**. Kept to document the path and justify later choices.

**Summary:** Despite extensive reconfiguration and debugging, this attempt with Wav2Vec2 (the pretrained `facebook/wav2vec2-base-960h` checkpoint) failed to converge, producing NaN-masked losses and underfitted outputs.  
After abandoning the unstable pretrained Wav2Vec2 weights, I trained the model architecture from scratch, which ran without NaN issues but suffered from underfitting due to limited data. As a result, the model failed to produce reliable transcriptions and was ultimately set aside.


# Prepare Environment

## 1. Mount Drive
Mount Google Drive (only needed in Colab).



In [None]:
# Mount Google Drive (Colab only). Outside Colab the import fails; skip the
# mount instead of aborting a "Run all", like the other Colab-only cells do.
try:
    from google.colab import drive
except ImportError:
    print("Skipping Drive mount (not in Colab).")
else:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Install Required Libraries


In [None]:
# Install the notebook's dependencies through the IPython shell.
# Only attempted when `google.colab` is importable; any failure along the way
# is reported as "not in Colab" and the installs are skipped.
try:
    import google.colab
    from IPython import get_ipython
    shell = get_ipython()
    for pip_command in (
        "pip install fsspec==2023.6.0 --quiet",
        "pip install transformers datasets librosa jiwer torchaudio --quiet",
        "pip install wandb --quiet",
    ):
        shell.system(pip_command)
except Exception:
    print("Skipping pip installs (not in Colab).")


## 3. Import packages
Load all Python libraries used later.




In [None]:
# Imports
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd
import numpy as np
import librosa
import random
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union

## 4. Set paths
Define project root and subfolders (change `ROOT` to your clone path).

In [None]:
# Cell — Config paths (Notebook 03: Transformer + CTC model training)
import os, sys
from pathlib import Path

# detect if running in Colab (true once an earlier cell has imported google.colab)
IN_COLAB = "google.colab" in sys.modules

# CONFIG: project root folder
# Change this path to the folder where you cloned/downloaded the repo
ROOT = Path("/content/drive/MyDrive/GitHub/musdb18-asr-dl") if IN_COLAB else Path.cwd()

# canonical subfolders
DATA_RAW        = ROOT / "data" / "raw"
DATA_PROCESSED  = ROOT / "data" / "processed"
OUT_DIR         = ROOT / "outputs"
CHECKPOINTS_DIR = ROOT / "checkpoints"
LOGS_DIR        = ROOT / "logs"
RESULTS         = ROOT / 'results'
MODELS_DIR      = RESULTS / 'models' / 'Transformer_model'


# Hugging Face cache (persistent if on Colab + Drive)
# NOTE(review): transformers/datasets are already imported by the imports cell.
# If they resolve their cache paths at import time, these variables only take
# effect in a fresh kernel where this cell runs before any HF import — verify.
HF_CACHE = Path(
    os.environ.get(
        "HF_CACHE",
        "/content/drive/MyDrive/hf_cache" if IN_COLAB else (Path.home() / ".cache" / "huggingface")
    )
)
os.environ["HF_HOME"] = str(HF_CACHE)
os.environ["HF_DATASETS_CACHE"] = str(HF_CACHE)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE)

# ensure every directory used by this notebook exists (one place for all mkdirs)
for d in [DATA_RAW, DATA_PROCESSED, OUT_DIR, CHECKPOINTS_DIR, LOGS_DIR, RESULTS, MODELS_DIR, HF_CACHE]:
    d.mkdir(parents=True, exist_ok=True)

# quick printout
print("ROOT         :", ROOT)
print("DATA_RAW     :", DATA_RAW)
print("DATA_PROCESSED:", DATA_PROCESSED)
print("OUT_DIR      :", OUT_DIR)
print("CHECKPOINTS  :", CHECKPOINTS_DIR)
print("LOGS_DIR     :", LOGS_DIR)
print("MODELS_DIR   :", MODELS_DIR)
print("HF_CACHE     :", HF_CACHE)


# Data Process

## Load and Prepare Chunked Audio + Text Data
Load the vocal-only chunk metadata (with aligned lyrics), filter out non-lyric segments, and prepare a HuggingFace Dataset for training.


In [None]:
csv_path = str(DATA_PROCESSED / 'train_segments_vocal_combined.csv')

df = pd.read_csv(csv_path)
# Keep only rows whose lyric exists and is not just whitespace (no empty CTC targets).
has_lyric = df['Lyric'].notnull() & (df['Lyric'].str.strip() != '')
df = df[has_lyric]

# Deterministic shuffle, then a positional 90 / 10 train/validation split.
# NOTE(review): the split is per chunk, so chunks cut from the same song may end
# up in both splits — confirm whether a song-level split is intended.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
split_idx = int(len(df) * 0.9)
train_df = df.iloc[:split_idx]
val_df = df.iloc[split_idx:]

# Wrap both splits as HuggingFace Datasets inside a single DatasetDict.
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
dataset = DatasetDict(train=train_dataset, validation=val_dataset)



### Load Wav2Vec2 model + processor

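Uses the tokenizer and feature extractor of `facebook/wav2vec2-base-960h`.  
Note: the feature encoder is not frozen anywhere in this notebook. The final training cell builds the model from the pretrained *config* only (`Wav2Vec2ForCTC(config)`, i.e. randomly initialised weights), so all layers train.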


In [None]:
# Clean up cached models/datasets (mainly useful on Colab).
# Removes the cached base-960h weights and temporary dataset/vocab artifacts.
# NOTE(review): HF_HOME is pointed at HF_CACHE in the config cell; if that took
# effect, the model cache lives there rather than under ~/.cache — verify.
try:
    import google.colab  # Only run in Colab
    from IPython import get_ipython

    shell = get_ipython()
    cleanup_commands = (
        # HuggingFace model cache
        "rm -rf ~/.cache/huggingface/hub/models--facebook--wav2vec2-base-960h",
        # Temporary Colab dataset artifacts
        "rm -rf /content/tokenized_bpe_dataset",
        "rm -rf /content/vocab.json",
    )
    for command in cleanup_commands:
        shell.system(command)
except Exception:
    print("Skipping Colab cleanup outside Colab.")


In [None]:
from transformers import Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

# Rebuild processor from the pretrained character tokenizer + feature extractor
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# ✅ Add <ctc_blank> token again.
# Look the ID up by name rather than assuming the new token is the last entry
# (add_tokens() adds nothing when the token already exists).
# NOTE(review): transformers' Wav2Vec2ForCTC computes its CTC loss with
# blank=config.pad_token_id, so this extra token is not the blank used by the
# built-in loss — verify against the installed version.
processor.tokenizer.add_tokens(["<ctc_blank>"])
blank_token_id = processor.tokenizer.convert_tokens_to_ids("<ctc_blank>")
print("✅ Processor loaded. Vocab size (with blank):", len(processor.tokenizer))


In [None]:
# Show the full token -> ID mapping (should end with the added "<ctc_blank>").
print("✅ Processor Vocab:", processor.tokenizer.get_vocab())


✅ Processor Vocab: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '|': 4, 'E': 5, 'T': 6, 'A': 7, 'O': 8, 'N': 9, 'I': 10, 'H': 11, 'S': 12, 'R': 13, 'D': 14, 'L': 15, 'U': 16, 'M': 17, 'W': 18, 'C': 19, 'F': 20, 'G': 21, 'Y': 22, 'P': 23, 'B': 24, 'V': 25, 'K': 26, "'": 27, 'X': 28, 'J': 29, 'Q': 30, 'Z': 31, '<ctc_blank>': 32}


### Preprocess text and audio

- Remove special characters from lyrics  
- Load each .wav chunk and downmix it to mono (resampling to 16 kHz happens later, in the tokenization step)  
- Add "speech", "sampling_rate" and cleaned "target_text" fields


In [None]:
### Text Cleaning Function

# Normalises lyric text to the character format of the pretrained Wav2Vec2
# tokenizer: uppercase letters, apostrophes and "|" as the word delimiter.


def clean_text(text):
    """Normalise a lyric string for the Wav2Vec2 character tokenizer.

    - removes punctuation except apostrophes (`'`) and existing "|" delimiters,
    - uppercases everything (the tokenizer vocabulary is uppercase),
    - turns every run of whitespace and/or "|" into a single "|" and trims
      leading/trailing delimiters.

    The function is idempotent: ``clean_text(clean_text(s)) == clean_text(s)``.
    Without that property, a second call would delete every word delimiter.

    Args:
        text: raw lyric string.

    Returns:
        The normalised string, e.g. "Don't stop!" -> "DON'T|STOP".
    """
    # Keep "|" so already-cleaned text survives another pass unchanged.
    text = re.sub(r"[^\w\s'|]", "", text)
    text = text.upper()
    # Newlines, tabs and repeated spaces would otherwise leave raw "\n" or "||".
    text = re.sub(r"[\s|]+", "|", text)
    return text.strip("|")


### Convert audio chunks to arrays

This function loads each audio chunk with **torchaudio**, converts stereo signals to mono, and ensures the waveform is stored as a clean 1-D NumPy array.  
For each chunk it adds:
- `speech`: the waveform (float32 array)  
- `sampling_rate`: the audio sampling rate  
- `target_text`: the cleaned lyric text  

In [None]:

def speech_file_to_array_fn(batch):
    """Load one audio chunk and attach its waveform, sampling rate and cleaned text.

    Multi-channel audio is averaged to mono. No resampling happens here: the
    native rate is stored in "sampling_rate" and handled in the tokenization step.

    Args:
        batch: one dataset row with "chunk_path" and "Lyric".

    Returns:
        The row with "speech" (1-D float32 array), "sampling_rate" and
        "target_text" added. On any load error it returns a 1-sample silent
        waveform and an empty "target_text"; such rows are filtered out after
        the map below.
    """
    try:
        speech_tensor, sr = torchaudio.load(batch["chunk_path"])

        # If stereo, average to mono
        if speech_tensor.shape[0] > 1:
            speech_tensor = torch.mean(speech_tensor, dim=0, keepdim=True)

        waveform = speech_tensor.squeeze().numpy()

        # Ensure waveform is a 1D array, not a scalar
        if isinstance(waveform, float) or np.isscalar(waveform) or waveform.ndim == 0:
            raise ValueError("Waveform is scalar")

        batch["speech"] = waveform.astype(np.float32)  # Ensure consistent dtype
        batch["sampling_rate"] = sr
        batch["target_text"] = clean_text(batch["Lyric"])

        return batch

    except Exception as e:
        print(f"⚠️ Skipping file: {batch['chunk_path']} due to error: {e}")
        return {
            "speech": np.zeros(1, dtype=np.float32),
            "sampling_rate": 16000,
            "target_text": ""
        }


dataset = dataset.map(speech_file_to_array_fn, num_proc=8)

# Drop rows whose audio failed to load or whose lyric cleaned down to nothing.
# An empty target gives the CTC loss nothing to learn from.
dataset = dataset.filter(lambda example: len(example["target_text"]) > 0)


In [None]:
from datasets import DatasetDict

# Persist the audio-loaded (not yet tokenized) dataset under MODELS_DIR on Drive,
# so later cells or a fresh session can reload it with load_from_disk.
untokenized_dir = MODELS_DIR / 'untokenized_vocal_dataset'
dataset.save_to_disk(str(untokenized_dir))

### Tokenize inputs and labels

This step converts raw audio and text into model-ready features:

- **Waveform handling**: ensures each `speech` sample is a mono tensor at 16 kHz.  
- **Feature extraction**: applies the Wav2Vec2 processor to produce `input_values` and an `attention_mask`.  
- **Label encoding**: tokenizes the cleaned lyric text into `labels`.  

The function is mapped across the dataset, creating a fully tokenized version that is then saved to disk for reuse.  


In [None]:
# --- Load previously saved untokenized dataset ---
# In Colab: copy it from Drive to local scratch space and load the copy.
# Elsewhere (or if any step fails): load it straight from MODELS_DIR.

from datasets import load_from_disk
import shutil

drive_copy = MODELS_DIR / "untokenized_vocal_dataset"
local_copy = Path("/content/untokenized_vocal_dataset")

try:
    import google.colab  # Colab-only copy step
    shutil.copytree(drive_copy, local_copy, dirs_exist_ok=True)
    dataset = load_from_disk(str(local_copy))
    print("Loaded dataset from Colab local copy.")
except Exception:
    dataset = load_from_disk(str(drive_copy))
    print("Loaded dataset directly from MODELS_DIR.")


In [None]:
from datasets import load_from_disk
from transformers import Wav2Vec2Processor
import torch
import torchaudio
import numpy as np
import re


# Preprocessing function
def preprocess(example):
    """Convert one audio-loaded example into wav2vec2 inputs and CTC labels.

    Args:
        example: row with "speech" (waveform), "sampling_rate" and the already
            cleaned "target_text".

    Returns:
        dict with "input_values" and "attention_mask" (flat lists produced by
        the feature extractor at 16 kHz) and "labels" (tokenizer IDs).
    """
    # Convert waveform to tensor if needed
    waveform = example["speech"]
    if isinstance(waveform, list):
        waveform = np.array(waveform, dtype=np.float32)
    if isinstance(waveform, np.ndarray):
        waveform = torch.tensor(waveform)
    # Average any leading channel axis down to 1-D (this also turns (1, T) into (T,)).
    if waveform.ndim > 1:
        waveform = torch.mean(waveform, dim=0)
    if example["sampling_rate"] != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=example["sampling_rate"], new_freq=16000)(waveform)

    # Convert waveform to input values (features)
    input_values = processor.feature_extractor(
        waveform.numpy(),
        sampling_rate=16000,
        return_attention_mask=True,
        return_tensors="pt",
        padding=True
    )

    # "target_text" was already normalised by clean_text() when the audio was
    # loaded. Cleaning it again is unnecessary, and with a character filter
    # that drops "|" it would delete every word delimiter.
    # The tokenizer is called directly, so the deprecated
    # as_target_processor() context is not needed.
    labels = processor.tokenizer(example["target_text"]).input_ids

    return {
        "input_values": input_values["input_values"][0].numpy().tolist(),
        "attention_mask": input_values["attention_mask"][0].numpy().tolist(),
        "labels": labels
    }

# Apply to full dataset
tokenized_dataset = dataset.map(preprocess, remove_columns=dataset["train"].column_names)

# Save to disk
tokenized_dataset.save_to_disk("/content/tokenized_bpe_dataset")


## Order the dataset by input length

This is meant to give more uniform batches and more stable gradients
without setting `group_by_length=True` in the training arguments, which slows down model initialisation (by approx. 13 minutes).

In [None]:
def compute_input_length(example):
    """Return the number of input samples of one tokenized example.

    A dict (not a bare int) is returned so that `Dataset.map` adds it as a new
    "input_length" column.
    """
    n_samples = len(example["input_values"])
    return {"input_length": n_samples}

# Map over dataset to get input lengths
tokenized_dataset = tokenized_dataset.map(
    compute_input_length,
    desc="🔍 Computing input lengths",
    num_proc=8,
)

# Sort every split shortest-first so neighbouring examples have similar lengths.
# NOTE(review): HF Trainer normally samples the train split randomly, which
# would undo this order during training (eval is read sequentially) — verify.
for split_name in ("train", "validation"):
    tokenized_dataset[split_name] = tokenized_dataset[split_name].sort("input_length")

# Save to disk
tokenized_dataset.save_to_disk("/content/sorted_tokenized_vocal_dataset")


### Save the dataset to drive

In [None]:
# --- Save tokenized dataset ---
# Replaces any earlier copy under MODELS_DIR with the current (sorted) version.
import shutil

tokenized_save_dir = MODELS_DIR / "tokenized_vocal_dataset"

if tokenized_save_dir.exists():
    shutil.rmtree(tokenized_save_dir)  # remove the previous version first

tokenized_dataset.save_to_disk(str(tokenized_save_dir))
print("Tokenized dataset saved to:", tokenized_save_dir)


### Sanity check the tokenized dataset

Inspect a few random samples to verify:
- `input_values` are normalised waveform samples at 16 kHz (length ≈ chunk duration in seconds × 16 000)
- `attention_mask` exists and matches input length
- `labels` are lists of integers (character token IDs)
- Label IDs can be decoded back into readable text


In [None]:
# --- Load previously saved tokenized dataset ---
# In Colab: copy it from Drive (MODELS_DIR) to local scratch space and load the
# copy. Elsewhere (or if any step fails): load it straight from MODELS_DIR.

from datasets import load_from_disk
import shutil

drive_tokenized = MODELS_DIR / "tokenized_vocal_dataset"
local_tokenized = Path("/content/tokenized_vocal_dataset")

try:
    import google.colab  # Colab-only copy step
    shutil.copytree(drive_tokenized, local_tokenized, dirs_exist_ok=True)
    tokenized_dataset = load_from_disk(str(local_tokenized))
    print("Loaded tokenized dataset from Colab local copy.")
except Exception:
    tokenized_dataset = load_from_disk(str(drive_tokenized))
    print("Loaded tokenized dataset directly from MODELS_DIR.")


In [None]:
import random
from IPython.display import Audio
from datasets import load_from_disk

# Sample 3 examples from the training set and decode their labels back to text.
train_split = tokenized_dataset["train"]
for i in range(3):
    sample = train_split[random.randrange(len(train_split))]

    print(f"🗂 Sample {i+1}")
    print(f" - input_values: shape = {len(sample['input_values'])}, type = {type(sample['input_values'][0])}")
    print(f" - attention_mask: length = {len(sample['attention_mask'])}")
    print(f" - labels (IDs): {sample['labels']}")

    # Labels are a plain character sequence, not frame-level CTC output, so
    # grouping must be off; otherwise doubled letters ("LL") collapse to one.
    decoded = processor.decode(sample["labels"], skip_special_tokens=True, group_tokens=False)
    print(f" - Decoded text: {decoded}")

    print("-" * 60)

# Model Training

## Trainer Utilities

### Standalone tokenizer Fix Cell

In [None]:
# === Add CTC blank token (idempotent) ===
# add_tokens() returns how many tokens were actually new. It returns 0 when an
# earlier cell already added "<ctc_blank>", so look the ID up by name instead
# of assuming it is the last vocabulary entry.
num_added = processor.tokenizer.add_tokens(["<ctc_blank>"])
blank_token_id = processor.tokenizer.convert_tokens_to_ids("<ctc_blank>")
if num_added:
    print("✅ Blank token added at ID:", blank_token_id)
else:
    print("✅ Blank token already present at ID:", blank_token_id)

# === Update vocab size on processor ===
print("✅ Updated vocab size:", len(processor.tokenizer))


✅ Blank token added at ID: 32
✅ Updated vocab size: 33


### Define compute_metrics using jiwer

In [None]:
import jiwer

def compute_metrics(pred):
    """Compute corpus-level WER and CER for a Trainer evaluation pass.

    Args:
        pred: EvalPrediction-like object with ``predictions`` (CTC logits, last
            axis = vocabulary) and ``label_ids`` (token IDs padded with -100).

    Returns:
        dict with "wer" and "cer". Both are 1.0 when every reference is empty.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks padding (set by the data collator / Trainer). Map ONLY those
    # positions back to the pad token so the tokenizer can decode and drop them.
    label_ids = np.where(
        pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids
    )

    # Predictions are frame-level CTC outputs: keep the default grouping so that
    # repeated frames collapse before pad frames are removed.
    # References are plain character sequences and must NOT be grouped,
    # otherwise doubled letters ("LL") collapse.
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    # Normalize
    pred_str = [s.lower().strip() for s in pred_str]
    label_str = [s.lower().strip() for s in label_str]

    # Drop pairs with an empty reference: WER/CER are undefined for them.
    pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref]
    if not pairs:
        return {"wer": 1.0, "cer": 1.0}

    hyps = [hyp for hyp, _ in pairs]
    refs = [ref for _, ref in pairs]
    return {
        "wer": jiwer.wer(refs, hyps),
        "cer": jiwer.cer(refs, hyps),
    }


### Define Data Collator

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Union
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of tokenized examples into one CTC training batch.

    Each feature must provide ``input_values`` (a sequence of floats) and
    ``labels`` (a list of token IDs). Returns "input_values" (B, T) float32
    zero-padded, "attention_mask" (B, T) with 1 for real samples, and "labels"
    (B, L) with -100 at padding positions.
    """

    processor: Wav2Vec2Processor  # supplies the tokenizer used to pad labels
    padding: Union[bool, str] = True  # forwarded to tokenizer.pad()

    def __call__(self, features: List[Dict[str, Union[List[float], List[int]]]]) -> Dict[str, torch.Tensor]:
        # reshape(-1) instead of squeeze(): a length-1 clip must stay 1-D,
        # because pad_sequence cannot pad 0-d tensors.
        input_values = [torch.tensor(f["input_values"], dtype=torch.float32).reshape(-1) for f in features]
        input_values_padded = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True, padding_value=0.0)

        # Attention mask: 1 over each example's real samples, 0 over right-padding.
        attention_mask = torch.zeros_like(input_values_padded, dtype=torch.long)
        for i, iv in enumerate(input_values):
            attention_mask[i, :iv.shape[0]] = 1

        # Pad labels using the tokenizer
        label_batch = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features],
            padding=self.padding,
            return_tensors="pt"
        )
        label_ids = label_batch["input_ids"]

        # Replace padding with -100 so the loss ignores it. Prefer the tokenizer's
        # own mask so a genuine pad-ID inside a label is not mistaken for padding.
        label_mask = label_batch.get("attention_mask")
        if label_mask is not None:
            label_ids = label_ids.masked_fill(label_mask.ne(1), -100)
        else:
            label_ids = label_ids.masked_fill(label_ids == self.processor.tokenizer.pad_token_id, -100)

        return {
            "input_values": input_values_padded,       # shape: (B, T)
            "attention_mask": attention_mask,          # shape: (B, T), 1s and 0s
            "labels": label_ids                        # shape: (B, L), with -100 for padding
        }


### Save To Drive Callback

In [None]:
from transformers import TrainerCallback
import os

class SaveToDriveCallback(TrainerCallback):
    """Save model + processor to a step-numbered folder at the end of every epoch.

    Args:
        base_drive_path: parent directory (str or Path) for the snapshots.
        processor: processor saved alongside each model snapshot.

    Note: snapshots are never pruned; one folder accumulates per epoch.
    """

    def __init__(self, base_drive_path, processor):
        self.base_drive_path = base_drive_path
        self.processor = processor

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # The model arrives as an optional kwarg; skip instead of crashing on None.
        if model is None:
            print("⚠️ SaveToDriveCallback: no model passed; skipping save.")
            return
        # In multi-process training only the main process should write files.
        if not state.is_world_process_zero:
            return
        step = state.global_step
        output_dir = os.path.join(self.base_drive_path, f"wav2vec2-ctc-vocal-step-{step}")
        model.save_pretrained(output_dir)
        self.processor.save_pretrained(output_dir)
        print(f"✅ Saved model and processor to: {output_dir}")


### Training configuration - Arguments and Trainer

Set batch size, learning rate, evaluation/save schedule and gradient checkpointing (no gradient accumulation is configured).  
Enable best-model selection (by WER) and FP16 training.

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    seed=42,
    output_dir="./wav2vec2-ctc-vocal",
    # False because the dataset was pre-sorted by length, which avoids the
    # ~13 min length-grouping pass at Trainer start-up.
    # NOTE(review): Trainer normally samples the train split randomly, so the
    # pre-sort may only benefit evaluation batches — verify.
    group_by_length=False,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=112,
    save_steps=112,  # a multiple of eval_steps, as load_best_model_at_end requires
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    num_train_epochs=10,
    gradient_checkpointing=True,  # trade compute for activation memory
    fp16=True,
    learning_rate=5e-4,  # deliberately high, intended to escape collapse
    weight_decay=0.005,  # decoupled AdamW weight decay: an L2-style penalty, not lasso/L1
    warmup_ratio=0.1,
    max_grad_norm=1.0,  # gradient clipping
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
    dataloader_pin_memory=True,
    report_to=[],
)


### Load dataset

In [None]:
# --- Load tokenized dataset before training ---
# In Colab: copy it from Drive (MODELS_DIR) to local scratch space and load the
# copy. Elsewhere (or if any step fails): load it straight from MODELS_DIR.

from datasets import load_from_disk
import shutil

drive_tokenized = MODELS_DIR / "tokenized_vocal_dataset"
local_tokenized = Path("/content/tokenized_vocal_dataset")

try:
    import google.colab  # Colab-only copy step
    shutil.copytree(drive_tokenized, local_tokenized, dirs_exist_ok=True)
    tokenized_dataset = load_from_disk(str(local_tokenized))
    print("Loaded tokenized dataset from Colab local copy.")
except Exception:
    tokenized_dataset = load_from_disk(str(drive_tokenized))
    print("Loaded tokenized dataset directly from MODELS_DIR.")


### quick sanity


In [None]:
### Sanity Check Cell
# Runs one training example through the pretrained wav2vec2 encoder with a
# freshly initialised (random) output layer sized to the current tokenizer.
# The decoded text is expected to be noise; the check is that shapes line up
# and the forward pass runs end to end.

from transformers import Wav2Vec2ForCTC
import torch

# Load tokenized dataset
tokenized = tokenized_dataset

# Inspect one sample
sample = tokenized["train"][0]
print("Sample keys:", sample.keys())
print("Input shape:", len(sample["input_values"]))
print("Label IDs:", sample["labels"])

# Reuse the global `processor` built earlier (it includes "<ctc_blank>").
# Reloading it from the hub here would silently drop that token, shrinking the
# vocabulary the training cell reads and desyncing `blank_token_id`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.lm_head = torch.nn.Linear(model.config.hidden_size, len(processor.tokenizer), bias=True)
model = model.to(device).eval()

# Convert to tensor
input_tensor = torch.tensor(sample["input_values"], dtype=torch.float32).unsqueeze(0).to(device)

# Forward pass
with torch.no_grad():
    output = model(input_tensor)
    logits = output.logits
    pred_ids = torch.argmax(logits, dim=-1)

# Decode frame-level CTC predictions with the default grouping, so repeated
# frames collapse and pad frames are dropped.
decoded = processor.batch_decode(pred_ids)
print("Decoded prediction:", decoded[0])


In [None]:
from datasets import load_from_disk

# Show split sizes/features, then the column names the Trainer will receive.
train_columns = tokenized_dataset["train"].column_names
print(tokenized_dataset)
print(train_columns)


## Train

Use HuggingFace Trainer with model, collator, processor  
This will log loss and save best checkpoint automatically

initializes the Wav2Vec2-CTC model with the custom vocabulary,  
sets up the `Trainer` for training on the tokenized dataset, and runs the training loop.  

Key steps:
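- Configure the model with the tokenizer's vocab size (the built-in Wav2Vec2 CTC loss uses `pad_token_id` as the blank, with `ctc_zero_infinity` guarding against infinite losses).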
- Replace and reinitialize the output layer (`lm_head`) for the new vocabulary.
- Remove unused dataset columns to avoid errors.
- Train the model with evaluation, metrics, and a callback that saves progress to Drive.
- Save the final model and tokenizer to `MODELS_DIR` for persistence.

In [None]:
import numpy as np
from transformers import Trainer, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2Config
from transformers.utils import logging
import logging as py_logging
import os
import torch.nn as nn
import gc, torch
from IPython.core.interactiveshell import InteractiveShell

# Disable external integrations
os.environ["WANDB_DISABLED"] = "true"  # avoid wandb init
# Reduce CUDA fragmentation. NOTE(review): this only takes effect if set before
# CUDA is first initialised; the sanity-check cell may already have done that.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Logging setup
logging.set_verbosity_info()
logger = logging.get_logger()
logger.setLevel(py_logging.INFO)

# Notebook display config
InteractiveShell.ast_node_interactivity = "all"

# Use already prepared dataset (same DatasetDict object as tokenized_dataset)
dataset = tokenized_dataset

# Set model save paths
MODEL_SAVE_DIR = MODELS_DIR / "wav2vec2-ctc-vocal"
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Initialize model config with tokenizer vocab size.
# Wav2Vec2ForCTC computes its loss with blank=config.pad_token_id and
# zero_infinity=config.ctc_zero_infinity, so infinite/NaN-prone CTC losses
# (labels too long for the input) must be neutralised through the config flag.
config = Wav2Vec2Config.from_pretrained(
    "facebook/wav2vec2-base-960h",
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    forced_decoder_ids=None,
    blank_token_id=blank_token_id,  # stored in the config only; not read by the CTC loss
)

# Build model (random weights from the config) and re-initialise the output layer.
# The former `model.ctc_loss = nn.CTCLoss(...)` assignment was removed: the
# model's forward never reads that attribute, so it had no effect. Its
# zero_infinity intent is now carried by ctc_zero_infinity above.
model = Wav2Vec2ForCTC(config)
model.lm_head = nn.Linear(model.config.hidden_size, model.config.vocab_size, bias=True)
nn.init.xavier_uniform_(model.lm_head.weight)
nn.init.zeros_(model.lm_head.bias)

# Remove the sorting helper column. `dataset` aliases `tokenized_dataset`, so
# skip it when a previous run of this cell already removed it.
for split_name in ("train", "validation"):
    if "input_length" in dataset[split_name].column_names:
        dataset[split_name] = dataset[split_name].remove_columns(["input_length"])

# Free memory
gc.collect()
torch.cuda.empty_cache()

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=processor,
    data_collator=DataCollatorCTCWithPadding(processor=processor),
    compute_metrics=compute_metrics,
    callbacks=[SaveToDriveCallback(base_drive_path=MODEL_SAVE_DIR, processor=processor)],
)

# Train
trainer.train()

# --- Save final model and tokenizer ---
model.save_pretrained(str(MODEL_SAVE_DIR))
processor.save_pretrained(str(MODEL_SAVE_DIR))

print("✅ Model and tokenizer saved to:", MODEL_SAVE_DIR)


### Sample Evaluation on a Few Training Examples

In [None]:
import torch
import random
from jiwer import wer, cer

model.eval()
# Trainer may have moved the model to GPU; inputs must live on the same device.
device = next(model.parameters()).device

# Pick a few random training examples by index instead of materialising the
# whole split with list(...), which would copy every waveform into Python.
train_split = dataset["train"]
sample_indices = random.sample(range(len(train_split)), min(5, len(train_split)))

for i, idx in enumerate(sample_indices):
    sample = train_split[idx]
    input_values = torch.tensor(sample["input_values"]).unsqueeze(0).to(device)  # [1, T]
    label_ids = sample["labels"]

    with torch.no_grad():
        logits = model(input_values).logits

    # Frame-level CTC output: default grouping collapses repeats, drops pad frames.
    pred_ids = torch.argmax(logits, dim=-1)
    pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

    # Decode label ids (a plain character sequence: no grouping, keep "LL")
    label_ids_tensor = torch.tensor(label_ids)
    label_ids_tensor[label_ids_tensor == -100] = processor.tokenizer.pad_token_id
    target_text = processor.decode(label_ids_tensor, skip_special_tokens=True, group_tokens=False)

    # jiwer cannot score an empty reference
    if not target_text.strip():
        print(f"\n--- Sample {i+1} --- (empty reference, skipped)")
        continue

    # Compute metrics
    sample_wer = wer(target_text, pred_text)
    sample_cer = cer(target_text, pred_text)

    # Print results
    print(f"\n--- Sample {i+1} ---")
    print(f"GT   : {target_text}")
    print(f"PRED : {pred_text}")
    print(f"WER  : {sample_wer:.3f}")
    print(f"CER  : {sample_cer:.3f}")


# Test


Load and filter test DataFrame

In [None]:
# Load the preprocessed test CSV and keep only chunks flagged as having lyrics.
test_df = pd.read_csv(DATA_PROCESSED / "test_segments_chunked.csv")
lyric_rows = test_df["has_lyrics"] == True  # noqa: E712 — keep exact-True comparison
test_df = test_df.loc[lyric_rows].reset_index(drop=True)

# Provide a positional ID column when the CSV does not already carry one.
if "id" not in test_df.columns:
    test_df["id"] = test_df.index


Convert test DataFrame to HuggingFace Dataset

In [None]:
# Extract relevant columns
def extract_subset(df):
    """Select the columns the test pipeline needs, exposing "Lyric" as "text".

    Works on a single row (-> dict of scalars) or on a whole DataFrame
    (-> dict of Series), so the subset can be built in one vectorised step.
    """
    return {
        "id": df["id"],
        "chunk_path": df["chunk_path"],
        "text": df["Lyric"]
    }

# Build the subset column-wise instead of with a row-by-row apply(axis=1):
# this is linear-time in C and still yields the right columns when test_df is empty.
raw_test = pd.DataFrame(extract_subset(test_df))
raw_dataset = Dataset.from_pandas(raw_test)


Attach audio loader to Dataset

In [None]:
from datasets import Audio  # not imported by any other cell of this notebook

# Cast chunk_path column to the datasets Audio feature (target sampling rate
# 16 kHz), then expose it under the "audio" column name.
raw_dataset = raw_dataset.cast_column("chunk_path", Audio(sampling_rate=16000))
raw_dataset = raw_dataset.rename_column("chunk_path", "audio")

Load trained processor

In [None]:
# Load the processor saved next to the trained model, so test features and labels
# use the same vocabulary and feature settings as training.
trained_processor_dir = MODELS_DIR / "wav2vec2-ctc-vocal"
processor = Wav2Vec2Processor.from_pretrained(trained_processor_dir)


Tokenize the test dataset

In [None]:
# Tokenize inputs and labels
def prepare_example(batch):
    """Turn one decoded test example into model inputs and CTC labels.

    Mirrors the training-time preprocessing: flat (1-D) input values, an explicit
    attention mask, and labels from the lyric normalised with clean_text().
    """
    audio_array = batch["audio"]["array"]
    # Ask for the attention mask explicitly. NOTE(review): the base-960h feature
    # extractor config appears to default to return_attention_mask=False — verify.
    inputs = processor.feature_extractor(
        audio_array, sampling_rate=16000, return_attention_mask=True
    )
    # Normalise lyrics exactly as for training (uppercase, no punctuation,
    # "|" word delimiters); a missing lyric becomes an empty label sequence.
    labels = processor.tokenizer(clean_text(batch["text"] or "")).input_ids
    # The extractor returns a batch of one: unwrap it so test rows have the same
    # flat layout as the training rows.
    batch["input_values"] = inputs["input_values"][0]
    batch["attention_mask"] = inputs["attention_mask"][0]
    batch["labels"] = labels
    return batch

processed_test = raw_dataset.map(
    prepare_example,
    remove_columns=["audio", "text", "id"],
    desc="Tokenizing"
)


Save tokenized test set to disk

In [None]:
# Persist the tokenized test split next to the other processed data.
tokenized_test_dir = DATA_PROCESSED / "tokenized_test_dataset"
processed_test.save_to_disk(tokenized_test_dir)
