# HW4 (b): Automatic Speech Recognition (ASR) by Fine-Tuning SSL models (30 points)

In this lab, you will learn how to retrieve speech self-supervised learning (SSL) models from Hugging Face and fine-tune them for the ASR task.

In particular, we will apply Connectionist Temporal Classification (CTC; [paper](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)). This lab provides hands-on experience in implementing and training an ASR model. Training takes several hours and may hit the GPU usage limit of the free version of Colab, so please START EARLY!

**Note - To work around the GPU usage limit, save a checkpoint at the end of every epoch so that you can resume training from the last saved epoch.**
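A minimal sketch of per-epoch checkpointing with `torch.save`/`torch.load`, assuming a `model` and an `optimizer` like the ones used in the training section. The helper names `save_checkpoint` and `load_checkpoint` are hypothetical and not part of the assignment template.

```python
import torch

def save_checkpoint(path, model, optimizer, epoch):
    # Store everything needed to resume: weights, optimizer state and the finished epoch.
    torch.save({"epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)

def load_checkpoint(path, model, optimizer, device="cpu"):
    # Restore weights and optimizer state in place; return the next epoch to run.
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1
```

For example, call `save_checkpoint(f"ckpt_epoch{t}.pt", model, optimizer, t)` after each epoch, and after a restart run `start = load_checkpoint(path, model, optimizer, device)` followed by `for t in range(start, epoch): ...`. Writing the files to Google Drive keeps them across Colab sessions, since the Colab runtime's local disk is wiped when the session ends.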

## **About submission**
Below, you will be asked to fill in some cells to implement the model. For submission, please make a zip file containing this notebook with a complete implementation of your best-performing model, selected by validation WER. You should also use this model to generate the test-set transcriptions.

For this assignment, you will also need to submit transcriptions of the test set provided with the data, as a text file in the exact format specified for the test audio files.

The submission file should have the name: 'asr_submission.txt'.
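The exact line format is not spelled out here, so the following is only a minimal sketch under an assumption: one `tag|TEXT` line per test file, mirroring the `tag|text` layout that `load_transcription` parses for the train and dev sets, where `tag` is the audio file name without the `.flac` extension. The helper `write_submission` is hypothetical; adjust it to the format your instructors specify.

```python
from pathlib import Path

def write_submission(tags, transcriptions, out_path="asr_submission.txt"):
    """Write one 'tag|TEXT' line per test utterance (ASSUMED format, mirroring the train/dev files)."""
    if len(tags) != len(transcriptions):
        raise ValueError("expected exactly one transcription per test audio file")
    # Upper-case to match the normalized reference texts.
    lines = [f"{tag}|{text.strip().upper()}" for tag, text in zip(tags, transcriptions)]
    Path(out_path).write_text("\n".join(lines) + "\n")
    return out_path
```

With the `SpeechTextDataset` defined later, the tags could come from `[f.stem for f in test_dataset.wav_files]` (where `test_dataset` is a hypothetical name), as long as the test `DataLoader` uses `shuffle=False` so the predictions stay in file order.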


## 1. Setting up environment and data

Please start a session with a GPU runtime.


In [7]:
import sys
import subprocess

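# Install into the Python environment that Jupyter is using.
# google.colab is preinstalled on Colab and is not provided by the unrelated
# PyPI package "google", so that package is not installed here.
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers", "flashlight-text"])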

0

In [3]:
# Install required packages with pip (flashlight-text is not available on the default conda channels)
!pip install transformers flashlight-text




Download the dataset from https://drive.google.com/file/d/1KEX_sLTRGOt82DjMsOt7ItqwI7SFN49G/view?usp=sharing

This dataset is a mini-subset of [LibriSpeech](https://ieeexplore.ieee.org/abstract/document/7178964), one of the most commonly used speech datasets, composed of audiobook recordings.

The mini-version created for this assignment contains 10 hours of multi-speaker English audio for training and one hour for the validation and test sets.

Please place the zip file in your Google Drive folder, YOUR_DIR. Then follow the instructions below to set up the dataset.

The texts are already normalized.

In [8]:
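# Mount your Google Drive in the Colab environment.
# google.colab only exists inside Google Colab; when running locally, skip this
# step and point DATA_DIR at a local copy of the dataset instead.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ModuleNotFoundError:
    print("Not running in Google Colab; skipping the Google Drive mount.")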

In [None]:
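# Path to the dataset zip inside your Drive folder (YOUR_DIR); adjust as needed.
data_zip = '/content/drive/MyDrive/183/student_dataset.zip'

import zipfile

# Extract the dataset into a folder on your Drive.
with zipfile.ZipFile(data_zip, 'r') as zip_ref:
    zip_ref.extractall('/content/drive/MyDrive/183dataset/')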

In [1]:
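# Root folder of the extracted dataset (split folders and *_transcriptions.txt files).
# This is a local Windows path; on Colab, use the folder the zip was extracted to.
DATA_DIR = "C:/Users/cjmc7/Documents/student_dataset"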
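Before building the dataset, it can help to check that `DATA_DIR` has the layout `SpeechTextDataset` reads: `train_transcriptions.txt` and `dev_transcriptions.txt` at the root, plus one folder of `.flac` files per split. The helper below is a hypothetical sketch; the test split folder name `test` is an assumption.

```python
from pathlib import Path

def check_dataset_layout(data_dir, splits=("train", "dev", "test")):
    """Hypothetical helper: return what SpeechTextDataset needs but cannot find under data_dir."""
    data_dir = Path(data_dir)
    missing = []
    # Transcriptions are only read for the train and dev splits.
    for name in ("train_transcriptions.txt", "dev_transcriptions.txt"):
        if not (data_dir / name).is_file():
            missing.append(name)
    # Each split folder must contain .flac files (the dataset globs "*.flac").
    for split in splits:
        split_dir = data_dir / split
        if not (split_dir.is_dir() and any(split_dir.glob("*.flac"))):
            missing.append(f"{split}/*.flac")
    return missing
```

Usage: `print(check_dataset_layout(DATA_DIR) or "Dataset layout looks OK")`.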

## 2. Basic helper functions for data loading and tokenization.

We use the PyTorch library to train the model. Here, we provide basic helper functions to set up the data loader and to tokenize text into character indices.


You don't have to implement anything here, but we highly recommend going over each function carefully.



In [2]:
# Loading packages

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torchaudio
import tqdm
from datetime import datetime

In [3]:
def load_transcription(transcription_file):
    '''
    Load transcription
    '''
    tag2text = {}
    with open(transcription_file, "r") as f:
        for line in f.readlines():
            tag, text = line.rstrip().split("|")
            tag2text[tag] = text
    return tag2text

class CharacterTokenizer(nn.Module):
    '''
    Tokenize texts into character indices, and decode indices back to texts.
    '''
    def __init__(self,):
        super().__init__()
        self.charactors = ['blank','pad', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
                           'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                           'U', 'V', 'W', 'X', 'Y', 'Z', ' ', "'" ]
        self.ch2idx = {ch:i for i, ch in enumerate(self.charactors)}
        self.pad_id = 1
        self.blank =  0
        self.vocab = self.charactors


    def encode(self, text):
        token_idxs = np.array([self.ch2idx[t] for t in text.upper() if t in self.charactors])
        return token_idxs

    def enocode_torch_batch(self, texts):
        token_idxs_batch = [torch.from_numpy(self.encode(text)).long() for text in texts]
        token_lens = torch.Tensor([len(token_idxs) for token_idxs in token_idxs_batch]).long()
        token_idxs_batch = nn.utils.rnn.pad_sequence(token_idxs_batch,
                                                batch_first=True, padding_value=self.pad_id)
        return token_idxs_batch, token_lens

    def decode(self, token_idxs):
        return ''.join([self.charactors[i] for i in token_idxs if i not in [self.pad_id, self.blank]])

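    # pad_id is stored as an attribute in __init__ (self.pad_id = 1), which shadows any
    # method of the same name, so read tokenizer.pad_id directly instead of calling it.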

    def blank_id(self):
        return self.blank
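To see concretely what `encode` and `enocode_torch_batch` produce, here is a small standalone sketch that re-creates the same index layout (0 = CTC blank, 1 = padding, then `A`-`Z`, space and apostrophe) without the class. The names `VOCAB`, `CH2IDX` and `encode` are illustrative copies, not the class API.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Same index layout as CharacterTokenizer: 0 = CTC blank, 1 = padding,
# 2-27 = 'A'-'Z', 28 = space, 29 = apostrophe.
VOCAB = ['blank', 'pad'] + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + [' ', "'"]
CH2IDX = {ch: i for i, ch in enumerate(VOCAB)}
PAD_ID = 1

def encode(text):
    # Upper-case the text and silently drop characters outside the vocabulary.
    return torch.tensor([CH2IDX[c] for c in text.upper() if c in CH2IDX], dtype=torch.long)

texts = ["hi", "a b!"]
token_ids = [encode(t) for t in texts]
token_lens = torch.tensor([len(t) for t in token_ids])
# Shorter sequences are right-padded with PAD_ID; token_lens records where each target ends.
batch = pad_sequence(token_ids, batch_first=True, padding_value=PAD_ID)
print(batch.tolist())       # [[9, 10, 1], [2, 28, 3]]
print(token_lens.tolist())  # [2, 3]
```

The padding value never reaches the loss in practice: `nn.CTCLoss` only reads the first `target_lengths[i]` entries of each row.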

In [4]:
class SpeechTextDataset(Dataset):

    def __init__(self, DATA_DIR, split='train',):
        super().__init__()
        DATA_DIR = Path(DATA_DIR)
        # Use split-specific transcription files
        if split == 'train':
            self.tag2text = load_transcription(DATA_DIR/"train_transcriptions.txt")
        elif split == 'dev':
            self.tag2text = load_transcription(DATA_DIR/"dev_transcriptions.txt")
        else:  # test split
            self.tag2text = {}  # No transcriptions for test set
        self.wav_files = [f for f in (DATA_DIR/split).glob("*.flac")]
        self.wav_files.sort()

    def __len__(self):
        return len(self.wav_files)

    def __getitem__(self,i):
        wav_file = self.wav_files[i]
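        # Load the waveform
        wav, sr = torchaudio.load(wav_file)  # shape = (1, L) for mono audio
        assert sr == 16000

        # z-score normalize the waveform
        wav = (wav - wav.mean()) / wav.std()

        # The test split has no transcriptions, so fall back to an empty string
        text = self.tag2text.get(wav_file.stem, '')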

        output = {'wav':wav[0],
                  'text':text}

        return output

    @staticmethod
    def collate(batch):
        data = {}
        input_values =  nn.utils.rnn.pad_sequence([d['wav'] for d in batch],
                                                batch_first=True, padding_value=0.0)
        data['input_values'] = input_values
        data['attention_mask'] = nn.utils.rnn.pad_sequence([torch.ones(len(d['wav'])) for d in batch],
                                                batch_first=True, padding_value=0)
        data['text'] = [d['text'] for d in batch]
        return data
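A toy illustration of what `collate` returns for two utterances of different lengths, using the same padding scheme: zeros for the audio and a 1/0 attention mask. Summing the mask recovers each utterance's true sample count, which is what the model later converts into valid frame counts for CTC. The waveforms here are random stand-ins, not real audio.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
# Two toy "waveforms" of 3 and 5 samples (random values, not real audio).
wavs = [torch.randn(3), torch.randn(5)]

# Same padding scheme as SpeechTextDataset.collate: zeros for audio, 1/0 for the mask.
input_values = pad_sequence(wavs, batch_first=True, padding_value=0.0)
attention_mask = pad_sequence([torch.ones(len(w)) for w in wavs],
                              batch_first=True, padding_value=0)

print(input_values.shape)                  # torch.Size([2, 5])
print(attention_mask.sum(dim=1).tolist())  # [3.0, 5.0], the true sample counts
```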

In [5]:
# Initialize dataset
# The default mode is 'train'
dataset = SpeechTextDataset(DATA_DIR, 'train')  # or 'dev' for validation

In [6]:
# Check what a data point looks like.

dataset[0]

{'wav': tensor([ 0.0164,  0.0142,  0.0120,  ...,  0.0133, -0.0030, -0.0109]),
 'text': 'PSYCHOTHERAPY AND THE COMMUNITY BOTH THE PHYSICIAN AND THE PATIENT FIND THEIR PLACE IN THE COMMUNITY THE LIFE INTERESTS OF WHICH ARE SUPERIOR TO THE INTERESTS OF THE INDIVIDUAL'}

## 3. Implementation of CTC ASR model with a pre-trained SSL upstream

Here, we load a pre-trained HuBERT-base model from Hugging Face. The ASR model consists of the SSL model as the upstream, followed by a few LSTM layers.

The model will be trained using CTC (`nn.CTCLoss`).
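`nn.CTCLoss` is picky about shapes, so here is a minimal sketch of a valid call with random inputs: log-probabilities shaped `(T, B, C)`, padded targets shaped `(B, S)`, and one valid length per utterance for both inputs and targets. The sizes are arbitrary; `C = 30` matches the `CharacterTokenizer` vocabulary with blank index 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, B, C = 50, 2, 30  # frames, batch size, vocabulary size (30 = CharacterTokenizer vocab)

# CTCLoss expects log-probabilities in (T, B, C) layout.
log_probs = torch.randn(T, B, C).log_softmax(dim=-1)
# Padded targets (B, S); entries past target_lengths (here the pad id 1) are ignored.
targets = torch.tensor([[9, 10, 1], [2, 28, 3]])
input_lengths = torch.tensor([50, 40])  # valid frames per utterance
target_lengths = torch.tensor([2, 3])   # valid characters per transcript

ctc = nn.CTCLoss(blank=0, zero_infinity=True, reduction='mean')
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
```

With `reduction='mean'`, each utterance's loss is divided by its target length before averaging over the batch, and `zero_infinity=True` zeroes out impossible alignments (for example, a transcript longer than the number of frames) instead of returning `inf`.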

Please fill in the `YOUR CODE` sections below to complete the model.

You need to decide which Transformer encoder layer of the SSL model to use (refer to the probing assignment for this) and which LSTM configuration to run.
(Please don't try very large models; Colab will not be able to run them.)
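The template hints that the layer count can be changed when the SSL model is loaded. In Hugging Face Transformers, setting `num_hidden_layers` in the config keeps only the first k Transformer layers, so the truncated model's `last_hidden_state` should match entry k of the full model's `hidden_states` (entry 0 is the input to the first Transformer layer). The sketch below checks this on a tiny, randomly initialised `HubertConfig`. Its sizes are arbitrary assumptions chosen so the check runs quickly offline; they are not HuBERT-base settings. For the real checkpoint, the corresponding call is `HubertModel.from_pretrained("facebook/hubert-base-ls960", num_hidden_layers=k)`.

```python
import torch
from transformers import HubertConfig, HubertModel

torch.manual_seed(0)

def tiny_config(num_layers):
    # Small HuBERT-style config (NOT hubert-base) so the demo runs quickly without downloads.
    return HubertConfig(hidden_size=32, num_hidden_layers=num_layers,
                        num_attention_heads=2, intermediate_size=64,
                        conv_dim=(32,) * 7, num_conv_pos_embeddings=16,
                        num_conv_pos_embedding_groups=4)

full = HubertModel(tiny_config(4)).eval()
truncated = HubertModel(tiny_config(2)).eval()
# Copy the shared weights; the keys of the two extra layers are ignored (strict=False).
truncated.load_state_dict(full.state_dict(), strict=False)

wav = torch.randn(1, 16000)  # one second of fake 16 kHz audio
with torch.no_grad():
    all_states = full(wav, output_hidden_states=True).hidden_states
    last = truncated(wav).last_hidden_state

print(len(all_states))                                # 5: embedding output + 4 layers
print(torch.allclose(last, all_states[2], atol=1e-5)) # True
```

Truncating also saves compute and memory, since the unused top layers are never run.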

In [None]:
# from torchaudio.models.decoder import ctc_decoder

# ######################## YOUR CODE ######################
# # you can find more from Hugging Face!
# # https://huggingface.co/
# from transformers import TODO
# #########################################################

# class ASR(nn.Module):
#     def __init__(self,
#                  ######################## YOUR CODE ######################
#                  # Please add any arguments required to initiate the model
#                  #########################################################
#                  ):


#         super().__init__()


#         self.token_processor = CharacterTokenizer()

#         ######################## YOUR CODE ######################
#         self.upstream =  TODO ## Initiate and load an SSL model,
#         # If you figure out the function call correctly,
#         # you will be able to change some config of the SSL (e.g., layer num) when you load the model.

#         self.lstm = TODO # LSTM module
#         self.lstm_output_size = TODO
#         self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab)) # Final classification layer

#         # We put a downsampling convolution here to reduce the temporal resolution from 50Hz to 25Hz
#         # to reduce the computation load.
#         # You need to properly set ssl_output_size, and  lstm_input_size.
#         #########################################################

#         self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
#                                     kernel_size=2, stride=2, padding=0, dilation=1,)
#         self.loss = nn.CTCLoss(blank = self.token_processor.blank_id(),zero_infinity=True,reduction='mean')
#         self.ctc_decoder = ctc_decoder(lexicon=None,
#                                       tokens=self.token_processor.vocab,
#                                       lm=None,
#                                       lm_dict=None,
#                                       nbest=1,
#                                       beam_size=5,
#                                       blank_token="blank",
#                                       sil_token=" ",
#                                      )

#     def forward(self, input_values, text, attention_mask=None):

#         if attention_mask is None:
#             attention_mask = torch.ones_like(input_values)

#         text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
#         text_idxs = text_idxs.to(input_values.device).to(torch.int32)
#         text_lens = text_lens.to(input_values.device).to(torch.int32)

#         ######################## YOUR CODE ######################
#         # fill in the argument below to run SSL upstream model.
#         # the resulting "source_encodings" should be (Batch size, Length, Dimension).
#         source_encodings = self.upstream(TODO).last_hidden_state
#         #########################################################

#         source_encodings = self.downsample(source_encodings.transpose(1,2)).transpose(1,2) # Don't have to change

#         ######################## YOUR CODE ######################
#         # You need to implement the code for the rest of the part (LSTM, logit, ...)
#         #
#         # And fill in the below loss code for calling CTCLoss
#         # Tip: check argument types/shapes in https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html
#         loss = TODO
#         #########################################################

#         return loss

#     def predict(self, input_values, attention_mask=None):

#         if attention_mask is None:
#             attention_mask = torch.ones_like(input_values)

#         ######################## YOUR CODE ######################
#         # fill in the inference code. This should be the same as "forward" function except loss calculation.
#         #########################################################

#         pred_texts = self.ctc_decoder(source_encodings.cpu(), source_lengths.cpu())
#         pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]
#         return pred_texts, source_encodings
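`predict` relies on torchaudio's lexicon-free beam-search `ctc_decoder` (`beam_size=5`). For quick debugging, a simpler technique (swapped in here; it is not what the template uses) is greedy best-path decoding: take the argmax token per frame, merge consecutive repeats, then drop blanks. A minimal sketch over a list of per-frame token ids, assuming the `CharacterTokenizer` layout (0 = blank, 1 = pad):

```python
# Vocabulary layout assumed to match CharacterTokenizer (0 = blank, 1 = pad).
VOCAB = ['blank', 'pad'] + [chr(c) for c in range(ord('A'), ord('Z') + 1)] + [' ', "'"]

def greedy_ctc_decode(frame_ids, vocab=VOCAB, blank=0, pad=1):
    """Best-path CTC decoding: merge consecutive repeated ids, then drop blank/pad ids."""
    chars = []
    prev = None
    for idx in frame_ids:
        # A repeated id only starts a new character if something (e.g. a blank) separated it.
        if idx != prev and idx not in (blank, pad):
            chars.append(vocab[idx])
        prev = idx
    return ''.join(chars)

print(greedy_ctc_decode([9, 9, 0, 10, 10, 0]))  # HI
print(greedy_ctc_decode([13, 0, 13, 13]))       # LL
```

Given model logits of shape `(B, T, C)`, the frame ids for utterance `b` would be `logits.argmax(dim=-1)[b, :source_lengths[b]].tolist()`. The blank between the two `L` frames is what lets CTC spell double letters.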




In [None]:
from torchaudio.models.decoder import ctc_decoder

######################## YOUR CODE ######################
from transformers import HubertModel
#########################################################

class ASR(nn.Module):
    def __init__(self,
                 ######################## YOUR CODE ######################
                 ssl_layer=10,  # Based on probing results, layer 10 was best for phonemes
                 lstm_hidden_size=256,
                 lstm_num_layers=2,
                 lstm_dropout=0.1,
                 #########################################################
                 ):

        super().__init__()

        self.token_processor = CharacterTokenizer()

        ######################## YOUR CODE ######################
        # Load HuBERT model and configure to output specific layer
        self.upstream = HubertModel.from_pretrained(
            "facebook/hubert-base-ls960",
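            # Keep only the first ssl_layer Transformer layers: last_hidden_state then equals
            # hidden_states[ssl_layer] of the full model (hidden_states[0] is the input to layer 1).
            num_hidden_layers=ssl_layer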
        )
        # Freeze upstream model to save memory and computation -- disabling this
        # for param in self.upstream.parameters():
        #     param.requires_grad = False

        # SSL output size for HuBERT base
        ssl_output_size = 768

        # LSTM input size after downsampling
        lstm_input_size = 768

        # LSTM configuration
        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout if lstm_num_layers > 1 else 0
        )

        # LSTM output size (bidirectional doubles the size)
        self.lstm_output_size = lstm_hidden_size * 2

        self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab))
        #########################################################

        # Downsampling convolution
        self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
                                    kernel_size=2, stride=2, padding=0, dilation=1)

        self.loss = nn.CTCLoss(blank=self.token_processor.blank_id(), zero_infinity=True, reduction='mean')
        self.ctc_decoder = ctc_decoder(lexicon=None,
                                      tokens=self.token_processor.vocab,
                                      lm=None,
                                      lm_dict=None,
                                      nbest=1,
                                      beam_size=5,
                                      blank_token="blank",
                                      sil_token=" ",
                                     )

    def forward(self, input_values, text, attention_mask=None):

      if attention_mask is None:
          attention_mask = torch.ones_like(input_values)

      text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
      text_idxs = text_idxs.to(input_values.device).to(torch.int32)
      text_lens = text_lens.to(input_values.device).to(torch.int32)

      ######################## YOUR CODE ######################
      # Run SSL upstream model
      source_encodings = self.upstream(
          input_values=input_values,
          attention_mask=attention_mask
      ).last_hidden_state
      #########################################################

      # Downsample
      source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)

      ######################## YOUR CODE ######################
      # LSTM forward pass
      lstm_out, _ = self.lstm(source_encodings)

      # Get logits
      logits = self.logit(lstm_out)  # (B, T, vocab_size)

      # Prepare for CTC loss
      # CTC expects: (T, B, C) format
      log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
      log_probs = log_probs.transpose(0, 1)  # (T, B, C)

      # Valid frames per utterance: HuBERT's conv encoder maps n samples to
      # (n - 400) // 320 + 1 frames, and the stride-2 downsampling halves that.
      audio_lengths = attention_mask.sum(dim=1).long()
      input_lengths = ((audio_lengths - 400) // 320 + 1) // 2
      input_lengths = input_lengths.clamp(max=log_probs.size(0))

      # CTC Loss
      loss = self.loss(
          log_probs=log_probs,
          targets=text_idxs,
          input_lengths=input_lengths,
          target_lengths=text_lens
      )
      #########################################################

      return loss

    def predict(self, input_values, attention_mask=None):

      if attention_mask is None:
          attention_mask = torch.ones_like(input_values)

      ######################## YOUR CODE ######################
      # Run SSL upstream model
      source_encodings = self.upstream(
          input_values=input_values,
          attention_mask=attention_mask
      ).last_hidden_state

      # Downsample
      source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)

      # LSTM forward pass
      lstm_out, _ = self.lstm(source_encodings)

      # Get logits
      logits = self.logit(lstm_out)  # (B, T, vocab_size)

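      # Valid frames per utterance (HuBERT: (n - 400) // 320 + 1, then halved by the
      # stride-2 downsampling), so padded frames are not decoded
      audio_lengths = attention_mask.sum(dim=1).long()
      source_lengths = (((audio_lengths - 400) // 320 + 1) // 2).clamp(max=logits.size(1))
      source_lengths = source_lengths.cpu().to(torch.int32)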
      #########################################################

      pred_texts = self.ctc_decoder(logits.cpu(), source_lengths.cpu())
      pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]

      return pred_texts, source_encodings

In [17]:
# Version 3

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torchaudio
from tqdm import tqdm
from datetime import datetime
from transformers import HubertModel, HubertConfig
from torchaudio.models.decoder import ctc_decoder

class ASR(nn.Module):
    def __init__(self,
                 hubert_layer=10,
                 lstm_hidden_size=768,
                 lstm_num_layers=2,
                 lstm_dropout=0.1):
        super().__init__()

        self.token_processor = CharacterTokenizer()

        # Load HuBERT model with specific layer configuration
        config = HubertConfig.from_pretrained('facebook/hubert-base-ls960')
        config.num_hidden_layers = hubert_layer
        self.upstream = HubertModel.from_pretrained('facebook/hubert-base-ls960', config=config)

        # Freeze upstream (comment out to fine-tune)
        for param in self.upstream.parameters():
            param.requires_grad = False

        # LSTM decoder
        ssl_output_size = 768  # HuBERT-base hidden size
        lstm_input_size = 768
        self.lstm_hidden_size = lstm_hidden_size

        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout if lstm_num_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )

        self.lstm_output_size = lstm_hidden_size * 2  # Bidirectional
        self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab))

        # Downsampling convolution (50Hz -> 25Hz)
        self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
                                     kernel_size=2, stride=2, padding=0, dilation=1)

        # CTC loss
        self.loss = nn.CTCLoss(blank=self.token_processor.blank_id(), zero_infinity=True, reduction='mean')

        # CTC decoder
        self.ctc_decoder = ctc_decoder(lexicon=None,
                                       tokens=self.token_processor.vocab,
                                       lm=None,
                                       lm_dict=None,
                                       nbest=1,
                                       beam_size=5,
                                       blank_token="blank",
                                       sil_token=" ")

    def forward(self, input_values, text, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
        text_idxs = text_idxs.to(input_values.device).to(torch.int32)
        text_lens = text_lens.to(input_values.device).to(torch.int32)

        # SSL upstream encoding
        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state

        # Downsample
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)

        # LSTM
        lstm_out, _ = self.lstm(source_encodings)

        # Logits
        logits = self.logit(lstm_out)

        # Log probabilities for CTC
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Calculate source lengths
        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.to(torch.int32)

        # Transpose for CTC: (T, B, C)
        log_probs = log_probs.transpose(0, 1)

        # CTC loss
        loss = self.loss(log_probs, text_idxs, source_lengths, text_lens)

        return loss

    def predict(self, input_values, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        # SSL encoding
        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state

        # Downsample
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)

        # LSTM
        lstm_out, _ = self.lstm(source_encodings)

        # Logits
        logits = self.logit(lstm_out)

        # Log probabilities
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Calculate source lengths
        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.cpu().to(torch.int32)

        # CTC decode
        pred_texts = self.ctc_decoder(log_probs.cpu(), source_lengths.cpu())
        pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]

        return pred_texts, source_encodings
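The `source_lengths` computation above relies on HuBERT's convolutional feature encoder. In the HuBERT-base configuration, its seven layers use kernel sizes (10, 3, 3, 3, 3, 2, 2) and strides (5, 2, 2, 2, 2, 2, 2). Composed, they have a 400-sample receptive field and a 320-sample hop (20 ms, i.e. 50 Hz at 16 kHz), which is where `(n - 400) // 320 + 1` comes from. The stride-2 `downsample` convolution then halves the count to 25 Hz. A small pure-Python check that the closed form matches the layer-by-layer formula `(n - kernel) // stride + 1`:

```python
# HuBERT-base feature encoder: (kernel, stride) of its 7 conv layers.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]

def conv_out_len(n, kernel, stride):
    # Output length of an unpadded 1-D convolution.
    return (n - kernel) // stride + 1

def hubert_frames_iterative(num_samples):
    n = num_samples
    for kernel, stride in CONV_LAYERS:
        n = conv_out_len(n, kernel, stride)
    return n

def hubert_frames_closed_form(num_samples):
    # Receptive field 400 samples, total stride 320; valid for num_samples >= 400.
    return (num_samples - 400) // 320 + 1

def downsampled_frames(num_samples):
    # Conv1d(kernel_size=2, stride=2) in the ASR model.
    return conv_out_len(hubert_frames_iterative(num_samples), 2, 2)

print(hubert_frames_iterative(16000), hubert_frames_closed_form(16000))  # 49 49
print(downsampled_frames(16000))                                         # 24
```

The two formulas agree exactly (not just approximately) because nested floor divisions by positive integers compose into a single floor division.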

In [7]:
# Cell 3: ASR class that loads HuBERT from the locally downloaded files (./hubert_model)
import torch
import torch.nn as nn
from transformers import HubertModel, HubertConfig
from torchaudio.models.decoder import ctc_decoder

class ASR(nn.Module):
    def __init__(self,
                 hubert_model_path='./hubert_model',
                 hubert_layer=10,
                 lstm_hidden_size=768,
                 lstm_num_layers=2,
                 lstm_dropout=0.1):
        super().__init__()

        self.token_processor = CharacterTokenizer()

        print(f"Loading HuBERT from {hubert_model_path}...")

        config = HubertConfig.from_pretrained(
            hubert_model_path,
            local_files_only=True
        )
        config.num_hidden_layers = hubert_layer

        self.upstream = HubertModel.from_pretrained(
            hubert_model_path,
            config=config,
            local_files_only=True
        )

        print("✓ HuBERT loaded!")

        # Freeze upstream
        for param in self.upstream.parameters():
            param.requires_grad = False

        # LSTM decoder
        ssl_output_size = 768
        lstm_input_size = 768
        self.lstm_hidden_size = lstm_hidden_size

        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout if lstm_num_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )

        self.lstm_output_size = lstm_hidden_size * 2
        self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab))

        self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
                                     kernel_size=2, stride=2, padding=0, dilation=1)

        self.loss = nn.CTCLoss(blank=self.token_processor.blank_id(),
                               zero_infinity=True, reduction='mean')

        self.ctc_decoder = ctc_decoder(lexicon=None,
                                       tokens=self.token_processor.vocab,
                                       lm=None,
                                       lm_dict=None,
                                       nbest=1,
                                       beam_size=5,
                                       blank_token="blank",
                                       sil_token=" ")

    def forward(self, input_values, text, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
        text_idxs = text_idxs.to(input_values.device).to(torch.int32)
        text_lens = text_lens.to(input_values.device).to(torch.int32)

        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)
        lstm_out, _ = self.lstm(source_encodings)
        logits = self.logit(lstm_out)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.to(torch.int32)

        log_probs = log_probs.transpose(0, 1)
        loss = self.loss(log_probs, text_idxs, source_lengths, text_lens)

        return loss

    def predict(self, input_values, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)
        lstm_out, _ = self.lstm(source_encodings)
        logits = self.logit(lstm_out)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.cpu().to(torch.int32)

        pred_texts = self.ctc_decoder(log_probs.cpu(), source_lengths.cpu())
        pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]

        return pred_texts, source_encodings

## 4. Training

Training an ASR model usually takes a long time and consumes a lot of GPU memory. Here, we do not require training to convergence for a very low WER (but try it if you have enough capacity in Colab!).

Below, we provide a very minimal training setting, with which we expect a WER below 50% after 20 epochs of training if the model is implemented successfully. As mentioned, fine-tuning can get below 20% WER. Also, the loss may not seem to improve for the first few epochs, which is normal for CTC. So be patient, and start early so that you can try several configurations.

This procedure takes several hours.

NOTE - The choice of the SSL layer used for training the ASR model is crucial. Try training the model with different layers.

The points are distributed based on the WER on the test set:

| Test-set WER | Points |
|---|---|
| < 70% | 10 |
| < 50% | 15 |
| < 45% | 20 |
| < 40% | 25 |
| < 35% | 30 (full points) |
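Word error rate is the word-level Levenshtein distance (substitutions + deletions + insertions) between the hypothesis and the reference, divided by the number of reference words. Because insertions count, it can exceed 100%. The `get_wer` helper below delegates the distance to `torchaudio.functional.edit_distance`; a dependency-free sketch of the same per-utterance computation:

```python
def word_edit_distance(ref, hyp):
    # Levenshtein distance over word lists; every edit costs 1.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution (0 if words match)
        prev = cur
    return prev[-1]

def wer(reference, hypothesis):
    ref = reference.upper().split()
    hyp = hypothesis.upper().split()
    return word_edit_distance(ref, hyp) / len(ref)

print(wer("THE CAT SAT", "THE CAT SAT DOWN"))  # one insertion / 3 words
print(wer("THE CAT SAT", "A CAT"))             # one substitution + one deletion / 3 words
```

Note that `test_loop` averages this per-utterance WER over utterances, which can differ slightly from a corpus-level WER that pools all edits and reference words first.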

In [8]:
## Helper functions for WER/CER calculation and train/validation loop.

def get_wer(gt_text, pred_text):
    gt_text_tokens = [t for t in gt_text.upper().split(' ') if t != '']
    pred_text_tokens = [t for t in pred_text.upper().split(' ') if t != '']
    return torchaudio.functional.edit_distance(pred_text_tokens,gt_text_tokens) /len(gt_text_tokens)

def get_cer(gt_text, pred_text):
    return torchaudio.functional.edit_distance(pred_text.upper(),gt_text.upper()) /len(gt_text)


def train_loop(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    report_step = np.ceil(len(dataloader)/5).astype(int)
    # Set the model to training mode; this matters here because it enables dropout
    # in the LSTM and in the HuBERT upstream
    model.train()
    for batch_i, batch in enumerate(dataloader):
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        text = batch['text']
        # Compute loss
        loss = model(input_values, text, attention_mask)

        # Backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

        if batch_i % report_step == 0:
            loss, current = loss.item(), batch_i * batch_size + len(input_values)
            print(f"    Train loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    wer, cer = 0, 0
    cnt = 0
    with torch.no_grad():
        for batch in dataloader:
            input_values = batch['input_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            gt_texts = batch['text']
            pred_texts, _ = model.predict(input_values, attention_mask)
            for gt_text, pred_text in zip(gt_texts, pred_texts):
                wer += get_wer(gt_text, pred_text)
                cer += get_cer(gt_text, pred_text)
                cnt += 1

    wer = wer/cnt
    cer = cer/cnt
    return wer, cer

In [9]:
# Check if GPU is available
# If False, please check the runtime type of the session.
print(torch.cuda.is_available())

True


In [10]:
## Training config

device = 'cuda'
batch_size = 4 # You may hit the GPU memory limit if you set this higher.
epoch = 20 # This number is not enough for convergence but sufficient to pass the assignment.
lr = 4e-5 # The learning rate should be sufficiently small (the original setting is 4e-5).

In [14]:
# Cell 1: Download using Python requests (works on Windows)
import requests
import os
from tqdm import tqdm

def download_file(url, filepath):
    """Download file with progress bar"""
    print(f"Downloading {os.path.basename(filepath)}...")

    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()  # fail on HTTP errors instead of saving an error page
    total_size = int(response.headers.get('content-length', 0))

    with open(filepath, 'wb') as file:
        if total_size == 0:
            file.write(response.content)
        else:
            downloaded = 0
            for data in response.iter_content(chunk_size=4096):
                downloaded += len(data)
                file.write(data)
                done = int(50 * downloaded / total_size)
                print(f"\r[{'=' * done}{' ' * (50-done)}] {downloaded}/{total_size} bytes", end='')
    print()  # New line after progress
    return filepath

# Create model directory
model_dir = "./hubert_model"
os.makedirs(model_dir, exist_ok=True)

# Download files
base_url = "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/"
files = {
    "config.json": "config.json",
    "preprocessor_config.json": "preprocessor_config.json",
    "pytorch_model.bin": "pytorch_model.bin"
}

print("Downloading HuBERT model files...\n")
for filename, local_name in files.items():
    url = base_url + filename
    filepath = os.path.join(model_dir, local_name)

    try:
        download_file(url, filepath)
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"✓ {filename} ({size_mb:.1f} MB)\n")
    except Exception as e:
        print(f"✗ Failed to download {filename}: {e}\n")

print("\n" + "="*50)
print("Download Summary:")
print("="*50)
for filename in files.values():
    filepath = os.path.join(model_dir, filename)
    if os.path.exists(filepath):
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"✓ {filename}: {size_mb:.1f} MB")
    else:
        print(f"✗ {filename}: MISSING")

Downloading HuBERT model files...

Downloading config.json...

✓ config.json (0.0 MB)

Downloading preprocessor_config.json...
✓ preprocessor_config.json (0.0 MB)

Downloading pytorch_model.bin...
✓ pytorch_model.bin (360.1 MB)


Download Summary:
✓ config.json: 0.0 MB
✓ preprocessor_config.json: 0.0 MB
✓ pytorch_model.bin: 360.1 MB


In [15]:
# Cell 2: Verify downloads and load model
import os
from transformers import HubertModel, HubertConfig

model_dir = "./hubert_model"

# Check all files exist
required_files = ["config.json", "preprocessor_config.json", "pytorch_model.bin"]
all_present = all(os.path.exists(os.path.join(model_dir, f)) for f in required_files)

if not all_present:
    print("❌ Not all files downloaded. Please run Cell 1 again.")
else:
    print("✓ All files present. Loading model...\n")

    # Load configuration
    config = HubertConfig.from_pretrained(
        model_dir,
        local_files_only=True
    )
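    config.num_hidden_layers = 10  # Keep only the first 10 transformer layers; layers 10 and 11 are dropped (hence the warning below)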

    # Load model
    hubert = HubertModel.from_pretrained(
        model_dir,
        config=config,
        local_files_only=True
    )

    print("✓ HuBERT loaded successfully!")
    print(f"  - Model layers: {config.num_hidden_layers}")
    print(f"  - Hidden size: {config.hidden_size}")

✓ All files present. Loading model...



Some weights of the model checkpoint at ./hubert_model were not used when initializing HubertModel: ['encoder.layers.11.attention.q_proj.weight', 'encoder.layers.11.attention.out_proj.bias', 'encoder.layers.11.attention.v_proj.weight', 'encoder.layers.10.attention.v_proj.weight', 'encoder.layers.11.final_layer_norm.bias', 'encoder.layers.10.layer_norm.weight', 'encoder.layers.11.attention.k_proj.bias', 'encoder.layers.11.attention.v_proj.bias', 'encoder.layers.10.attention.out_proj.bias', 'encoder.layers.10.feed_forward.intermediate_dense.weight', 'encoder.layers.10.feed_forward.intermediate_dense.bias', 'encoder.layers.11.feed_forward.intermediate_dense.weight', 'encoder.layers.10.attention.k_proj.bias', 'encoder.layers.11.layer_norm.weight', 'encoder.layers.11.attention.q_proj.bias', 'encoder.layers.10.attention.v_proj.bias', 'encoder.layers.11.feed_forward.intermediate_dense.bias', 'encoder.layers.10.feed_forward.output_dense.weight', 'encoder.layers.10.final_layer_norm.weight', 'en

✓ HuBERT loaded successfully!
  - Model layers: 10
  - Hidden size: 768


In [17]:
# Initialize the model
model = ASR()
model = model.to(device).train()

Loading HuBERT from ./hubert_model...


Some weights of the model checkpoint at ./hubert_model were not used when initializing HubertModel: ['encoder.layers.11.attention.q_proj.weight', 'encoder.layers.11.attention.out_proj.bias', 'encoder.layers.11.attention.v_proj.weight', 'encoder.layers.10.attention.v_proj.weight', 'encoder.layers.11.final_layer_norm.bias', 'encoder.layers.10.layer_norm.weight', 'encoder.layers.11.attention.k_proj.bias', 'encoder.layers.11.attention.v_proj.bias', 'encoder.layers.10.attention.out_proj.bias', 'encoder.layers.10.feed_forward.intermediate_dense.weight', 'encoder.layers.10.feed_forward.intermediate_dense.bias', 'encoder.layers.11.feed_forward.intermediate_dense.weight', 'encoder.layers.10.attention.k_proj.bias', 'encoder.layers.11.layer_norm.weight', 'encoder.layers.11.attention.q_proj.bias', 'encoder.layers.10.attention.v_proj.bias', 'encoder.layers.11.feed_forward.intermediate_dense.bias', 'encoder.layers.10.feed_forward.output_dense.weight', 'encoder.layers.10.final_layer_norm.weight', 'en

✓ HuBERT loaded!


In [21]:
train_dataset = SpeechTextDataset(DATA_DIR, 'train')
val_dataset = SpeechTextDataset(DATA_DIR, 'dev')

train_dataloader = DataLoader(train_dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=0, # Changed from 2 to 0 to prevent pickling issues
                        drop_last=True,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)

val_dataloader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=0, # Changed from 2 to 0 to prevent pickling issues
                        drop_last=False,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)

In [22]:
# Set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [23]:
import time
def train_loop(dataloader, model, optimizer, epoch_num=None):  # epoch_num is unused here; accepted so calls like train_loop(..., e) also work
    size = len(dataloader.dataset)
    report_step = max(1, len(dataloader) // 5)  # Report 5 times per epoch
    model.train()

    start_time = time.time()

    for batch_i, batch in enumerate(dataloader):
        batch_start = time.time()

        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        text = batch['text']

        # Print first batch info
        if batch_i == 0:
            print(f"    First batch shape: {input_values.shape}")
            print(f"    Device: {input_values.device}")

        # Compute loss
        loss = model(input_values, text, attention_mask)

        # Backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

        batch_time = time.time() - batch_start

        if batch_i % report_step == 0 or batch_i == 0:
            loss_val = loss.item()
            current = batch_i * len(input_values)
            print(f"    Batch {batch_i}/{len(dataloader)} - Loss: {loss_val:>7f} - Time: {batch_time:.2f}s - [{current:>5d}/{size:>5d}]")

    total_time = time.time() - start_time
    print(f"    Epoch completed in {total_time:.1f}s ({total_time/60:.1f}min)")

In [None]:
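from datetime import datetime

def snapshot(state_dict):
    # state_dict() returns references to the live parameters, which keep changing
    # during training, so store an independent copy (on the CPU to save GPU memory)
    return {k: v.detach().cpu().clone() for k, v in state_dict.items()}

best_weights = snapshot(model.state_dict())
best_wer = float('inf')
for e in range(epoch):
    print(f"Epoch [{e+1}/{epoch}] - {datetime.now()}")
    train_loop(train_dataloader, model, optimizer, e)
    val_wer, val_cer = test_loop(val_dataloader, model)
    print(f"    Validation WER: {(100*val_wer):>0.2f}%, CER: {(100*val_cer):>0.2f}%")
    if val_wer < best_wer:
        best_wer = val_wer
        best_weights = snapshot(model.state_dict())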



Epoch [1/20] - 2025-11-14 02:54:10.992956
    First batch shape: torch.Size([8, 236320])
    Device: cuda:0
    HuBERT (upstream) time: 612948.62 ms
    Downsample time: 46.82 ms
    LSTM time: 207.89 ms
    Logit time: 1.25 ms
    Batch 0/358 - Loss: 4.511792 - Time: 631.10s - [    0/ 2866]
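The profiled first batch above reports four forward-pass times. Converting them to shares of the measured total makes the comparison explicit. In this single measurement the frozen HuBERT upstream takes more than 99.9% of the time, while the LSTM takes well under 0.1%. There are two caveats. First, the first batch includes one-time start-up costs such as CUDA initialization and memory allocation. Second, the CUDA events only cover the forward pass. It is therefore worth repeating the measurement on a later batch before drawing firm conclusions. The helper name below is hypothetical.

```python
def time_shares(timings_ms):
    """Fraction of the total measured time spent in each component."""
    total = sum(timings_ms.values())
    return {name: t / total for name, t in timings_ms.items()}


# First-batch forward-pass timings printed by the profiled run above (ms)
first_batch_ms = {
    "upstream (HuBERT)": 612948.62,
    "downsample": 46.82,
    "lstm": 207.89,
    "logit": 1.25,
}

shares = time_shares(first_batch_ms)
for name, frac in sorted(shares.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:>18}: {100 * frac:7.3f}%")
```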


In [None]:
# Save the weights tracked as best during training (best_weights), not just the last epoch's

torch.save({'state_dict': best_weights}, "best_model.pt")
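The note at the top of the notebook recommends saving a checkpoint every epoch, so that a run interrupted by the GPU limit can be restarted. Below is a minimal sketch that assumes a checkpoint dictionary with the hypothetical keys `epoch`, `state_dict`, `optimizer`, and `best_wer`. Storing the optimizer state as well lets Adam resume with its running moment estimates instead of restarting them from zero.

```python
import torch


def save_checkpoint(path, model, optimizer, epoch, best_wer):
    """Save everything needed to resume training after `epoch` (0-based)."""
    torch.save({
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # includes Adam's moment estimates
        "best_wer": best_wer,
    }, path)


def load_checkpoint(path, model, optimizer, map_location="cpu"):
    """Restore model and optimizer state; return (next_epoch, best_wer)."""
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["state_dict"])
    optimizer.load_state_dict(ckpt["optimizer"])  # state is moved to each parameter's device
    return ckpt["epoch"] + 1, ckpt["best_wer"]
```

For example, you could call `save_checkpoint(f"checkpoint_epoch{e+1}.pt", model, optimizer, e, best_wer)` at the end of each epoch. To resume, run `start_epoch, best_wer = load_checkpoint(path, model, optimizer)` and loop over `range(start_epoch, epoch)`. Because `state_dict()` includes the frozen HuBERT weights, each file is a few hundred MB.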

In [None]:
from pathlib import Path
from tqdm import tqdm

# Load your best model
model.load_state_dict(torch.load("best_model.pt", map_location="cpu")['state_dict'])

# Load test dataset
test_dataset = SpeechTextDataset(DATA_DIR, 'test')
submission_order_file = Path(DATA_DIR) / "test_submission_order.txt"

# Generate test transcriptions
model.eval()
test_transcriptions = {}

print("Generating test transcriptions...")
with torch.no_grad():
    for i in tqdm(range(len(test_dataset)), desc="Transcribing test set"):
        sample = test_dataset[i]
        wav_file = test_dataset.wav_files[i]
        filename = wav_file.stem

        input_values = sample['wav'].unsqueeze(0).to(device)
        # Single unpadded utterance: every sample is valid
        attention_mask = torch.ones_like(input_values, dtype=torch.long)

        pred_texts, _ = model.predict(input_values, attention_mask)
        prediction = pred_texts[0] if len(pred_texts) > 0 else ""
        test_transcriptions[filename] = prediction

# Create submission file in correct order
with open(submission_order_file, 'r') as f:
    submission_order = [line.strip() for line in f.readlines()]

output_lines = []
for audio_id in submission_order:
    if audio_id in test_transcriptions:
        output_lines.append(f"{audio_id}|{test_transcriptions[audio_id]}")
    else:
        output_lines.append(f"{audio_id}|")

with open("asr_submission.txt", "w") as f:
    for line in output_lines:
        f.write(line + "\n")

print("Test transcriptions saved to asr_submission.txt")
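`model.predict` decodes with torchaudio's `ctc_decoder`, using beam search with `beam_size=5` and no lexicon or language model. The CTC collapsing rule underneath is easiest to see in greedy (best-path) decoding. You take the highest-scoring token in each frame, merge consecutive repeats, and then drop blanks. Below is a pure-Python sketch that assumes a vocabulary list whose index `blank_id` is the blank token, for example the tokenizer's `vocab` and `blank_id()`. Greedy decoding is a simpler alternative to beam search and is not what the notebook uses. Without a language model its output is often close to beam search, but not always identical.

```python
def ctc_greedy_decode(frame_scores, vocab, blank_id, num_frames=None):
    """Best-path CTC decoding for one utterance.

    frame_scores: per-frame score lists (e.g. log-probabilities), shape [T][V]
    num_frames:   valid (unpadded) length; padded frames beyond it are ignored
    """
    if num_frames is not None:
        frame_scores = frame_scores[:num_frames]
    # 1) most likely token in each frame
    best_path = [max(range(len(s)), key=s.__getitem__) for s in frame_scores]
    # 2) merge consecutive repeats, 3) drop blanks
    tokens, prev = [], None
    for idx in best_path:
        if idx != prev and idx != blank_id:
            tokens.append(vocab[idx])
        prev = idx
    return "".join(tokens)
```

A blank between two identical tokens keeps them separate: the frame path `a a blank a` decodes to `aa`, whereas `a a` decodes to a single `a`.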

### !!!! Please make sure that the output file format is what's expected.
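One way to check the output file is to compare it against `test_submission_order.txt`. The submission cell writes exactly one `audio_id|transcription` line per id in that file, in the same order. The sketch below assumes that this is the expected format. It reports lines with a wrong count, a missing `|` separator, an id out of order, or an empty transcription. The helper name `check_submission` is hypothetical.

```python
from pathlib import Path


def check_submission(submission_path, order_path):
    """Return a list of problems found in an `id|transcription` submission file."""
    # Mirror the notebook: one audio id per line, whitespace stripped
    expected_ids = [line.strip() for line in Path(order_path).read_text().splitlines()]
    lines = Path(submission_path).read_text().splitlines()

    problems = []
    if len(lines) != len(expected_ids):
        problems.append(f"expected {len(expected_ids)} lines, found {len(lines)}")
    for i, (line, expected_id) in enumerate(zip(lines, expected_ids), start=1):
        audio_id, sep, text = line.partition("|")
        if not sep:
            problems.append(f"line {i}: missing '|' separator")
        elif audio_id != expected_id:
            problems.append(f"line {i}: expected id {expected_id!r}, found {audio_id!r}")
        elif not text.strip():
            problems.append(f"line {i}: empty transcription for {audio_id!r}")
    return problems
```

For example, `check_submission("asr_submission.txt", submission_order_file)` should return an empty list before you zip the notebook.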

# Task
Optimize the performance of the ASR model's training process by reducing the complexity of the LSTM decoder and potentially increasing the batch size. Profile the `ASR` module's `forward` pass to identify bottlenecks, monitor local GPU usage, and then analyze the impact of these changes on training speed and resource utilization to confirm if the `ASR` module was the bottleneck and how it was alleviated.

## Gather Local GPU Information

### Subtask:
Share the specifications of your local GPU.


**Reasoning**:
To gather the specifications of the local GPU, I will use PyTorch to query the device name and total VRAM.



In [25]:
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    vram_total_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_total_gb = vram_total_bytes / (1024**3)
    print(f"GPU Manufacturer and Model: {gpu_name}")
    print(f"Total VRAM: {vram_total_gb:.2f} GB")
else:
    print("No GPU available.")


GPU Manufacturer and Model: NVIDIA GeForce RTX 2070 SUPER
Total VRAM: 8.00 GB


## Add Profiling to ASR Forward Pass

### Subtask:
Insert `torch.cuda.Event` calls within the `ASR` model's `forward` method to measure the execution time of different components like the `upstream` HuBERT model, `downsample` convolution, `lstm` layers, and `logit` calculation.


**Reasoning**:
To measure the execution time of different components within the `ASR` model's `forward` pass, I need to modify the `ASR` class definition to include `torch.cuda.Event` calls around each targeted operation. This allows profiling the `upstream` HuBERT model, the `downsample` convolution, the `lstm` layers, and the `logit` calculation.



In [26]:
import torch
import torch.nn as nn
from transformers import HubertModel, HubertConfig
from torchaudio.models.decoder import ctc_decoder

class ASR(nn.Module):
    def __init__(self,
                 hubert_model_path='./hubert_model',
                 hubert_layer=10,
                 lstm_hidden_size=768,
                 lstm_num_layers=2,
                 lstm_dropout=0.1):
        super().__init__()

        self.token_processor = CharacterTokenizer()

        print(f"Loading HuBERT from {hubert_model_path}...")

        config = HubertConfig.from_pretrained(
            hubert_model_path,
            local_files_only=True
        )
        config.num_hidden_layers = hubert_layer

        self.upstream = HubertModel.from_pretrained(
            hubert_model_path,
            config=config,
            local_files_only=True
        )

        print("✓ HuBERT loaded!")

        # Freeze upstream
        for param in self.upstream.parameters():
            param.requires_grad = False

        # LSTM decoder
        ssl_output_size = 768
        lstm_input_size = 768
        self.lstm_hidden_size = lstm_hidden_size

        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            dropout=lstm_dropout if lstm_num_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )

        self.lstm_output_size = lstm_hidden_size * 2
        self.logit = nn.Linear(self.lstm_output_size, len(self.token_processor.vocab))

        self.downsample = nn.Conv1d(ssl_output_size, lstm_input_size,
                                     kernel_size=2, stride=2, padding=0, dilation=1)

        self.loss = nn.CTCLoss(blank=self.token_processor.blank_id(),
                               zero_infinity=True, reduction='mean')

        self.ctc_decoder = ctc_decoder(lexicon=None,
                                       tokens=self.token_processor.vocab,
                                       lm=None,
                                       lm_dict=None,
                                       nbest=1,
                                       beam_size=5,
                                       blank_token="blank",
                                       sil_token=" ")

    def forward(self, input_values, text, attention_mask=None, profile=False):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        text_idxs, text_lens = self.token_processor.enocode_torch_batch(text)
        text_idxs = text_idxs.to(input_values.device).to(torch.int32)
        text_lens = text_lens.to(input_values.device).to(torch.int32)

        # Initialize events for profiling if requested
        if profile and torch.cuda.is_available():
            start_hubert = torch.cuda.Event(enable_timing=True)
            end_hubert = torch.cuda.Event(enable_timing=True)
            start_downsample = torch.cuda.Event(enable_timing=True)
            end_downsample = torch.cuda.Event(enable_timing=True)
            start_lstm = torch.cuda.Event(enable_timing=True)
            end_lstm = torch.cuda.Event(enable_timing=True)
            start_logit = torch.cuda.Event(enable_timing=True)
            end_logit = torch.cuda.Event(enable_timing=True)
        else:
            profile = False # Disable profiling if CUDA not available

        # SSL upstream encoding
        if profile: start_hubert.record()
        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state
        if profile:
            end_hubert.record()
            torch.cuda.synchronize()
            print(f"    HuBERT (upstream) time: {start_hubert.elapsed_time(end_hubert):.2f} ms")

        # Downsample
        if profile: start_downsample.record()
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)
        if profile:
            end_downsample.record()
            torch.cuda.synchronize()
            print(f"    Downsample time: {start_downsample.elapsed_time(end_downsample):.2f} ms")

        # LSTM
        if profile: start_lstm.record()
        lstm_out, _ = self.lstm(source_encodings)
        if profile:
            end_lstm.record()
            torch.cuda.synchronize()
            print(f"    LSTM time: {start_lstm.elapsed_time(end_lstm):.2f} ms")

        # Logits
        if profile: start_logit.record()
        logits = self.logit(lstm_out)
        if profile:
            end_logit.record()
            torch.cuda.synchronize()
            print(f"    Logit time: {start_logit.elapsed_time(end_logit):.2f} ms")

        # Log probabilities for CTC
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Calculate source lengths
        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.to(torch.int32)

        # Transpose for CTC: (T, B, C)
        log_probs = log_probs.transpose(0, 1)

        # CTC loss
        loss = self.loss(log_probs, text_idxs, source_lengths, text_lens)

        return loss

    def predict(self, input_values, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values)

        # SSL encoding
        source_encodings = self.upstream(input_values, attention_mask=attention_mask).last_hidden_state

        # Downsample
        source_encodings = self.downsample(source_encodings.transpose(1, 2)).transpose(1, 2)

        # LSTM
        lstm_out, _ = self.lstm(source_encodings)

        # Logits
        logits = self.logit(lstm_out)

        # Log probabilities
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Calculate source lengths
        audio_lengths = attention_mask.sum(dim=1)
        hubert_lengths = ((audio_lengths - 400) // 320) + 1
        source_lengths = hubert_lengths // 2
        source_lengths = source_lengths.cpu().to(torch.int32)

        # CTC decode
        pred_texts = self.ctc_decoder(log_probs.cpu(), source_lengths.cpu())
        pred_texts = [self.token_processor.decode(pred_text[0].tokens) for pred_text in pred_texts]

        return pred_texts, source_encodings

# Note: This code block replaces the previous ASR class definition.
# You will need to re-initialize your model after running this cell.
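Both `forward` and `predict` compute `source_lengths` as `((audio_lengths - 400) // 320) + 1` followed by `// 2`. The constants come from HuBERT's convolutional feature extractor: a 400-sample receptive field (25 ms at 16 kHz) and a total stride of 320 samples (20 ms). The final halving comes from the `downsample` convolution. The sketch below assumes the HuBERT-base default feature-extractor settings, which you can compare with `config.conv_kernel` and `config.conv_stride` for your checkpoint. It checks that chaining the seven unpadded convolutions gives exactly the closed form. The helper names are hypothetical.

```python
# Assumed HuBERT-base feature-extractor settings; compare with
# config.conv_kernel and config.conv_stride for your checkpoint.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def conv_out_len(n, kernel, stride):
    """Output length of an unpadded 1-D convolution."""
    return (n - kernel) // stride + 1


def hubert_frames_layerwise(num_samples):
    n = num_samples
    for k, s in zip(CONV_KERNELS, CONV_STRIDES):
        n = conv_out_len(n, k, s)
    return n


def hubert_frames_closed_form(num_samples):
    # Same expression as in ASR.forward / ASR.predict
    return (num_samples - 400) // 320 + 1


def decoder_frames(num_samples):
    # self.downsample is Conv1d(kernel_size=2, stride=2, padding=0)
    return conv_out_len(hubert_frames_layerwise(num_samples), 2, 2)


# Padded length of the profiled first batch: 236320 samples (about 14.8 s at 16 kHz)
print(hubert_frames_closed_form(236320), decoder_frames(236320))
```

If these lengths disagreed with the actual number of LSTM output frames, `nn.CTCLoss` would be given wrong `input_lengths`, so the check is worth repeating whenever the upstream or the downsampling layer changes.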

## Reduce LSTM Complexity

### Subtask:
Reduce the complexity of the LSTM decoder by modifying its hidden size and number of layers, then re-initialize the model and optimizer.


**Reasoning**:
To reduce LSTM complexity, I will re-instantiate the ASR model with a smaller hidden size (256) and fewer layers (1) for the LSTM, then move it to the GPU and re-initialize the optimizer.



In [27]:
model = ASR(lstm_hidden_size=256, lstm_num_layers=1)
model = model.to(device).train()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print("ASR model re-initialized with reduced LSTM complexity (hidden_size=256, num_layers=1).")
print(f"New optimizer created for the re-initialized model.")

Loading HuBERT from ./hubert_model...


Some weights of the model checkpoint at ./hubert_model were not used when initializing HubertModel: ['encoder.layers.11.attention.q_proj.weight', 'encoder.layers.11.attention.out_proj.bias', 'encoder.layers.11.attention.v_proj.weight', 'encoder.layers.10.attention.v_proj.weight', 'encoder.layers.11.final_layer_norm.bias', 'encoder.layers.10.layer_norm.weight', 'encoder.layers.11.attention.k_proj.bias', 'encoder.layers.11.attention.v_proj.bias', 'encoder.layers.10.attention.out_proj.bias', 'encoder.layers.10.feed_forward.intermediate_dense.weight', 'encoder.layers.10.feed_forward.intermediate_dense.bias', 'encoder.layers.11.feed_forward.intermediate_dense.weight', 'encoder.layers.10.attention.k_proj.bias', 'encoder.layers.11.layer_norm.weight', 'encoder.layers.11.attention.q_proj.bias', 'encoder.layers.10.attention.v_proj.bias', 'encoder.layers.11.feed_forward.intermediate_dense.bias', 'encoder.layers.10.feed_forward.output_dense.weight', 'encoder.layers.10.final_layer_norm.weight', 'en

✓ HuBERT loaded!
ASR model re-initialized with reduced LSTM complexity (hidden_size=256, num_layers=1).
New optimizer created for the re-initialized model.
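To quantify how much smaller the decoder becomes, you can count the parameters of a bidirectional `nn.LSTM`. The count below assumes PyTorch's default `bias=True` and no `proj_size`. Each layer and direction has four gates with input-to-hidden and hidden-to-output weights plus two bias vectors. Layers after the first receive both directions' outputs. The output layer `nn.Linear(2 * lstm_hidden_size, vocab_size)` also shrinks, from 1536 to 512 inputs. The helper name is hypothetical.

```python
def lstm_param_count(input_size, hidden_size, num_layers, bidirectional=True):
    """Parameters of torch.nn.LSTM with bias=True and no projection."""
    directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # W_ih (4H x in) + W_hh (4H x H) + b_ih (4H) + b_hh (4H)
        per_direction = 4 * hidden_size * (layer_input + hidden_size) + 8 * hidden_size
        total += directions * per_direction
        layer_input = directions * hidden_size  # next layer sees both directions
    return total


original = lstm_param_count(768, 768, num_layers=2)
reduced = lstm_param_count(768, 256, num_layers=1)
print(f"original: {original:,}  reduced: {reduced:,}  ratio: {original / reduced:.1f}x")
```

Judging from the profiled first batch, the LSTM forward pass took about 208 ms, compared with about 613 s for the upstream. A roughly 11x smaller decoder is therefore unlikely, on its own, to change the total step time much.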


**Reasoning**:
The user provided a new, more detailed subtask description for cell `2b64bb24`. I need to update the content of this `text_cell` accordingly.



In [28]:
## Reduce LSTM Complexity

### Subtask:
Reduce the complexity of the LSTM decoder by modifying its hidden size and number of layers, then re-initialize the model and optimizer.

#### Instructions
1. Modify the `ASR` class initialization by changing the `lstm_hidden_size` parameter from `768` to a smaller value (e.g., `512` or `256`).
2. Optionally, change the `lstm_num_layers` parameter from `2` to `1` to further reduce complexity.
3. Instantiate the `ASR` model with these new parameters (e.g., `model = ASR(lstm_hidden_size=256, lstm_num_layers=1)`).
4. Move the `model` to the appropriate device (e.g., `model = model.to(device).train()`).
5. Re-initialize the optimizer with the new model parameters (e.g., `optimizer = torch.optim.Adam(model.parameters(), lr=lr)`).

SyntaxError: invalid syntax (4039499260.py, line 4)

**Reasoning**:
The previous attempt to update cell `7bc09b7c` failed with a `SyntaxError` because the `cell_type` was incorrectly set to `python` while the content was markdown. I need to correct the `cell_type` to `markdown` to properly update the text cell.



# Task
**Consider Increasing Batch Size**: Update the `batch_size` variable from 4 to 8, then re-initialize `train_dataloader` and `val_dataloader` with the new batch size.

## Consider Increasing Batch Size

### Subtask:
Update the `batch_size` variable from 4 to 8, then re-initialize `train_dataloader` and `val_dataloader` with the new batch size.


## Consider Increasing Batch Size

### Subtask:
Update the `batch_size` variable from 4 to 8, then re-initialize `train_dataloader` and `val_dataloader` with the new batch size.

#### Instructions
1. Update the `batch_size` variable to 8.
2. Re-initialize the `train_dataloader` and `val_dataloader` using the updated `batch_size`.

**Reasoning**:
To increase the batch size, I will update the `batch_size` variable and then re-initialize `train_dataloader` and `val_dataloader` so that they use the new value.



In [30]:
batch_size = 8 # Update batch_size to 8

train_dataloader = DataLoader(train_dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=0,
                        drop_last=True,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)

val_dataloader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=0,
                        drop_last=False,
                        pin_memory=True,
                        collate_fn=SpeechTextDataset.collate)

print(f"Batch size updated to {batch_size}.")
print("Train and validation dataloaders re-initialized with the new batch size.")

Batch size updated to 8.
Train and validation dataloaders re-initialized with the new batch size.
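Doubling the batch size halves the number of optimizer steps per epoch. With `drop_last=True`, the 2,866 training utterances give 716 steps at batch size 4 and 358 at batch size 8, which matches the `Batch 0/358` log above. The sketch below also estimates padding overhead. This assumes that `SpeechTextDataset.collate` pads each batch to its longest clip, as the `attention_mask` suggests. Larger random batches tend to contain more padding, and HuBERT's convolutions still process it. The function names and the example clip lengths are hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_last):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


def padding_fraction(batch_lengths):
    """Share of a padded batch tensor that is padding (padded to the longest clip)."""
    padded_total = max(batch_lengths) * len(batch_lengths)
    return 1 - sum(batch_lengths) / padded_total


for bs in (4, 8):
    print(f"batch_size={bs}: {steps_per_epoch(2866, bs, drop_last=True)} training steps/epoch")

# Hypothetical clip lengths (in samples) for one batch of 4
print(f"padding: {padding_fraction([236320, 160000, 120000, 80000]):.1%}")
```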


## Execute Training with Profiling

### Subtask:
Run the training loop again, ensuring the `ASR` model's `forward` method is called with `profile=True` for the first batch to get detailed timing information for the different components, and monitor local GPU usage during this process.


**Reasoning**:
To profile only the first batch of the first epoch, I will update `train_loop` to accept an `epoch_num` parameter. It will pass `profile=True` to the model only when `epoch_num == 0` and `batch_i == 0`, and `profile=False` otherwise.



In [31]:
import time
def train_loop(dataloader, model, optimizer, epoch_num):
    size = len(dataloader.dataset)
    report_step = max(1, len(dataloader) // 5)  # Report 5 times per epoch
    model.train()

    start_time = time.time()

    for batch_i, batch in enumerate(dataloader):
        batch_start = time.time()

        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        text = batch['text']

        # Enable profiling for the first batch of the first epoch
        profile_batch = (epoch_num == 0 and batch_i == 0)

        # Print first batch info
        if batch_i == 0:
            print(f"    First batch shape: {input_values.shape}")
            print(f"    Device: {input_values.device}")

        # Compute loss
        loss = model(input_values, text, attention_mask, profile=profile_batch)

        # Backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()

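        # loss.item() copies the loss to the CPU, which waits for all queued GPU work
        # (forward, backward, optimizer step); without it, batch_time could measure
        # only kernel-launch time
        loss_val = loss.item()
        batch_time = time.time() - batch_start

        if batch_i % report_step == 0:  # also true for batch 0
            current = batch_i * len(input_values)
            print(f"    Batch {batch_i}/{len(dataloader)} - Loss: {loss_val:>7f} - Time: {batch_time:.2f}s - [{current:>5d}/{size:>5d}]")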

    total_time = time.time() - start_time
    print(f"    Epoch completed in {total_time:.1f}s ({total_time/60:.1f}min)")

## Execute Training with Profiling

### Subtask:
Run the training loop again, ensuring the 'ASR' model's 'forward' method is called with 'profile=True' for the first batch to get detailed timing information for different components, and monitor local GPU usage during this process.

#### Instructions
1. Modify the `train_loop` function to call `model.forward(input_values, text, attention_mask, profile=True)` for the first batch of the first epoch, and `profile=False` for subsequent batches. This will print profiling information for the first batch.
2. Start the training process by executing the training loop. While training is running, use your system's GPU monitoring tools (e.g., `nvidia-smi` for NVIDIA GPUs, in a separate terminal) to observe GPU utilization and memory usage. This helps determine whether the GPU is fully utilized at the increased batch size.
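In a terminal, `nvidia-smi -l 1` refreshes the full status every second. To log only utilization and memory, the sketch below polls `nvidia-smi --query-gpu=... --format=csv,noheader,nounits` and parses each reading, with memory values in MiB. The notebook kernel is busy while training runs, so the polling function is meant to be run from a separate process, such as a second notebook or a script. The function names are hypothetical, and only the first GPU is read.

```python
import shutil
import subprocess
import time

QUERY_FIELDS = "utilization.gpu,memory.used,memory.total"


def parse_smi_line(line):
    """Parse one CSV line from `nvidia-smi --format=csv,noheader,nounits`."""
    util, used, total = (float(field) for field in line.split(","))
    return {"util_pct": util, "mem_used_mib": used, "mem_total_mib": total}


def poll_gpu(samples=10, interval_s=1.0):
    """Sample the first GPU `samples` times; returns [] if nvidia-smi is missing."""
    if shutil.which("nvidia-smi") is None:
        return []
    readings = []
    for _ in range(samples):
        out = subprocess.run(
            ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        ).stdout
        readings.append(parse_smi_line(out.strip().splitlines()[0]))  # first GPU only
        time.sleep(interval_s)
    return readings
```

Sustained utilization well below 100% with memory far from the 8 GB limit would suggest that the GPU is waiting on something else, such as data loading with `num_workers=0`. Memory near the limit suggests that the batch size cannot grow further.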

**Reasoning**:
The user wants to update a `text_cell` with new content. I need to generate an `addTextCellAction` for the specified cell ID with the new markdown content.

