<a href="https://colab.research.google.com/github/beinghorizontal/wav2vec2/blob/main/finetune_crossdelenna_medium_cross_en.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install required packages


In [2]:
import os
os.system('pip install datasets transformers jiwer evaluate')


0

In [3]:
import torch
import evaluate
import numpy as np
import random
import librosa
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from transformers import (
    WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor,
    WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
)
from google.colab import drive, output


# Enable custom widget manager


In [4]:
output.enable_custom_widget_manager()


# Check GPU availability


In [5]:
gpu_info = os.popen('nvidia-smi').read()
if 'failed' in gpu_info:
    print('Not connected to a GPU')
else:
    print(gpu_info)


Tue Feb 11 02:28:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

# Load dataset

In [6]:
import datasets
timit = datasets.load_dataset("crossdelenna/whisper_data_merge2")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Split dataset


In [7]:
num_rows = len(timit['train'])
num_test_rows = num_rows // 7
num_train_rows = num_rows - num_test_rows
timit_train = timit["train"].select(range(num_train_rows))
timit_test = timit["train"].select(range(num_test_rows))


# Load Whisper components from Hugging Face Hub


In [8]:
feature_extractor = WhisperFeatureExtractor.from_pretrained("crossdelenna/whisper_med_alex.en")
tokenizer = WhisperTokenizer.from_pretrained("crossdelenna/whisper_med_alex.en", language="English", task="transcribe")
processor = WhisperProcessor.from_pretrained("crossdelenna/whisper_med_alex.en", language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("crossdelenna/whisper_med_alex.en")


# Data collator


In [9]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)


# Evaluation metric


In [10]:
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


# Freeze layers


In [11]:
def freeze_whisper_layers(model):
    for param in model.parameters():
        param.requires_grad = False

    try:
        encoder_layers = model.model.encoder.layers
        for layer in encoder_layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
    except AttributeError:
        print("Could not access encoder layers")

    try:
        decoder_layers = model.model.decoder.layers
        for layer in decoder_layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
    except AttributeError:
        print("Could not access decoder layers")

    try:
        model.model.encoder.layer_norm.requires_grad = True
    except AttributeError:
        print("Could not access encoder layer norm")

    try:
        model.model.decoder.layer_norm.requires_grad = True
    except AttributeError:
        print("Could not access decoder layer norm")

    for name, module in model.named_children():
        if 'proj' in name or 'head' in name or 'classifier' in name:
            for param in module.parameters():
                param.requires_grad = True

    return model

model = freeze_whisper_layers(model)

# Verify trainable parameters


In [12]:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
print(f"Percentage of trainable parameters: {trainable_params/total_params*100:.2f}%")

Total parameters: 763856896
Trainable parameters: 111888384
Percentage of trainable parameters: 14.65%


# Training arguments


In [13]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-medium.en",
    per_device_train_batch_size=24,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=10,
    max_steps=901,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=300,
    eval_steps=300,
    logging_steps=300,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    hub_strategy="checkpoint",
    hub_model_id="crossdelenna/whisper_med_alex.en",
    hub_token='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',
    resume_from_checkpoint=True  # This will resume training from the last checkpoint
)

# trainer = Seq2SeqTrainer(
#     model=model,
#     data_collator=data_collator,
#     args=training_args,
#     compute_metrics=compute_metrics,
#     train_dataset=timit_train,
#     eval_dataset=timit_test,
#     tokenizer=processor.feature_extractor,
# )


  trainer = Seq2SeqTrainer(


# Custom Seq2SeqTrainer to use sampled validation subset. Default random sample size is 300 from test data for faster evaluation at each eval_steps.


In [None]:
# Function to sample a subset of the validation data
def sample_validation_data(dataset, sample_size=300, seed=42):
    return dataset.shuffle(seed=seed).select(range(sample_size))

class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # Sample a smaller validation subset if it's an evaluation step
        if self.state.global_step % self.args.eval_steps == 0:
            eval_dataset = sample_validation_data(self.eval_dataset, sample_size=300)
        else:
            eval_dataset = self.eval_dataset
        # Call the parent class's evaluate method with the modified eval_dataset
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

trainer = CustomSeq2SeqTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit_train,
    eval_dataset=timit_test,
    tokenizer=processor.feature_extractor,
)


# Save processor and tokenizer locally


In [14]:
processor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)


('./whisper-medium.en/tokenizer_config.json',
 './whisper-medium.en/special_tokens_map.json',
 './whisper-medium.en/vocab.json',
 './whisper-medium.en/merges.txt',
 './whisper-medium.en/normalizer.json',
 './whisper-medium.en/added_tokens.json')

# Train model


In [15]:
checkpoint_path = "./whisper_med_alex.en/last-checkpoint"  # Specify the path to the checkpoint

trainer.train(resume_from_checkpoint=checkpoint_path)


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
300,0.2094,0.157849,10.239179
600,0.15,0.129385,7.854614
900,0.1244,0.120753,7.276537


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=901, training_loss=0.16121626560153496, metrics={'train_runtime': 8291.1482, 'train_samples_per_second': 2.608, 'train_steps_per_second': 0.109, 'total_flos': 2.20144478552064e+19, 'train_loss': 0.16121626560153496, 'epoch': 3.4787644787644787})

# Push to hub


In [16]:
trainer.push_to_hub()


events.out.tfevents.1739240955.4cee2e63b939.6105.0:   0%|          | 0.00/7.85k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/crossdelenna/whisper_med_alex.en/commit/0891c81f4f4bd8599dc2ea3ea4d84146666136b1', commit_message='End of training', commit_description='', oid='0891c81f4f4bd8599dc2ea3ea4d84146666136b1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/crossdelenna/whisper_med_alex.en', endpoint='https://huggingface.co', repo_type='model', repo_id='crossdelenna/whisper_med_alex.en'), pr_revision=None, pr_num=None)

# Save model, processor, and tokenizer locally


In [17]:
processor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)


['./whisper-medium.en/preprocessor_config.json']

# Push processor and tokenizer to the Hugging Face Hub


In [19]:
processor.push_to_hub("crossdelenna/whisper_med_alex.en",token="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")
tokenizer.push_to_hub("crossdelenna/whisper_med_alex.en", token="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")
feature_extractor.push_to_hub("crossdelenna/medium_cross.en", token="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")

README.md:   0%|          | 0.00/1.69k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


README.md:   0%|          | 0.00/1.69k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/crossdelenna/medium_cross.en/commit/f4ca35fcba58dd44c960387b1f732188f7380c8a', commit_message='Upload feature extractor', commit_description='', oid='f4ca35fcba58dd44c960387b1f732188f7380c8a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/crossdelenna/medium_cross.en', endpoint='https://huggingface.co', repo_type='model', repo_id='crossdelenna/medium_cross.en'), pr_revision=None, pr_num=None)