In [None]:
import re

import torch
from datasets import DatasetDict, load_dataset

# Load Common Voice 13 (English). Train on train+validation and keep the official
# test split for evaluation. Wrapping both in a DatasetDict lets later cells index
# dataset["train"] / dataset["test"]; a single "train+validation+test" split would
# return one Dataset with no such keys.
train_split, test_split = load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "en",
    split=["train+validation", "test"],
)
dataset = DatasetDict({"train": train_split, "test": test_split})


def _has_transcript_and_audio(example):
    """Keep only samples that have audio and a non-blank transcription."""
    sentence = example["sentence"]
    # Empty / whitespace-only sentences give the CTC loss nothing to learn from.
    # NOTE(review): reading example["audio"] decodes every clip; consider a cheaper check.
    return sentence is not None and sentence.strip() != "" and example["audio"] is not None


dataset = dataset.filter(_has_transcript_and_audio)

In [None]:
import torchaudio
import librosa
from transformers import Wav2Vec2Processor

# Processor bundling the feature extractor (waveform -> input_values) and the
# character tokenizer (text -> label ids) that match the pretrained checkpoint.
processor = Wav2Vec2Processor.from_pretrained(
    pretrained_model_name_or_path="facebook/wav2vec2-base-960h",
)

def _normalize_transcript(text):
    """Map a free-form transcript onto the checkpoint's character vocabulary.

    The facebook/wav2vec2-base-960h vocabulary is upper-case letters plus the
    apostrophe (spaces become the "|" word delimiter), so raw Common Voice text
    with lower-case letters and punctuation would mostly tokenize to <unk>.
    Every character outside [A-Z' ] becomes a space, and whitespace runs are collapsed.
    """
    text = text.replace("\u2019", "'").upper()  # curly apostrophe -> ASCII apostrophe
    text = re.sub(r"[^A-Z' ]+", " ", text)
    return " ".join(text.split())


def preprocess(example):
    """Turn one dataset row into wav2vec2 CTC training inputs.

    Args:
        example: row with an ``audio`` dict (``array``, ``sampling_rate``) and a
            ``sentence`` transcription string.

    Returns:
        The same row with ``input_values`` (processed 16 kHz waveform) and
        ``labels`` (tokenizer ids of the normalised transcript) added.
    """
    audio = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # The processor does not resample; the model expects 16 kHz audio.
    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    example["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    example["labels"] = processor.tokenizer(_normalize_transcript(example["sentence"])).input_ids
    return example

# Raw columns that are no longer needed once input_values/labels have been computed.
columns_to_drop = ["audio", "client_id", "sentence", "gender", "age"]
dataset = dataset.map(preprocess, remove_columns=columns_to_drop)


In [None]:
from transformers import Wav2Vec2ForCTC

# Start from the pretrained CTC checkpoint and fine-tune it on Common Voice.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    # Average the CTC loss instead of summing it, so its scale does not grow with
    # batch size / transcript length (a summed loss is fragile under fp16).
    ctc_loss_reduction="mean",
    # The CTC blank is the tokenizer's pad token; keep model and tokenizer in agreement.
    pad_token_id=processor.tokenizer.pad_token_id,
)
# The convolutional feature encoder is already well pretrained; freezing it is the
# usual fine-tuning recipe and reduces memory use.
model.freeze_feature_encoder()


In [None]:
from transformers import TrainingArguments, Trainer

def collate_ctc_batch(features):
    """Pad a list of preprocessed rows into one CTC training batch.

    Args:
        features: list of dicts with ``input_values`` (waveform) and ``labels``
            (token ids), as produced by ``preprocess``.

    Returns:
        A batch with padded ``input_values`` tensors and ``labels`` padded with -100.
    """
    batch = processor.feature_extractor.pad(
        [{"input_values": f["input_values"]} for f in features],
        padding=True,
        return_tensors="pt",
    )
    label_batch = processor.tokenizer.pad(
        [{"input_ids": f["labels"]} for f in features],
        padding=True,
        return_tensors="pt",
    )
    # Label padding must be -100 so the CTC loss ignores it instead of scoring <pad> ids.
    batch["labels"] = label_batch["input_ids"].masked_fill(label_batch["attention_mask"].ne(1), -100)
    return batch


# Use the dataset's own train/test splits when present; otherwise hold out 10% for eval.
if isinstance(dataset, DatasetDict):
    train_dataset, eval_dataset = dataset["train"], dataset["test"]
else:
    held_out = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = held_out["train"], held_out["test"]

training_args = TrainingArguments(
    output_dir="./wav2vec2-stt",
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
    # Mixed precision needs a CUDA device; fall back to fp32 elsewhere instead of crashing.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_ctc_batch,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor.feature_extractor,
)

trainer.train()


In [None]:
from jiwer import wer

def compute_metrics(pred):
    """Compute word error rate (WER) for a ``Trainer`` evaluation pass.

    Args:
        pred: ``EvalPrediction`` with ``predictions`` (per-frame logits) and
            ``label_ids`` (token ids; the Trainer pads them with -100 across batches).

    Returns:
        ``{"wer": float}`` comparing decoded references with greedy predictions.
    """
    pred_ids = pred.predictions.argmax(-1)
    pred_str = processor.batch_decode(pred_ids)

    # -100 is not a vocabulary id; decoding it yields "<unk>" and corrupts the
    # references. Map it to the pad token, which decoding drops. Copy first so the
    # caller's array is left untouched.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # group_tokens=False: references must keep repeated characters (no CTC collapsing).
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    return {"wer": wer(label_str, pred_str)}


In [None]:
def transcribe(audio_path):
    """Transcribe one audio file with the wav2vec2 CTC model (greedy decoding).

    Args:
        audio_path: path to an audio file readable by ``soundfile``.

    Returns:
        The decoded transcription string.
    """
    import soundfile as sf

    speech, sr = sf.read(audio_path, dtype="float32")
    # soundfile returns (frames, channels) for multi-channel files; the model expects mono.
    if speech.ndim > 1:
        speech = speech.mean(axis=1)
    # The processor does not resample, so convert to the 16 kHz the model expects.
    if sr != 16000:
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)

    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # After training the model may still be in train mode (dropout on) and on a GPU.
    model.eval()
    with torch.no_grad():
        logits = model(input_values.to(model.device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
