In [1]:
# This notebook is for ASR fine-tuning.

In [2]:
# Import packages

import os
import torch
import torchaudio
import random
import numpy as np
import glob
import csv
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import jiwer
import tqdm


In [3]:
def get_uttname_from_path(path):
    # Get the utterance name from the audio file path
    return(os.path.join(os.path.basename(os.path.dirname(path)), os.path.basename(path)))

In [4]:
def read_ref_csv(path):
    # Read a CSV file containing reference word sequences
    ref = dict()
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.rstrip().split(",")
            if ".mp3" in line[0]:
                utt_name = line[0]
                assert(utt_name not in ref)
                ref[utt_name] = line[1].upper() # wav2vec2-large-960h only supports upper case
    return ref

In [5]:
# Define dataset

class common_voice_dataset(torch.utils.data.Dataset):
    def __init__(self, audio_filenames, reference_filename, processor, max_audio_len):
        super(common_voice_dataset, self).__init__()
        self.audio_filenames = audio_filenames
        self.ref_words = dict()
        self.ref_token_ids = dict()
        self.processor = processor
        self.max_audio_len = max_audio_len

        # Read reference outputs
        self.ref_words = read_ref_csv(reference_filename)

        # Convert reference word sequence to grapheme ID sequence
        for utt_name in sorted(self.ref_words.keys()):
            self.ref_token_ids[utt_name] = torch.LongTensor(self.processor.tokenizer.encode(self.ref_words[utt_name]))


    def __len__(self):
        return len(self.audio_filenames)


    def __getitem__(self, idx):

        # Get utterance name
        utt_name = get_uttname_from_path(self.audio_filenames[idx])
        
        # Read audio file
        expected_sampling_rate = 16000
        audio, sr = torchaudio.load(self.audio_filenames[idx])
        if sr != expected_sampling_rate:
            audio = torchaudio.transforms.Resample(sr, expected_sampling_rate)(audio)
        audio = audio[0]

        # Truncate audio to max_audio_len to prevent CUDA out of memory
        if len(audio) > self.max_audio_len:
            audio = audio[:self.max_audio_len]

        # Extract features
        features = self.processor(audio, return_tensors="pt", sampling_rate=expected_sampling_rate).input_values
        
        return {
            "audio": features,
            "words": self.ref_words[utt_name],
            "token_ids": self.ref_token_ids[utt_name],
            "utt_name": utt_name
        }

In [6]:
# Define dataloader collate function

def common_voice_collate_fn(samples):

    batch_size = len(samples)

    # Get audio lengths
    audio_lens = torch.LongTensor([samples[i]["audio"].size(1) for i in range(batch_size)])

    # Find maximum audio length
    audio_len_max = audio_lens.max().item()

    # Collate audio and zero-pad shorter sequences
    audio = torch.zeros([batch_size, audio_len_max])
    for i in range(batch_size):
        audio[i, :audio_lens[i]] = samples[i]["audio"].squeeze()

    # Get output lengths
    token_ids_lens = torch.LongTensor([len(samples[i]["token_ids"]) for i in range(batch_size)])

    # Find maximum output length
    token_id_len_max = token_ids_lens.max().item()

    # Collate output and -100-pad shorter sequences
    token_ids = torch.ones([batch_size, token_id_len_max]).to(torch.long) * (-100)
    for i in range(batch_size):
        token_ids[i, :token_ids_lens[i]] = samples[i]["token_ids"]

    # Collate reference words
    words = [samples[i]["words"] for i in range(batch_size)]

    # Collate utterance names
    utt_name = [samples[i]["utt_name"] for i in range(batch_size)]

    batch = {
        "audio": audio,
        "token_ids": token_ids,
        "words": words,
        "audio_lens": audio_lens,
        "token_ids_lens": token_ids_lens,
        "utt_name": utt_name
    }
    return batch

In [7]:
# Hyper-parameters

seed = 0
device = "cuda"
data_dir = "/home/jeremy/datasets/common_voice/cv-valid-train/cv-valid-train"
ref_path = "/home/jeremy/datasets/common_voice/cv-valid-train.csv"
exp_dir = "/home/jeremy/htx_test1234/asr-train"
val_fraction = 0.3
batch_size_tr = 8
batch_size_val = 32
lr = 0.00001
max_epochs = 2 # Should ideally train for longer, but I do not have time
val_check_interval = 1000
accumulate_grad_batches = 8
gradient_clip_val = 10
num_workers = 8
max_audio_len = 256000 # Truncate audio length to 16 seconds to avoid CUDA out of memory
limit_val_batches = 1.0
checkpoint_interval = val_check_interval # measured in global_steps


In [8]:
# Initialisation

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
pl.seed_everything(seed, workers=True)


Seed set to 0


0

In [9]:
# Split data

# Find audio files
audio_filenames = glob.glob(os.path.join(data_dir, "*.mp3"))

# Split train and validation
audio_filenames = np.random.permutation(audio_filenames)
audio_filenames_tr = audio_filenames[:int(np.round(len(audio_filenames)*(1-val_fraction)))]
audio_filenames_val = audio_filenames[int(np.round(len(audio_filenames)*(1-val_fraction))):]

In [10]:
# Read reference transcriptions
ref = read_ref_csv(ref_path)

In [11]:
# Load model
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
model.to(device)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=1024, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder

In [12]:
# Initialise dataset and dataloader

dataset_tr = common_voice_dataset(audio_filenames_tr, ref_path, processor, max_audio_len)
dataset_val = common_voice_dataset(audio_filenames_val, ref_path, processor, max_audio_len)

dataloader_tr = torch.utils.data.DataLoader(
    dataset_tr,
    batch_size=batch_size_tr,
    shuffle=True,
    collate_fn=common_voice_collate_fn,
    pin_memory=False,
    num_workers=num_workers
)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=batch_size_val,
    shuffle=False,
    collate_fn=common_voice_collate_fn,
    pin_memory=False,
    num_workers=num_workers
)

In [13]:
# Define PyTorch Lightning model wrapper

class model_pl_wrapper(pl.LightningModule):

    def __init__(self, init_model, lr):
        super(model_pl_wrapper, self).__init__()
        self.model = init_model
        self.lr = lr

    
    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr
        )


    def training_step(self, batch, batch_idx):
        device = self.device()
        batch_size = len(batch["audio"])
        
        audio = batch["audio"].to(device)
        audio_lens = batch["audio_lens"].to(device)
        ref_token_ids = batch["token_ids"].to(device)
        ref_token_ids_lens = batch["token_ids_lens"].to(device)

        # The model type is Wav2Vec2ForCTC, which is configured for return_attention_mask=False.
        # This indicates that the model is pretrained to attend over zero-padding.
        # Therefore, the model should also be allowed to attend over zero-padding during fine-tuning, to prevent mismatch with pre-training.
        # Therefore, no attention_mask is supplied to the model and audio_lens is not used.

        # Forward through model and compute CTC loss
        output = self.model(audio, labels=ref_token_ids)

        loss = output.loss

        # Log training loss into Tensorboard
        self.log("train_loss", loss, batch_size=batch_size, reduce_fx="mean", prog_bar=True)

        return loss


    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            device = self.device()
            batch_size = len(batch["audio"])
    
            audio = batch["audio"].to(device)
            ref_token_ids = batch["token_ids"].to(device)
            ref_words = batch["words"]

            # Forward through model and compute CTC loss
            model_was_training = self.model.training
            self.model.eval()
            output = self.model(audio, labels=ref_token_ids)
            if model_was_training:
                self.model.train()

            loss = output.loss
            logits = output.logits

            # Log validation loss into Tensorboard
            self.log("val_loss", loss, batch_size=batch_size, reduce_fx="mean", prog_bar=True)

            # Decode words
            predict_ids = torch.argmax(logits, dim=-1)
            hyp_words = processor.batch_decode(predict_ids)

            # Compute WER
            wer = torch.zeros([batch_size])
            total_num_ref_words = 0
            for i in range(batch_size):
                wer[i] = jiwer.wer(ref_words[i], hyp_words[i]) * len(ref_words[i].split(" "))
                total_num_ref_words += len(ref_words[i])
            avg_wer = wer.sum() / total_num_ref_words # Weighted average within batch, simple average between batches because weighted average is too difficult to implement

            # Log validation WER into Tensorboard, to measure how the model performance generalises to the use-case WER metric
            self.log("val_wer", avg_wer, batch_size=batch_size, reduce_fx="mean")


    def forward(self, audio):
        device = self.device()

        output = self.model(audio.to(device))
        return output.logits


    def device(self):
        for param in self.model.parameters():
            return param.device


    def on_after_backward(self):
        with torch.no_grad():
            grad_norm = self.compute_grad_norm()
            param_norm = self.compute_param_norm()

        # Log gradient norm to monitor training stability
        self.log("grad_norm", grad_norm, reduce_fx="mean")

        # Log parameter norm to monitor training stability
        self.log("param_norm", param_norm, reduce_fx="mean")


    def compute_grad_norm(self):
        parameters = [p for p in self.model.parameters() if p.grad is not None]

        total_norm = torch.norm(
            torch.stack([
                torch.norm(p.grad.detach(), 2)
                for p in parameters
            ]),
            2
        )
        return total_norm


    def compute_param_norm(self):
        parameters = [p for p in self.model.parameters() if p is not None]

        total_norm = torch.norm(
            torch.stack([
                torch.norm(p.detach(), 2)
                for p in parameters
            ]),
            2
        )
        return total_norm

In [14]:
# Initialise Tensorboard logger to visualise training progress

logger = TensorBoardLogger("lightning_logs", name="htx_asr_finetune")

In [15]:
# Setup model checkpointing

# Save model at regular steps
step_checkpoint_callback = pl.callbacks.ModelCheckpoint(
    every_n_train_steps=checkpoint_interval,
    save_top_k=-1
)

# Save model that has the best validation WER
val_wer_checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_wer",
    mode="min",
    save_top_k=1,
    filename="best-wer-checkpoint",
    verbose=True
)

In [None]:
# Run training

wrapped_model = model_pl_wrapper(model, lr)
wrapped_model.train()

trainer = pl.Trainer(
    max_epochs=max_epochs,
    val_check_interval=val_check_interval,
    accumulate_grad_batches=accumulate_grad_batches,
    gradient_clip_val=gradient_clip_val,
    num_nodes=1,
    use_distributed_sampler=False,
    logger=logger,
    accelerator="gpu" if device=="cuda" else None,
    devices=1 if device=="cuda" else 0,
    default_root_dir=exp_dir,
    log_every_n_steps=10,
    limit_val_batches=limit_val_batches,
    callbacks=[step_checkpoint_callback, val_wer_checkpoint_callback]
)

trainer.fit(wrapped_model, dataloader_tr, dataloader_val)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type           | Params | Mode 
-------------------------------------------------
0 | model | Wav2Vec2ForCTC | 315 M  | train
-------------------------------------------------
315 M     Trainable params
0         Non-trainable params
315 M     Total params
1,261.847 Total estimated model params size (MB)
403       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                    | 0/? [0…

Training: |                                                                                           | 0/? [0…

Validation: |                                                                                         | 0/? [0…

Epoch 0, global step 125: 'val_wer' reached 0.02566 (best 0.02566), saving model to 'lightning_logs/htx_asr_finetune/version_0/checkpoints/best-wer-checkpoint.ckpt' as top 1


Validation: |                                                                                         | 0/? [0…

Epoch 0, global step 250: 'val_wer' reached 0.02543 (best 0.02543), saving model to 'lightning_logs/htx_asr_finetune/version_0/checkpoints/best-wer-checkpoint.ckpt' as top 1


Validation: |                                                                                         | 0/? [0…

In [21]:
# The following steps are to load the model checkpoint that is going to be evaluated

# Find latest training run index
runs = glob.glob(os.path.join(exp_dir, "lightning_logs", "htx_asr_finetune", "version_*"))
latest_run_idx = 0
for r in runs:
    idx = int(os.path.basename(r).replace("version_", ""))
    if idx > latest_run_idx:
        latest_run_idx = idx

# Load checkpoint with best validation WER
checkpoint_path = os.path.join(exp_dir, "lightning_logs", "htx_asr_finetune", "version_{}".format(latest_run_idx), "checkpoints", "best-wer-checkpoint.ckpt")
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]

# Replace state_dict keys from model wrapping
new_state_dict = dict()
for key in state_dict:
    new_state_dict[key.replace("model.", "")] = state_dict[key]

# Load parameters into model
wrapped_model.eval()
model.eval()
with torch.no_grad():
    model.to("cpu")
    model.load_state_dict(new_state_dict)
    model.to(device)

In [22]:
# Use model checkpoint to decode cv-valid-test set

test_data_dir = "/home/jeremy/datasets/common_voice/cv-valid-test/cv-valid-test"
test_output_filename = "/home/jeremy/htx_test1234/asr-train/cv-valid-test.csv"

# Find all audio files
test_audio_filenames = glob.glob(os.path.join(test_data_dir, "*.mp3"))

test_transcriptions = dict()

with torch.no_grad():
    for audio_filename in tqdm.tqdm(sorted(test_audio_filenames)):
        utt_name = os.path.join(os.path.basename(os.path.dirname(audio_filename)), os.path.basename(audio_filename))
        assert(utt_name not in test_transcriptions)

        # Load audio
        expected_sr = 16000
        audio, sr = torchaudio.load(audio_filename)
        if sr != expected_sr:
            audio = torchaudio.transforms.Resample(sr, expected_sr)(audio)
        audio = audio[0]

        # Extract features from audio
        features = processor(audio, return_tensors="pt", sampling_rate=expected_sr).input_values

        # Parse audio through ASR model
        logits = model(features.to(device)).logits

        # Decode output distribution by choosing most likely token at each frame
        predict_ids = torch.argmax(logits, dim=-1)

        # Convert token sequence into word sequence
        test_transcriptions[utt_name] = processor.batch_decode(predict_ids)[0]

# Write transcriptions to file
with open(test_output_filename, "w", encoding="utf-8") as file:
    print("utternace_name,generated_text", file=file)
    for utt_name in sorted(test_transcriptions.keys()):
        print("{},{}".format(utt_name, test_transcriptions[utt_name]), file=file)

100%|███████████████████████████████████████████████████████████████████████████████████| 3995/3995 [01:35<00:00, 41.78it/s]


In [25]:
# Measure cv-valid-test WER

test_ref_path = "/home/jeremy/datasets/common_voice/cv-valid-test.csv"

# Read reference from file
test_ref = read_ref_csv(test_ref_path)

# Compute WER
wer = 0
total_num_ref_words = 0
for utt_name in sorted(test_ref.keys()):
    num_ref_words = len(test_ref[utt_name].split(" "))
    total_num_ref_words += num_ref_words
    if utt_name not in test_transcriptions:
        wer += num_ref_words # All deletions
    else:
        wer += jiwer.wer(test_ref[utt_name], test_transcriptions[utt_name]) * num_ref_words
wer /= total_num_ref_words

print("WER = {} %".format(wer * 100))
with open(os.path.join(exp_dir, "cv-valid-test-wer.txt"), "w") as file:
    print("WER = {} %".format(wer * 100), file=file)

WER = 7.066004675177 %


In [24]:
# Copy checkpoint to wav2vec2-large-960h-cv

import shutil

shutil.copyfile(checkpoint_path, os.path.join(exp_dir, "wav2vec2-large-960h-cv.ckpt"))

'/home/jeremy/htx_test1234/asr-train/wav2vec2-large-960h-cv.ckpt'