# Debug Training
In this notebook, we
- ensure that training works (in principle) with `ContextWhisperForCausalLM`, checking that no errors are thrown during the training loop;
- overfit on a single sample, validating that training works as expected;
- freeze all parameters except those of the `text_encoder`, validating that the `text_encoder` signal is (at least somewhat) useful.


The notebook was developed in Google Colab, granting access to GPU resources for small experiments.

## Install dependencies

In [None]:
# Clone the project repository.
# SECURITY: a GitHub personal access token used to be hard-coded in this URL. Never put
# credentials in a notebook; that token is exposed in the notebook's history and must be
# revoked. The token is now read from the GITHUB_TOKEN environment variable, or prompted for.
import os
import subprocess
from getpass import getpass

REPO_DIR = 'context-whisper'
if not os.path.isdir(REPO_DIR):  # keep the cell idempotent: re-running must not fail on an existing clone
    github_token = os.environ.get('GITHUB_TOKEN') or getpass('GitHub token (leave empty for anonymous clone): ')
    if github_token:
        clone_url = f'https://{github_token}:x-oauth-basic@github.com/fnestaas/context-whisper.git'
    else:
        clone_url = 'https://github.com/fnestaas/context-whisper.git'
    result = subprocess.run(['git', 'clone', clone_url, REPO_DIR], capture_output=True, text=True)
    if result.returncode != 0:
        error_text = result.stderr
        if github_token:
            # Redact the token so it can never end up in the saved cell output.
            error_text = error_text.replace(github_token, '***')
        raise RuntimeError('git clone failed:\n' + error_text)

In [None]:
# Install pdm and uv, and configure pdm to use uv.
# NOTE(review): pdm is not invoked anywhere else in this notebook; confirm it is still needed.
!pip install pdm uv
!pdm config use_uv true
# Copy the package source into the working directory so that the later
# `from context_whisper... import ...` statements can find it.
!cp -r context-whisper/src/context_whisper .

In [None]:
!pip install --upgrade pip
!uv pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio


In [None]:
# NOTE(review): redundant; `datasets` (with the `audio` extra) is already installed by the previous cell.
!uv pip install datasets

## Load a small dataset for debugging

In [None]:
from datasets import load_dataset

# Small dataset from the Hugging Face Hub; its 'train' and 'test' splits are used below.
ds = load_dataset(path="rodoggx/ATCO2-ASR-1h")

## Define the model

In [None]:
import torch
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor
from transformers.models.whisper.tokenization_whisper import WhisperTokenizer

from context_whisper.modules import ContextWhisperConfig, ContextWhisperForCausalLM, ContextWhisperModel
from context_whisper.processing import ContextWhisperProcessor

# Pretrained checkpoints: one Whisper checkpoint and one BERT checkpoint
# (the latter backs the text encoder, see `text_encoder_pretrained_str`).
whisper_str = 'openai/whisper-small'
bert_str = 'google-bert/bert-base-uncased'

# The Whisper feature extractor and tokenizer come from the same checkpoint;
# prompts are tokenized with the BERT tokenizer matching the text encoder.
feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_str)
tokenizer = WhisperTokenizer.from_pretrained(whisper_str)
prompt_tokenizer = BertTokenizer.from_pretrained(bert_str)

processor = ContextWhisperProcessor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    prompt_tokenizer=prompt_tokenizer,
)

# NOTE(review): d_model=768 is presumably chosen to match the hidden size of
# both pretrained checkpoints; confirm against the ContextWhisperConfig docs.
config = ContextWhisperConfig(
    text_encoder_pretrained_str=bert_str,
    whisper_pretrained_str=whisper_str,
    d_model=768,
)

model = ContextWhisperForCausalLM(config)
model = model.to('cuda')


## Process the dataset

In [None]:
import torch
from datasets import Audio, load_dataset
from tqdm import tqdm

AdamW = torch.optim.AdamW

# The Whisper feature extractor expects 16 kHz audio. Casting the column makes
# `datasets` resample on access, so the sampling rate passed below is always the
# audio's true rate instead of a hard-coded assumption.
TARGET_SAMPLING_RATE = 16000
train_dataset = ds['train'].cast_column('audio', Audio(sampling_rate=TARGET_SAMPLING_RATE))
val_dataset = ds['test'].cast_column('audio', Audio(sampling_rate=TARGET_SAMPLING_RATE))

# The prompt is identical for every example, so tokenize it once instead of per example.
PROMPT = 'This is a recording about a fabulous view'  # Example prompt
prompt_ids_tensor = torch.tensor(
    processor(prompt=PROMPT, padding=True, truncation=True)['input_ids']
).reshape(-1)


# Data preprocessing function
def preprocess_function(examples):
    """Convert one raw dataset row into model inputs.

    Args:
        examples: a single row (``Dataset.map`` is called without ``batched=True``)
            holding a decoded ``audio`` dict (``array``, ``sampling_rate``) and the
            transcript in ``text_Str``.

    Returns:
        dict of CPU tensors: ``input_features`` (size-1 dims squeezed away),
        ``input_ids`` and ``prompt_ids`` (both flattened to 1-D).
    """
    audio = examples['audio']
    # Pass the real sampling rate so a mismatch is reported rather than silently ignored.
    audio_features = processor(audio=audio['array'], sampling_rate=audio['sampling_rate'])
    text_tokens = processor(text=examples['text_Str'], padding=True, truncation=True)  # Text data

    # Keep tensors on the CPU: `datasets` serialises map outputs to Arrow anyway,
    # so moving them to the GPU here only adds two device transfers per example.
    input_features = torch.tensor(audio_features['input_features']).squeeze()
    # reshape(-1) (rather than squeeze) keeps a length-1 token sequence 1-D.
    input_ids = torch.tensor(text_tokens['input_ids']).reshape(-1)

    return {'input_features': input_features, 'input_ids': input_ids, 'prompt_ids': prompt_ids_tensor}


# Apply the preprocessing function
val_dataset = val_dataset.map(preprocess_function)
train_dataset = train_dataset.map(preprocess_function)

## Debugging: overfit on one sample
Here we
- take a single sample from the training data and try to overfit on it, monitoring the loss on that sample by also using it as the validation data;
- freeze the output embeddings, the `decoder` and the `spectrogram_encoder`, leaving only the `text_encoder` changeable during training.
If the loss improves, it is due to the `text_encoder` updating its parameters.
Note that we could run similar experiments by freezing other modules too.

In [None]:
# Overfit on one sample: the first training example also serves as the validation
# set, so the eval loss tracks exactly the sample being fit.
from datasets import Dataset

first_example = train_dataset[[0]]
val_dataset = Dataset.from_dict(first_example)
train_dataset = Dataset.from_dict(first_example)

# Does the text_encoder matter? Freeze the output embeddings, decoder and spectrogram encoder.
# NOTE(review): any other submodules of `model` stay trainable; confirm the
# text_encoder is the only one left unfrozen.
for frozen_module_name in ("output_embeddings", "decoder", "spectrogram_encoder"):
    model.freeze_module(frozen_module_name)


def model2params(m: torch.nn.Module):
    """Snapshot every parameter of ``m`` as one flat 1-D NumPy array on the CPU.

    Values are concatenated in ``m.parameters()`` order, so two snapshots of the
    same module can be compared element-wise.

    Args:
        m: module whose parameters are copied.

    Returns:
        A NumPy array holding a copy of all parameter values; an empty array if
        ``m`` has no parameters.
    """
    # Detach and move each tensor to the CPU *before* concatenating. This avoids
    # building an autograd graph and allocating a second full copy of the model
    # on the GPU just to take a snapshot.
    flat_params = [p.detach().reshape(-1).cpu() for p in m.parameters()]
    if not flat_params:
        # torch.cat raises on an empty list.
        return torch.empty(0).numpy()
    # torch.cat always allocates a new tensor, so the snapshot never aliases live params.
    return torch.cat(flat_params).numpy()
# Snapshots taken before training, compared against after training.
# Only the text_encoder's parameters are expected to change in this experiment.
pre_training_encoder = model2params(model.get_text_encoder())
pre_training_params = model2params(model)

## Training setup and training loop

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, input_pad_id=0, prompt_pad_id=0, label_pad_id=-100):
    """
    Custom collate function to pad text and audio data for batching.

    Args:
        batch (list): List of samples from the dataset, each a dict with
            'input_features', 'input_ids' and 'prompt_ids'.
        input_pad_id (int): padding value for 'input_ids' (default 0, as before).
        prompt_pad_id (int): padding value for 'prompt_ids' (default 0, as before).
        label_pad_id (int): padding value for 'labels'; -100 is the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so padded
            positions can be excluded from the loss.

    Returns:
        dict: 'input_features' stacked along a new batch dimension;
        'input_ids' and 'prompt_ids' right-padded to the longest sequence in
        the batch; and 'labels', the same token ids as 'input_ids' but padded
        with ``label_pad_id``.
    """
    # Stack the audio features along the batch dimension (they share one shape).
    input_features = torch.stack([torch.as_tensor(s['input_features']) for s in batch], dim=0)

    # Flatten token ids to 1-D. `squeeze(0)` would turn a length-1 sequence into a
    # 0-d tensor, which `pad_sequence` cannot pad.
    input_ids = [torch.as_tensor(s['input_ids']).reshape(-1) for s in batch]
    prompt_ids = [torch.as_tensor(s['prompt_ids']).reshape(-1) for s in batch]

    return {
        'input_features': input_features,
        'input_ids': pad_sequence(input_ids, batch_first=True, padding_value=input_pad_id),
        'prompt_ids': pad_sequence(prompt_ids, batch_first=True, padding_value=prompt_pad_id),
        'labels': pad_sequence(input_ids, batch_first=True, padding_value=label_pad_id),
    }
# Setup DataLoader for training and validation; both share batch size and collate function.
from torch.utils.data import DataLoader

batch_size = 1
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

In [None]:
# Seed PyTorch so dropout and DataLoader shuffling are reproducible across re-runs.
SEED = 0
torch.manual_seed(SEED)

# Define optimizer over the parameters that are still trainable after freezing.
# AdamW already skips parameters whose .grad stays None, but handing it only the
# trainable ones makes the experiment's intent explicit and keeps optimizer state small.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)


def compute_batch_loss(model, batch, device='cuda'):
    """Run the encoder and decoder on one collated batch and return the model's loss.

    Args:
        model: the ContextWhisperForCausalLM being trained.
        batch: dict from ``collate_fn`` with 'input_features', 'input_ids' and
            'prompt_ids'. If it also carries 'labels' (padding set to -100),
            those are used as targets; otherwise 'input_ids' are used.
        device: device the batch tensors are moved to.

    Returns:
        The loss tensor reported by the model (``outputs.loss``).
    """
    input_features = batch['input_features'].to(device)
    input_ids = batch['input_ids'].to(device)
    prompt_ids = batch['prompt_ids'].to(device)
    labels = batch['labels'].to(device) if 'labels' in batch else input_ids

    # Call the module itself (not `.forward`) so registered hooks are run.
    encoder_out = model.get_encoder()(
        spectrogram_input_features=input_features,
        output_hidden_states=True,
        text_encoder_input_ids=prompt_ids,
    )
    # NOTE(review): using the same ids as decoder inputs and labels assumes the
    # model shifts labels internally (causal-LM convention). If it does not, the
    # decoder is trained to copy its input; confirm in ContextWhisperForCausalLM.
    outputs = model(
        decoder_input_ids=input_ids,
        encoder_outputs=encoder_out,
        output_hidden_states=True,
        labels=labels,
    )
    return outputs.loss


# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    loop = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')
    total_loss = 0.0
    for step, batch in enumerate(loop, start=1):
        loss = compute_batch_loss(model, batch)
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=total_loss / step)

    # Evaluate after each epoch (no gradients needed)
    model.eval()
    eval_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            eval_loss += compute_batch_loss(model, batch).item()

    # max(..., 1) guards against an empty validation loader.
    print(f'Epoch {epoch + 1} - Eval Loss: {eval_loss / max(len(val_dataloader), 1)}')

# Optionally save the model after training
model.save_pretrained('./context_whisper_model', safe_serialization=False)

## Validate that training only changed the `text_encoder`


In [None]:
post_training_params = model2params(model)
# Element-wise change of every scalar parameter during training
# (same ordering as `pre_training_params`).
train_param_diff = post_training_params - pre_training_params
# Flat indices of the parameters that changed at all.
(nz_diff,) = train_param_diff.nonzero()

In [None]:
# Sanity check first: the comparison is only meaningful if the text encoder is a
# strict subset of the model's parameters.
assert len(post_training_params) > len(pre_training_encoder), "All parameters can be changed"
# The experiment expects the text_encoder to learn. An empty diff would also make
# `nz_diff.max()` below raise an unhelpful ValueError.
assert len(nz_diff) > 0, "No parameters changed during training"
assert len(nz_diff) <= len(pre_training_encoder), "Too many parameters changed"
# NOTE(review): assumes the text_encoder's parameters form one contiguous run in
# `model.parameters()` order, so every changed index must fall within a window no
# wider than the encoder.
assert nz_diff.max() - nz_diff.min() < len(pre_training_encoder), "Too many parameters changed"
