In [1]:
import torch
import torchaudio
import torchtext
import torchaudio.functional as F
import torchaudio.transforms as T
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from audio_augmentations import *

import os, re, random
import numpy as np
import sklearn
import itertools

import pickle
from tqdm.auto import tqdm
from IPython.display import clear_output
import IPython.display as ipd
import gc
import matplotlib.pyplot as plt
import wandb

import sys
sys.path.append('..')
from models.cnn import ResidualCNN
from models.encoder import Encoder
from models.attention import Attention
from models.model import Speech_recognition_model

print(torch.__version__)
print(torchaudio.__version__)

  from .autonotebook import tqdm as notebook_tqdm


2.0.0
2.0.1


In [15]:
class CNNLayerNorm(nn.Module):
    """Layer normalization built for cnns input"""
    def __init__(self, n_feats):
        super(CNNLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # x (batch, channel, feature, time)
        x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
        x = self.layer_norm(x)
        return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) 


class ResidualCNN(nn.Module):
    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
        except with layer norm instead of batch norm
    """
    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super(ResidualCNN, self).__init__()

        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        residual = x  # (batch, channel, feature, time)
        x = self.layer_norm1(x)
        x = F.gelu(x)
        x = self.dropout1(x)
        x = self.cnn1(x)
        x = self.layer_norm2(x)
        x = F.gelu(x)
        x = self.dropout2(x)
        x = self.cnn2(x)
        x += residual
        return x # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()

        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)
        x = self.dropout(x)
        return x


class SpeechRecognitionModel(nn.Module):
    
    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        n_feats = n_feats//2
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting heirachal features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) 
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
        x = x.transpose(1, 2) # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x


In [19]:
learning_rate=5e-4
batch_size=20
epochs=10
hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride":2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }
test_asr_model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        )

In [21]:
test_asr_model.load_state_dict(torch.load('../checkpoints/model_asr.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

In [22]:
from collections import OrderedDict
submodel = OrderedDict(list(test_asr_model.state_dict().items())[:28])
torch.save(submodel, '../checkpoints/sub_model_asr.pth')

In [23]:
import yaml

params = None
with open("../configs/model_params_tl.yaml", "r") as stream:
    try:
        params = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)
model = Speech_recognition_model(**(params['Architecture']))

In [24]:
model.state_dict()

OrderedDict([('cnn.weight',
              tensor([[[[-0.2682, -0.2088, -0.0109],
                        [-0.1491, -0.0043,  0.1976],
                        [-0.1710, -0.0059,  0.1245]]],
              
              
                      [[[-0.1058, -0.2418, -0.1758],
                        [-0.2706,  0.3230, -0.0857],
                        [ 0.2923,  0.3083,  0.1840]]],
              
              
                      [[[ 0.1227,  0.0579,  0.0771],
                        [-0.3205,  0.3324, -0.1856],
                        [-0.0907, -0.0952,  0.1629]]],
              
              
                      [[[ 0.0440,  0.1645, -0.2154],
                        [-0.2481, -0.1938, -0.2603],
                        [-0.0848, -0.1096, -0.1254]]],
              
              
                      [[[ 0.2594,  0.0084,  0.1558],
                        [ 0.1859,  0.2870,  0.0746],
                        [ 0.1965,  0.1124,  0.1423]]],
              
              
                 

In [25]:
model.load_state_dict(torch.load('../checkpoints/sub_model_asr.pth', map_location=torch.device('cpu')), strict=False)

_IncompatibleKeys(missing_keys=['encoder.rnn.weight_ih_l0', 'encoder.rnn.weight_hh_l0', 'encoder.rnn.bias_ih_l0', 'encoder.rnn.bias_hh_l0', 'encoder.rnn.weight_ih_l0_reverse', 'encoder.rnn.weight_hh_l0_reverse', 'encoder.rnn.bias_ih_l0_reverse', 'encoder.rnn.bias_hh_l0_reverse', 'encoder.rnn.weight_ih_l1', 'encoder.rnn.weight_hh_l1', 'encoder.rnn.bias_ih_l1', 'encoder.rnn.bias_hh_l1', 'encoder.rnn.weight_ih_l1_reverse', 'encoder.rnn.weight_hh_l1_reverse', 'encoder.rnn.bias_ih_l1_reverse', 'encoder.rnn.bias_hh_l1_reverse', 'encoder.rnn.weight_ih_l2', 'encoder.rnn.weight_hh_l2', 'encoder.rnn.bias_ih_l2', 'encoder.rnn.bias_hh_l2', 'encoder.rnn.weight_ih_l2_reverse', 'encoder.rnn.weight_hh_l2_reverse', 'encoder.rnn.bias_ih_l2_reverse', 'encoder.rnn.bias_hh_l2_reverse', 'encoder.rnn.weight_ih_l3', 'encoder.rnn.weight_hh_l3', 'encoder.rnn.bias_ih_l3', 'encoder.rnn.bias_hh_l3', 'encoder.rnn.weight_ih_l3_reverse', 'encoder.rnn.weight_hh_l3_reverse', 'encoder.rnn.bias_ih_l3_reverse', 'encoder.r

In [26]:
model.state_dict()

OrderedDict([('cnn.weight',
              tensor([[[[-2.1007e-03, -8.3343e-02, -7.8526e-04],
                        [ 2.8128e-04,  5.8829e-03, -1.9605e-04],
                        [-1.3582e-03, -8.5892e-02, -4.4249e-04]]],
              
              
                      [[[ 1.4868e-02,  8.4935e-03, -9.5815e-02],
                        [ 4.7094e-03, -1.9104e-02,  1.2676e-02],
                        [-2.7919e-02,  8.4048e-03, -4.4409e-01]]],
              
              
                      [[[ 7.1305e-02, -1.6534e-01,  9.3900e-02],
                        [-3.5325e-02,  5.1871e-02, -1.0762e-02],
                        [-3.2981e-02,  5.2358e-02, -3.1448e-03]]],
              
              
                      [[[ 1.1526e-02,  2.9924e-01, -1.2135e-01],
                        [ 8.0982e-04,  1.6567e-01,  1.5670e-02],
                        [ 4.9740e-03,  2.3941e-01,  1.0255e-01]]],
              
              
                      [[[-6.6552e-02, -1.6664e-01, -5.7732e-04],