# Training Notebook
Notebook for training the acoustic model of my speech-to-text AI.
### To-Do
* Build data loaders
* Write code to convert wav to log mel spectrograms > separate module
* Add SpecAugment to improve training
* Build acoustic model > separate module
* Training loop
* Validation loop
* Training visualization

In [1]:
import torch
import torchaudio
import torch.nn as nn
import pandas as pd
import numpy as np
import utilities

In [2]:
class LogMelSpec(nn.Module):
    """Turn a waveform into a log-scaled mel spectrogram.

    Wraps ``torchaudio.transforms.MelSpectrogram`` and applies a natural
    logarithm to its output.

    Args:
        sample_rate: Forwarded to ``MelSpectrogram`` as ``sample_rate``.
        n_mels: Forwarded to ``MelSpectrogram`` as ``n_mels``.
        win_length: Forwarded to ``MelSpectrogram`` as ``win_length``.
        hop_length: Forwarded to ``MelSpectrogram`` as ``hop_length``.
    """

    def __init__(self, sample_rate=8000, n_mels=128, win_length=160, hop_length=80):
        super().__init__()
        self.transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            win_length=win_length,
            hop_length=hop_length,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(mel_spectrogram(x) + 1e-14)`` as a tensor.

        Args:
            x: Input passed unchanged to the wrapped mel-spectrogram transform.

        Returns:
            The element-wise natural log of the mel spectrogram.
        """
        mel = self.transform(x)
        # The small offset keeps log() finite where the spectrogram is exactly 0.
        # torch.log (not np.log) keeps the result a tensor on the input's device
        # and inside autograd; np.log fails for grad-tracking or GPU tensors.
        return torch.log(mel + 1e-14)

In [65]:
class DataLoader(torch.utils.data.Dataset):
    """Dataset yielding (spectrogram, label, spec_len, label_len) samples.

    Despite its name this is a ``torch.utils.data.Dataset``. Rows come from a
    JSON-lines file read with pandas; each row provides an audio path in the
    ``key`` column and a transcript in the ``text2`` column.

    Args:
        json_path: Path of the JSON-lines file describing the samples.
        sample_rate: Sample rate handed to ``LogMelSpec``.
        n_feats: Number of mel bins handed to ``LogMelSpec``.
        specaug_rate, specaug_policy, time_mask, freq_mask: Reserved for the
            SpecAugment stage, which is not wired in yet; currently unused.
        valid: Selects the validation transform pipeline. Until SpecAugment
            is added, both pipelines are identical.
        shuffle: Currently unused.
        text_to_int: Currently unused.
        log_ex: When true, print the reason a sample was skipped.
    """

    def __init__(self, json_path, sample_rate, n_feats, specaug_rate, specaug_policy, time_mask, freq_mask,
                 valid=False, shuffle=True, text_to_int=True, log_ex=True):
        self.log_ex = log_ex
        self.text_process = utilities.TextProcess()

        print("Loading data json file from", json_path)
        self.data = pd.read_json(json_path, lines=True)

        if valid:
            # Validation: feature extraction only, never augmented.
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec(sample_rate=sample_rate, n_mels=n_feats, win_length=160, hop_length=80)
            )
        else:
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec(sample_rate=sample_rate, n_mels=n_feats, win_length=160, hop_length=80),
                # TODO: add SpecAugment(specaug_rate, specaug_policy, freq_mask, time_mask)
            )

    def __len__(self):
        """Return the number of rows in the loaded JSON file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the nearest usable one.

        If the requested sample cannot be loaded or fails validation, the
        indices below it are tried in descending order, then the indices
        above it in ascending order. Every index is tried at most once, so
        the search always terminates.

        Args:
            idx: Integer index (or a tensor holding one). Negative values
                count from the end.

        Returns:
            Tuple ``(spectrogram, label, spec_len, label_len)``.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: No sample in the dataset could be loaded.
        """
        if torch.is_tensor(idx):
            idx = idx.item()

        size = len(self.data)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError('index %s is out of range for dataset of size %s' % (idx, size))

        for candidate in self._fallback_order(idx, size):
            # Resolved outside the try so the log line below can always name the file.
            file_path = self.data.key.iloc[candidate]
            try:
                return self._load_sample(candidate, file_path)
            except Exception as e:
                # Deliberately broad: any unreadable or invalid sample is
                # skipped in favour of a neighbour rather than aborting.
                if self.log_ex:
                    print(str(e), file_path)

        raise RuntimeError('no usable sample found in dataset')

    @staticmethod
    def _fallback_order(idx, size):
        """Yield ``idx``, then the indices below it down to 0, then those above it."""
        yield from range(idx, -1, -1)
        yield from range(idx + 1, size)

    def _load_sample(self, idx, file_path):
        """Load, transform and validate row ``idx``; raise if it is unusable."""
        print("file_path: " + file_path)
        waveform, _ = torchaudio.load(file_path)
        transcript = self.data['text2'].iloc[idx]
        print(transcript)
        label = self.text_process.text_to_int_seq(transcript)
        spectrogram = self.audio_transforms(waveform)  # (channel, feature, time)
        # NOTE(review): the halving presumably mirrors a 2x time reduction
        # in the acoustic model -- confirm against the model definition.
        spec_len = spectrogram.shape[-1] // 2
        label_len = len(label)

        if spec_len < label_len:
            raise ValueError('spectrogram len is smaller than label len')
        if spectrogram.shape[0] > 1:
            raise ValueError('dual channel, skipping audio file %s' % file_path)
        # Optional size cap, currently disabled:
        # if spectrogram.shape[2] > 1650:
        #     raise ValueError('spectrogram too big. size %s' % spectrogram.shape[2])
        if label_len == 0:
            raise ValueError('label len is zero... skipping %s' % file_path)

        return spectrogram, label, spec_len, label_len

    def describe(self):
        """Return ``describe()`` summary statistics of the underlying DataFrame."""
        return self.data.describe()

In [66]:
# Location of the JSON-lines manifest to load (machine-specific absolute path).
path = 'F:/cv-corpus-11.0-2022-09-21/en/test.json'

# Feature-extraction settings passed through to the spectrogram transform.
sample_rate = 8000
n_feats = 81

# SpecAugment settings; accepted by DataLoader but not applied until the
# augmentation stage is implemented.
specaug_rate = 0.5
specaug_policy = 3
time_mask = 70
freq_mask = 15

data_loader = DataLoader(
    json_path=path,
    sample_rate=sample_rate,
    n_feats=n_feats,
    specaug_rate=specaug_rate,
    specaug_policy=specaug_policy,
    time_mask=time_mask,
    freq_mask=freq_mask,
)

Loading data json file from F:/cv-corpus-11.0-2022-09-21/en/test.json


In [68]:
# NOTE(review): the recorded output below is a sample tuple rather than an
# object repr, so this cell was presumably last executed as an indexing
# expression (e.g. `data_loader[i]`) and edited afterwards -- confirm.
data_loader

file_path: F:/cv-corpus-11.0-2022-09-21/en/wav/common_voice_en_22917213.wav
the city of manningham encompasses two victorian state electorates


(tensor([[[-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362],
          [-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362],
          [-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362],
          ...,
          [-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362],
          [-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362],
          [-32.2362, -32.2362, -32.2362,  ..., -32.2362, -32.2362, -32.2362]]]),
 [21,
  9,
  6,
  1,
  4,
  10,
  21,
  26,
  1,
  16,
  7,
  1,
  14,
  2,
  15,
  15,
  10,
  15,
  8,
  9,
  2,
  14,
  1,
  6,
  15,
  4,
  16,
  14,
  17,
  2,
  20,
  20,
  6,
  20,
  1,
  21,
  24,
  16,
  1,
  23,
  10,
  4,
  21,
  16,
  19,
  10,
  2,
  15,
  1,
  20,
  21,
  2,
  21,
  6,
  1,
  6,
  13,
  6,
  4,
  21,
  16,
  19,
  2,
  21,
  6,
  20],
 2275,
 66)