In [1]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [2]:
from torchaudio.datasets import LIBRISPEECH
from pathlib import Path
import torchaudio
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2ForCTC,
    Trainer,
    TrainingArguments,
    logging
)
from torch.utils.data import Dataset
import evaluate
import torch
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
logging.set_verbosity_error()


In [3]:
# loading dataset
root = Path("data/raw/LIBRISPEECH")
root.mkdir(parents=True, exist_ok=True)

train_dataset = LIBRISPEECH(root=root, url="train-clean-100", download=True)
eval_dataset = LIBRISPEECH(root=root, url="dev-clean", download=True)


100%|██████████| 5.95G/5.95G [05:28<00:00, 19.4MB/s]
100%|██████████| 322M/322M [00:17<00:00, 19.4MB/s]


In [4]:
# tokenizing transcripts
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base")

class LibriSpeechDataset(Dataset):
    def __init__(self, torchaudio_dataset, tokenizer):
        self.dataset = torchaudio_dataset
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        data = self.dataset[idx]

        if len(data) == 2:
            waveform, sr = data
            transcript = ""
        else:
            waveform, sr, transcript, *_ = data

        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
        input_values = waveform.squeeze(0).numpy()
        labels = self.tokenizer(transcript).input_ids
        return {"input_values": input_values, "labels": labels}

    def __len__(self):
        return len(self.dataset)


train_dataset = LibriSpeechDataset(train_dataset, tokenizer)
eval_dataset = LibriSpeechDataset(eval_dataset, tokenizer)


tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

In [5]:
# initializing wav2vec2 model
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")
model.freeze_feature_encoder()

def data_collator(batch):
    input_values = [torch.tensor(b["input_values"]) for b in batch]
    labels = [torch.tensor(b["labels"]) for b in batch]

    input_values = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

    return {"input_values": input_values, "labels": labels}


pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

In [6]:
!pip install --no-cache-dir jiwer

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-4.0.0 rapidfuzz-3.14.1


In [7]:
# training, fine-tuning

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)

    # decode
    pred_str = tokenizer.batch_decode(pred_ids)
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(label_ids)

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
        "cer": cer_metric.compute(predictions=pred_str, references=label_str)
    }

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    logging_strategy="steps",
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    max_grad_norm=1.0,
    gradient_accumulation_steps=2,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

{'loss': 12621.5613, 'grad_norm': 13906.1279296875, 'learning_rate': 9.862668161434978e-06, 'epoch': 0.028026905829596414}
{'loss': 4816.9197, 'grad_norm': 1839.320068359375, 'learning_rate': 9.722533632286997e-06, 'epoch': 0.05605381165919283}
{'loss': 4453.6809, 'grad_norm': 2015.767578125, 'learning_rate': 9.582399103139015e-06, 'epoch': 0.08408071748878924}
{'loss': 4351.3872, 'grad_norm': 1736.87109375, 'learning_rate': 9.442264573991032e-06, 'epoch': 0.11210762331838565}
{'loss': 4373.7825, 'grad_norm': 2047.0941162109375, 'learning_rate': 9.30213004484305e-06, 'epoch': 0.14013452914798205}
{'loss': 4323.9197, 'grad_norm': 1569.660888671875, 'learning_rate': 9.161995515695067e-06, 'epoch': 0.1681614349775785}
{'loss': 4298.5866, 'grad_norm': 1651.0101318359375, 'learning_rate': 9.021860986547086e-06, 'epoch': 0.1961883408071749}
{'loss': 4271.3366, 'grad_norm': 1591.7454833984375, 'learning_rate': 8.881726457399104e-06, 'epoch': 0.2242152466367713}
{'loss': 4301.1, 'grad_norm': 7

TrainOutput(global_step=3568, training_loss=2559.048037131271, metrics={'train_runtime': 5410.6273, 'train_samples_per_second': 10.549, 'train_steps_per_second': 0.659, 'train_loss': 2559.048037131271, 'epoch': 2.0})

In [9]:
# metrics
eval_results = trainer.evaluate()
print(f"Validation WER: {eval_results['eval_wer']:.4f}")
print(f"Validation CER: {eval_results['eval_cer']:.4f}")

{'eval_loss': 433.98602294921875, 'eval_wer': 0.3444358663284438, 'eval_cer': 0.09733575055924854, 'eval_runtime': 91.6491, 'eval_samples_per_second': 29.493, 'eval_steps_per_second': 3.688, 'epoch': 2.0}
Validation WER: 0.3444
Validation CER: 0.0973


In [11]:
test_dataset = LIBRISPEECH(root=root, url="test-clean", download=True)
test_dataset = LibriSpeechDataset(test_dataset, tokenizer)

In [12]:
test_results = trainer.evaluate(test_dataset)
print(f"Test WER: {test_results['eval_wer']:.4f}")
print(f"Test CER: {test_results['eval_cer']:.4f}")

{'eval_loss': 444.38134765625, 'eval_wer': 0.3518905964698722, 'eval_cer': 0.09871712264799785, 'eval_runtime': 95.7412, 'eval_samples_per_second': 27.365, 'eval_steps_per_second': 3.426, 'epoch': 2.0}
Test WER: 0.3519
Test CER: 0.0987
