# Wav2Vec Conformer MoE Model Fine-Tuning and Optimization

In [41]:
!pip install torch torchaudio transformers datasets sentencepiece jiwer



## Pre-Processing

In [25]:
from datasets import load_dataset
from transformers import Wav2Vec2Processor
import random

# Load the Common Voice Urdu dataset
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ur")

# Limit the dataset to 10% per split
def limit_dataset_by_percentage(dataset_dict, percentage):
    limited_dataset = {}
    for split, data in dataset_dict.items():
        num_rows = int(len(data) * percentage / 100)
        selected_indices = random.sample(range(len(data)), num_rows)
        limited_dataset[split] = data.select(selected_indices)
    return limited_dataset

# Apply the limit
percentage = 1
limited_dataset = limit_dataset_by_percentage(dataset, percentage)

# Load pre-trained processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

def preprocess_data(batch):
    # Extract audio and process it
    audio = batch["audio"]["array"]
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    
    # Tokenize the sentence and add the token IDs to labels
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    
    # Remove all unnecessary columns except 'input_values' and 'labels'
    columns_to_remove = ["audio", "client_id", "path", "up_votes", "down_votes", "age", "gender", "accent", "locale", "segment"]
    batch = {k: v for k, v in batch.items() if k not in columns_to_remove}
    
    return batch


# Apply preprocessing
for split, data in limited_dataset.items():
    print(f"Preprocessing {split} split with {data.num_rows} rows.")
    limited_dataset[split] = data.map(preprocess_data, remove_columns=["audio", "sentence"], batched=False)

# Inspect the processed dataset
for split, data in limited_dataset.items():
    print(f"{split} split processed: {data.num_rows} rows")




Preprocessing train split with 41 rows.


Map:   0%|          | 0/41 [00:00<?, ? examples/s]

Preprocessing validation split with 33 rows.


Map:   0%|          | 0/33 [00:00<?, ? examples/s]

Preprocessing test split with 33 rows.


Map:   0%|          | 0/33 [00:00<?, ? examples/s]

Preprocessing other split with 851 rows.


Map:   0%|          | 0/851 [00:00<?, ? examples/s]

Preprocessing invalidated split with 32 rows.


Map:   0%|          | 0/32 [00:00<?, ? examples/s]

train split processed: 41 rows
validation split processed: 33 rows
test split processed: 33 rows
other split processed: 851 rows
invalidated split processed: 32 rows


In [26]:
print(limited_dataset["train"].column_names)
print(limited_dataset["test"].column_names)

['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'input_values', 'labels']
['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'input_values', 'labels']


In [31]:
from transformers import Wav2Vec2ForCTC

# Load model with processor's vocab size
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self",
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

# Freeze feature extractor for memory efficiency
model.freeze_feature_extractor()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
import torch

def data_collator(batch):
    # Pad input and labels
    input_values = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(b["input_values"]) for b in batch], batch_first=True, padding_value=processor.tokenizer.pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(b["labels"]) for b in batch], batch_first=True, padding_value=processor.tokenizer.pad_token_id
    )
    return {"input_values": input_values, "labels": labels}

In [33]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,  # Reduce batch size
    per_device_eval_batch_size=4,   # Reduce batch size
    gradient_accumulation_steps=4,   # Simulate larger batch size
    dataloader_num_workers=0,       # Reduce number of workers
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="epoch",          # Save model every epoch
    fp16=True,                      # Mixed precision training
    save_steps=500,                 # Save every 500 steps
    remove_unused_columns=False,    # Keep all columns in dataset
)

In [34]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=limited_dataset["train"],
    eval_dataset=limited_dataset["test"],
    tokenizer=processor.feature_extractor
)

# Train the model
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111366951110742, max=1.0)…

  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Epoch,Training Loss,Validation Loss
0,No log,9.861096
1,No log,9.242449
2,No log,9.191039


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


TrainOutput(global_step=6, training_loss=9.145123799641928, metrics={'train_runtime': 4475.0139, 'train_samples_per_second': 0.027, 'train_steps_per_second': 0.001, 'total_flos': 5.35758468231168e+16, 'train_loss': 9.145123799641928, 'epoch': 2.1818181818181817})

In [35]:
model.save_pretrained("./wav2vec2-finetuned-urdu")
processor.save_pretrained("./wav2vec2-finetuned-urdu")

[]

# Conformer MoE Fine-Tuned Model

In [38]:
from jiwer import wer
import torch
import torchaudio
from transformers import Wav2Vec2Processor

# Load the pre-trained processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

# Function to resample audio with correct data type
def resample_audio(audio, original_sampling_rate, target_sampling_rate=16000):
    resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=target_sampling_rate)
    # Convert audio to float32
    audio_tensor = torch.tensor(audio, dtype=torch.float32)
    return resampler(audio_tensor).numpy()

# Generate predictions
predictions, references = [], []

for batch in dataset["test"]:
    with torch.no_grad():
        # Extract raw audio array and resample
        audio = batch["audio"]["array"]
        resampled_audio = resample_audio(audio, original_sampling_rate=batch["audio"]["sampling_rate"])
        
        # Process audio dynamically
        input_values = processor(resampled_audio, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values).unsqueeze(0)  # Add batch dimension
        
        # Generate logits and predictions
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        
        # Decode predictions and add to list
        predictions.append(processor.batch_decode(pred_ids)[0])
        references.append(batch["sentence"])  # Use existing sentence as reference

# Calculate WER
print("WER:", wer(references, predictions))

WER: 52.628284749354


In [7]:
print(f"Audio data shape: {audio.shape}, dtype: {audio.dtype}")


Audio data shape: (506304,), dtype: float64


# Whisper-Turbo 

In [40]:
import requests
import numpy as np
from scipy.io.wavfile import write
import librosa
import os
from datasets import load_dataset

# Hugging Face Inference API URL for Whisper model
whisper_model_url = "https://api-inference.huggingface.co/models/openai/whisper-large-v3-turbo"

# Your Hugging Face API key
hf_api_key = "API-Key" #I've replaced mine after running

# Function to process and convert audio into the correct format
def get_whisper_transcription(audio):
    # Ensure the audio is in 16kHz format using librosa
    try:
        # Resample audio to 16kHz
        audio_resampled = librosa.resample(audio, orig_sr=16000, target_sr=16000)
        print(f"Audio resampled to 16kHz, length: {len(audio_resampled)}")
    except Exception as e:
        print(f"Error resampling audio: {e}")
        return None

    # Convert the resampled audio to WAV format and save it to disk
    try:
        # Save the audio to a file on disk
        file_name = "/tmp/audio.wav"  # Temporary location
        write(file_name, 16000, audio_resampled.astype(np.int16))  # Save as 16kHz, 16-bit PCM
        print(f"Audio written to disk at {file_name}")
    except Exception as e:
        print(f"Error converting audio to WAV: {e}")
        return None

    # Prepare the request headers
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
    }

    # Send the audio file to the Hugging Face API
    try:
        files = {
            'file': ('audio.wav', open(file_name, 'rb'))  # Send as WAV file from disk
        }
        response = requests.post(whisper_model_url, headers=headers, files=files)

        if response.status_code == 200:
            response_data = response.json()
            transcription = response_data.get("text", "")
            return transcription
        else:
            print(f"Error: {response.status_code}, {response.text}")
            return None
    except Exception as e:
        print(f"Error sending request to Whisper API: {e}")
        return None

# Load the Urdu subset of the Common Voice dataset (small test subset)
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ur", split="test[:1]")  # Using a very small test slice for simplicity

# Function to process the dataset and get transcriptions
def transcribe_dataset(dataset):
    transcriptions = []
    for example in dataset:
        audio = example["audio"]["array"]  # Get the audio array from the dataset
        print(f"Processing audio with shape: {audio.shape}, dtype: {audio.dtype}")
        transcription = get_whisper_transcription(audio)

        if transcription:
            transcriptions.append(transcription)
        else:
            transcriptions.append("")
    
    return transcriptions

# Get transcriptions for the dataset
transcriptions = transcribe_dataset(dataset)

# Add the transcriptions to the dataset
dataset = dataset.add_column("transcription", transcriptions)

# Calculate Word Error Rate (WER)
ground_truths = dataset["sentence"]
predictions = dataset["transcription"]

# Calculate the WER (Word Error Rate)
wer_score = wer(ground_truths, predictions)

# Print the WER
print(f"WER: {wer_score:.2f}")

WER: 0.089
