# HW4 (b): Automatic Speech Recognition (ASR) by Fine-Tuning SSL models (30 points)

In this lab, you will learn how to retrieve speech SSL models from Hugging Face and fine-tune them for the ASR task.

In particular, we will apply Connectionist Temporal Classification (CTC; [paper](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)). This lab gives you hands-on experience implementing and training an ASR model. Training takes several hours and may exceed the GPU quota of the free version of Colab, so please START EARLY!
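
With CTC, the model outputs one label per audio frame, and the vocabulary includes an extra *blank* label. A frame-level label path is turned into text by first merging consecutive repeats and then removing blanks, so a blank between two identical labels keeps them apart. The pure-Python sketch below illustrates this collapse rule with greedy (best-path) decoding. It assumes the character vocabulary of the `CharacterTokenizer` defined later in this notebook; the notebook itself decodes with torchaudio's beam-search `ctc_decoder`, not with this helper.

```python
from itertools import groupby

# Vocabulary mirroring CharacterTokenizer: index 0 is the CTC blank, index 1 is padding.
VOCAB = ['blank', 'pad'] + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + [' ', "'"]
BLANK_ID = 0

def ctc_greedy_collapse(frame_ids):
    """Collapse a frame-level CTC label path into a token sequence.

    Step 1: merge consecutive repeats of the same label.
    Step 2: drop blank labels.
    """
    merged = [label for label, _ in groupby(frame_ids)]
    return [label for label in merged if label != BLANK_ID]

def ids_to_text(ids):
    return ''.join(VOCAB[i] for i in ids)

if __name__ == "__main__":
    # Frame path for "HELLO": the blank (0) between the two L's keeps them from merging.
    H, E, L, O = (VOCAB.index(ch) for ch in "HELO")
    path = [0, H, H, 0, E, L, L, 0, L, O, O, 0]
    print(ids_to_text(ctc_greedy_collapse(path)))  # HELLO
```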

**Note - To work around the GPU limit, save a checkpoint every epoch so that you can resume your run from that epoch.**

## **About submission**
Below, you will be asked to fill in some cells to implement the model. For submission, please make a zip containing this notebook with a complete implementation and the model checkpoint `best_model.pt`, i.e., the best-performing model based on the validation WER. You should also use this model to generate the test-set transcriptions.

For this assignment, you also need to submit transcriptions of the test set provided with the data, as a text file in the exact expected format: one `audio_id|transcription` line per test audio file, in the order listed in `test_submission_order.txt`.

The submission file must be named `asr_submission.txt`.


## 1. Setting up environment and data

Please select a GPU runtime for this session (in Colab: Runtime > Change runtime type).


In [None]:
# Install required packages
!pip install transformers flashlight-text

Download the dataset from https://drive.google.com/file/d/1KEX_sLTRGOt82DjMsOt7ItqwI7SFN49G/view?usp=sharing

This dataset is a mini-subset of [LibriSpeech](https://ieeexplore.ieee.org/abstract/document/7178964), one of the most commonly used speech datasets, composed of audiobook recordings.

The mini-version created for this assignment contains 10 hours of multi-speaker English audio for training and one hour for the validation and test sets.

Please place the zip file in your Google Drive folder, `YOUR_DIR`. Then, follow the instructions below to set up the dataset.

The texts are already normalized.

In [None]:
# Mounting your Google drive to Colab environment.

from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
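# Path to the downloaded zip file; replace YOUR_DIR with your own Drive folder.
data_zip = '/content/drive/MyDrive/YOUR_DIR/student_dataset.zip'

import zipfile

# Extract the dataset into the same Drive folder.
with zipfile.ZipFile(data_zip, 'r') as zip_ref:
    zip_ref.extractall('/content/drive/MyDrive/YOUR_DIR/')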

In [None]:
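# Folder created by the extraction above. It should contain the train/, dev/ and test/ folders
# and the transcription files (train_transcriptions.txt, dev_transcriptions.txt, test_submission_order.txt).
DATA_DIR = "/content/drive/MyDrive/YOUR_DIR/DATA-FOLDER-NAME"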

## 2. Basic helper functions for data loading and tokenization.

We use the PyTorch library to train the model. Here, we provide basic helper functions to set up the dataloader and to tokenize texts into character indices.

You don't have to implement anything here, but we highly recommend going over each function carefully.



In [None]:
# Loading packages

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torchaudio
import tqdm
from datetime import datetime

In [None]:
def load_transcription(transcription_file):
    '''
    Load a transcription file with one "tag|text" line per utterance
    and return a {tag: text} dictionary.
    '''
    tag2text = {}
    with open(transcription_file, "r") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue  # skip empty lines (e.g., blank lines at the end of the file)
            tag, text = line.split("|")
            tag2text[tag] = text
    return tag2text

class CharacterTokenizer(nn.Module):
    '''
    Tokenize texts into character indices, and decode indices back into texts.
    '''
    def __init__(self):
        super().__init__()
        self.characters = ['blank', 'pad', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
                           'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                           'U', 'V', 'W', 'X', 'Y', 'Z', ' ', "'"]
        self.ch2idx = {ch: i for i, ch in enumerate(self.characters)}
        # Stored as `pad` / `blank` so the attributes do not shadow the pad_id() / blank_id() methods.
        self.pad = 1
        self.blank = 0
        self.vocab = self.characters

    def encode(self, text):
        # Upper-case the text and drop any character outside the vocabulary.
        token_idxs = np.array([self.ch2idx[t] for t in text.upper() if t in self.ch2idx], dtype=np.int64)
        return token_idxs

    def enocode_torch_batch(self, texts):
        # Returns right-padded targets of shape (B, S_max) and their true lengths of shape (B,).
        token_idxs_batch = [torch.from_numpy(self.encode(text)).long() for text in texts]
        token_lens = torch.tensor([len(token_idxs) for token_idxs in token_idxs_batch], dtype=torch.long)
        token_idxs_batch = nn.utils.rnn.pad_sequence(token_idxs_batch,
                                                     batch_first=True, padding_value=self.pad)
        return token_idxs_batch, token_lens

    def decode(self, token_idxs):
        # Map indices back to characters, skipping padding and blank tokens.
        return ''.join([self.characters[i] for i in token_idxs if i not in [self.pad, self.blank]])

    def pad_id(self):
        return self.pad

    def blank_id(self):
        return self.blank
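
`enocode_torch_batch` right-pads the character indices with the padding index and also returns the true lengths. Since `nn.CTCLoss` reads only the first `target_lengths[i]` entries of each padded target row, the padding value never enters the loss. The sketch below mirrors that logic in plain Python (no NumPy/PyTorch) so it is easy to inspect; `encode` and `encode_batch` here are illustrative names, not part of the notebook.

```python
# Stand-alone mirror of CharacterTokenizer.encode / enocode_torch_batch using plain lists.
VOCAB = ['blank', 'pad'] + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + [' ', "'"]
CH2IDX = {ch: i for i, ch in enumerate(VOCAB)}
PAD_ID = 1

def encode(text):
    # Upper-case, then drop any character that is not in the vocabulary.
    return [CH2IDX[ch] for ch in text.upper() if ch in CH2IDX]

def encode_batch(texts):
    """Encode texts and right-pad them with PAD_ID to the longest sequence."""
    encoded = [encode(t) for t in texts]
    lengths = [len(ids) for ids in encoded]
    max_len = max(lengths, default=0)
    padded = [ids + [PAD_ID] * (max_len - len(ids)) for ids in encoded]
    return padded, lengths

if __name__ == "__main__":
    # "!" is not in the vocabulary, so it is dropped.
    print(encode_batch(["hi", "a b!"]))  # ([[9, 10, 1], [2, 28, 3]], [2, 3])
```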

In [None]:
class SpeechTextDataset(Dataset):

    def __init__(self, DATA_DIR, split='train'):
        super().__init__()
        DATA_DIR = Path(DATA_DIR)
        # Use split-specific transcription files
        if split == 'train':
            self.tag2text = load_transcription(DATA_DIR/"train_transcriptions.txt")
        elif split == 'dev':
            self.tag2text = load_transcription(DATA_DIR/"dev_transcriptions.txt")
        else:  # test split
            self.tag2text = {}  # No transcriptions for the test set
        self.wav_files = sorted((DATA_DIR/split).glob("*.flac"))

    def __len__(self):
        return len(self.wav_files)

    def __getitem__(self, i):
        wav_file = self.wav_files[i]
        # Load the waveform
        wav, sr = torchaudio.load(wav_file)  # shape = (1, L) for mono audio
        assert sr == 16000

        # z-score the waveform (the small epsilon guards against silent files)
        wav = (wav - wav.mean()) / (wav.std() + 1e-7)

        # Test files have no transcription, so fall back to an empty string.
        text = self.tag2text.get(wav_file.stem, "")

        output = {'wav': wav[0],
                  'text': text}

        return output

    @staticmethod
    def collate(batch):
        data = {}
        # Zero-pad the waveforms to the longest one in the batch: (B, L_max)
        input_values = nn.utils.rnn.pad_sequence([d['wav'] for d in batch],
                                                 batch_first=True, padding_value=0.0)
        data['input_values'] = input_values
        # 1 for real samples, 0 for padding: (B, L_max)
        data['attention_mask'] = nn.utils.rnn.pad_sequence([torch.ones(len(d['wav'])) for d in batch],
                                                           batch_first=True, padding_value=0)
        data['text'] = [d['text'] for d in batch]
        return data

In [None]:
# Initialize dataset
# The default mode is 'train'
dataset = SpeechTextDataset(DATA_DIR, 'train')  # or 'dev' for validation

In [None]:
# Check what a data point looks like.

dataset[0]

{'wav': tensor([ 0.0164,  0.0142,  0.0120,  ...,  0.0133, -0.0030, -0.0109]),
 'text': 'PSYCHOTHERAPY AND THE COMMUNITY BOTH THE PHYSICIAN AND THE PATIENT FIND THEIR PLACE IN THE COMMUNITY THE LIFE INTERESTS OF WHICH ARE SUPERIOR TO THE INTERESTS OF THE INDIVIDUAL'}

## 3. Implementation of CTC ASR model with a pre-trained SSL upstream

Here, we will load a pre-trained HuBERT-Base model from Hugging Face. The ASR model is built from the SSL model as the upstream, followed by a few LSTM layers.

The model will be trained using CTC (`nn.CTCLoss`).
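
CTC can only align a transcript to an utterance that has enough frames: each character needs its own frame, and two identical consecutive characters (the "LL" in "HELLO") need a blank frame between them so they are not merged. When an utterance is too short, the CTC loss is infinite, which is why the model below sets `zero_infinity=True` (PyTorch then zeroes such losses and their gradients). A small sketch of that lower bound, with label ids taken from this notebook's character vocabulary; the function names are illustrative.

```python
def min_ctc_frames(label_ids):
    """Smallest number of frames on which CTC can align `label_ids`.

    Every label needs one frame, and each pair of identical consecutive labels
    needs an extra blank frame between them; otherwise they would be merged.
    """
    label_ids = list(label_ids)
    repeats = sum(1 for a, b in zip(label_ids, label_ids[1:]) if a == b)
    return len(label_ids) + repeats

def ctc_feasible(num_frames, label_ids):
    return num_frames >= min_ctc_frames(label_ids)

if __name__ == "__main__":
    hello = [9, 6, 13, 13, 16]  # H, E, L, L, O
    print(min_ctc_frames(hello))  # 6
```

This is also a reason not to downsample too aggressively: at about 25 frames per second, a transcript must average fewer than about 25 characters (counting repeats) per second of audio.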

Please fill in the `YOUR CODE` sections below to complete the model.

You need to decide which layer of the SSL model's Transformer encoder to use (refer to the probing assignment for this) and which LSTM configuration to run.
(Please don't try very large models; Colab is likely to run out of GPU memory.)
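
HuBERT's convolutional feature encoder turns 16 kHz audio into roughly 50 frames per second, and the `downsample` convolution in the model below halves that to roughly 25. These frame counts matter because both `nn.CTCLoss` and the decoder need the number of valid frames in each zero-padded utterance. The pure-Python sketch below computes them from sample counts with the standard `nn.Conv1d` output-length formula. The kernel sizes and strides are assumed to be those of the HuBERT Base configuration (`config.conv_kernel` / `config.conv_stride` in Hugging Face), so verify them against the checkpoint you load.

```python
# Assumed conv feature-encoder settings of wav2vec 2.0 / HuBERT Base;
# check model.config.conv_kernel and model.config.conv_stride for your checkpoint.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

def conv_out_len(length, kernel, stride, padding=0):
    # Output length of an nn.Conv1d with dilation 1.
    return (length + 2 * padding - kernel) // stride + 1

def ssl_frames(num_samples):
    """Number of SSL frames (about 50 Hz) produced for `num_samples` samples at 16 kHz."""
    length = num_samples
    for k, s in zip(CONV_KERNELS, CONV_STRIDES):
        length = conv_out_len(length, k, s)
    return length

def downsampled_frames(num_samples):
    """Frames after the notebook's Conv1d(kernel_size=2, stride=2, padding=0)."""
    return conv_out_len(ssl_frames(num_samples), kernel=2, stride=2)

if __name__ == "__main__":
    print(ssl_frames(16000), downsampled_frames(16000))  # 49 24
```

Within the model, the unpadded sample counts are `attention_mask.sum(-1)`; applying the same integer arithmetic to that tensor is one way to obtain per-utterance frame lengths.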

In [None]:
from torchaudio.models.decoder import ctc_decoder

######################## YOUR CODE ######################
# You can find more on Hugging Face!
# https://huggingface.co/
from transformers import TODO
#########################################################

class ASR(nn.Module):
    def __init__(self,
                 ######################## YOUR CODE ######################
                 # Please add any arguments required to instantiate the model
                 #########################################################
                 ):


        super().__init__()


        self.token_processor = CharacterTokenizer()

        ######################## YOUR CODE ######################
        self.upstream = TODO ## Instantiate and load an SSL model.
        # If you find the right function call, you will be able to change
        # some of the SSL config (e.g., the number of layers) when you load the model.

        self.lstm = TODO # LSTM module
        self.lstm_output_size = TODO
        self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab)) # Final classification layer

        # We put a downsampling convolution here to reduce the temporal resolution from 50 Hz to 25 Hz,
        # which reduces the computation load.
        # You need to set ssl_output_size and lstm_input_size properly.
        #########################################################

        self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
                                    kernel_size=2, stride=2, padding=0, dilation=1,)
        self.loss = nn.CTCLoss(blank = self.token_processor.blank_id(),zero_infinity=True,reduction='mean')
        self.ctc_decoder = ctc_decoder(lexicon=None,
                                      tokens=self.token_processor.vocab,
                                      lm=None,
                                      lm_dict=None,
                                      nbest=1,
                                      beam_size=5,
                                      blank_token="blank",
                                      sil_token=" ",
                                     )

    def forward(self, input_values, text, attention_mask=None):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
        text_idxs = text_idxs.to(input_values.device).to(torch.int32)
        text_lens = text_lens.to(input_values.device).to(torch.int32)

        ######################## YOUR CODE ######################
        # Fill in the arguments below to run the SSL upstream model.
        # The resulting "source_encodings" should have shape (Batch size, Length, Dimension).
        source_encodings = self.upstream(TODO).last_hidden_state
        #########################################################

        source_encodings = self.downsample(source_encodings.transpose(1,2)).transpose(1,2) # No need to change

        ######################## YOUR CODE ######################
        # Implement the rest of the model (LSTM, logit, ...).
        #
        # Then fill in the CTCLoss call below.
        # Tip: check argument types/shapes in https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html
        # (nn.CTCLoss expects log-probabilities of shape (Length, Batch, Classes),
        #  plus per-utterance input lengths and target lengths).
        loss = TODO
        #########################################################

        return loss

    def predict(self, input_values, attention_mask=None):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

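        ######################## YOUR CODE ######################
        # Fill in the inference code. It should mirror "forward" without the loss computation,
        # and define source_encodings (emissions such as log-probabilities, shape (Batch, Length, Vocab))
        # and source_lengths (number of valid frames per utterance) for the decoder below.
        # torchaudio's ctc_decoder requires contiguous float32 CPU emissions.
        #########################################################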

        pred_texts = self.ctc_decoder(source_encodings.cpu(), source_lengths.cpu())
        pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]
        return pred_texts, source_encodings




## 4. Training

Training an ASR model usually takes a long time and consumes a lot of GPU memory. Here, we do not ask you to train until convergence to reach a very low WER (but try it if you have enough capacity in Colab!).

Below, we provide a very minimal training setup; if your implementation is correct, we expect the WER to drop below 50% after 20 epochs. Given enough training, fine-tuning can bring the WER below 20%. Also, the loss may not seem to improve during the first few epochs, which is normal for CTC. So be patient, and start early so that you can try several configurations.

This procedure takes several hours.

NOTE - The choice of layer used for training the ASR model is crucial. Try training the model with different layers.

Points are awarded based on the WER on the test set:

- WER < 70%: 10 points
- WER < 50%: 15 points
- WER < 45%: 20 points
- WER < 40%: 25 points
- WER < 35%: 30 points (full points)
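To make the thresholds above concrete, here is a small hypothetical helper that maps a WER given as a fraction (the same scale as the values returned by `test_loop`) to points. The handout does not specify points for a WER of 70% or more, so the sketch returns `None` there rather than guessing.

```python
# Grading thresholds copied from the list above: (WER upper bound, points), tightest first.
THRESHOLDS = [(0.35, 30), (0.40, 25), (0.45, 20), (0.50, 15), (0.70, 10)]

def points_for_wer(wer):
    """Return the points for a test-set WER given as a fraction (0.33 means 33%).

    Returns None for WER >= 70%, which the handout does not cover.
    """
    for bound, points in THRESHOLDS:
        if wer < bound:
            return points
    return None

if __name__ == "__main__":
    print(points_for_wer(0.38))  # 25
```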

In [None]:
## Helper functions for WER/CER calculation and train/validation loop.

def get_wer(gt_text, pred_text):
    gt_text_tokens = [t for t in gt_text.upper().split(' ') if t != '']
    pred_text_tokens = [t for t in pred_text.upper().split(' ') if t != '']
    return torchaudio.functional.edit_distance(pred_text_tokens,gt_text_tokens) /len(gt_text_tokens)

def get_cer(gt_text, pred_text):
    return torchaudio.functional.edit_distance(pred_text.upper(),gt_text.upper()) /len(gt_text)


def train_loop(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    report_step = np.ceil(len(dataloader)/5).astype(int)
    # Set the model to training mode. This enables the dropout layers of the SSL upstream
    # (and of the LSTM, if configured), so it does matter here.
    model.train()
    for batch_i, batch in enumerate(dataloader):
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        text = batch['text']
        # Compute loss
        loss = model(input_values, text, attention_mask)

        # Backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

        if batch_i % report_step == 0:
            loss, current = loss.item(), batch_i * batch_size + len(input_values)
            print(f"    Train loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    wer, cer = 0, 0
    cnt = 0
    with torch.no_grad():
        for batch in dataloader:
            input_values = batch['input_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gt_texts = batch['text']
            pred_texts, _ = model.predict(input_values, attention_mask)
            for gt_text, pred_text in zip(gt_texts, pred_texts):
                wer += get_wer(gt_text, pred_text)
                cer += get_cer(gt_text, pred_text)
                cnt += 1

    wer = wer/cnt
    cer = cer/cnt
    return wer, cer
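
`get_wer` and `get_cer` rely on `torchaudio.functional.edit_distance`, a Levenshtein distance in which insertions, deletions, and substitutions each cost 1. The dependency-free sketch below reimplements the same computation to show what is being counted (words for WER, characters for CER). It splits on any whitespace, a minor difference from `get_wer`, which splits on single spaces and drops empty tokens.

```python
def edit_distance(a, b):
    """Levenshtein distance between two sequences (each edit costs 1)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,              # delete x
                            curr[j - 1] + 1,          # insert y
                            prev[j - 1] + (x != y)))  # substitute (free if equal)
        prev = curr
    return prev[-1]

def wer(gt_text, pred_text):
    gt = gt_text.upper().split()
    pred = pred_text.upper().split()
    return edit_distance(pred, gt) / len(gt)

def cer(gt_text, pred_text):
    return edit_distance(pred_text.upper(), gt_text.upper()) / len(gt_text)

if __name__ == "__main__":
    print(wer("THE CAT SAT", "THE BAT SAT"))  # one substitution out of three words
```

Note that `test_loop` averages the per-utterance WER. A corpus-level WER (total edits divided by total reference words) weights long utterances more heavily, so the two numbers can differ.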

In [None]:
# Check if GPU is available
# If False, please check the runtime type of the session.
print(torch.cuda.is_available())

False


In [None]:
## Training config

device = 'cuda'
batch_size = 4 # You may hit the GPU memory limit if you set this higher.
epoch = 20 # This number is not enough for convergence but sufficient to pass the assignment.
lr = 4e-5 # this should be sufficiently small

In [None]:
# Initialize the model
model = ASR()
model = model.to(device).train()

In [None]:
# Initialize dataloader

train_dataset = SpeechTextDataset(DATA_DIR, 'train')
val_dataset = SpeechTextDataset(DATA_DIR, 'dev')

train_dataloader = DataLoader(train_dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=2,
                        drop_last=True,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)

val_dataloader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=2,
                        drop_last=False,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)



In [None]:
# Set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
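best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
best_wer = float('inf')
for e in range(epoch):
    print(f"Epoch [{e+1}/{epoch}] - {datetime.now()}")
    train_loop(train_dataloader, model, optimizer)
    val_wer, val_cer = test_loop(val_dataloader, model)
    print(f"    Validation WER: {(100*val_wer):>0.2f}%, CER: {(100*val_cer):>0.2f}%")
    if val_wer < best_wer:
        best_wer = val_wer
        # state_dict() returns references to the live parameters, so keep a CPU copy
        best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        torch.save({'state_dict': best_weights}, "best_model.pt")
    # Save a checkpoint every epoch so that an interrupted run can be resumed
    # ("last_checkpoint.pt" is just an example name; consider saving it to Google Drive).
    torch.save({'epoch': e + 1, 'best_wer': best_wer, 'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}, "last_checkpoint.pt")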



In [None]:
# Save the best model (lowest validation WER) as the submission checkpoint

torch.save({'state_dict': best_weights}, "best_model.pt")

In [None]:
# Load your best model
model.load_state_dict(torch.load("best_model.pt", map_location="cpu")['state_dict'])

# Load test dataset
test_dataset = SpeechTextDataset(DATA_DIR, 'test')
submission_order_file = Path(DATA_DIR) / "test_submission_order.txt"

# Generate test transcriptions
model.eval()
test_transcriptions = {}

print("Generating test transcriptions...")
with torch.no_grad():
    for i in tqdm.tqdm(range(len(test_dataset)), desc="Transcribing test set"):
        sample = test_dataset[i]
        wav_file = test_dataset.wav_files[i]
        filename = wav_file.stem
        
        input_values = sample['wav'].unsqueeze(0).to(device)
        attention_mask = torch.ones_like(input_values).to(device)
        
        pred_texts, _ = model.predict(input_values, attention_mask)
        prediction = pred_texts[0] if len(pred_texts) > 0 else ""
        test_transcriptions[filename] = prediction

# Create submission file in correct order
with open(submission_order_file, 'r') as f:
    submission_order = [line.strip() for line in f.readlines()]

output_lines = []
for audio_id in submission_order:
    if audio_id in test_transcriptions:
        output_lines.append(f"{audio_id}|{test_transcriptions[audio_id]}")
    else:
        output_lines.append(f"{audio_id}|")

with open("asr_submission.txt", "w") as f:
    for line in output_lines:
        f.write(line + "\n")

print("Test transcriptions saved to asr_submission.txt")

### Please make sure that the output file follows the expected format: one `audio_id|transcription` line per test file, in the order of `test_submission_order.txt`!
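
Before submitting, it can help to re-read the file you wrote. The hypothetical `check_submission` helper below (not part of the provided code) checks the two properties implied by the generation code above: one `audio_id|transcription` line per id in `test_submission_order.txt`, in the same order. It returns a list of problems, which is empty when none are found.

```python
from pathlib import Path

def check_submission(submission_path, order_path):
    """Sanity-check a submission file against the expected id order.

    Expects one `audio_id|transcription` line per id in the order file, in the same order.
    Returns a list of problem descriptions (empty if none were found).
    """
    order = [line.strip() for line in Path(order_path).read_text().splitlines() if line.strip()]
    lines = Path(submission_path).read_text().splitlines()
    problems = []
    if len(lines) != len(order):
        problems.append(f"expected {len(order)} lines, found {len(lines)}")
    for n, (line, expected_id) in enumerate(zip(lines, order), start=1):
        if line.count("|") != 1:
            problems.append(f"line {n}: expected exactly one '|'")
            continue
        audio_id, _text = line.split("|")
        if audio_id != expected_id:
            problems.append(f"line {n}: id {audio_id!r} != expected {expected_id!r}")
    return problems
```

For example, `check_submission("asr_submission.txt", Path(DATA_DIR) / "test_submission_order.txt")` should return `[]` for a well-formed file.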