# Wav2Vec2BERT - 40 epochs

In [17]:
import os
import re
import json
import random
import string
from dataclasses import dataclass
from typing import Dict, List, Union, Optional

import torch
import torchaudio
import librosa
import evaluate
from datasets import load_dataset, Audio, DatasetDict
from transformers import (
    Wav2Vec2BertForCTC,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2BertProcessor,
    TrainingArguments,
    Trainer,
    set_seed,
)

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(42)


Torch: 2.8.0+cu126
CUDA available: True


In [18]:
print(device)

cuda


In [19]:
import random

def display10(dataset):
    for i in range(10):
        r = random.randint(0,len(dataset))
        print(i+1 , dataset[r]['sentence'])

In [20]:
# Dataset and language
CV_VERSION = "mozilla-foundation/common_voice_16_0"
LANG_ID = "hi"  # Hindi

# Base SSL model (wav2vec2-bert encoder)
BASE_MODEL = "facebook/w2v-bert-2.0"

# Audio parameters
TARGET_SAMPLING_RATE = 16000

# Training output dir
OUTPUT_DIR = "Ed-168/w2vbert-hi-ctc-cv16"


BATCH_SIZE          = 1      
GRAD_ACCUM          = 16     
LEARNING_RATE       = 3e-5  
NUM_TRAIN_EPOCHS    = 40
EVAL_STRATEGY       = "steps"
EVAL_STEPS          = 1000   
SAVE_STEPS          = 1000   
LOGGING_STEPS       = 50
WARMUP_RATIO        = 0.05
FP16                = torch.cuda.is_available()      


PUSH_TO_HUB = False
HF_REPO_ID = "Ed-168/w2vbert-hi-ctc-cv16"  # e.g. "username/w2vbert-hi-ctc-cv17"


In [21]:
# This will download and prepare the dataset (first run may take a while)
from datasets import load_dataset

common_voice_train = load_dataset(
    "mozilla-foundation/common_voice_16_0",
    "hi",
    split="train+validation",
    trust_remote_code=True
)
common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_16_0",
    "hi",
    split="test",
    trust_remote_code=True
)



Using the latest cached version of the module from C:\Users\EDWIN\.cache\huggingface\modules\datasets_modules\datasets\mozilla-foundation--common_voice_16_0\3076bf9caad479bbd4fa71669eac459841567c9efac7e647db5ae1ef78abe82a (last modified on Sun Aug 10 17:11:43 2025) since it couldn't be found locally at mozilla-foundation/common_voice_16_0, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from C:\Users\EDWIN\.cache\huggingface\modules\datasets_modules\datasets\mozilla-foundation--common_voice_16_0\3076bf9caad479bbd4fa71669eac459841567c9efac7e647db5ae1ef78abe82a (last modified on Sun Aug 10 17:11:43 2025) since it couldn't be found locally at mozilla-foundation/common_voice_16_0, or remotely on the Hugging Face Hub.


In [22]:
NUM_TRAIN_SAMPLES = 1500
NUM_TEST_SAMPLES = 750

common_voice_train = common_voice_train.select(range(NUM_TRAIN_SAMPLES))
common_voice_test = common_voice_test.select(range(NUM_TEST_SAMPLES))


In [23]:
print(len(common_voice_train))
print(len(common_voice_test))

1500
750


In [24]:

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes" , "variant"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes" , "variant"])

In [25]:
from datasets import Audio

common_voice_train = common_voice_train.cast_column("audio" , Audio(sampling_rate = 16000))
common_voice_test = common_voice_test.cast_column("audio" , Audio(sampling_rate = 16000))

In [26]:
display10(common_voice_train)

1 फैशन की दुनिया में कामयाब होने के लिए अपनाए तरीके
2 आप क्या पहनने वाले हैं?
3 टॉम ने नीले कपड़े पहने थे।
4 उसने मुझसे वह बात छिपाई।
5 यूपी में मूर्ति, माफिया और मुल्जिमों की सरकार: नकवी
6 उसने विस्तृत रूप से अपनी योजना समझाई।
7 नोएडाः वेब वर्क कंपनी के दफ्तर पर पुलिस का छापा, अहम दस्तावेज बरामद
8 टॉम बिलकुल हमारी तरह है।
9 कोड़ा को न दी जाए जमानत: सीबीआई
10 हॉकी मैच के दौरान स्टेडियम में मौजूद रहेंगे प्रधानमंत्री मनमोहन सिंह


In [27]:
# Define the regex at the top level so that subprocesses can access it
chars_to_ignore_regex = r"[\"\'\(\)\[\]\{\}\<\>\—\–\-\—\—\–\—\.\,\?\!\:\;\।\d\@\#\$\%\^\&\*\+\=\_\\\/\|~`]+"

def normalize_text(batch):
    text = batch["sentence"]
    text = text.lower()
    text = re.sub(chars_to_ignore_regex, " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    batch["sentence"] = text
    return batch

common_voice_train = common_voice_train.map(normalize_text)
common_voice_test = common_voice_test.map(normalize_text)
display10(common_voice_train)


1 इस दुर्घटना के लिए कौन ज़िम्मेदार है
2 दिल्ली नॉर्थ एमसीडी में आर्थिक तंगी कमिश्नर को नहीं मिली तीन महीने से सैलरी
3 छत्तीसगढ़ नक्सलगढ़ पर प्रहार का प्लान केंद्रीय गृह सचिव ने बुलाई उच्च स्तरीय बैठक
4 वे विद्यार्थी कोरियाई हैं
5 मैंने उसको पैसे देने की कोशिश करी पर उसने इनकार कर दिया
6 वह अपने प्रयोगों में कबूतरों का उपयोग करता था
7 ओम पुरी का निधन सलमान ने शेयर की ये खास तस्वीर
8 एशिया में तीसरा सबसे बड़ा देश भारत है
9 नई नीतियां अपनाएं बराक ओबामा ईरान
10 मुझे नौवे महीने का नाम बताओ


In [28]:

def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    return {"all_text": [all_text]}

vocabs = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, remove_columns=common_voice_train.column_names)
all_text = " ".join(vocabs["all_text"])
vocab_list = sorted(list(set(list(all_text))))

if " " in vocab_list:
    vocab_list.remove(" ")


vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict["|"] = len(vocab_dict) 
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

print("Vocab size:", len(vocab_dict))
print("Sample of vocab keys:", list(vocab_dict.keys())[:60])


os.makedirs(OUTPUT_DIR, exist_ok=True)
vocab_path = os.path.join(OUTPUT_DIR, "vocab.json")
with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump(vocab_dict, f, ensure_ascii=False, indent=2)
print("Saved vocab to:", vocab_path)


Vocab size: 97
Sample of vocab keys: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'ँ', 'ं', 'ः', 'अ', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ऋ', 'ए', 'ऐ', 'ऑ', 'ओ', 'औ', 'क', 'ख', 'ग', 'घ', 'च', 'छ', 'ज', 'झ', 'ञ', 'ट', 'ठ', 'ड', 'ढ', 'ण', 'त', 'थ', 'द', 'ध', 'न', 'प', 'फ']
Saved vocab to: Ed-168/w2vbert-hi-ctc-cv16\vocab.json


In [29]:
# Tokenizer for CTC

from transformers import SeamlessM4TFeatureExtractor

tokenizer = Wav2Vec2CTCTokenizer(
    vocab_path,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)


feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(BASE_MODEL)


from transformers import Wav2Vec2BertProcessor

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor , tokenizer=tokenizer)

# Save processor for later use/inference
processor.save_pretrained(OUTPUT_DIR)
print("Processor saved to:", OUTPUT_DIR)


Processor saved to: Ed-168/w2vbert-hi-ctc-cv16


In [30]:
rand_clip = random.randint(0 , len(common_voice_train) -1 )
print("Target text:", common_voice_train[rand_clip]["sentence"])
print("Input array shape:", common_voice_train[rand_clip]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_clip]["audio"]["sampling_rate"])

Target text: बिजली बिल मुद्दे पर अरविंद केजरीवाल करेंगे अनशन
Input array shape: (77184,)
Sampling rate: 16000


In [32]:
def prepare_dataset(batch):
    audio = batch['audio']
    batch['input_features'] = processor(audio['array'] , sampling_rate=audio['sampling_rate'])
    batch["input_length"] = len(batch['input_features'])

    batch['labels'] = processor(text = batch['sentence']).input_ids

    return batch

common_voice_train = common_voice_train.map(prepare_dataset , remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset , remove_columns = common_voice_test.column_names)

print("Example lengths:", len(common_voice_train[0]["input_features"]), len(common_voice_train[0]["labels"]))


Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Map:   0%|          | 0/750 [00:00<?, ? examples/s]

Example lengths: 2 23


In [34]:
# import torch

# from dataclasses import dataclass, field
# from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


In [35]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [36]:
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    # pred.predictions is float logits of shape (batch, time, vocab_size)
    pred_logits = pred.predictions
    pred_ids = torch.from_numpy(pred_logits).argmax(-1)

    # Decode predictions and references
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Replace -100 with pad_token_id for decoding refs
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


Using the latest cached version of the module from C:\Users\EDWIN\.cache\huggingface\modules\evaluate_modules\metrics\evaluate-metric--wer\e41eaa77ca7152430cd94704de20946c1b004b5b488ab5d20b26fb81c6c15506 (last modified on Sun Aug 10 22:24:10 2025) since it couldn't be found locally at evaluate-metric--wer, or remotely on the Hugging Face Hub.


In [37]:
from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    BASE_MODEL, 
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

Some weights of Wav2Vec2BertForCTC were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [39]:
model.freeze_feature_extractor()

AttributeError: 'Wav2Vec2BertForCTC' object has no attribute 'freeze_feature_extractor'

In [None]:
# # Initialize the CTC head on top of wav2vec2-bert encoder
# model = Wav2Vec2BertForCTC.from_pretrained(
#     BASE_MODEL,
#     vocab_size=len(processor.tokenizer),
#     pad_token_id=processor.tokenizer.pad_token_id,
#     ctc_loss_reduction="mean",
#     # You can set this to True for long-form training stability with LayerDrop models
#     # but w2v-bert-2.0 doesn't use LayerDrop by default.
# )

# # Make sure the model knows the correct special tokens
# model.config.pad_token_id = processor.tokenizer.pad_token_id
# model.config.vocab_size = len(processor.tokenizer)
# model.to(device)

# # Optionally freeze the feature encoder for a few epochs if you have small compute
# # (uncomment to try). Often helps stabilize early training.
# # if hasattr(model, "freeze_feature_encoder"):
# #     model.freeze_feature_encoder()

# # Print parameter count
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")


Some weights of Wav2Vec2BertForCTC were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total params: 580,589,470 | Trainable: 580,589,470


In [None]:
# TrainingArguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    group_by_length=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    save_steps=SAVE_STEPS,
    eval_steps=EVAL_STEPS,
    logging_steps=LOGGING_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_ratio=WARMUP_RATIO,
    fp16=FP16,
    save_total_limit=2,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    report_to=["none"],  
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,  # ensures padding works
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Trainer is ready.")


  trainer = Trainer(


Trainer is ready.


In [23]:
# common_voice_train = common_voice_train.rename_column("input_features", "input_values")
# common_voice_train

In [24]:
common_voice_train

Dataset({
    features: ['input_features', 'input_length', 'labels'],
    num_rows: 1000
})

In [None]:
train_result = trainer.train()
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)


print("Training complete. Model and processor saved to:", OUTPUT_DIR)

Step,Training Loss
50,24.5543
100,18.3017
150,14.4906
200,13.8688
250,12.6651
300,12.5078
350,11.7391
400,11.6225
450,11.0105
500,10.7895


In [None]:
# After training
metrics = train_result.metrics
metrics["train_samples"] = len(common_voice_train)
metrics["num_epochs"] = training_args.num_train_epochs  

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

print(f"Training complete! Model and processor saved to: {OUTPUT_DIR}")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Metrics: {metrics}")


In [5]:
import torch
import librosa
from transformers import Wav2Vec2BertForCTC, Wav2Vec2BertProcessor

from transformers import Wav2Vec2BertProcessor
processor = Wav2Vec2BertProcessor.from_pretrained("w2vbert-hi-ctc-cv16")


# # Load your fine-tuned model and processor
model_dir = "w2vbert-hi-ctc-cv16"
model = Wav2Vec2BertForCTC.from_pretrained(model_dir)
# processor = Wav2Vec2BertProcessor.from_pretrained(model_dir)

In [1]:
from datasets import load_dataset
import soundfile as sf

# Load only a small sample from the Common Voice Hindi test set
common_voice_test = load_dataset("mozilla-foundation/common_voice_16_0", "hi", split="test")

# Select one example (e.g., first sample)
sample = common_voice_test[10]

# Access the audio array and sample rate for that example
audio_array = sample["audio"]["array"]
sampling_rate = sample["audio"]["sampling_rate"]

# Save the audio to a WAV file locally
output_wav_path = "test_hindi_clip.wav"
sf.write(output_wav_path, audio_array, samplerate=sampling_rate)

print(f"Saved test audio clip to: {output_wav_path}")
print("Transcript:", sample["sentence"])


Saved test audio clip to: test_hindi_clip.wav
Transcript: जानें उत्तर प्रदेश में किस-किस की है मुस्लिम वोटों पर नजर


In [6]:
audio_path = "test_hindi_clip.wav"

# Load and preprocess audio
speech, sr = librosa.load(audio_path, sr=processor.feature_extractor.sampling_rate)
sampling_rate = processor.feature_extractor.sampling_rate
print(sampling_rate)
# Preprocess input for the model
inputs = processor(audio=speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(**inputs).logits

# Get predicted token IDs

predicted_ids = torch.argmax(logits, dim=-1)
print(predicted_ids)
# Decode to text (Hindi)
transcription = processor.batch_decode(predicted_ids)[0]
print("Transcription:", transcription)

16000
tensor([[91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91,
         91, 91, 91, 9

In [6]:
model.eval()

Wav2Vec2BertForCTC(
  (wav2vec2_bert): Wav2Vec2BertModel(
    (feature_projection): Wav2Vec2BertFeatureProjection(
      (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=160, out_features=1024, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Wav2Vec2BertEncoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0-23): 24 x Wav2Vec2BertEncoderLayer(
          (ffn1_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (ffn1): Wav2Vec2BertFeedForward(
            (intermediate_dropout): Dropout(p=0.0, inplace=False)
            (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): SiLU()
            (output_dense): Linear(in_features=4096, out_features=1024, bias=True)
            (output_dropout): Dropout(p=0.0, inplace=False)
          )
          (self_attn_layer_norm): LayerNorm((10

In [7]:
import torch
import evaluate
from datasets import load_dataset

# Load the Hindi Common Voice test split
common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_16_0",
    "hi",
    split="test",
    trust_remote_code=True
)

# Put model in eval mode
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Target sampling rate from processor
TARGET_SAMPLING_RATE = processor.feature_extractor.sampling_rate

# Initialize WER metric
wer_metric = evaluate.load("wer")

# Take only 5 samples for demo
sampled_test = common_voice_test.shuffle(seed=42).select(range(5))

preds = []
refs = []

print("Evaluating on 5 samples...")
for sample in sampled_test:
    audio = sample["audio"]
    ref = sample["sentence"]

    # Preprocess audio
    inputs = processor(audio["array"], sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_features.to(device)).logits

    # Decode prediction
    pred_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

    preds.append(pred)
    refs.append(ref)

    print(f"REF: {ref}\nHYP: {pred}\n{'-'*80}")

# Compute WER
demo_wer = wer_metric.compute(predictions=preds, references=refs)
print(f"\n✅ Demo WER on 5 samples: {demo_wer:.4f}")


Evaluating on 5 samples...
REF: इस जीत के बाद भी नहीं मुस्कुराता तो लोग एबनॉर्मल कहतेः विराट कोहली
HYP: [PAD]
--------------------------------------------------------------------------------
REF: इसे ठीक करना नामुमकिन है|
HYP: [PAD]
--------------------------------------------------------------------------------
REF: मैं अपनी वर्तमान आमदनी से संतुष्ट हूँ।
HYP: [PAD]
--------------------------------------------------------------------------------
REF: चमत्कारी हैं मां महालक्ष्मी
HYP: [PAD]
--------------------------------------------------------------------------------
REF: मैं अपने बच्चों को टीवी नहीं देखने देता।
HYP: [PAD]
--------------------------------------------------------------------------------

✅ Demo WER on 5 samples: 1.0000
