In [1]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [2]:
!pip install torchcodec

Collecting torchcodec
  Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m0.9/2.0 MB[0m [31m26.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m35.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec
Successfully installed torchcodec-0.8.1


In [None]:
from torchaudio.datasets import LIBRISPEECH
from pathlib import Path
import torchaudio
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Trainer,
    TrainingArguments,
    logging
)
from torch.utils.data import Dataset, ConcatDataset
import evaluate
import torch
import warnings
import glob
import os
import shutil

warnings.filterwarnings("ignore", category=UserWarning)
logging.set_verbosity_error()


In [4]:
# loading dataset
root = Path("data/raw/LIBRISPEECH")
root.mkdir(parents=True, exist_ok=True)

train_ds = LIBRISPEECH(root=root, url="train-clean-100", download=True)
eval_ds = LIBRISPEECH(root=root, url="dev-clean", download=True)


100%|██████████| 5.95G/5.95G [05:41<00:00, 18.7MB/s]
100%|██████████| 322M/322M [00:22<00:00, 15.3MB/s]


In [5]:
# tokenizing transcripts
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = processor.tokenizer
feature_extractor = processor.feature_extractor

class LibriSpeechDataset(Dataset):
    def __init__(self, torchaudio_dataset, tokenizer):
        self.dataset = torchaudio_dataset
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        data = self.dataset[idx]

        if len(data) == 2:
            waveform, sr = data
            transcript = ""
        else:
            waveform, sr, transcript, *_ = data

        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)

        input_values = waveform.squeeze(0).numpy()

        return {"input_values": input_values, "labels": transcript}

    def __len__(self):
        return len(self.dataset)


train_dataset = LibriSpeechDataset(train_ds, tokenizer)
eval_dataset = LibriSpeechDataset(eval_ds, tokenizer)


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [6]:
# initializing wav2vec2 model
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base",
                                       pad_token_id=processor.tokenizer.pad_token_id,
                                       vocab_size=len(processor.tokenizer),
                                       ctc_loss_reduction="mean"
                                       )
model.freeze_feature_encoder()

def data_collator(batch):
    audio = [b["input_values"] for b in batch]
    text = [b["labels"] for b in batch]

    inputs = feature_extractor(
        audio,
        sampling_rate=16000,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt"
    )

    labels_batch = tokenizer(
        text,
        padding=True,
        return_tensors="pt",
        add_special_tokens=False
    )

    labels = labels_batch.input_ids
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_values": inputs["input_values"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels
    }



config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

In [7]:
!pip install --no-cache-dir jiwer

model.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m76.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-4.0.0 rapidfuzz-3.14.3


In [None]:
# training, fine-tuning

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)

    # decode
    pred_str = tokenizer.batch_decode(pred_ids)
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(label_ids)

    print("\n" + "="*30)
    print(f"SAMPLE 1 TARGET: {label_str[0]}")
    print(f"SAMPLE 1 PRED:   {pred_str[0]}")
    print("-" * 10)
    print(f"SAMPLE 2 TARGET: {label_str[1]}")
    print(f"SAMPLE 2 PRED:   {pred_str[1]}")
    print("="*30 + "\n")

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
        "cer": cer_metric.compute(predictions=pred_str, references=label_str)
    }

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    warmup_steps=1000,
    logging_steps=100,
    fp16=torch.cuda.is_available(),
    max_grad_norm=1.0,
    gradient_accumulation_steps=2,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()


In [12]:
# metrics
eval_results = trainer.evaluate()
print(f"Validation WER: {eval_results['eval_wer']:.4f}")
print(f"Validation CER: {eval_results['eval_cer']:.4f}")


SAMPLE 1 TARGET: MISTER QUILTER IS THE APOSTLE OF THE MIDLE CLASES AND WE ARE GLAD TO WELCOME HIS GOSPEL
SAMPLE 1 PRED:   MISTER QUILER IS THE OPPOSAL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
----------
SAMPLE 2 TARGET: NOR IS MISTER QUILTER'S MANER LES INTERESTING THAN HIS MATER
SAMPLE 2 PRED:   NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER

{'eval_loss': 0.10452383756637573, 'eval_wer': 0.16186904893202456, 'eval_cer': 0.039728733536182406, 'eval_runtime': 127.8705, 'eval_samples_per_second': 21.139, 'eval_steps_per_second': 1.322, 'epoch': 3.368834080717489}
Validation WER: 0.1619
Validation CER: 0.0397


In [13]:
test_dataset = LIBRISPEECH(root=root, url="test-clean", download=True)
test_dataset = LibriSpeechDataset(test_dataset, tokenizer)

100%|██████████| 331M/331M [00:19<00:00, 17.7MB/s]


In [14]:
test_results = trainer.evaluate(test_dataset)
print(f"Test WER: {test_results['eval_wer']:.4f}")
print(f"Test CER: {test_results['eval_cer']:.4f}")


SAMPLE 1 TARGET: HE HOPED THERE WOULD BE STEW FOR DINER TURNIPS AND CAROTS AND BRUISED POTATOES AND FAT MUTON PIECES TO BE LADLED OUT IN THICK PEPERED FLOUR FATENED SAUCE
SAMPLE 1 PRED:   HE HOPED THERE WOULD BE STE FOR DINNER TURNIPS AND CARRETS AND BRUISED POTATOES AND FAT MUTTEN PIECES TO BE LAIDLED OUT IN THICK PEPPERED FLOWER FATTINED SAUCE
----------
SAMPLE 2 TARGET: STUF IT INTO YOU HIS BELY COUNSELED HIM
SAMPLE 2 PRED:   STUFF IT INTO YOU HIS BELLY COUNCELED HIM

{'eval_loss': 0.11569350212812424, 'eval_wer': 0.16172778454047473, 'eval_cer': 0.03994150310945724, 'eval_runtime': 130.2868, 'eval_samples_per_second': 20.109, 'eval_steps_per_second': 1.259, 'epoch': 3.368834080717489}
Test WER: 0.1617
Test CER: 0.0399


In [None]:
trainer.save_model("./pretrained_model")
processor.save_pretrained("./pretrained_model")

In [None]:
shutil.make_archive('Wav2Vec2-base-LibriSpeech100h', 'zip', "./pretrained_model")

## Domain adaptation: training on custom datasets

In [None]:
class CommandDataset(Dataset):
    def __init__(self, folder_path, tokenizer, override_label=None):
        self.files = glob.glob(os.path.join(folder_path, "*.flac"))
        self.tokenizer = tokenizer
        self.override_label = override_label

    def __getitem__(self, idx):
        file_path = self.files[idx]

        if self.override_label:
            transcript = self.override_label
        else:
            filename = os.path.basename(file_path)
            clean_name = filename.replace(".flac", "")
            text_part = "_".join(clean_name.split("_")[:-1])
            text_part = text_part.replace("dynamic_", "")
            transcript = text_part.replace("_", " ").upper()

        waveform, sr = torchaudio.load(file_path)
        
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)

        input_values = waveform.squeeze(0).numpy()

        return {"input_values": input_values, "labels": transcript}

    def __len__(self):
        return len(self.files)

In [None]:
if not os.path.exists("data/raw/custom/commands_dataset"):
    shutil.unpack_archive("commands_dataset.zip", "data/raw/custom/commands_dataset")

if not os.path.exists("data/raw/custom/wakeup_dataset"):
    shutil.unpack_archive("wakeup_dataset.zip", "data/raw/custom/wakeup_dataset")

if not os.path.exists("results/Wav2Vec2-base-LibriSpeech100h"):
    shutil.unpack_archive("Wav2Vec2-base-LibriSpeech100h.zip", "results/Wav2Vec2-base-LibriSpeech100h")

In [None]:
trained_processor = Wav2Vec2Processor.from_pretrained("Wav2Vec2-base-LibriSpeech100h")
trained_tokenizer = trained_processor.tokenizer
trained_feature_extractor = trained_processor.feature_extractor

trained_model = Wav2Vec2ForCTC.from_pretrained("Wav2Vec2-base-LibriSpeech100h")
trained_model.freeze_feature_encoder()

In [None]:
command_dataset = CommandDataset("data/raw/custom/commands_dataset", trained_tokenizer)
wakeup_dataset = CommandDataset("data/raw/custom/wakeup_dataset", trained_tokenizer, override_label="WAKE UP TYPIST")

In [None]:
librispeech_subset = torch.utils.data.Subset(train_ds, torch.randperm(len(train_ds))[:1000])
subset_dataset = LibriSpeechDataset(librispeech_subset, trained_tokenizer)

train_domain_adapted = ConcatDataset([subset_dataset] + [command_dataset] * 2 + [wakeup_dataset] * 2)
print(f"Domain adapted train dataset length: {len(train_domain_adapted)}")

In [None]:
eval_size = int(0.1 * len(train_domain_adapted))
train_adapted_ds, eval_adapted_ds = torch.utils.data.random_split(train_domain_adapted, [len(train_domain_adapted)-eval_size, eval_size])

In [None]:
def data_collator_adapted(batch):
    audio = [b["input_values"] for b in batch]
    text = [b["labels"] for b in batch]

    inputs = trained_feature_extractor(
        audio,
        sampling_rate=16000,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt"
    )

    labels_batch = trained_tokenizer(
        text,
        padding=True,
        return_tensors="pt",
        add_special_tokens=False
    )

    labels = labels_batch.input_ids
    labels[labels == trained_tokenizer.pad_token_id] = -100

    return {
        "input_values": inputs["input_values"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels
    }

In [None]:
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics_adapted(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)

    # decode
    pred_str = trained_tokenizer.batch_decode(pred_ids)
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = trained_tokenizer.pad_token_id
    label_str = trained_tokenizer.batch_decode(label_ids, group_tokens=False)

    print(f"\nTARGET: {label_str[0]} \nPRED: {pred_str[0]}")

    return {
        "wer": wer_metric.compute(predictions=pred_str, references=label_str),
        "cer": cer_metric.compute(predictions=pred_str, references=label_str)
    }

In [None]:
os.environ["WANDB_DISABLED"] = "true"

In [None]:
training_args_adapted = TrainingArguments(
    output_dir="./finetuned_results",
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    warmup_steps=100,
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_total_limit=2,
)

trainer_adapted = Trainer(
    model=trained_model,
    args=training_args_adapted,
    train_dataset=train_adapted_ds,
    eval_dataset=eval_adapted_ds,
    data_collator=data_collator_adapted,
    compute_metrics=compute_metrics_adapted
)

trainer_adapted.train()

In [None]:
trainer_adapted.save_model("./final_model")
trained_tokenizer.save_pretrained("./final_model")
trained_feature_extractor.save_pretrained("./final_model")

In [None]:
shutil.make_archive('Wav2Vec2-base-LibriSpeech100h-Custom', 'zip', "./final_model")