In [None]:
"""
File Name: main.ipynb
Author: Eli Claggett
Date: Apr 2023

Description:
    In-class project implementing a bidirectional long short-term memory (BiLSTM) neural network for speech recognition
    
"""

In [None]:
# Imports
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
from ctcdecode import CTCBeamDecoder
# from torchsummaryX import summary
from torch.autograd import Variable
import torchaudio.transforms as tat
from collections import namedtuple
import torch.nn.functional as F
from tqdm import tqdm
import torch.nn as nn
import pandas as pd
import numpy as np
import Levenshtein
import ctcdecode
import datetime
import warnings
import zipfile
import random
import torch
import wandb
import gc
import os

In [None]:
# Configuration
# Silence every Python warning for the remainder of the notebook.
warnings.filterwarnings('ignore')

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

# ARPABET phoneme mapping: CMUdict phoneme symbol -> single-character label.
# Insertion order matters: a phoneme's position in this dict becomes its class
# index (see p2iMap below). The empty-string entry sits at index 0.
# NOTE(review): index 0 presumably serves as the CTC blank, since nn.CTCLoss is
# later built with its default blank index - confirm.
CMUdict_ARPAbet = {
    "" : " ",
    "[SIL]": "-", "NG": "G", "F" : "f", "M" : "m", "AE": "@", 
    "R"    : "r", "UW": "u", "N" : "n", "IY": "i", "AW": "W", 
    "V"    : "v", "UH": "U", "OW": "o", "AA": "a", "ER": "R", 
    "HH"   : "h", "Z" : "z", "K" : "k", "CH": "C", "W" : "w", 
    "EY"   : "e", "ZH": "Z", "T" : "t", "EH": "E", "Y" : "y", 
    "AH"   : "A", "B" : "b", "P" : "p", "TH": "T", "DH": "D", 
    "AO"   : "c", "G" : "g", "L" : "l", "JH": "j", "OY": "O", 
    "SH"   : "S", "D" : "d", "AY": "Y", "S" : "s", "IH": "I",
    "[SOS]": "[SOS]", "[EOS]": "[EOS]"
}

# Parallel lists: CMUdict[i] is the phoneme symbol whose label is ARPAbet[i].
CMUdict = list(CMUdict_ARPAbet.keys())
ARPAbet = list(CMUdict_ARPAbet.values())

# Drop the trailing "[SOS]" / "[EOS]" entries; the remaining 41 symbols are the
# classes used for labels and decoding.
PHONEMES = CMUdict[:-2]
LABELS = ARPAbet[:-2]
# Phoneme symbol -> integer class index (position in PHONEMES).
p2iMap = {p:i for i, p in enumerate(PHONEMES)}

In [None]:
# Create audio dataset
class AudioDataset(torch.utils.data.Dataset):
    """Paired MFCC-feature / phoneme-transcript dataset for one data partition.

    Files are read from ``<partition>/mfcc/`` and ``<partition>/transcript/``.
    Both directory listings are sorted and paired purely by position.

    Args:
        partition: Path of the partition directory (no trailing slash needed).
        phoneme_dict: Mapping from phoneme symbol to integer class index.
    """

    def __init__(self, partition, phoneme_dict = p2iMap):
        self.phoneme_dict = phoneme_dict

        self.mfcc_dir       = partition + "/mfcc/"
        self.transcript_dir = partition + "/transcript/"

        mfcc_names       = sorted(os.listdir(self.mfcc_dir))
        transcript_names = sorted(os.listdir(self.transcript_dir))

        # Pairing is positional, so a count mismatch would silently misalign
        # features and labels.
        assert len(mfcc_names) == len(transcript_names), (
            "found {} mfcc files but {} transcript files under {!r}".format(
                len(mfcc_names), len(transcript_names), partition))

        self.mfcc_names       = [self.mfcc_dir + name for name in mfcc_names]
        self.transcript_names = [self.transcript_dir + name for name in transcript_names]
        self.length = len(mfcc_names)

    def __len__(self):
        """Return the number of utterances in the partition."""
        return self.length

    def __getitem__(self, ind):
        """Load one utterance.

        Returns:
            ``(mfcc, phoneme)``: the feature tensor as stored on disk and an
            int64 tensor of class indices for the transcript with its first and
            last entries (the SOS / EOS tokens) removed.
        """
        mfcc = np.load(self.mfcc_names[ind])
        raw_phoneme = np.load(self.transcript_names[ind])

        # Get phonemes without SOS and EOS tokens. The dtype is forced so an
        # empty transcript still yields an integer tensor.
        phoneme = torch.tensor(
            [self.phoneme_dict[p] for p in raw_phoneme[1:-1]], dtype=torch.long)

        return torch.tensor(mfcc), phoneme

    @staticmethod
    def _cepstral_norm(mfcc, eps=1e-8):
        """Zero-mean / unit-variance normalise each coefficient over dim 0.

        Works out of place, so the caller's tensor is left untouched. The
        standard deviation is floored at ``eps`` so a constant coefficient
        normalises to zeros instead of NaN / inf.
        """
        centered = mfcc - torch.mean(mfcc, axis=0, keepdims=True)
        std = torch.std(centered, axis=0, keepdims=True).clamp_min(eps)
        return centered / std

    def collate_fn(self,batch):
        '''
        Assemble a list of ``(mfcc, transcript)`` samples into one padded batch.

        1.  Split the samples into features and labels.
        2.  Normalise every feature sequence (per-utterance mean / variance).
        3.  Zero-pad features and labels along dim 0 to the longest in the batch.
        4.  Return padded features, padded labels, feature lengths, label lengths.
        '''

        batch_mfcc = [x for x, y in batch]
        batch_transcript = [y for x, y in batch]

        norm_mfcc = [self._cepstral_norm(mfcc) for mfcc in batch_mfcc]

        batch_mfcc_pad = pad_sequence(norm_mfcc, batch_first=True)
        lengths_mfcc = [m.shape[0] for m in norm_mfcc]

        batch_transcript_pad = pad_sequence(batch_transcript, batch_first=True)
        lengths_transcript = [t.shape[0] for t in batch_transcript]

        return batch_mfcc_pad, batch_transcript_pad, torch.tensor(lengths_mfcc), torch.tensor(lengths_transcript)

In [None]:
# Create audio dataset for test data
class AudioDatasetTest(torch.utils.data.Dataset):
    """Label-free MFCC dataset for a test partition.

    Files are read from ``<partition>/mfcc/`` in sorted order.

    Args:
        partition: Path of the partition directory (no trailing slash needed).
        phoneme_dict: Mapping from phoneme symbol to class index. It is stored
            on the instance but not used by any method of this class.
    """

    def __init__(self, partition, phoneme_dict = p2iMap):
        self.phoneme_dict = phoneme_dict

        self.mfcc_dir = partition + "/mfcc/"

        mfcc_names      = sorted(os.listdir(self.mfcc_dir))
        self.mfcc_names = [self.mfcc_dir + name for name in mfcc_names]

        self.length = len(mfcc_names)

    def __len__(self):
        """Return the number of utterances in the partition."""
        return self.length

    def __getitem__(self, ind):
        """Load the feature tensor of utterance ``ind`` as stored on disk."""
        mfcc = np.load(self.mfcc_names[ind])
        return torch.tensor(mfcc)

    @staticmethod
    def _cepstral_norm(mfcc, eps=1e-8):
        """Zero-mean / unit-variance normalise each coefficient over dim 0.

        The standard deviation is floored at ``eps`` so a constant coefficient
        normalises to zeros instead of NaN / inf.
        """
        centered = mfcc - torch.mean(mfcc, axis=0, keepdims=True)
        std = torch.std(centered, axis=0, keepdims=True).clamp_min(eps)
        return centered / std

    def collate_fn(self,batch):
        '''
        Assemble a list of feature tensors into one padded batch.

        1.  Normalise every feature sequence (per-utterance mean / variance).
        2.  Zero-pad the sequences along dim 0 to the longest in the batch.
        3.  Return the padded features and the original sequence lengths.
        '''

        # Do cepstral norm
        norm_mfcc = [self._cepstral_norm(mfcc) for mfcc in batch]

        batch_mfcc_pad = pad_sequence(norm_mfcc, batch_first=True)
        lengths_mfcc = [m.shape[0] for m in norm_mfcc]

        return batch_mfcc_pad, torch.tensor(lengths_mfcc)

In [None]:
# Let the computer rest: force a garbage-collection pass before building the loaders.
# (`gc` is already imported in the imports cell at the top of the notebook.)
gc.collect()

In [None]:
# Hyperparameters for this run. The whole dict is also handed to wandb.init as
# the run configuration.
# NOTE(review): Network hard-codes its LSTM depth / width / dropout, so the
# 'rnn_*' entries are not consumed by the model as it is constructed in this
# notebook - confirm whether they are meant to be wired in.
config = {
    'batch_size': 32,
    'feature_dim': 27,
    'rnn_hidden_size': 256,
    'rnn_num_layers': 4,
    'rnn_dropout': 0.35,
    'rnn_residual': True,
    'learning_rate': 1e-3,
    'epochs' : 5,
    'beam_width': 16,
    'clip_thresh': 1.,
    'context': 20,
    'architecture': 'apc'
}

# NOTE(review): `root` is not referenced by the cells visible here; the dataset
# cell spells out './data/...' paths directly.
root = './data'

# Set of transforms to apply to the dataset
# NOTE(review): left empty and not referenced by the cells visible here.
transforms = []

In [None]:
# Build the three dataset splits.
# NOTE(review): train_data and val_data both read './data/dev-clean', so the
# validation metrics are measured on the training data - confirm this is intended.
train_data = AudioDataset(partition='./data/dev-clean')
val_data   = AudioDataset(partition='./data/dev-clean')
test_data  = AudioDatasetTest(partition='./data/test-clean')

# Loader options shared by every split; only shuffling and the collate
# function differ between them.
_loader_kwargs = dict(
    batch_size  = config['batch_size'],
    num_workers = 0,
    pin_memory  = True,
)

# Only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_data, shuffle=True, collate_fn=train_data.collate_fn, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(
    val_data, shuffle=False, collate_fn=val_data.collate_fn, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(
    test_data, shuffle=False, collate_fn=test_data.collate_fn, **_loader_kwargs)

# Summarise split sizes.
print("Batch size: ", config['batch_size'])
print("Train dataset samples = {}, batches = {}".format(len(train_data), len(train_loader)))
print("Val dataset samples = {}, batches = {}".format(len(val_data), len(val_loader)))
print("Test dataset samples = {}, batches = {}".format(len(test_data), len(test_loader)))

# Release cached GPU memory before the model is created.
torch.cuda.empty_cache()

In [None]:
# Create the bidirectional LSTM network

class Network(nn.Module):
    """Bidirectional LSTM encoder followed by a per-frame linear classifier.

    Args:
        input_size: Number of features per input frame.
        output_size: Number of output classes.
        hidden_size: LSTM hidden size per direction (default 256).
        num_layers: Number of stacked LSTM layers (default 4).
        dropout: Dropout applied between LSTM layers (default 0.4).

    The defaults reproduce the previously hard-coded architecture, so
    ``Network(input_size, output_size)`` builds the same model as before.
    """

    def __init__(self, input_size, output_size, hidden_size=256, num_layers=4, dropout=0.4):

        super(Network, self).__init__()

        self.lstm = nn.LSTM(input_size      = input_size,
                            num_layers      = num_layers,
                            hidden_size     = hidden_size,
                            dropout         = dropout,
                            batch_first     = True,
                            bidirectional   = True)
        # The bidirectional LSTM concatenates both directions: 2 * hidden_size features.
        self.classification = nn.Sequential(
            torch.nn.Linear(2 * hidden_size, output_size)
        )

    def forward(self, x, lx):
        """Compute per-frame log-probabilities.

        Args:
            x: Padded batch of features, batch first: ``(batch, time, input_size)``.
            lx: Valid length of every sequence in ``x`` (tensor or list).

        Returns:
            ``(out, lens)``: log-softmax scores of shape
            ``(batch, max_len, output_size)`` and the sequence lengths as
            returned by ``pad_packed_sequence``.
        """
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        lengths = lx.cpu() if torch.is_tensor(lx) else lx

        input_pack = pack_padded_sequence(x, lengths,
                                          batch_first   = True,
                                          enforce_sorted= False)

        out, _ = self.lstm(input_pack)
        out, lens = pad_packed_sequence(out, batch_first=True)

        out = self.classification(out)
        out = F.log_softmax(out, dim=2)

        return out, lens

In [None]:
# Setup the model
RNNConfig = namedtuple(
  'RNNConfig',
  ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual'])

# As described in the paper
prenet_config = None
# Every RNNConfig field is required; supplying only the input size raises TypeError.
# NOTE(review): rnn_config is not passed to Network below - confirm whether it
# is meant to drive the architecture.
rnn_config = RNNConfig(
      input_size  = config['feature_dim'],
      hidden_size = config['rnn_hidden_size'],
      num_layers  = config['rnn_num_layers'],
      dropout     = config['rnn_dropout'],
      residual    = config['rnn_residual'])

# One output class per entry of PHONEMES; place the model on the device chosen
# in the configuration cell so the notebook also runs without a GPU.
model = Network(config['feature_dim'], len(PHONEMES)).to(device)

# Setup the training paradigm (scheduler, optimizer, loss function)
criterion = nn.CTCLoss(zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
# Cosine decay from the initial learning rate to 1e-5 over config['epochs'] scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config['epochs'], 1e-5)
scaler = torch.cuda.amp.GradScaler()
decoder = CTCBeamDecoder(PHONEMES, log_probs_input=True, beam_width = config['beam_width'])

In [None]:
# Decode output of the model
def decode_prediction(output, output_lens, decoder, PHONEME_MAP= LABELS):
    """Beam-search decode a batch of model outputs into label strings.

    Args:
        output: Batch-first model output passed straight to ``decoder.decode``.
        output_lens: Tensor with the valid length of every sequence in ``output``.
        decoder: Object exposing ``decode(output, seq_lens=...)`` that returns a
            4-tuple whose first element holds the beams and whose last element
            holds the beam lengths (e.g. a ``CTCBeamDecoder``).
        PHONEME_MAP: Sequence mapping a class index to its output character.

    Returns:
        One string per batch element, built from the first beam of that element.
    """
    results, _, _, lens = decoder.decode(output, seq_lens= output_lens)

    pred_strings = []

    for b in range(output_lens.shape[0]):
        # First beam of this utterance, trimmed to its decoded length.
        beam = results[b][0][:lens[b][0]]
        # Honour the caller-supplied map instead of the module-level LABELS.
        pred_strings.append(''.join([PHONEME_MAP[idx] for idx in beam]))

    return pred_strings

# Calculate levenshtein distance as model performance metric
def calculate_levenshtein(output, label, output_lens, label_lens, decoder, PHONEME_MAP= LABELS):
    """Mean Levenshtein distance between decoded predictions and reference labels.

    Args:
        output: Batch-first model output, forwarded to ``decode_prediction``.
        label: Padded batch of reference class indices, one row per utterance.
        output_lens: Valid length of every sequence in ``output``.
        label_lens: Valid length of every row in ``label``.
        decoder: Beam-search decoder forwarded to ``decode_prediction``.
        PHONEME_MAP: Sequence mapping a class index to its output character;
            used for both the predictions and the references.

    Returns:
        The summed edit distance divided by the batch size.
    """
    dist            = 0
    batch_size      = label.shape[0]

    pred_strings    = decode_prediction(output, output_lens, decoder, PHONEME_MAP)

    for b in range(batch_size):

        pred_string = pred_strings[b]

        # Strip the padding from the reference before mapping it to characters.
        reference = label[b][:label_lens[b]]

        # Use the same map as the predictions so both strings share one alphabet.
        label_string = ''.join([PHONEME_MAP[idx] for idx in reference])
        dist += Levenshtein.distance(pred_string, label_string)

    dist /= batch_size
    return dist

In [None]:
# Track training progress on Weights & Biases
# Never commit the API key: read it from the WANDB_API_KEY environment variable.
# When the variable is unset, key=None lets wandb fall back to its own
# credential lookup.
wandb.login(key=os.environ.get('WANDB_API_KEY'))

run = wandb.init(
    name = config['architecture'],
    reinit = True,
    project = "bilstm-ablations",
    config = config
)

In [None]:
# Setup a training function
def train_model(model, train_loader, criterion, optimizer):
    """Run one training epoch with mixed precision and gradient clipping.

    Relies on the module-level ``device``, ``scaler`` (GradScaler) and
    ``config`` (for ``config['clip_thresh']``).

    Args:
        model: Network called as ``model(x, lx)`` and returning
            ``(batch-first log-probs, output lengths)``.
        train_loader: Iterable yielding ``(x, y, lx, ly)`` batches.
        criterion: Loss called as ``criterion(log_probs, y, lh, ly)`` with
            time-first log-probs.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        The mean loss per batch over the epoch.
    """
    model.train()
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    total_loss = 0

    for i, data in enumerate(train_loader):
        optimizer.zero_grad()

        x, y, lx, ly = data
        x, y = x.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            h, lh = model(x, lx)
            # The loss expects time-first input: (batch, time, class) -> (time, batch, class).
            h = torch.permute(h, (1, 0, 2))
            loss = criterion(h, y, lh, ly)

        total_loss += loss.item()

        # Scale the loss so FP16 gradients do not underflow.
        scaler.scale(loss).backward()

        # Gradients are still multiplied by the loss scale here; unscale them
        # first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_( model.parameters(),
                                                    config['clip_thresh'])

        batch_bar.set_postfix(
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr'])),
            grad_norm="{:.04f}".format(float(grad_norm)))

        batch_bar.update()

        # step() skips the update when the gradients contain inf / NaN.
        scaler.step(optimizer)
        scaler.update()

        del x, y, lx, ly, h, lh, loss
        torch.cuda.empty_cache()

    batch_bar.close()

    return total_loss / len(train_loader)

# Create a validation function
def validate_model(model, val_loader, decoder, phoneme_map= LABELS):
    """Evaluate ``model`` on ``val_loader``.

    Uses the module-level ``criterion`` and ``device``.

    Args:
        model: Network called as ``model(x, lx)``.
        val_loader: Iterable yielding ``(x, y, lx, ly)`` batches.
        decoder: Beam-search decoder forwarded to ``calculate_levenshtein``.
        phoneme_map: Index-to-character map forwarded to ``calculate_levenshtein``.

    Returns:
        ``(mean loss per batch, mean Levenshtein distance per batch)``.
    """
    model.eval()
    progress = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')

    running_loss = 0
    running_dist = 0

    for step, batch in enumerate(val_loader, start=1):
        feats, labels, feat_lens, label_lens = batch
        feats = feats.to(device)
        labels = labels.to(device)

        with torch.inference_mode():
            log_probs, out_lens = model(feats, feat_lens)
            # The loss takes time-first input; the decoder below keeps batch-first.
            loss = criterion(torch.permute(log_probs, (1, 0, 2)), labels, out_lens, label_lens)

        running_loss += float(loss)
        running_dist += calculate_levenshtein(log_probs, labels, out_lens, label_lens, decoder, phoneme_map)

        progress.set_postfix(
            loss="{:.04f}".format(float(running_loss / step)),
            dist="{:.04f}".format(float(running_dist / step)))
        progress.update()

        # Drop batch references and release cached GPU memory between batches.
        del feats, labels, feat_lens, label_lens, log_probs, out_lens, loss
        torch.cuda.empty_cache()

    progress.close()

    num_batches = len(val_loader)
    return running_loss / num_batches, running_dist / num_batches

In [None]:
# Create a function to save the model state
def save_model(model, optimizer, scheduler, metric, epoch, path):
    """Write a training checkpoint to ``path``.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        metric: ``[name, value]`` pair; stored under the key ``name``.
        epoch: Epoch number stored under ``'epoch'``.
        path: Destination file passed to ``torch.save``.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    # The metric is stored under its own name, followed by the epoch counter.
    checkpoint[metric[0]] = metric[1]
    checkpoint['epoch'] = epoch

    torch.save(checkpoint, path)

# Create a function to load a saved model
def load_model(path, model, metric= 'valid_acc', optimizer= None, scheduler= None):
    """Restore a checkpoint written by ``save_model``.

    Args:
        path: Checkpoint file to read.
        model: Module that receives ``checkpoint['model_state_dict']``.
        metric: Key under which the tracked metric was stored.
        optimizer: Optional optimizer to restore; skipped when ``None``.
        scheduler: Optional scheduler to restore; skipped when ``None``.

    Returns:
        ``[model, optimizer, scheduler, epoch, metric_value]``.

    Raises:
        KeyError: If ``metric`` (or another expected key) is missing from the checkpoint.
    """
    # Deserialise onto the CPU so a checkpoint saved on a GPU also loads on a
    # CPU-only host; load_state_dict copies the values onto the model's device.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch   = checkpoint['epoch']
    metric  = checkpoint[metric]

    return [model, optimizer, scheduler, epoch, metric]

In [None]:
# Let the computer rest
torch.cuda.empty_cache()
gc.collect()

# Variables for tracking progress
last_epoch_completed = 0
start = last_epoch_completed
end = config["epochs"]
best_lev_dist = float("inf")
epoch_model_path = 'epoch.pth'
# NOTE(review): '.bth' looks like a typo for '.pth'; left unchanged so existing
# checkpoints and wandb artifacts keep their name.
best_model_path = 'best.bth'

# Train the model
for epoch in range(start, end):

    print("\nEpoch: {}/{}".format(epoch+1, config['epochs']))

    # Record the learning rate used for this epoch before the scheduler advances it.
    curr_lr = float(optimizer.param_groups[0]['lr'])
    train_loss = train_model(model, train_loader, criterion, optimizer)
    valid_loss, valid_dist = validate_model(model, val_loader, decoder, phoneme_map= LABELS)
    # CosineAnnealingLR follows a fixed schedule: advance it by one epoch.
    # Passing the validation metric would be interpreted as an epoch index.
    scheduler.step()

    print("\tTrain Loss {:.04f}\t Learning Rate {:.07f}".format(train_loss, curr_lr))
    print("\tVal Dist {:.04f}%\t Val Loss {:.04f}".format(valid_dist, valid_loss))

    wandb.log({
        'train_loss': train_loss,
        'valid_dist': valid_dist,
        'valid_loss': valid_loss,
        'lr'        : curr_lr
    })

    # Always keep the latest epoch so an interrupted run can be resumed.
    save_model(model, optimizer, scheduler, ['valid_dist', valid_dist], epoch, epoch_model_path)
    wandb.save(epoch_model_path)
    print("Saved epoch model")

    # Lower Levenshtein distance is better; ties overwrite the previous best.
    if valid_dist <= best_lev_dist:
        best_lev_dist = valid_dist
        save_model(model, optimizer, scheduler, ['valid_dist', valid_dist], epoch, best_model_path)
        wandb.save(best_model_path)
        print("Saved best model")

run.finish()

In [None]:
# Test the model

# Decode the test set with a beam twice as wide as the one used for validation.
TEST_BEAM_WIDTH = config['beam_width'] * 2

test_decoder = CTCBeamDecoder(PHONEMES, log_probs_input=True, beam_width = TEST_BEAM_WIDTH)
results = []

model.eval()
print("Testing")
for data in tqdm(test_loader):

    x, lx   = data
    x       = x.to(device)

    with torch.no_grad():
        h, lh = model(x, lx)

    # Use the wide-beam decoder built above, not the validation decoder.
    prediction_string= decode_prediction(h, lh, test_decoder, PHONEME_MAP=LABELS)
    # decode_prediction returns one string per utterance in the batch; keep
    # `results` flat so it lines up row-for-row with the submission template.
    results.extend(prediction_string)

    del x, lx, h, lh
    torch.cuda.empty_cache()

In [None]:
# Save test results
template = './data/test-clean/transcript/result_template.csv'
df = pd.read_csv(template)
# Item assignment always writes a real column. Attribute assignment
# (`df.label = ...`) would only set a Python attribute, and the predictions
# would be missing from the CSV, if the template had no 'label' column.
df['label'] = results
df.to_csv('test_bilstm.csv', index = False)