# Fine-Tuning Wav2Vec2

This notebook follows the tutorial from [this Hugging Face blog post](https://huggingface.co/blog/fine-tune-wav2vec2-english), which demonstrates how Wav2Vec2's [base](https://huggingface.co/facebook/wav2vec2-base), [large](https://huggingface.co/facebook/wav2vec2-large), and [large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) checkpoints can be fine-tuned on any English dataset. We also draw on [this Kaggle notebook](https://www.kaggle.com/code/vitouphy/phoneme-recognition-with-wav2vec2/notebook#Training), which uses a similar method to fine-tune [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) for phone recognition. Here is the [model card](https://huggingface.co/vitouphy/wav2vec2-xls-r-300m-timit-phoneme) of the resulting fine-tuned Wav2Vec2 model.

Some other fine-tuned Wav2Vec2 models:
- [speech31/wav2vec2-large-english-TIMIT-phoneme_v3](https://huggingface.co/speech31/wav2vec2-large-english-TIMIT-phoneme_v3)

# 1. Dataset

## 1-1. Load Dataset

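Load the data. Audio is resampled to 16 kHz at load time. Rescaling (normalisation) is left to the Wav2Vec2 processor later on, and no MFCC transform is needed because Wav2Vec2 takes the raw waveform as input. Annotations are pre-processed by reducing the label set from 72 to 41 labels (39 phones plus `[SIL]` and `[UNK]`).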
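
The label reduction described above can be sketched in isolation. This is a minimal illustration that assumes ARPAbet-style labels with trailing stress digits (e.g. `AH0`); the helper name `reduce_phone` is hypothetical and only mirrors the idea.

```python
def reduce_phone(label: str) -> str:
    """Map a raw aligner label to the reduced label set (hypothetical helper)."""
    # silence, short pause, and empty intervals collapse into one silence label
    if label in ("sp", "sil", ""):
        return "[SIL]"
    # spoken noise is mapped to the unknown label
    if label == "spn":
        return "[UNK]"
    # strip stress digits so that AH0, AH1, and AH2 all become AH
    return "".join(ch for ch in label if not ch.isdigit())


raw = ["sil", "HH", "AH0", "L", "OW1", "sp", "spn", ""]
reduced = [reduce_phone(p) for p in raw]

# phones of one utterance are joined with "|" to form the label string
label_string = "|".join(reduced)
print(label_string)
```

Joining with "|" matches the label strings shown later in the notebook, where the same character serves as the delimiter token.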

In [1]:
import os
import glob
import math
import numpy as np
import pandas as pd
import textgrid
import librosa as lbr
from tqdm import tqdm

win_length = 512
n_fft = 1024      
hop_length = 512  
fmin = 200
fmax = 8000
lifter = 40
n_mfcc = 13

def downsized_phones(old_phone):
    if old_phone == "sp" or old_phone == "sil" or old_phone == "":
        return "[SIL]"
    elif old_phone == "spn":
        return "[UNK]"
    else:
        return ''.join(ph for ph in old_phone if not ph.isdigit())

def rescaling_data(X):
    flatten_data = [item for sent in X for item in sent]
    mean = np.mean(flatten_data)
    std = np.std(flatten_data)

    rescaled_data = []
    for sent in tqdm(X):
        flat = sent.flatten()
        norm = (flat-flat.min())/(flat.max() - flat.min())
        restore = norm.reshape(sent.shape)
        rescaled_data.append(restore)

    return rescaled_data

def load_processed_data(directory="./librespeech_360/train-clean-360", max_songs=600, sampling_rate=44100):
    audios = sorted(glob.glob(directory + "/**/*.flac", recursive=True))
    annotations = sorted(glob.glob(directory + "/**/*.TextGrid", recursive=True))

    X = []
    Y_p = []
    vocab = set()
    phoneme = set()

    max_songs = min(len(audios), max_songs)

    # STEP 1. load and process audio data and annotations and separate by frames
    for audio, annotation, _ in tqdm(zip(audios, annotations, range(max_songs)), total=max_songs):
        tg = textgrid.TextGrid.fromFile(annotation)
        if audio.split("/")[-1].split(".")[0] != annotation.split("/")[-1].split(".")[0]:
            print(f"Files mismatch! {audio} and {annotation}")

        try:
            # process audio (resampling, get MFCCs)
            wav, rate = lbr.load(audio, sr=sampling_rate)
            #x = lbr.feature.mfcc(y=wav, sr=rate,
            #                      win_length=win_length, hop_length=hop_length,
            #                      n_fft=n_fft, fmin=fmin, fmax=fmax, lifter=lifter,
            #                      n_mfcc=n_mfcc)
            #delta = lbr.feature.delta(x, mode="wrap")
            #delta2 = lbr.feature.delta(x, order=2, mode="wrap")
    
            #X.append(np.vstack([x, delta, delta2]).T)
            X.append(wav)
    
            #yp = [["[SIL]"]] * x.shape[1]
            yp = ""
    
            # process phonemes (71 phones -> 39+2 phones)
            for i in range(len(tg[1])):
            #    start = max(0, round(tg[1][i].minTime * rate / hop_length))
            #    end = min(x.shape[1], round(tg[1][i].maxTime * rate / hop_length))
                new_phone = downsized_phones(tg[1][i].mark)
                yp += new_phone
                yp += "|"
            #    yp[start:end] = [[new_phone]] * (end - start)
                phoneme.add(new_phone)
            
            Y_p.append(yp[:-1])
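        except Exception as e:
            print(f"Error when processing file {audio}: {e}")
            # keep X and Y_p aligned if the failure happened after the audio was appended
            if len(X) > len(Y_p):
                X.pop()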
    
    # STEP 2. normalise audio data
    #X_norm = rescaling_data(X)

    # STEP 3. get phone label dict
    #label_to_ix = {ph:i for i, ph in enumerate(phoneme)}

    return X, Y_p, list(phoneme)

# 1. load and preprocess audio and label data
X, Y_p, phoneme = load_processed_data(max_songs=5000, sampling_rate=16000)

100%|████████████████████████████████████████████████████████| 5000/5000 [00:19<00:00, 260.02it/s]


In [2]:
phoneme

['[SIL]',
 'L',
 'F',
 'Y',
 'SH',
 'R',
 'AE',
 'W',
 'IY',
 'AA',
 'EY',
 'AH',
 'ER',
 'V',
 'NG',
 'DH',
 'D',
 'HH',
 'OW',
 'T',
 'ZH',
 'N',
 'K',
 'CH',
 '[UNK]',
 'AY',
 'M',
 'G',
 'P',
 'OY',
 'AW',
 'UW',
 'S',
 'EH',
 'TH',
 'UH',
 'B',
 'JH',
 'AO',
 'IH',
 'Z']

## 1-2. Split Dataset

Split the data into train/validation/test sets in an 8:1:1 ratio.

In [3]:
from helper_functions import *

In [4]:
X_train, y_train, X_test, y_test = get_xy_split_data(X, Y_p, split=0.2)
X_test, y_test, X_val, y_val = get_xy_split_data(X_test, y_test, split=0.5)
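
`get_xy_split_data` comes from `helper_functions`, which is not shown here. As a rough sketch of how a two-stage split of this shape can produce an 8:1:1 ratio, the hypothetical `split_xy` below assumes the same return order (kept part first, held-out part second); the real helper may shuffle or split differently.

```python
import random


def split_xy(X, Y, split=0.2, seed=0):
    """Hypothetical stand-in: returns (X_a, Y_a, X_b, Y_b), where part b holds `split` of the data."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)          # shuffle indices reproducibly
    n_b = int(round(len(idx) * split))        # size of the held-out part
    b_idx, a_idx = idx[:n_b], idx[n_b:]
    return ([X[i] for i in a_idx], [Y[i] for i in a_idx],
            [X[i] for i in b_idx], [Y[i] for i in b_idx])


X_demo = list(range(100))
Y_demo = [str(i) for i in X_demo]

# stage 1: hold out 20% of the data
X_tr, y_tr, X_rest, y_rest = split_xy(X_demo, Y_demo, split=0.2)
# stage 2: halve the held-out part into test and validation
X_te, y_te, X_va, y_va = split_xy(X_rest, y_rest, split=0.5)

print(len(X_tr), len(X_va), len(X_te))
```

Holding out 20% and then halving it leaves 80% for training and 10% each for validation and test.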

# 2. Preprocessing

## 2-1. Processing into Huggingface Dataset

First, we transform the raw data into the Hugging Face `Dataset` type. Each example must be a 1-dimensional sequence whose length equals the audio file length. Because we load raw waveforms rather than 39-dimensional feature frames, the arrays already have this shape (this mirrors what the `feature_extractor` expects when `feature_size` = 1).

In [5]:
from datasets import Dataset
import pandas as pd

def get_data_dict(X, Y):
    '''
    Helper function of get_huggingface_dataset
    
    Turns raw data (a list of waveforms and a list of label strings)
    into a dictionary {"audio": X, "text": Y}
    '''
    adict = {"audio":[], "text":[]}
    for audios, labels in zip(X, Y):
        #aarr_x = audios.reshape(-1)
        #astr_y = ""
        #for label in labels:
        #    astr_y += label[0] + " "
        #adict["audio"].append(aarr_x)
        #adict["phones"].append(astr_y)
        adict["audio"].append(audios)
        adict["text"].append(labels)
        
    return adict

#def convert_to_numpy(example):
#    example["audio"] = np.array(example["audio"])
#    return example

def get_huggingface_dataset(X_train, y_train, X_val, y_val, X_test, y_test):
    '''
    This function takes X, Y in raw form (lists of waveforms and label strings)
    and returns them as Huggingface Datasets
    '''
    # get dictionary
    train_dict = get_data_dict(X_train, y_train)
    val_dict = get_data_dict(X_val, y_val)
    test_dict = get_data_dict(X_test, y_test)

    # transfer dictionary to dataset
    train_dataset = Dataset.from_dict(train_dict)
    val_dataset = Dataset.from_dict(val_dict)
    test_dataset = Dataset.from_dict(test_dict)

    # transfer dataset array from list to np.array
    #train_dataset = train_dataset.map(convert_to_numpy)
    #test_dataset = test_dataset.map(convert_to_numpy)
    return train_dataset, val_dataset, test_dataset


In [6]:
train_dataset, val_dataset, test_dataset = get_huggingface_dataset(X_train, y_train, X_val, y_val, X_test, y_test)

In [7]:
# check test_dataset
test_dataset

Dataset({
    features: ['audio', 'text'],
    num_rows: 500
})

In [8]:
# check train_dataset
train_dataset

Dataset({
    features: ['audio', 'text'],
    num_rows: 4000
})

Have a peek at a random sample of `train_dataset`:

In [9]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(train_dataset)-1)

print("Phonetics:", train_dataset[rand_int]["text"])
print("Input array shape:", np.array(train_dataset[rand_int]["audio"]).shape)
#print("Sampling rate:", train_dataset[rand_int]["audio"]["sampling_rate"])
ipd.Audio(data=np.array(train_dataset[rand_int]["audio"]), autoplay=False, rate=16000)


Phonetics: [SIL]|B|AO|L|AE|N|D|[SIL]|IY|V|N|IH|NG|R|IH|S|EH|P|SH|AH|N|[SIL]|AH|M|AO|R|IH|L|AE|B|R|AH|T|F|AO|R|M|AH|V|K|AH|M|IH|NG|AW|T|P|AA|R|T|IY|K|AH|N|S|IH|S|T|S|AH|V|AH|B|AO|L|[SIL]|AO|R|AH|V|AH|N|IY|V|N|IH|NG|R|IH|S|EH|P|SH|AH|N|F|AA|L|OW|D|B|AY|D|AE|N|S|IH|NG|[SIL]|[SIL]
Input array shape: (179280,)


## 2-2. Build Character Set

This vocabulary is used by the `processor` later. We need to add a special padding token and a "|" token that replaces the space " ".

In [10]:
label_to_ix = {ph:i for i, ph in enumerate(phoneme)}

# add PAD token
label_to_ix["[PAD]"] = len(label_to_ix)

# add "|" that represents space token " " (easier to read)
label_to_ix["|"] = len(label_to_ix)

# add UNK token for OOV
#label_to_ix["[UNK]"] = len(label_to_ix)

len(label_to_ix)

43

In [11]:
# save vocab.json
import json 
with open('./tokenizer/vocab.json', 'w') as vocab_file:
    json.dump(label_to_ix, vocab_file)
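
To make the role of the vocabulary concrete, here is a small sketch of how a label string could be turned into ids with such a mapping. The toy phone list and the `encode` helper are assumptions for illustration only; in the notebook the actual encoding is done by `Wav2Vec2CTCTokenizer`, whose internals are not reproduced here.

```python
import json

phones = ["[SIL]", "B", "AO", "L"]            # toy subset, not the full inventory
vocab = {ph: i for i, ph in enumerate(phones)}
vocab["[PAD]"] = len(vocab)                    # padding token
vocab["|"] = len(vocab)                        # delimiter between phones


def encode(label_string, vocab):
    """Illustrative encoder: one id per phone, with the delimiter id in between."""
    ids = []
    for i, ph in enumerate(label_string.split("|")):
        if i > 0:
            ids.append(vocab["|"])
        ids.append(vocab[ph])
    return ids


ids = encode("[SIL]|B|AO|L", vocab)
print(ids)

# the mapping survives a JSON round trip unchanged, as with vocab.json
restored = json.loads(json.dumps(vocab))
```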

## 2-3. Tokenizer, Feature Extractor, and Processor

ASR models transcribe speech to text, which means that we need both a `feature extractor` that converts the speech signal into the model's input format, e.g. a feature vector, and a `tokenizer` that converts the model's output into text.

In 🤗 Transformers, the `Wav2Vec2` model is thus accompanied by both a tokenizer, called `Wav2Vec2CTCTokenizer`, and a feature extractor, called `Wav2Vec2FeatureExtractor`.

The `feature extractor` transforms the audio waveform into the model's input format. It is instantiated with the following parameters:
- `feature_size`: Speech models take a sequence of feature vectors as an input, and this defines the dimension of the extracted features. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal.
- `sampling_rate`: The sampling rate of the audio the model was trained on.
- `padding_value`: For batched inference, shorter inputs need to be padded with a specific value.
- `do_normalize`: Whether the input should be zero-mean-unit-variance normalized or not. Usually, speech models perform better when the input is normalized.
- `return_attention_mask`: Whether the model should make use of an attention_mask for batched inference. In general, models should always make use of the attention_mask to mask padded tokens. However, due to a very specific design choice of Wav2Vec2's "base" checkpoint, better results are achieved when using no attention_mask. This is not recommended for other speech models. If you want to fine-tune large-lv60, this parameter should be set to True.
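
The effect of `do_normalize`, `padding_value`, and `return_attention_mask` can be illustrated with a pure-Python sketch. The function names and the `eps` value are assumptions made for this example; it shows the idea of the parameters, not the library's implementation.

```python
import math


def zero_mean_unit_var(samples, eps=1e-7):
    """Normalise one waveform to zero mean and unit variance.

    `eps` is an assumed small constant that avoids division by zero on silent input.
    """
    n = len(samples)
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples) / n
    return [(s - mean) / math.sqrt(var + eps) for s in samples]


def pad_batch(batch, padding_value=0.0):
    """Right-pad waveforms to the longest one and build a matching attention mask."""
    longest = max(len(x) for x in batch)
    padded = [x + [padding_value] * (longest - len(x)) for x in batch]
    mask = [[1] * len(x) + [0] * (longest - len(x)) for x in batch]
    return padded, mask


waves = [[0.1, 0.3, -0.2, 0.4], [0.5, -0.5]]
normed = [zero_mean_unit_var(w) for w in waves]   # each waveform is normalised on its own
padded, mask = pad_batch(normed)                  # mask marks real samples with 1, padding with 0
print(mask)
```

The attention mask is what lets the model distinguish padded positions from real audio in a batch.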

In [12]:
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from tokenizers.processors import TemplateProcessing

In [13]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./tokenizer/", unk_token="[UNK]", 
                                                 pad_token="[PAD]", word_delimiter_token="|", )  # './tokenizer' to load vocab.json 
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, 
                                             return_attention_mask=True)                          
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

## 2-4. Preprocessing

Usually this step involves resampling the data and, for many speech models, feature extraction such as log-mel features. We have already resampled the audio, and Wav2Vec2 works on the raw waveform, so all that remains is to pass the audio through the `processor` and encode the phonetic labels into label ids.

In [14]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(text=batch["text"]).input_ids

    return batch

def preprocess_wav2vec2(batch):

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(batch['audio'], sampling_rate=16000).input_values[0]
    
    batch["labels"] = processor(text=batch["text"]).input_ids

    return batch

In [15]:
train_dataset = train_dataset.map(preprocess_wav2vec2)
val_dataset = val_dataset.map(preprocess_wav2vec2)
test_dataset = test_dataset.map(preprocess_wav2vec2)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [16]:
# check if the mapping is successful
train_dataset.features

{'audio': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'text': Value(dtype='string', id=None),
 'input_values': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}

In [17]:
len(train_dataset[0]['input_values']) == len(train_dataset[0]['audio'])

True

In [18]:
len(train_dataset[0]['labels'])

31

# 3. Training & Evaluation

After the data is processed, we are ready to set up the training pipeline. We use the Hugging Face `Trainer` to train the model. Before doing so, we need to prepare the following:
- Define a `data collator`
- Define an evaluation metric
- Load a pretrained checkpoint
- Define the training configuration



## 3-1. Data Collator

In contrast to most NLP models, `Wav2Vec2` has a much larger input length than output length. Therefore, we need a special padding data collator when fine-tuning a `Wav2Vec2` model.

The data collator below treats `input_values` and `labels` differently and applies separate padding functions to them. This is because speech input and output are of different modalities.

As in common data collators, we replace the padded label positions with -100 so that they are ignored when the loss is computed.

In [19]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels

        return batch
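
What the collator does to the labels can be seen in a small pure-Python sketch. The helper name `pad_labels` and the pad id `41` are assumptions for this illustration; the collator above does the same thing with tensors via `masked_fill`.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def pad_labels(label_lists, pad_id):
    """Pad label sequences to the longest one, then mask the padding with IGNORE_INDEX."""
    longest = max(len(labels) for labels in label_lists)
    masked = []
    for labels in label_lists:
        n_pad = longest - len(labels)
        padded = labels + [pad_id] * n_pad
        attention = [1] * len(labels) + [0] * n_pad
        # wherever attention is 0 the position is padding, so it must not count in the loss
        masked.append([tok if keep else IGNORE_INDEX
                       for tok, keep in zip(padded, attention)])
    return masked


batch_labels = pad_labels([[3, 7, 1], [5]], pad_id=41)
print(batch_labels)
```

Masking by attention rather than by comparing against `pad_id` keeps the logic correct even if the pad id were to appear inside a real label sequence.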

## 3-2. Evaluation Metrics

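We use the word error rate (WER) to measure performance. Because phones are separated by the word delimiter "|", each phone is decoded as its own "word", so the WER here is effectively a phone error rate.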

In [20]:
from evaluate import load

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
wer_metric = load("wer")

In [21]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
    
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {
        "wer": wer,
    }
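
For intuition about the number being reported, the sketch below computes an error rate from scratch as edit distance divided by reference length. It is a simplified stand-in written for this notebook, not the implementation behind `load("wer")`; the example strings are made up.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token lists."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(references, predictions):
    """Total edit distance over all pairs divided by the total number of reference tokens."""
    errors = sum(edit_distance(r.split(), p.split())
                 for r, p in zip(references, predictions))
    total = sum(len(r.split()) for r in references)
    return errors / total


refs = ["[SIL] B AO L [SIL]"]
hyps = ["[SIL] B AA L"]
rate = error_rate(refs, hyps)   # one substitution and one deletion over five reference phones
print(rate)
```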

## 3-3. Load Pretrained Checkpoint

Now, we load the pretrained Wav2Vec2 checkpoint. We are using checkpoint `facebook/wav2vec2-xls-r-300m` ([link to model description](https://huggingface.co/facebook/wav2vec2-xls-r-300m)). 

The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id`, which in the case of `Wav2Vec2ForCTC` also serves as CTC's blank token.

We set the loss reduction to "mean".
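
The role of the blank token is easiest to see in greedy CTC decoding: consecutive repeats are collapsed first, and only then are blanks removed, so a blank between two identical ids keeps them as two separate outputs. The sketch below uses made-up frame ids and assumes a blank id of `41`.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, blank_id):
    """Collapse consecutive repeats, then drop blanks."""
    collapsed = [tok for tok, _ in groupby(frame_ids)]
    return [tok for tok in collapsed if tok != blank_id]


BLANK = 41  # assumed id of the pad/blank token for this example
frames = [0, 0, BLANK, 7, 7, BLANK, 7, BLANK, BLANK, 3]
decoded = ctc_greedy_decode(frames, BLANK)
print(decoded)
```

Here the two runs of `7` are separated by a blank and therefore decode to two `7`s, which would be lost if blanks were removed before collapsing.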



In [22]:
from transformers import Wav2Vec2ForCTC

model_id = "facebook/wav2vec2-xls-r-300m"

model = Wav2Vec2ForCTC.from_pretrained(
    model_id, 
    attention_dropout=0.1,
    mask_time_prob=0.75, 
    mask_time_length=5,
    mask_feature_prob=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The first component of `Wav2Vec2` consists of a stack of CNN layers that has already been sufficiently trained during pretraining. Therefore, we can freeze the parameters of the feature extraction part:

In [23]:
model.freeze_feature_encoder()

## 3-4. Training Configurations

To save GPU memory, we enable PyTorch's gradient checkpointing ([link](https://docs.pytorch.org/docs/stable/checkpoint.html)) 

In [24]:
from transformers import TrainingArguments

model_name = model_id.split("/")[-1]

training_args = TrainingArguments(
    f"{model_name}-finetuned-gtzan",
    group_by_length=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=10,
    gradient_checkpointing=True,
    fp16=True,
    save_steps=500,  # not used while save_strategy="epoch"
    eval_steps=500,  # not used while eval_strategy="epoch"
    logging_steps=1,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    save_total_limit=3,
    metric_for_best_model="wer",
    greater_is_better=False,
    load_best_model_at_end=True
)
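
It helps to work out what these settings imply for the training schedule. The arithmetic below assumes a single device and the 4000 training examples from above; the exact rounding applied by the `Trainer` may differ slightly.

```python
import math

num_examples = 4000      # size of train_dataset
per_device_batch = 8     # per_device_train_batch_size
grad_accum = 4           # gradient_accumulation_steps
num_devices = 1          # assumption: training on a single device
epochs = 10              # num_train_epochs
warmup_ratio = 0.1

# examples consumed per optimizer step
effective_batch = per_device_batch * grad_accum * num_devices
# optimizer steps per epoch and over the whole run
steps_per_epoch = math.ceil(num_examples / effective_batch)
total_steps = steps_per_epoch * epochs
# steps over which the learning rate ramps up
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(effective_batch, steps_per_epoch, total_steps, warmup_steps)
```

Gradient accumulation gives an effective batch of 32 while only 8 examples need to fit in memory at a time.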

## 3-5. Train with Trainer

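Now we can bring everything above together in a `Trainer`. We first define helper functions that clean up predicted and gold token ids before decoding: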

In [25]:
PAD_ID = tokenizer.encode("[PAD]")[0]
EMPTY_ID = tokenizer.encode("|")[0]

def collapse_tokens(tokens: List[Union[str, int]]) -> List[Union[str, int]]:
    """
    Collapse consecutive duplicates and join the remaining tokens with EMPTY_ID
    """
    prev_token = None
    out = []
    for token in tokens:
        if token != prev_token:
            # put a delimiter between tokens, and keep the final token as well
            if out:
                out.append(EMPTY_ID)
            out.append(token)
        prev_token = token
    return out

def clean_token_ids(token_ids: List[int]) -> List[int]:
    """
    Remove [PAD] and "|" tokens, then collapse duplicated token_ids
    """
    token_ids = [x for x in token_ids if x not in [PAD_ID, EMPTY_ID]]
    token_ids = collapse_tokens(token_ids)
    return token_ids


In [26]:
# Training arguments
simple_training_args = TrainingArguments(
    output_dir="./test_output",
    eval_strategy="epoch",        # Try "epoch" instead of "steps"
    logging_steps=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    report_to=[],                       # Disable external logging
) 

ix_to_label = {ix: lab for lab, ix in label_to_ix.items()}

def simple_compute_metrics(pred):
    print("✓ compute_metrics is being called!")
    pred_logits = pred.predictions
    predicted_ids = np.argmax(pred_logits, axis=-1)
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # decode every example in the evaluation set, not only the first one
    predicted_strs, label_strs = [], []
    for p_ids, l_ids in zip(predicted_ids.tolist(), label_ids.tolist()):
        predicted_strs.append(tokenizer.decode(clean_token_ids(p_ids), group_tokens=False))
        label_strs.append(tokenizer.decode(clean_token_ids(l_ids), group_tokens=False))

    #print(f"Prediction: {predicted_strs[0]}")
    #print(f"Goldstandards: {label_strs[0]}")

    wer = wer_metric.compute(predictions=predicted_strs, references=label_strs)

    return {
        "wer": wer,
    }

In [27]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=simple_compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor.feature_extractor
)

In [None]:
import logging
import os
import warnings

# Enable transformers logging
logging.basicConfig(level=logging.INFO)
os.environ["TRANSFORMERS_VERBOSITY"] = "info"

# Disable external loggers that might interfere
os.environ["WANDB_DISABLED"] = "true"

# This warning is shown when training runs on CPU - the following line suppresses it
warnings.filterwarnings("ignore", message=".*pin_memory.*")

trainer.train()

Epoch,Training Loss,Validation Loss


In [None]:
def diagnose_training_setup(trainer):
    """Diagnose why no logs are showing"""
    print("=== Training Setup Diagnosis ===")
    
    args = trainer.args
    dataset_size = len(trainer.train_dataset)
    batch_size = args.per_device_train_batch_size
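    # logging_steps counts optimizer steps, which also depend on gradient accumulation (single device assumed)
    steps_per_epoch = math.ceil(dataset_size / (batch_size * args.gradient_accumulation_steps))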
    
    print(f"Dataset size: {dataset_size}")
    print(f"Batch size: {batch_size}")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Logging every: {args.logging_steps} steps")
    print(f"Total epochs: {args.num_train_epochs}")
    print(f"Total steps: {steps_per_epoch * args.num_train_epochs}")
    
    # Check if logging will happen
    if args.logging_steps > steps_per_epoch:
        print("❌ PROBLEM: logging_steps > steps_per_epoch")
        print(f"   You need logging_steps <= {steps_per_epoch}")
    else:
        print("✅ Logging configuration looks OK")
    
    return steps_per_epoch

# Diagnose your setup
steps_per_epoch = diagnose_training_setup(trainer)

# 4. Evaluate

In [None]:
trainer.evaluate(test_dataset)

In [None]:
model.eval()  # switch off dropout and time masking for inference
for x in test_dataset:
    # keep the input on the same device as the model
    input_values = torch.tensor(x['input_values']).unsqueeze(0).to(model.device)

    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)

    # clean the predicted ids and decode them into a phone string
    predicted_ids = clean_token_ids(predicted_ids[0].tolist())
    predicted_str = tokenizer.decode(predicted_ids, group_tokens=False)

    # the gold labels are already a list of ids
    label_ids = clean_token_ids(list(x['labels']))
    label_str = tokenizer.decode(label_ids, group_tokens=False)

    print(f"Prediction: {predicted_str}")
    print(f"Goldstandards: {label_str}")
        