# This is the script in which we train Wav2Vec2 on the Creolese Audio

In [1]:
#!pip install datasets transformers torchaudio jiwer librosa soundfile
!pip install transformers[torch] hf_xet



In [2]:
from transformers import TrainingArguments, Trainer
from datasets import Dataset, Audio
import torch
import json
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cpu


In [3]:
audio_path = "../creolese-audio-dataset/finetune_eligible"
transcription_path = "../creolese-audio-dataset/finetune_eligible/transcripts.json"

# Load transcripts JSON
with open(transcription_path, 'r') as f:
    transcripts = json.load(f)

# Create a list of dicts pairing audio files and transcripts
data = []
for item in transcripts:
    audio_file = os.path.join(audio_path, item['audio'])
    if os.path.exists(audio_file):
        print(f"Found file: {audio_file}")
        data.append({'audio': audio_file, 'text': item['text']})
    else:
        print(f"Missing file: {audio_file}")

Found file: ../creolese-audio-dataset/finetune_eligible/1a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1c.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1d.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1e.wav
Found file: ../creolese-audio-dataset/finetune_eligible/3a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/3b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/2.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4c.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4d.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4e.wav
Found file: ../creolese-audio-dataset/finetune_eligible/5a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/5b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/

In [4]:
dataset = Dataset.from_list(data)

# Cast the audio column to automatically load audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
print(dataset)


Dataset({
    features: ['audio', 'text'],
    num_rows: 239
})


In [5]:
# Split dataset into 80% training, 20% evaluation
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]



In [6]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False
)

tokenizer = Wav2Vec2CTCTokenizer("../tokenizer_files/vocab (Copy).json", unk_token="<unk>", pad_token="<pad>", word_delimiter_token="|")
print(f"Pad token ID: {tokenizer.pad_token_id}")
print(f"Vocab size: {len(tokenizer)}")
processor = Wav2Vec2Processor(feature_extractor, tokenizer)


Pad token ID: 40
Vocab size: 53


In [7]:
def prepare_dataset(batch, processor):
    audio = batch["audio"]

    # Get input values from audio
    input_values = processor(
        audio["array"], 
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt"
    ).input_values[0]

    # Get labels from text
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    

    # Return proper format for CTC
    return {
        "input_values": input_values,
        "labels": batch["labels"]
    }


# Apply preprocessing to both splits
train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names, num_proc=4, fn_kwargs={"processor": processor})
eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=eval_dataset.column_names, num_proc=4, fn_kwargs={"processor": processor})


Map (num_proc=4):   0%|          | 0/191 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/48 [00:00<?, ? examples/s]

In [8]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    # "facebook/wav2vec2-large-960h",
    "facebook/wav2vec2-base-960h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(tokenizer),
    ignore_mismatched_sizes=True
)
model.wav2vec2.feature_extractor._freeze_parameters()
torch.nn.init.normal_(model.wav2vec2.masked_spec_embed, mean=0.0, std=0.02)
torch.nn.init.constant_(model.lm_head.bias, -1.0)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized because the shapes did not match:
- lm_head.bias: found shape torch.Size([32]) in the checkpoint and torch.Size([53]) in the model instantiated
- lm_head.weight: found shape torch.Size([32, 768]) in the checkpoint and torch.Size([53, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameter containing:
tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
       requires_grad=True)

# This is a custom attempt of a CTC Data Collator
Maybe you can include a description of what this is, what the issue is with the other one and why this is necessary

In [9]:
import torch
from dataclasses import dataclass
from typing import Dict, List, Union, Any

@dataclass
class DataCollatorCTCWithPadding:
    """
    Proper data collator for Wav2Vec2 CTC training
    This fixes the zero loss issue
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Union[int, None] = None
    max_length_labels: Union[int, None] = None
    pad_to_multiple_of: Union[int, None] = None
    pad_to_multiple_of_labels: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Separate inputs and labels
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad input features
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Pad labels
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace padding with -100 to ignore in loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        batch["labels"] = labels

        return batch

In [10]:
data_collator = DataCollatorCTCWithPadding(processor=processor)

# Let's start training

In [11]:
pip install accelerate>=0.26.0 transformers[torch]

Note: you may need to restart the kernel to use updated packages.


In [12]:
# import jiwer
# import torch

# def compute_metrics(pred):
#     pred_logits = pred.predictions
#     pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)
#     pred_str = processor.batch_decode(pred_ids)
#     label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

#     wer = jiwer.wer(label_str, pred_str)
#     mer = jiwer.mer(label_str, pred_str)
#     cer = jiwer.cer(label_str, pred_str)
#     return {"wer": wer, "mer": mer, "cer": cer}


def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)

    # Clean predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # Fix labels - replace -100 with pad token
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Debug prints
    print(f"Sample pred: '{pred_str[0]}'")
    print(f"Sample label: '{label_str[0]}'")

    try:
        wer = jiwer.wer(label_str, pred_str)
        mer = jiwer.mer(label_str, pred_str)
        cer = jiwer.cer(label_str, pred_str)
        return {"wer": wer, "mer": mer, "cer": cer}
    except:
        return {"fallback wer": 1.0, "fallback mer": 1.0, "fallback cer": 1.0, }  # Fallback

In [13]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="../training_outputs/wav2vec2_training_output",
    per_device_train_batch_size=1,
    num_train_epochs=5,
    dataloader_num_workers=0,
    logging_steps=5,
    save_steps=2,
    learning_rate=1e-5, #lower learning rate
    save_total_limit=2,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    
    eval_steps=5,            # Every 50 steps
    eval_strategy="steps",
)


In [14]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=processor.feature_extractor
)


In [15]:
# Test one batch first
print("Testing one training step...")
batch = [train_dataset[0]]
collated = data_collator(batch)
outputs = model(**collated)
print(f"Test loss: {outputs.loss.item()}")

Testing one training step...
Test loss: 15.605981826782227


In [16]:
# trainer.train(resume_from_checkpoint=True)
trainer.train()
trainer.evaluate()

Step,Training Loss,Validation Loss,Fallback wer,Fallback mer,Fallback cer
5,16.9035,23.668852,1.0,1.0,1.0
10,14.8361,22.626188,1.0,1.0,1.0


Sample pred: '(-hn)oo:hn:oung:)ooehooouam(p:(:oomhnoua-pdoohnrkwouoo)s:dkakwd:ruhnuahnur:ouoo):soodkakwou:ee,:nounghn):(-uo:r-:v:(-mpou:oo):ohnmaouehoo-kwnkw-poo:ou-rooovdkw:r:rngm:ochpoomoukwoo:ooou:ngp:uuo-m(ksya-m(pypmoo(:a-moo)puou)soodkahnvyee)ng:ya:(poo:ouak)(hn:hn:kw:oogeeeoukwouoo:ooouu(gmsh:oung:poo)oodkwooeoo:s:dahnou:-oo:oueeedp(oueeou(moouungoo-ng):m(p:yo)pooo)oo-:a:ohn):rmoovd:pouaoo:ao-oomp-n?-mp?:ngmouuup)-rooouung)apy-)avd:kwmehp!ouhndroo:oupookwjryooeh:-ouoouou)pmeeoo:ookwjeeao)ooroaooodo-m(rd)oouuthnnghn:)(:p:mpchooksyang:?-:(:oueeey,)kw'
Sample label: 'riiyeerz intu yuujiidatairiliidooniivin riliilaik koodinglaik da rait. i wono om,aiwono fookos moor on art koz das somtingdatairiliilaik, koz, yu noo,aim a kriiyeetiv porson. nd om,aidoon riliihav a lot ov art stof, om, biikaazai,aikaainda dischroyd dem a wail bak, om. ozaidid vizuwol arts and maitiichor sokt. ndaiendid op heetin art for laik a gud'
Sample pred: '(oohnoo:ou)oopam(po:oomhnapoo)s:dkkwd:ruhnuahnur:oo):sde

KeyboardInterrupt: 

In [None]:
# Debug what the model is actually predicting
test_sample = eval_dataset[0]
inputs = processor(test_sample["input_values"], return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)

print(f"Raw predicted IDs: {predicted_ids}")
print(f"Predicted tokens: {[processor.tokenizer.decode([id]) for id in predicted_ids[0]]}")
print(f"Decoded text: '{processor.batch_decode(predicted_ids)[0]}'")
print(f"Expected: '{processor.tokenizer.decode(test_sample['labels'])}'")

In [None]:
model.save_pretrained("./wav2vec2-creolese-finetuned")
processor.save_pretrained("./wav2vec2-creolese-finetuned")
