# Sequential CRNN

This notebook trains a sequential CRNN model for the speech-to-text (STT) task. The processed audio features are first fed to convolutional layers; the resulting feature maps are then fed to recurrent layers, which produce the output. In contrast, a parallel CRNN runs its convolutional and recurrent layers in parallel and fuses their features into a single output.

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import h5py
import numpy as np
import json
from tqdm.notebook import tqdm
import librosa
import Levenshtein as lev

## Vocab Building & Tokenization

### Vocab Functions

In [2]:
def build_vocab():
    """
    Return the fixed character vocabulary used for CTC training.

    Id 0 is reserved for the ``'<blank>'`` CTC token. The lowercase letters
    a-z follow at ids 1-26, and the space character gets id 27.

    Returns:
        dict: Symbol -> integer id, inserted in id order.
    """
    symbols = ["<blank>", *"abcdefghijklmnopqrstuvwxyz "]
    return {symbol: position for position, symbol in enumerate(symbols)}

def save_vocab(vocab, filepath):
    """Serialize ``vocab`` as JSON into ``filepath``, replacing any existing file."""
    with open(filepath, 'w') as out_file:
        out_file.write(json.dumps(vocab))

def load_vocab(filepath):
    """Read the JSON vocabulary stored at ``filepath`` and return it as a dict."""
    with open(filepath, 'r') as in_file:
        return json.load(in_file)

### Tokenization Functions

In [3]:
def encode_label(label, vocab):
    """
    Encode a transcript into a list of integer ids, one per character.

    Raises:
        KeyError: If ``label`` contains a character that is not in ``vocab``.
    """
    return [vocab[char] for char in label]


def decode_label(encoded_label, vocab, blank_id=0):
    """
    Decode a sequence of ids back into a string.

    Only ``blank_id`` is skipped. In this vocabulary id 0 is both the CTC
    ``'<blank>'`` token and the value labels are padded with. Id 1 is the
    letter ``'a'`` and must be kept.

    Args:
        encoded_label (Iterable[int]): Token ids to decode.
        vocab (dict): Character -> id mapping.
        blank_id (int): Id to drop from the output. Defaults to 0.

    Returns:
        str: The decoded text.
    """
    inv_vocab = {v: k for k, v in vocab.items()}
    return ''.join(inv_vocab[token] for token in encoded_label if token != blank_id)

### Building the Vocab (use only if you don't already have the vocab built!)

In [4]:
# Construct the vocabulary, echo it for inspection, then persist it to disk.
print(vocab := build_vocab())
save_vocab(vocab, 'vocab.json')

{'<blank>': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, ' ': 27}


### Loading the vocab

In [5]:
# Reload the saved vocabulary; its size is the model's number of output classes.
VOCAB_SIZE = len(vocab := load_vocab('vocab.json'))

## Dataset Class Definition

In [6]:
class SpeechDataset(Dataset):
    """
    Map-style dataset over an HDF5 file of precomputed MFCCs and transcripts.

    Each top-level group of the file is treated as one clip. Groups without a
    ``'label'`` entry are skipped at construction time. Every item's MFCC
    matrix is zero-padded or truncated along its second (time) axis to
    ``max_length_frames``, so all items share one shape.
    """

    def __init__(self, hdf5_path, vocab, max_length_frames=247):
        """
        Args:
            hdf5_path (str): Path to the HDF5 file.
            vocab (dict): Character -> id mapping used to encode transcripts.
            max_length_frames (int): Fixed number of time frames per item.
                Original note: 247 ~ 8 s @ 16 kHz with a 512-sample MFCC hop.
                NOTE(review): 8 * 16000 / 512 = 250, so confirm the exact value.
        """
        super().__init__()
        self.hdf5_path = hdf5_path
        self.vocab = vocab
        self.max_length_frames = max_length_frames

        # Index only the groups that carry a transcript.
        self.keys = []
        with h5py.File(hdf5_path, 'r') as file:
            for key in file.keys():
                if 'label' in file[key]:
                    self.keys.append(key)
                else:
                    print(f"Skipping {key} due to missing label.")

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """
        Returns:
            tuple: (float32 features tensor [1, n_features, max_length_frames],
                    int64 encoded-label tensor [label_len],
                    max_length_frames (int),
                    int64 scalar tensor holding label_len)
        """
        key = self.keys[idx]
        # The file is opened per item; no handle is kept on the dataset.
        with h5py.File(self.hdf5_path, 'r') as f:
            group = f[key]
            # To use other features (the original code hints at a
            # 'melspectrogram' entry), read that dataset here instead.
            mfccs = np.array(group['mfccs']).astype(np.float32)
            label_str = group['label'][()].decode('utf-8')

        # Zero-pad or truncate along the time (second) axis.
        padding_length = self.max_length_frames - mfccs.shape[1]
        if padding_length > 0:
            mfccs = np.pad(mfccs, ((0, 0), (0, padding_length)), mode='constant', constant_values=0)
        elif padding_length < 0:
            mfccs = mfccs[:, :self.max_length_frames]

        mfccs = np.expand_dims(mfccs, 0)  # add channel axis -> [1, Freq, Time]
        label = encode_label(label_str, self.vocab)

        return (
            torch.tensor(mfccs),
            torch.tensor(label, dtype=torch.int64),
            self.max_length_frames,
            torch.tensor(len(label), dtype=torch.int64),
        )

## CRNN Class Definition

In [7]:
class CRNN(nn.Module):
    """
    Sequential CRNN: a 2-D CNN front-end, a GRU, and a per-frame linear classifier.

    Input layout is ``[batch, 1, num_mfcc_features, time]``, i.e. features on
    the height axis and time on the width axis. Convolutions use
    ``padding=1`` with ``kernel_size=3`` and preserve spatial size. Each of
    the two ``MaxPool2d(2)`` stages floor-halves both axes, so the output has
    ``time // 4`` steps; for 247 input frames that is 61.
    """

    def __init__(self, num_mfcc_features, hidden_size, num_layers=2, num_classes=None):
        """
        Args:
            num_mfcc_features (int): Size of the input's feature (height) axis.
            hidden_size (int): GRU hidden size.
            num_layers (int): Number of stacked GRU layers.
            num_classes (int | None): Output classes, including the CTC blank.
                Defaults to the module-level ``VOCAB_SIZE`` (previous behavior).
        """
        super().__init__()
        self.fc_out_size = VOCAB_SIZE if num_classes is None else num_classes

        # Convolutional layers with Batch Normalization and Dropout
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
        )

        # After the conv stack the feature axis is (F // 2) // 2 == F // 4 bins
        # with 64 channels each; every RNN time step sees them flattened.
        self.rnn_input_size = 64 * (num_mfcc_features // 4)

        self.rnn = nn.GRU(
            input_size=self.rnn_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        self.fc = nn.Linear(hidden_size, self.fc_out_size)

    def forward(self, x):
        """
        Map ``[B, 1, F, T]`` features to per-frame scores ``[B, T // 4, num_classes]``.

        Returns raw (unnormalized) scores; apply ``log_softmax`` before CTC loss.
        """
        x = self.conv(x)  # [B, 64, F // 4, T // 4]
        batch, channels, height, width = x.size()
        # Move time (width) to dim 1 for the batch-first GRU, then flatten
        # channels x feature bins into one vector per time step.
        x = x.permute(0, 3, 1, 2).contiguous()  # [B, W, C, H]
        x = x.view(batch, width, channels * height)

        output, _ = self.rnn(x)
        return self.fc(output)

## Inference Functions

### Decoder Function

In [8]:
def greedy_decoder(output, labels, blank_label=0):
    """
    Greedy (best-path) CTC decoding of a batch of network outputs.

    For each sequence the arg-max class is taken at every time step. Runs of
    the same class are collapsed to one, and blanks are removed, so a blank
    between two equal classes keeps both.

    Args:
        output (torch.Tensor): Scores (logits or log-probabilities) of shape
            [N, T, C]: batch first, as returned by ``CRNN.forward``.
        labels (Mapping[int, str]): Class index -> character, e.g. an inverted
            vocab. Must contain every non-blank index the model emits.
        blank_label (int): Index of the CTC blank. Defaults to 0.

    Returns:
        List[str]: One decoded string per batch element.
    """
    # One host transfer for the whole batch instead of per-element .item() calls.
    best_paths = torch.argmax(output, dim=2).tolist()
    decodes = []
    for path in best_paths:
        chars = []
        previous = None
        for index in path:
            if index != blank_label and index != previous:
                chars.append(labels[index])
            previous = index
        decodes.append(''.join(chars))
    return decodes


### Preprocessing Function

In [9]:
def preprocess_audio(audio_path, sampling_rate=16000, n_mfcc=13, max_length_frames=247):
    """
    Load an audio file and convert it into a fixed-length MFCC array.

    The time axis is zero-padded or truncated to ``max_length_frames``, the
    same treatment SpeechDataset applies to stored features.

    Args:
        audio_path (str): Path of the audio file to load.
        sampling_rate (int): Rate the audio is resampled to when loading.
        n_mfcc (int): Number of MFCC coefficients to extract.
        max_length_frames (int): Number of time frames in the result.

    Returns:
        np.ndarray: Array of shape [1, n_mfcc, max_length_frames].
    """
    waveform, rate = librosa.load(audio_path, sr=sampling_rate)
    features = librosa.feature.mfcc(y=waveform, sr=rate, n_mfcc=n_mfcc)

    n_frames = features.shape[1]
    if n_frames < max_length_frames:
        features = np.pad(features, ((0, 0), (0, max_length_frames - n_frames)), mode='constant', constant_values=0)
    elif n_frames > max_length_frames:
        features = features[:, :max_length_frames]

    # Leading channel axis, matching the per-item layout the model consumes.
    return features[np.newaxis, ...]

## Training functions

### Saver/Loader functions

In [10]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """
    Save a checkpoint dict to ``filename`` with ``torch.save``.

    The parent directory is created if missing, so paths such as
    ``./checkpoints/checkpoint_epoch_0.pth.tar`` work on a fresh checkout
    instead of failing after the first training epoch.

    Args:
        state (dict): Objects to persist (the training loop stores the epoch
            plus model and optimizer ``state_dict``s).
        filename (str): Destination path.
    """
    import os

    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer):
    """
    Restore model and optimizer state from an already-loaded checkpoint dict.

    Args:
        checkpoint (dict): Must provide ``'state_dict'`` (model weights) and
            ``'optimizer'`` (optimizer state) entries.
        model (nn.Module): Receives ``checkpoint['state_dict']``.
        optimizer (torch.optim.Optimizer): Receives ``checkpoint['optimizer']``.
    """
    for target, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
        target.load_state_dict(checkpoint[key])


### Custom Collate Function
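This is necessary because the default collation function (`default_collate`) stacks all tensors in a batch along a new dimension, which requires every tensor to have the same shape. With CTC, the transcript labels differ in length from sample to sample. This function therefore pads them to the longest label in the batch, and it also returns the per-sample input and label lengths that the CTC loss needs.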

In [11]:
def custom_collate_fn(batch):
    """
    Collate ``(features, label, input_length, label_length)`` samples for CTC.

    Features are combined with ``pad_sequence``, which also stacks them. Their
    trailing dimensions must already match, which the fixed-length padding
    in the dataset provides. Labels are right-padded with 0 to the longest
    label in the batch. CTC loss only reads the first ``label_length``
    entries of each row, so the padding value is never used as a target.

    Returns:
        tuple: (features [B, ...], labels [B, max_label_len],
                input_lengths int64 [B], label_lengths int64 [B])
    """
    features, labels, in_lens, lab_lens = list(zip(*batch))

    stacked_features = pad_sequence(features, batch_first=True, padding_value=0)
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=0)

    return (
        stacked_features,
        padded_labels,
        torch.tensor(in_lens, dtype=torch.long),
        torch.tensor(lab_lens, dtype=torch.long),
    )

### Eval Metrics

In [12]:
def cer(target, prediction):
    """
    Character Error Rate.

    Levenshtein distance between ``target`` and ``prediction`` divided by the
    length of ``target``. An empty target counts as length 1, so the result
    is never a division by zero.
    """
    reference_length = len(target) or 1
    return lev.distance(target, prediction) / reference_length

def _word_edit_distance(reference, hypothesis):
    """Minimum word insertions, deletions and substitutions turning ``reference`` into ``hypothesis``."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        current = [i]
        for j, hyp_word in enumerate(hypothesis, start=1):
            current.append(min(
                previous[j] + 1,                           # delete ref_word
                current[j - 1] + 1,                        # insert hyp_word
                previous[j - 1] + (ref_word != hyp_word),  # substitute / match
            ))
        previous = current
    return previous[-1]


def wer(target, prediction):
    """
    Word Error Rate.

    The word-level edit distance between the whitespace-split ``target`` and
    ``prediction``, divided by the number of target words (at least 1).
    Whole words are the edit unit: substituting one word costs 1, however
    many characters differ.
    """
    target_words = target.split()
    prediction_words = prediction.split()
    return _word_edit_distance(target_words, prediction_words) / max(len(target_words), 1)

### Train Function

In [13]:
def _ctc_forward(model, criterion, mels, labels, label_lengths):
    """Run the model on one batch; return (log-probs [B, T, C], CTC loss)."""
    outputs = F.log_softmax(model(mels), dim=2)
    # All inputs are padded to the same frame count, so every output sequence
    # spans the full time axis the network produced (61 for 247 input frames).
    input_lengths = torch.full(
        size=(outputs.size(0),),
        fill_value=outputs.size(1),
        dtype=torch.long,
        device=outputs.device,
    )
    # CTCLoss expects log-probs as [T, N, C].
    loss = criterion(outputs.permute(1, 0, 2), labels, input_lengths, label_lengths)
    return outputs, loss


def _train_one_epoch(model, device, loader, optimizer, criterion, epoch):
    """Run one optimization pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch_no, (mels, labels, _, label_lengths) in enumerate(loader, start=1):
        mels = mels.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        _, loss = _ctc_forward(model, criterion, mels, labels, label_lengths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(f'Epoch {epoch}, Audio Batch No: {batch_no} processed!')
    return total_loss / max(len(loader), 1)


def _validate_one_epoch(model, device, loader, criterion, inv_vocab, epoch):
    """Evaluate on ``loader``; return (mean batch loss, mean CER, mean WER)."""
    model.eval()
    total_loss = 0.0
    total_cer = 0.0
    total_wer = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch_no, (mels, labels, _, label_lengths) in enumerate(loader, start=1):
            mels = mels.to(device)
            labels = labels.to(device)
            outputs, loss = _ctc_forward(model, criterion, mels, labels, label_lengths)
            total_loss += loss.item()

            decoded_preds = greedy_decoder(outputs, inv_vocab, blank_label=vocab["<blank>"])
            for label_ids, length, prediction in zip(labels.tolist(), label_lengths.tolist(), decoded_preds):
                # Only the first `length` ids are real characters; the rest is
                # collate padding. (Id 1 is 'a' and must not be filtered out.)
                target = ''.join(inv_vocab.get(idx, '') for idx in label_ids[:length])
                total_cer += cer(target, prediction)
                total_wer += wer(target, prediction)
            total_samples += len(labels)
            print(f'Epoch {epoch}, Audio Batch No: {batch_no} evaluated!')

    n_samples = max(total_samples, 1)
    return total_loss / max(len(loader), 1), total_cer / n_samples, total_wer / n_samples


def train_and_validate(model, device, train_loader, val_loader, optimizer, epochs, start_epoch=0):
    """
    Train ``model`` with CTC loss for epochs ``start_epoch .. epochs - 1``.

    After each epoch the model is evaluated on ``val_loader`` (loss, CER,
    WER). A checkpoint is saved to ``./checkpoints/checkpoint_epoch_{epoch}.pth.tar``
    and a metrics line is appended to ``training_metrics.txt``.

    Args:
        model (nn.Module): Network returning per-frame scores [B, T, C].
        device (torch.device): Device the batches are moved to.
        train_loader, val_loader (DataLoader): Yield
            (features, labels, input_lengths, label_lengths) batches.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        epochs (int): Exclusive upper bound on the epoch index.
        start_epoch (int): First epoch index, e.g. when resuming.
    """
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    inv_vocab = {v: k for k, v in vocab.items()}

    for epoch in range(start_epoch, epochs):
        avg_train_loss = _train_one_epoch(model, device, train_loader, optimizer, criterion, epoch)
        print(f'Epoch {epoch}, Training Loss: {avg_train_loss:.4f}')

        avg_val_loss, avg_cer, avg_wer = _validate_one_epoch(
            model, device, val_loader, criterion, inv_vocab, epoch
        )
        print(f'Epoch {epoch}, Validation Loss: {avg_val_loss:.4f}, CER: {avg_cer:.4f}, WER: {avg_wer:.4f}')

        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=f"./checkpoints/checkpoint_epoch_{epoch}.pth.tar")

        with open("training_metrics.txt", "a") as file:
            file.write(f"Epoch {epoch}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, CER: {avg_cer:.4f}, WER: {avg_wer:.4f}\n")

### Train/Validate Split

In [14]:
def split_dataset(dataset, train_size=0.8, random_state=42):
    """
    Randomly split ``dataset`` into training and validation ``Subset``s.

    Args:
        dataset (Dataset): Any sized dataset.
        train_size (float | int): Fraction (or absolute count) of items used
            for training, passed through to ``train_test_split``.
        random_state (int | None): Seed for the shuffle. The default 42 keeps
            the split reproducible across runs.

    Returns:
        tuple: (train_subset, val_subset)
    """
    indices = np.arange(len(dataset))
    train_idx, val_idx = train_test_split(indices, train_size=train_size, random_state=random_state)
    return Subset(dataset, train_idx), Subset(dataset, val_idx)

## Execution

### Params

In [15]:
# HDF5 file with one group per clip, each holding 'mfccs' and a 'label' transcript
hdf5_path = r"C:\Users\jonec\Documents\SUTD\T6\AI\STT\Recorded-Lecture-Transcription-STT\reduced_mfcc_dataset.h5"

# Optimization settings
learning_rate = 1e-3
epochs = 200
batch_size = 1024

# Architecture; num_mfcc_features should match the n_mfcc used to extract the features
num_mfcc_features = 13
hidden_size = 256
num_layers = 2

In [16]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset and train/validation split
dataset = SpeechDataset(hdf5_path, vocab)
train_subset, val_subset = split_dataset(dataset)

# Shuffle only the training data; validation order stays fixed
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

model = CRNN(num_mfcc_features=num_mfcc_features, hidden_size=hidden_size, num_layers=num_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# To resume, set this to a checkpoint path, e.g. r"./checkpoints/checkpoint_epoch_14.pth.tar".
# Leave it as None to train from scratch.
resume_checkpoint_path = None
start_epoch = 0
if resume_checkpoint_path is not None:
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    load_checkpoint(checkpoint, model, optimizer)
    # Continue after the saved epoch instead of redoing epochs (and
    # overwriting their checkpoint files) from 0.
    start_epoch = checkpoint['epoch'] + 1

train_and_validate(model, device, train_loader, val_loader, optimizer, epochs, start_epoch)

Using device: cuda


### Infer & Decode

In [None]:
# Example usage: transcribe a single clip with the trained model
audio_path = r"C:\Users\jonec\Documents\SUTD\T6\AI\Voice dataset\cv-corpus-4\clips\common_voice_en_12.mp3"
mfccs = preprocess_audio(audio_path)
# [1, n_mfcc, frames] -> [1, 1, n_mfcc, frames]: prepend the batch dimension
mfccs_tensor = torch.tensor(mfccs).float().unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    output = model(mfccs_tensor)

# Map class indices back to characters for decoding
inv_vocab = {i: char for char, i in vocab.items()}
decoded_output = greedy_decoder(output, inv_vocab, blank_label=vocab["<blank>"])

# greedy_decoder returns one string per batch element; this batch holds one clip
print("Transcription:", decoded_output[0])