# Building an end-to-end Speech Recognition model in PyTorch for Igbo

## Welcome to this notebook, where we walk you through building an end-to-end speech recognition model in PyTorch for the Igbo language. This is part of the [OkwuGbe paper](https://arxiv.org/abs/2103.07762) on building ASR models for Fon and Igbo.


---





This code was inspired by the article [Building an end-to-end Speech Recognition model in PyTorch](https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch)

---



We connect to Google Drive and set the directory where model checkpoints (and other outputs) are saved.

In [None]:
import os
# Connect to Google Drive so model checkpoints survive Colab session resets.
from google.colab import drive
drive.mount('/content/drive')

# Make sure the experiment folder exists on Drive before anything is saved.
model_path = '/content/drive/MyDrive/IgboASR'
os.makedirs(model_path, exist_ok=True)

# Checkpoint files written by train():
#   model_path      -> best validation WER
#   model_path_loss -> best validation CTC loss
model_path = os.path.join(model_path, 'ig_asr')
model_path_loss = '/content/drive/MyDrive/IgboASR/ig_asr_best_loss'
# For a local (non-Drive) run use instead: model_path = './ig_asr'

Mounted at /content/drive


Install dependencies.


> Torch audio v 0.4.0

> Torch v 1.4.0

> gdown is used to download files from Google Drive, e.g. to fetch the zipped dataset saved on your Drive straight into the workspace.







In [None]:
!pip install torchaudio==0.4.0 torch==1.4.0 > /dev/null

In [None]:
!pip install gdown >/dev/null

In [None]:
#All imports

import os,sys,re
import unicodedata #to normalize diacritics
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np
import zipfile
import random
from typing import Tuple




This cell downloads the data. The dataset used for the experiments described in the paper is not open source: we signed an agreement not to disclose it, so you need to supply the Google Drive id of your own zipped dataset below.

In [None]:
!gdown --id [id of your zipped dataser] -O igbodata.zip #This downloads the zipped data set and saves as igbodata.zip

In [None]:
# Unpack the downloaded archive. Later cells expect the metadata files at
# ./IgboAudio/data/{train,test}/data.txt.
with zipfile.ZipFile("igbodata.zip", mode="r") as archive:
    archive.extractall(path="./IgboAudio")

## Train-test-valid split

Our zipped dataset only has train and test folders, so we need to build the validation set ourselves. To help generalization, we take 70% of the validation set from the train split and 30% from the test split.

In [None]:
#To get the valid indices
# Choose which rows of the train/test metadata files form the validation set.
# The chosen indices are positions in the *unfiltered* data.txt files, because
# get_data() uses them to index the raw lines. They are drawn only among rows
# that pass the max_words filter, so every chosen row really ends up in the
# validation split.
random.seed(123)
max_words = 10  # utterances with more words than this are dropped from every split


def _short_line_indices(lines, limit):
    """Return the indices of metadata lines whose utterance has at most `limit` words."""
    return [i for i, line in enumerate(lines)
            if len(line.split('|')[1].strip().split(' ')) <= limit]


with open('./IgboAudio/data/train/data.txt', newline='', encoding='UTF-8') as f:
    tr_lines = f.readlines()
train_candidates = _short_line_indices(tr_lines, max_words)
tr_data = [tr_lines[i] for i in train_candidates]  # short train utterances only

v = 10000  #samples for valid (not used below; v_train + v_test)
v_train = 7000 #how many to take from train data
v_test = 3000 #how many to take from test data
patience = 50 #our patience
BATCH_MULTIPLIER = 7  # mini-batches whose gradients are accumulated per optimizer step

# random.sample draws WITHOUT replacement; random.choices could return the same
# row several times and silently shrink the validation set. k is capped so a
# small dataset does not raise ValueError.
valid_indices_train = random.sample(train_candidates, k=min(v_train, len(train_candidates)))
print(f"Valid indices from train \n {valid_indices_train}")

with open('./IgboAudio/data/test/data.txt', newline='', encoding='UTF-8') as f:
    t_data = f.readlines()
test_list = _short_line_indices(t_data, max_words)
valid_indices_test = random.sample(test_list, k=min(v_test, len(test_list)))
print(f"Valid indices from test \n {valid_indices_test}")
 


We define a function that, given the split type ('train', 'test' or 'valid'), returns the metadata lines for that split; each line holds the path to an audio file and its utterance.



In [None]:
def get_data(datatype):
    """Return the metadata lines of one data split.

    Each line has the form ``"<audio path>|<utterance>"``. Only utterances
    with at most ``max_words`` space-separated words are kept.

    Args:
        datatype: ``"train"``, ``"test"`` or ``"valid"``. The validation split
            consists of the train rows listed in ``valid_indices_train`` plus
            the test rows listed in ``valid_indices_test``; those rows are left
            out of ``"train"`` and ``"test"`` respectively.

    Returns:
        list of str: the selected lines, as returned by ``readlines``.

    Raises:
        ValueError: if ``datatype`` is not one of the three names above.
    """
    def read_lines(split):
        with open('./IgboAudio/data/{}/data.txt'.format(split), 'r', encoding='UTF-8') as f:
            return f.readlines()

    def is_short(line):
        return len(line.split('|')[1].strip().split(' ')) <= max_words

    # Set membership is O(1); the index lists hold thousands of entries and
    # are probed once per line.
    held_out_train = set(valid_indices_train)
    held_out_test = set(valid_indices_test)

    if datatype == "train":
        lines = read_lines("train")
        train_data = [line for i, line in enumerate(lines)
                      if i not in held_out_train and is_short(line)]
        print(f"Length of train data: {len(train_data)}")
        return train_data

    if datatype == "test":
        lines = read_lines("test")
        test_data = [line for i, line in enumerate(lines)
                     if i not in held_out_test and is_short(line)]
        print(f"Length of test data: {len(test_data)}")
        return test_data

    if datatype == "valid":
        train_lines = read_lines("train")
        test_lines = read_lines("test")
        v_train_data = [line for i, line in enumerate(train_lines)
                        if i in held_out_train and is_short(line)]
        # Filter on the test line itself, not on a train line that happens to
        # share the same index.
        v_test_data = [line for i, line in enumerate(test_lines)
                       if i in held_out_test and is_short(line)]
        val_data = v_train_data + v_test_data
        print(f"Length of validation data: {len(val_data)}")
        return val_data

    raise ValueError("datatype must be 'train', 'test' or 'valid', got {!r}".format(datatype))


## Dataloader class creation

Here, we create our special IgboASR dataset that can be used in a torch DataLoader

In [None]:
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
    download_url,
    extract_archive,
    walk_files,
)

def load_audio_item(d: str) -> Tuple[Tensor, int, str]:
    """Load the audio clip described by one metadata line.

    Args:
        d: a line ``"<wav path without extension>|<utterance>"`` from a
            ``data.txt`` file. Backslashes in the path are turned into forward
            slashes, ``.wav`` is appended, and the ``./data/`` prefix is
            redirected to the extracted ``./IgboAudio/data/`` folder.

    Returns:
        ``(waveform, sample_rate, utterance)``: the two values returned by
        ``torchaudio.load`` plus the stripped transcript.
    """
    fields = d.split("|")
    utterance = fields[1].strip()
    # Windows-style separators in the metadata -> POSIX paths.
    wav_path = fields[0].replace('\\', '/') + '.wav'
    wav_path = wav_path.replace("./data/", './IgboAudio/data/')

    waveform, sample_rate = torchaudio.load(wav_path)
    return waveform, sample_rate, utterance


class IgboASR(torch.utils.data.Dataset):
    """Map-style dataset over one split of the Igbo ASR corpus.

    Args:
        data_type: ``'train'``, ``'test'`` or ``'valid'``; forwarded to
            ``get_data``, which returns the metadata lines of that split.

    Each item is the ``(waveform, sample_rate, utterance)`` tuple built by
    ``load_audio_item`` from the corresponding metadata line.
    """

    def __init__(self, data_type):
        """Read the metadata lines of split ``data_type`` ('test', 'train' or 'valid')."""
        self.data = get_data(data_type)

    def __len__(self) -> int:
        """Number of metadata lines (= utterances) in the split."""
        return len(self.data)

    def __getitem__(self, n: int):
        """Load and return the ``n``-th ``(waveform, sample_rate, utterance)`` tuple."""
        return load_audio_item(self.data[n])


## Taking care of accents and diacritics (for Fon)

Since we added diacritics and accents to the model (as described in our paper), this code takes care of the mapping from characters to numbers. The aim is that a character and its accent are treated as a single symbol.

In [None]:
# Combining accents recognised by get_better_mapping, in their
# str.encode("unicode_escape") form: U+0301 acute, U+0300 grave,
# U+0306 breve, U+0308 diaeresis, U+0303 tilde.
accent_code = [b'\\u0301',b'\\u0300',b'\\u0306',b'\\u0308',b'\\u0303']
# Base letters that can absorb a following accent, with their offset into `mapping`.
alpha = {'ɔ':0,'ɛ':5}
# Offset of each accent; alpha[letter] + accents[accent] is the key into `mapping`.
accents = {b'\\u0301':1,b'\\u0300':2,b'\\u0306':3,b'\\u0308':4,b'\\u0303':5}
# Merged letter+accent symbols. Combinations without an entry here (e.g. either
# letter with diaeresis or tilde) are not merged.
# NOTE(review): the values for keys 6 and 7 may be Greek epsilon code points
# rather than 'ɛ' + combining accent -- verify against the transcripts.
mapping={
    1:'ɔ́',2:'ɔ̀',3:'ɔ̆',6:'έ',7:'ὲ',8:'ɛ̆'
}
#we are following the idea that the composition gives the letter first followed by the sign(accent)
def get_better_mapping(text):
    """Split ``text`` into symbols, merging a base letter with the accent that follows it.

    ``text`` is expected in decomposed form (base letter, then a combining
    accent from ``accent_code``). When the previous symbol is a key of
    ``alpha`` and ``mapping`` has an entry for ``alpha[base] + accents[accent]``,
    the pair becomes that single symbol. Any other accent is dropped, including
    one at the very start of the text that has no base symbol, and a blank line
    is printed. An accent at the start used to raise IndexError.

    Returns:
        list of str: one entry per resulting symbol.
    """
    symbols = []
    for ch in text:
        code = ch.encode("unicode_escape")
        if code not in accent_code:
            symbols.append(ch)
            continue
        try:
            merged = mapping[alpha[symbols[-1]] + accents[code]]
        except (IndexError, KeyError):
            # No preceding symbol, or no precomposed form for this pair:
            # keep the base symbol as it is and drop the accent.
            #print("Could not find for {} in sentence {} | Proceeding with default.".format(ch,text))
            print("")
            continue
        symbols[-1] = merged
    return symbols

## Setting up your metrics
Here we define the functions to calculate the error of the model using the metrics:
1.   WER
2.   CER



In [None]:

def avg_wer(wer_scores, combined_ref_len):
    """Return the summed error counts divided by the combined reference length."""
    total_errors = float(sum(wer_scores))
    return total_errors / float(combined_ref_len)


def _levenshtein_distance(ref, hyp):
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element substitutions,
    insertions and deletions needed to turn one sequence into the other. Any
    indexable sequences work: strings give a character-level distance, lists
    of words a word-level distance.

    Only two rows of the dynamic-programming table are kept, and the shorter
    sequence indexes the columns, so extra memory is O(min(len(ref), len(hyp))).
    """
    rows, cols = len(ref), len(hyp)

    # trivial cases
    if ref == hyp:
        return 0
    if rows == 0:
        return cols
    if cols == 0:
        return rows

    # Longer sequence on the rows keeps the row buffers short.
    if rows < cols:
        ref, hyp, rows, cols = hyp, ref, cols, rows

    prev = np.arange(cols + 1, dtype=np.int32)  # distances from the empty prefix
    cur = np.zeros(cols + 1, dtype=np.int32)
    for r in range(1, rows + 1):
        cur[0] = r
        ref_item = ref[r - 1]
        for c in range(1, cols + 1):
            if ref_item == hyp[c - 1]:
                cur[c] = prev[c - 1]
            else:
                # substitution, insertion, deletion
                cur[c] = 1 + min(prev[c - 1], cur[c - 1], prev[c])
        prev, cur = cur, prev

    # after the final swap, `prev` holds the last computed row
    return prev[cols]


def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Compute the word-level Levenshtein distance between reference and hypothesis.

    Both sentences are split on ``delimiter`` and empty tokens are discarded
    (from leading, trailing or repeated delimiters, or an empty sentence).
    ``"a  b"`` therefore counts as two words and ``""`` as zero words, as the
    WER definition in :func:`wer` expects.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: str
    :return: Levenshtein distance (as float) and word count of the reference.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)


def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Compute the character-level Levenshtein distance between two sentences.

    Before comparing, each sentence is optionally lower-cased, and its space
    characters are normalized. Runs of spaces collapse to a single space and
    leading/trailing spaces are dropped. With ``remove_space`` enabled, all
    spaces are removed instead.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Levenshtein distance (as float) and length of the normalized reference.
    :rtype: tuple
    """
    separator = '' if remove_space == True else ' '
    lowercase = ignore_case == True

    def _normalize(sentence):
        if lowercase:
            sentence = sentence.lower()
        return separator.join(token for token in sentence.split(' ') if token)

    reference = _normalize(reference)
    hypothesis = _normalize(hypothesis)

    distance = _levenshtein_distance(reference, hypothesis)
    return float(distance), len(reference)


def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate the word error rate (WER) of ``hypothesis`` against ``reference``.

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where Sw, Dw and Iw are the numbers of substituted, deleted and inserted
    words and Nw is the number of words in the reference. The edit count and
    Nw come from :func:`word_errors`, which also defines how the sentences
    are tokenised on ``delimiter``.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: str
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If the reference contains no words.
    """
    errors, n_ref_words = word_errors(reference, hypothesis, ignore_case, delimiter)
    if n_ref_words == 0:
        raise ValueError("Reference's word number should be greater than 0.")
    return float(errors) / n_ref_words


def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate the character error rate (CER) of ``hypothesis`` against ``reference``.

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where Sc, Dc and Ic are the numbers of substituted, deleted and inserted
    characters and Nc is the number of characters in the reference. Both the
    counts and Nc are taken after the space normalization done by
    :func:`char_errors`. Leading and trailing spaces are dropped, and runs of
    spaces become one space (or are removed when ``remove_space`` is set).

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to compare case-insensitively.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the normalized reference is empty.
    """
    errors, n_ref_chars = char_errors(reference, hypothesis, ignore_case, remove_space)
    if n_ref_chars == 0:
        raise ValueError("Length of reference should be greater than 0.")
    return float(errors) / n_ref_chars

class TextTransform:
    """Maps characters to integer labels and back.

    Symbol inventory used for the Igbo experiments:

    * apostrophe -> 0
    * space -> 1 (stored as ``<SPACE>`` in ``char_map``)
    * ``a`` .. ``z`` -> 2 .. 27

    Index 28 is not produced here; ``nn.CTCLoss(blank=28)`` and
    ``GreedyDecoder`` use it as the CTC blank.
    """

    def __init__(self):
        symbols = ["'", "<SPACE>"] + [chr(code) for code in range(ord('a'), ord('z') + 1)]
        self.char_map = {symbol: idx for idx, symbol in enumerate(symbols)}
        self.index_map = {idx: symbol for symbol, idx in self.char_map.items()}
        # Decoding emits a literal space rather than the "<SPACE>" token.
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """Encode ``text`` (outer whitespace stripped) as a list of label indices.

        Characters missing from ``char_map`` are reported on stdout and
        encoded as 0 (the apostrophe index).
        """
        sequence = []
        for symbol in text.strip():
            if symbol == ' ':
                sequence.append(1)
                continue
            try:
                sequence.append(self.char_map[symbol])
            except KeyError:
                print("Error for character {} in this sentence: {}".format(symbol, text))
                sequence.append(0)
        return sequence

    def int_to_text(self, labels):
        """Decode a sequence of label indices back into a string."""
        text = ''.join(self.index_map[label] for label in labels)
        return text.replace('<SPACE>', ' ')

## Data processing and Greedy Decoder

Here we define the data preprocessing: the mel spectrogram and the data augmentation techniques. We also define the GreedyDecoder algorithm, which decodes the model output.

In [None]:
# Feature extraction: a mel spectrogram for every split, plus frequency and
# time masking (data augmentation) applied to training audio only.
_train_feature_steps = [
    torchaudio.transforms.MelSpectrogram(),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
]
train_audio_transforms = nn.Sequential(*_train_feature_steps)

valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
test_audio_transforms = torchaudio.transforms.MelSpectrogram()

# Shared character <-> label codec used by data_processing and GreedyDecoder.
text_transform = TextTransform()

def data_processing(data, data_type="train"):
    """Collate a list of dataset items into one padded batch.

    Args:
        data: iterable of ``(waveform, sample_rate, utterance)`` tuples, as
            returned by ``IgboASR.__getitem__``.
        data_type: ``'train'`` (mel spectrogram + masking augmentation),
            ``'valid'`` or ``'test'`` (plain mel spectrogram).

    Returns:
        ``(spectrograms, labels, input_lengths, label_lengths)``:

        * ``spectrograms`` are zero-padded along time and arranged as
          ``(batch, 1, n_mels, time)``.
        * ``labels`` are the lower-cased utterances encoded by
          ``text_transform`` and zero-padded to the longest one.
        * ``input_lengths`` holds each spectrogram's frame count halved
          (presumably to match the model's stride-2 first convolution).
        * ``label_lengths`` holds the number of labels of each utterance.

    Raises:
        Exception: on the first item if ``data_type`` is not one of the above.
    """
    feature_pipelines = (
        ('train', train_audio_transforms),
        ('valid', valid_audio_transforms),
        ('test', test_audio_transforms),
    )
    spectrograms, labels, input_lengths, label_lengths = [], [], [], []

    for waveform, _, utterance in data:
        pipeline = next((fn for name, fn in feature_pipelines if data_type == name), None)
        if pipeline is None:
            raise Exception('data_type should be train, valid or test')
        # (channel, n_mels, time) -> (time, n_mels); assumes a mono waveform
        spec = pipeline(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)

        target = torch.Tensor(text_transform.text_to_int(utterance.lower()))
        labels.append(target)
        input_lengths.append(spec.shape[0] // 2)
        label_lengths.append(len(target))

    # pad along time, add a channel axis, then put time last
    padded = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    spectrograms = padded.unsqueeze(1).transpose(2, 3)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths


def GreedyDecoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
    """Greedy (best-path) CTC decoding of a batch of model outputs.

    Args:
        output: scores shaped ``(batch, time, n_class)``; the arg-max class is
            taken at every frame.
        labels: padded target labels, one row per batch item.
        label_lengths: true (unpadded) length of each row of ``labels``.
        blank_label: class index of the CTC blank. Frames predicting it are
            dropped. It should match the ``blank`` given to ``nn.CTCLoss``.
        collapse_repeated: if True, a frame is skipped when it predicts the
            same class as the immediately preceding frame.

    Returns:
        ``(decodes, targets)``: lists of predicted and reference strings.
    """
    # One bulk transfer to Python ints; comparing tensor elements frame by
    # frame forces a device synchronisation per frame on GPU.
    best_paths = torch.argmax(output, dim=2).tolist()
    decodes = []
    targets = []
    for i, path in enumerate(best_paths):
        targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
        decoded = []
        for j, index in enumerate(path):
            if index == blank_label:
                continue
            if collapse_repeated and j != 0 and index == path[j - 1]:
                continue
            decoded.append(index)
        decodes.append(text_transform.int_to_text(decoded))
    return decodes, targets


## The Model
See [here](https://drive.google.com/file/d/1gT4r1R8Iq_183WkU3l0nNtdy4YYvYHPp/view?usp=sharing) for more explanation. 

In [None]:
class CNNLayerNorm(nn.Module):
    """Layer normalization over the feature axis of a conv activation.

    ``nn.LayerNorm`` normalizes the last dimension, so an input laid out as
    ``(batch, channel, feature, time)`` has its feature and time axes swapped,
    is normalized, and is swapped back.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        feature_last = torch.transpose(x, 2, 3).contiguous()  # (batch, channel, time, feature)
        normalized = self.layer_norm(feature_last)
        return torch.transpose(normalized, 2, 3).contiguous()  # (batch, channel, feature, time)


class ResidualCNN(nn.Module):
    """Pre-activation residual conv block (https://arxiv.org/pdf/1603.05027.pdf)
    with layer norm over the feature axis instead of batch norm.

    Each of the two stages is LayerNorm -> GELU -> Dropout -> Conv2d, and the
    block input is added to the result. ``stride`` is applied to both
    convolutions, so the skip connection only matches in shape when
    ``stride == 1`` and ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super().__init__()
        same_padding = kernel // 2
        # Registration order is kept as-is: it fixes the parameter order that
        # saved optimizer state dicts rely on.
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=same_padding)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=same_padding)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        shortcut = x  # (batch, channel, feature, time)
        stages = (
            (self.layer_norm1, self.dropout1, self.cnn1),
            (self.layer_norm2, self.dropout2, self.cnn2),
        )
        for norm, drop, conv in stages:
            x = conv(drop(F.gelu(norm(x))))
        x += shortcut
        return x  # (batch, channel, feature, time)

class BidirectionalLSTM(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional LSTM -> Dropout.

    The output feature size is ``2 * hidden_size`` (forward and backward
    states concatenated). ``batch_first`` is passed straight to ``nn.LSTM``.
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        self.BiLSTM = nn.LSTM(input_size=rnn_dim, hidden_size=hidden_size, num_layers=1,
                              batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        sequence_out, _final_states = self.BiLSTM(activated)
        return self.dropout(sequence_out)


class BidirectionalGRU(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional GRU -> Dropout.

    The output feature size is ``2 * hidden_size`` (forward and backward
    states concatenated). ``batch_first`` is passed straight to ``nn.GRU``.
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=hidden_size, num_layers=1,
                            batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        sequence_out, _final_state = self.BiGRU(activated)
        return self.dropout(sequence_out)


class SpeechRecognitionModel(nn.Module):
    """CNN + bidirectional-RNN acoustic model producing per-frame class scores.

    Pipeline: a strided Conv2d, ``n_cnn_layers`` ResidualCNN blocks, a linear
    projection to ``rnn_dim``, ``n_rnn_layers`` bidirectional LSTMs,
    ``n_rnn_layers`` bidirectional GRUs and a two-layer classifier.

    Input ``(batch, 1, n_feats, time)``; output ``(batch, time', n_class)``
    unnormalized scores, with ``time' = (time - 1) // stride + 1``.
    """

    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        # Feature size after the first conv (kernel 3, padding 1):
        # (n_feats - 1) // stride + 1. This equals n_feats // 2 for the default
        # stride of 2 and an even n_feats, and stays correct for other values.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting heirachal features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
        # Every recurrent layer receives (batch, time, feature) tensors, so all
        # of them must be batch_first. With batch_first only on the first
        # layer, the others ran their recurrence across the batch axis and
        # mixed different utterances. Parameter shapes do not depend on
        # batch_first, so existing checkpoints still load, but weights trained
        # with the old layout should be fine-tuned/re-evaluated.
        self.birnn_layers1 = nn.Sequential(*[
            BidirectionalLSTM(rnn_dim=rnn_dim if i == 0 else rnn_dim*2,
                              hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])
        self.birnn_layers2 = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for _ in range(n_rnn_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers1(x)
        x = self.birnn_layers2(x)
        x = self.classifier(x)
        return x  # (batch, time, n_class)


## The Training and Evaluating Script

We define separate functions to perform training, validation and testing, and then put them together inside a 'main' function.

Here we also made tweaks to allow for saving best model weights to specified path in GDrive and re-training from last saved checkpoint.

In [None]:
class IterMeter(object):
    """Counter of training iterations; train() advances it once per optimizer step."""

    def __init__(self):
        self.val = 0

    def step(self):
        """Advance the counter by one."""
        self.val = self.val + 1

    def get(self):
        """Return the current count."""
        return self.val


def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment,valid_loader,best_loss,curr_patience,best_test_loss):
    """Train for one epoch, then validate, checkpoint and update early stopping.

    Gradients of ``BATCH_MULTIPLIER`` consecutive mini-batches are accumulated
    before each optimizer step. The final window of the epoch is stepped even
    when it holds fewer than ``BATCH_MULTIPLIER`` batches, so no computed
    gradients are thrown away.

    After the epoch, ``valid`` is called and returns ``(WER, CTC loss)``:

    * The WER is compared with ``best_loss``. An improvement saves a
      checkpoint to ``model_path``; otherwise ``curr_patience`` is incremented.
    * The loss is compared with ``best_test_loss``. An improvement saves a
      checkpoint to ``model_path_loss``.
    * Either improvement resets ``curr_patience`` to 0.

    Args:
        scheduler: unused (learning-rate scheduling is disabled).
        experiment: metric log; ``(epoch loss, step)`` is appended to
            ``experiment['loss']`` (``valid`` appends its own entries).
        The remaining arguments are the model, data loaders, loss, optimizer
        and the running best values / patience carried between epochs.

    Returns:
        tuple: the updated ``(best_loss, curr_patience, best_test_loss)``.
    """
    model.train()
    data_len = len(train_loader.dataset)
    num_batches = len(train_loader)
    train_loss = 0        # loss of the current accumulation window
    batch_train_loss = 0  # loss summed over the whole epoch

    optimizer.zero_grad()
    for batch_idx, _data in enumerate(train_loader):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
        train_loss += loss.item() / (num_batches * BATCH_MULTIPLIER)
        loss.backward()

        is_last_batch = batch_idx == num_batches - 1
        if (batch_idx + 1) % BATCH_MULTIPLIER == 0 or is_last_batch:
            optimizer.step()
            iter_meter.step()
            optimizer.zero_grad()
            batch_train_loss += train_loss
            train_loss = 0
        # Progress every 300 batches and on the final batch.
        if batch_idx % 300 == 0 or is_last_batch:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(spectrograms), data_len,
                100. * batch_idx / num_batches, loss.item()))

    experiment['loss'].append((batch_train_loss, iter_meter.get()))
    val_loss, val_test_loss = valid(model, device, valid_loader, criterion, epoch, iter_meter, experiment)
    if val_loss < best_loss:
        curr_patience = 0
        best_loss = val_loss
        _save_checkpoint(model_path, epoch, model, optimizer, val_loss)
    else:
        curr_patience += 1
        print("...No improvement in validation WER from {}...".format(best_loss))
    if val_test_loss < best_test_loss:
        print("Improvement in main loss. Saving model weights to model_path_loss ")
        curr_patience = 0
        best_test_loss = val_test_loss
        _save_checkpoint(model_path_loss, epoch, model, optimizer, val_test_loss)

    return best_loss, curr_patience, best_test_loss


def _save_checkpoint(path, epoch, model, optimizer, val_loss):
    """Save the epoch, model/optimizer state dicts and the tracked metric to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }, path)


def valid(model, device, test_loader, criterion, epoch, iter_meter, experiment):
    """Evaluate on the validation loader, log the metrics and return them.

    Appends ``(value, iter_meter.get())`` to ``experiment['val_loss']``,
    ``experiment['cer']`` and ``experiment['wer']``. ``epoch`` is unused.

    Returns:
        tuple: ``(mean WER, mean CTC loss)`` over the loader (WER first).
    """
    model.eval()
    mean_loss = 0
    cer_scores, wer_scores = [], []
    n_batches = len(test_loader)

    with torch.no_grad():
        for spectrograms, labels, input_lengths, label_lengths in test_loader:
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)

            log_probs = F.log_softmax(model(spectrograms), dim=2)  # (batch, time, n_class)
            log_probs = log_probs.transpose(0, 1)  # (time, batch, n_class) for CTCLoss

            loss = criterion(log_probs, labels, input_lengths, label_lengths)
            mean_loss += loss.item() / n_batches

            predictions, references = GreedyDecoder(log_probs.transpose(0, 1), labels, label_lengths)
            for reference, prediction in zip(references, predictions):
                cer_scores.append(cer(reference, prediction))
                wer_scores.append(wer(reference, prediction))

    mean_cer = sum(cer_scores) / len(cer_scores)
    mean_wer = sum(wer_scores) / len(wer_scores)
    step = iter_meter.get()
    experiment['val_loss'].append((mean_loss, step))
    experiment['cer'].append((mean_cer, step))
    experiment['wer'].append((mean_wer, step))

    print('Valid set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n'.format(mean_loss, mean_cer, mean_wer))
    return mean_wer, mean_loss

    

def test(model, device, test_loader, criterion, epoch, iter_meter, experiment):
    """Evaluate on the test loader and print the mean loss, CER and WER.

    ``epoch``, ``iter_meter`` and ``experiment`` mirror the signature of
    ``valid`` but are not used: nothing is logged or returned.
    """
    print('\nevaluating...')
    model.eval()
    mean_loss = 0
    cer_scores, wer_scores = [], []
    n_batches = len(test_loader)

    with torch.no_grad():
        for spectrograms, labels, input_lengths, label_lengths in test_loader:
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)

            log_probs = F.log_softmax(model(spectrograms), dim=2)  # (batch, time, n_class)
            log_probs = log_probs.transpose(0, 1)  # (time, batch, n_class) for CTCLoss

            loss = criterion(log_probs, labels, input_lengths, label_lengths)
            mean_loss += loss.item() / n_batches

            predictions, references = GreedyDecoder(log_probs.transpose(0, 1), labels, label_lengths)
            for reference, prediction in zip(references, predictions):
                cer_scores.append(cer(reference, prediction))
                wer_scores.append(wer(reference, prediction))

    mean_cer = sum(cer_scores) / len(cer_scores)
    mean_wer = sum(wer_scores) / len(wer_scores)

    print('Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n'.format(mean_loss, mean_cer, mean_wer))


def main(learning_rate, batch_size, epochs,experiment, disabled=True):
    """Build data loaders and the model, train with early stopping, then test.

    If the best-loss checkpoint at ``model_path_loss`` exists, training
    resumes from it with the epoch after the saved one. Otherwise training
    starts from scratch at epoch 1.

    Args:
        learning_rate: AdamW learning rate.
        batch_size: mini-batch size (``train`` accumulates gradients over
            ``BATCH_MULTIPLIER`` mini-batches per optimizer step).
        epochs: number of the last epoch to train.
        experiment: metric log dict with ``'loss'``, ``'val_loss'``, ``'cer'``
            and ``'wer'`` lists, appended to during training/validation.
        disabled: unused; kept for backward compatibility.
    """
    hparams = {
        "n_cnn_layers": 7,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("DEVICE: {}".format(device))

    if not os.path.isdir("./data"):
        print("Making dir of /data")
        os.makedirs("./data")

    train_dataset = IgboASR("train")
    valid_dataset = IgboASR("valid")
    test_dataset = IgboASR("test")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=True,
                                collate_fn=lambda x: data_processing(x, 'train'),
                                **kwargs)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, 'valid'),
                                **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, 'test'),
                                **kwargs)

    model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        ).to(device)

    print(model)
    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

    optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
    # Label 28 (= n_class - 1) is the CTC blank; GreedyDecoder defaults to the same index.
    criterion = nn.CTCLoss(blank=28,zero_infinity=True).to(device)
    # Learning-rate scheduling is disabled; the value is only passed through to train().
    scheduler = None

    iter_meter = IterMeter()
    best_loss = 1000       # best validation WER seen so far
    best_test_loss = 1000  # best validation CTC loss seen so far
    curr_patience = 0
    start_epoch = 1

    # Resume from the best-validation-loss checkpoint when one exists (e.g.
    # after a Colab crash). To resume from the best-WER checkpoint instead,
    # load `model_path` and restore `best_loss` from its 'val_loss' entry.
    if os.path.isfile(model_path_loss):
        checkpoint = torch.load(model_path_loss, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_test_loss = checkpoint['val_loss']
        # The saved epoch was already completed; continue with the next one.
        start_epoch = checkpoint['epoch'] + 1
        print("Resuming from {} (saved after epoch {})".format(model_path_loss, checkpoint['epoch']))

    epoch = start_epoch - 1  # keeps `epoch` defined if there is nothing left to train
    for epoch in range(start_epoch, epochs + 1):
        best_loss, curr_patience, best_test_loss = train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment, valid_loader, best_loss, curr_patience, best_test_loss)
        if curr_patience == patience:
            print("Early stopping with patience of {}".format(patience))
            break

    print("Evaluating on Test data:")
    test(model, device, test_loader, criterion, epoch, iter_meter, experiment)
    
    

## Running Experiment

Here you choose the learning rate, batch size and epochs. Then you run the cell and your model starts training.

In [None]:
"""
experiment logs the metrics of the model while training. each array element is a tuple (metric,step).
For example, we may have loss= [(loss:1.0, step:2)]
"""
experiment={
    'loss':[],
    'val_loss':[],
    'cer':[],
    'wer':[]

}

learning_rate = 0.001
batch_size = 6
epochs = 1000

print("Using BATCH SIZE -> {} and multiplier -> {}".format(batch_size,BATCH_MULTIPLIER))
main(learning_rate, batch_size, epochs, experiment)


The cells below are for debugging.


In [None]:

# Debug: show how get_better_mapping splits a Fon sample into symbols,
# merging ɔ/ɛ with a following combining accent where `mapping` has an entry.
txt = get_better_mapping('ɖɔ̆lɔ̆àyìɔ̆zέn')
print(txt)

In [None]:
# Debug: NFC-normalize a letter+accent sample, then print its characters and
# the unicode_escape form of each one (e.g. to check whether normalization
# produced a single precomposed character or a base letter + combining mark).
txt = unicodedata.normalize("NFC",'ɛ̃')
#txt='Bäume'
c=0
print([i for i in txt])
for t in txt:
  c+=1

  print(t.encode("unicode_escape"))
  