In [None]:
%pip install transformers datasets torchaudio

In [None]:
import unicodedata
from dataclasses import dataclass

import torch
from datasets import Audio, load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor


In [None]:
# Load a small (1%) slice of the English Common Voice training split to keep the demo fast.
# NOTE(review): the legacy "common_voice" loading script may not load on recent `datasets`
# releases; if this fails, switch to a maintained Common Voice dataset id — TODO confirm.
dataset = load_dataset(path="common_voice", name="en", split="train[:1%]")

In [None]:
# One object bundling the audio feature extractor (raw waveform -> input_values) and the CTC
# tokenizer (transcript text -> label ids) used by the cells below.
processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path="facebook/wav2vec2-base-960h")
# Characters the facebook/wav2vec2-base-960h CTC vocabulary can encode: upper-case A-Z and the
# apostrophe (spaces become the "|" word delimiter inside the tokenizer).
# NOTE: re-derive from processor.tokenizer.get_vocab() if you switch checkpoints.
CTC_ALPHABET = frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZ' ")


def normalize_transcript(text):
    """Map a raw transcript onto the checkpoint's character set.

    Upper-cases the text (the vocabulary has no lower-case letters), folds the typographic
    apostrophe to ASCII, strips accents via NFKD decomposition, turns whitespace and dashes
    into single spaces, and drops every other unsupported character (punctuation, digits),
    so the tokenizer does not emit ``<unk>`` for them.

    Args:
        text: raw transcript; ``None``/empty yields ``""``.

    Returns:
        The normalized transcript with single spaces between words.
    """
    if not text:
        return ""
    folded = unicodedata.normalize("NFKD", text.replace("\u2019", "'")).upper()
    chars = []
    for ch in folded:
        if ch in CTC_ALPHABET:
            chars.append(ch)
        elif ch.isspace() or ch in "-\u2013\u2014":
            chars.append(" ")  # keep word boundaries, e.g. "well-known" -> "WELL KNOWN"
        # anything else (punctuation, digits, combining accent marks) is dropped
    return " ".join("".join(chars).split())


def preprocess(batch):
    """Turn one dataset example into model inputs and CTC labels.

    Args:
        batch: a single example with an ``"audio"`` entry (``array``, ``sampling_rate``) and a
            ``"sentence"`` transcript.

    Returns:
        dict with ``"input_values"`` (feature-extracted waveform) and ``"labels"``
        (tokenizer ids of the normalized transcript).
    """
    audio = batch["audio"]
    # Declare the clip's real rate: the feature extractor raises on a mismatch instead of
    # silently treating e.g. 48 kHz audio as 16 kHz. Resample the column before mapping.
    input_values = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values[0]
    labels = processor.tokenizer(normalize_transcript(batch["sentence"])).input_ids
    return {"input_values": input_values, "labels": labels}

# Common Voice clips are typically 48 kHz; have `datasets` resample on decode to the rate the
# feature extractor expects (16 kHz for this checkpoint) before extracting features.
dataset = dataset.cast_column(
    "audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate)
)
# Keep only the model inputs (input_values, labels). The raw columns are dropped, so re-run
# the load cell before re-running this one.
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)


In [None]:
# Load the same pretrained CTC checkpoint the processor was built from.
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Inference mode (disables dropout) for the preview cell below; Trainer.train() switches the
# model back to training mode itself. The trailing ';' suppresses the very long module repr.
model.eval();


In [None]:
# Sanity-check the pretrained checkpoint on a few preprocessed clips before fine-tuning.
# Re-assert inference mode and follow the model's device: once the Trainer cells have run in
# this kernel, the model may be in training mode and/or moved to a GPU.
model.eval()
n_preview = min(5, len(dataset))  # don't over-index a tiny dataset
for sample in dataset.select(range(n_preview)):
    # One clip -> batch of size 1 on the model's device.
    input_values = torch.tensor(sample["input_values"]).unsqueeze(0).to(model.device)
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    prediction = processor.batch_decode(predicted_ids)[0]
    # Labels are plain token ids, not CTC frames, so repeated characters must not be merged.
    reference = processor.decode(sample["labels"], group_tokens=False)
    print(f"REF:  {reference}\nPRED: {prediction}\n")


In [None]:
from transformers import TrainingArguments, Trainer

# Fine-tuning hyper-parameters. `eval_strategy` replaces the `evaluation_strategy` keyword,
# which was deprecated and has been removed in recent transformers releases. With no
# `eval_steps` given, evaluation runs every `logging_steps` (100) steps.
training_args = TrainingArguments(
    output_dir="./wav2vec2-ft",
    per_device_train_batch_size=4,
    eval_strategy="steps",
    num_train_epochs=3,
    save_steps=500,
    logging_steps=100,
)

def _word_edit_distance(hypothesis, reference):
    """Levenshtein distance between two word lists (substitutions + insertions + deletions)."""
    previous = list(range(len(reference) + 1))
    for i, hyp_word in enumerate(hypothesis, start=1):
        current = [i]
        for j, ref_word in enumerate(reference, start=1):
            current.append(min(
                previous[j] + 1,                           # extra hypothesis word (insertion)
                current[j - 1] + 1,                        # missing reference word (deletion)
                previous[j - 1] + (hyp_word != ref_word),  # match or substitution
            ))
        previous = current
    return previous[-1]


def word_error_rate(predictions, references):
    """Corpus-level word error rate: total word edits divided by total reference words.

    Args:
        predictions: hypothesis strings.
        references: reference strings, aligned with ``predictions``.

    Returns:
        float WER (0.0 is perfect; can exceed 1.0 with many insertions).

    Raises:
        ValueError: if the lists differ in length or the references contain no words.
    """
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    errors = 0
    total_words = 0
    for prediction, reference in zip(predictions, references):
        ref_words = reference.split()
        errors += _word_edit_distance(prediction.split(), ref_words)
        total_words += len(ref_words)
    if total_words == 0:
        raise ValueError("references contain no words; WER is undefined")
    return errors / total_words


def compute_metrics(pred):
    """Trainer metric hook: greedy-decode the logits and report word error rate.

    Args:
        pred: EvalPrediction-like object with ``predictions`` (logits array) and
            ``label_ids`` (label id array, possibly padded with -100).

    Returns:
        ``{"wer": float}``.
    """
    pred_ids = pred.predictions.argmax(-1)
    pred_str = processor.batch_decode(pred_ids)
    # Trainer pads labels with -100 (also the CTC loss's ignore value), which is not a valid
    # token id; map it to the pad token so it decodes to nothing. Copy so the caller's array
    # is left untouched.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # group_tokens=False: labels are not CTC frames, so repeated characters must not be merged.
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    return {"wer": word_error_rate(pred_str, label_str)}

@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of preprocessed examples into one CTC training batch.

    Audio features and label ids have different lengths and are padded separately: inputs
    with the feature extractor, labels with the tokenizer. Padded label positions are set to
    -100, the value Wav2Vec2ForCTC's loss ignores.
    """

    processor: Wav2Vec2Processor

    def __call__(self, features: list[dict]):
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            input_features, padding=True, return_tensors="pt"
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=True, return_tensors="pt"
        )
        # Replace tokenizer padding with -100 so padded positions are not treated as targets.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch


# Hold out 10% of the clips: evaluating on the training data would report an optimistic WER.
dataset_splits = dataset.train_test_split(test_size=0.1, seed=42)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorCTCWithPadding(processor=processor),
    train_dataset=dataset_splits["train"],
    eval_dataset=dataset_splits["test"],
    # Saved with each checkpoint so it can be reloaded with from_pretrained. `processing_class`
    # supersedes the deprecated `tokenizer=` argument (transformers >= 4.46).
    processing_class=processor,
    compute_metrics=compute_metrics,
)


In [None]:
# Run fine-tuning; the returned TrainOutput (global step, training loss, metrics) is displayed
# as the cell output.
train_output = trainer.train()
train_output