## Matching Network (MN) + LSTM + CTC training using a support set (SS) with untagged blanks (2-layer BiLSTM; as configured below: input dim = 39, query embedding dim = 2 × 128 = 256)

In [None]:
from torch.optim import Adam
from torch import nn
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
import numpy as np
import torch
from torchvision import transforms
import torchvision
import sys
import os
from datetime import datetime
import random

from pytorch_end2end import CTCDecoder
from jiwer import wer
from dataloader.ctc_support_set_loader import SupportDataSet, collate_wrapper, stackup
from dataloader.ctc_batch_loader import QueryDataSet, batch_stackup, pad_batch
from torch.utils.data import DataLoader

import Configuration as config
from tqdm.notebook import tqdm
import editdistance as ed

In [None]:
EPSILON = 1e-8
seed = 25
# Seed every RNG the notebook may draw from so runs are reproducible:
# torch (weight init), numpy, and Python's `random` module (imported above;
# presumably used by the data loaders -- TODO confirm).
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# CNN Embedding

In [None]:
# cnn embedding - encoder_g (SS)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into a single one.

    ``(B, d1, d2, ...) -> (B, d1 * d2 * ...)``; the batch size is read from dim 0.
    """

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)

def conv_block(in_channels: int, out_channels: int) -> nn.Module:
    """Build one CNN stage: 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

    Channels go from ``in_channels`` to ``out_channels``; the stride-2 pooling
    halves (floor) both spatial dimensions.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)

def few_shot_cnn_encoder(num_input_channels=1, outchnls=64) -> nn.Module:
    """CNN embedding ``g`` for support-set samples.

    Three ``conv_block`` stages (the first maps ``num_input_channels`` to
    ``outchnls``, the other two keep ``outchnls``) followed by ``Flatten``, so
    each sample becomes one vector of length ``outchnls * H' * W'``.
    """
    channel_pairs = [
        (num_input_channels, outchnls),
        (outchnls, outchnls),
        (outchnls, outchnls),
    ]
    stages = [conv_block(c_in, c_out) for c_in, c_out in channel_pairs]
    return nn.Sequential(*stages, Flatten())

# LSTM Embedding

In [None]:
# lstm embedding - encoder_f (Batch utterances)
class few_shot_lstm_encoder(nn.Module):
    """Bidirectional LSTM embedding ``f`` for query utterances.

    Input:  ``(batch, frames, lstm_inp_dim)`` (the LSTM is ``batch_first``).
    Output: ``(batch, frames, 2 * hidden_dim)`` -- forward and backward hidden
    states concatenated per frame, i.e. 256 with the default ``hidden_dim=128``.

    Args:
        lstm_inp_dim: per-frame feature size.
        hidden_dim: hidden size of each LSTM direction (default 128).
        num_layers: number of stacked LSTM layers (default 2).
    """

    def __init__(self, lstm_inp_dim, hidden_dim=128, num_layers=2):
        super(few_shot_lstm_encoder, self).__init__()

        # Kept for backward compatibility; neither attribute is read in this class.
        self.count = 0
        self.batch_norm = True

        # LSTM parameters
        self.lstm_inp_dim = lstm_inp_dim
        self.lstm_hidden_dim = hidden_dim
        self.lstm_layers = num_layers

        self.bilstm = nn.LSTM(input_size=self.lstm_inp_dim, hidden_size=self.lstm_hidden_dim,
                              num_layers=self.lstm_layers, bidirectional=True, batch_first=True)

    def forward(self, input):
        """Return the per-frame BiLSTM outputs; the final ``(h, c)`` states are discarded."""
        ys, _ = self.bilstm(input)  # (batch, frames, 2 * hidden_dim)
        return ys

# Matching Network

In [None]:
class MatchingNetwork(nn.Module):
    """Container for the two Matching-Network embedding functions.

    * ``encoder_g`` -- CNN embedding (``few_shot_cnn_encoder``) for support-set samples.
    * ``encoder_f`` -- BiLSTM embedding (``few_shot_lstm_encoder``) for query utterances.

    The episode code calls the two encoders directly, so ``forward`` is a no-op.
    For the similarity scoring the CNN output width must equal the BiLSTM
    output width (256 with the default LSTM settings) -- this depends on the
    support-sample spatial size, which is not visible here; TODO confirm.

    Args:
        n, k, q: episode sizes (shots, ways, queries); stored, not used in this module.
        num_input_channels: channels of the support-set inputs to ``encoder_g``.
        lstm_input_size: per-frame feature size fed to ``encoder_f``.
        use_cuda: request ``cuda:0``; falls back to CPU when CUDA is unavailable.
    """

    def __init__(self, n, k, q, num_input_channels, lstm_input_size, use_cuda=False):

        super(MatchingNetwork, self).__init__()
        self.n = n
        self.k = k
        self.q = q
        self.num_input_channels = num_input_channels
        self.lstm_input_size = lstm_input_size
        self.encoder_g = few_shot_cnn_encoder(self.num_input_channels)  # CNN encoder for the support set
        self.encoder_f = few_shot_lstm_encoder(self.lstm_input_size)    # BiLSTM encoder for the queries

        self.use_cuda = use_cuda

        # Logical `and`, not bitwise `&`: use the GPU only if it was requested AND exists.
        if self.use_cuda and torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

    def forward(self, inputs):
        """Unused: the episode logic calls ``encoder_g``/``encoder_f`` directly. Returns None."""
        pass

# Cosine-similarity scoring (support embeddings are L2-normalised; query embeddings are not)

In [None]:
class DistanceNetwork(nn.Module):
    """Score every query frame against every support sample.

    The score is ``<query, support> / ||support||``: only the support
    embeddings are L2-normalised, the query embeddings are used as-is.

    The first ``bshot`` support samples are blanks.  Only the ``n`` highest
    blank scores per query frame are kept, so the blank class then occupies
    ``n`` columns like every other class (matching the ``k * n`` one-hot
    layout used when pooling scores into class logits).

    Args:
        n: shots per class (number of blank scores kept).
        k, q: ways and queries; stored, not used in ``forward``.
        bshot: number of blank samples at the start of each support set.
    """

    def __init__(self, n, k, q, bshot):
        super(DistanceNetwork, self).__init__()
        self.n = n
        self.k = k
        self.q = q
        self.bshot = bshot

    def forward(self, support_set, query_set):
        """Return support-normalised similarity scores.

        Args:
            support_set: ``(batch, num_support, dim)``; the first ``bshot`` rows are blanks.
            query_set: ``(batch, frames, dim)``.

        Returns:
            ``(batch, frames, n + num_support - bshot)`` scores; the first ``n``
            columns are the top-n blank scores in descending order.
        """
        eps = 1e-10
        n_blanks = self.bshot  # total no. of blanks in the support set

        # 1 / ||s|| per support sample; clamped so an all-zero embedding cannot divide by zero.
        sum_support = torch.sum(torch.pow(support_set, 2), 2)
        support_magnitude = sum_support.clamp(eps, float("inf")).rsqrt()

        dot_product = query_set.matmul(support_set.permute(0, 2, 1))
        scores = dot_product * support_magnitude.unsqueeze(1)

        # Keep only the n best-matching blank samples per frame.
        top_blanks = scores[:, :, :n_blanks].topk(self.n, dim=2).values
        return torch.cat([top_blanks, scores[:, :, n_blanks:]], dim=2)

In [None]:
class episodeBuilder:
    """Drive Matching-Network + CTC training, validation and test episodes.

    Each episode pairs one support set (drawn from ``data_train``) with one
    batch of query utterances.  Support samples are embedded with
    ``model.encoder_g``, query frames with ``model.encoder_f``; ``DistanceNetwork``
    scores frames against support samples and ``matching_net_predictions``
    pools those scores into per-class logits, where class 0 is the CTC blank.
    """

    def __init__(self, model, n_shot, k_way, b_shot, q_queries=1,    # changes for handling more blank samples
                 batch_size=1,
                 data_train=None, batch_train=None, batch_val=None, batch_test=None):

        self.ctc_loss = nn.CTCLoss(reduction='sum')
        self.optimiser = Adam(model.parameters(), lr=1e-3)

        self.label_dict = config.label_dict
        self.ctc_labels = config.ctc_labels

        # Test-time beam-search decoder; blank_idx=0 matches class 0 of the one-hot
        # layout built in matching_net_predictions.
        self.decoder = CTCDecoder(blank_idx=0, beam_width=100, time_major=False,
                                  after_logsoftmax=True, labels=self.ctc_labels)

        self.model = model

        self.n_shot = n_shot
        self.b_shot = b_shot  # number of blank samples per support set
        self.k_way = k_way
        self.q_queries = q_queries
        self.batch_size = batch_size

        self.dn = DistanceNetwork(n_shot, k_way, q_queries, b_shot)

        self.data_train = data_train
        self.batch_train = batch_train
        self.batch_val = batch_val
        self.batch_test = batch_test

        self.device = self.model.device

    def compute_wer(self, index, input_sizes, targets, target_sizes):
        """Greedy-CTC token error counts for a batch.

        Each row of ``index`` (per-frame argmax class ids) is cut to its valid
        length, collapsed CTC-style (blanks = 0 dropped, consecutive repeats
        merged) and compared with its reference by edit distance.

        Returns:
            ``(total edit distance, total reference tokens)`` over the batch.
        """
        batch_errs = 0
        batch_tokens = 0
        for i in range(len(index)):
            label = targets[i][:target_sizes[i]]
            frames = index[i][:input_sizes[i]]
            pred = [tok for j, tok in enumerate(frames)
                    if tok != 0 and (j == 0 or tok != frames[j - 1])]
            batch_errs += ed.eval(label, pred)
            batch_tokens += len(label)
        return batch_errs, batch_tokens

    def _next_support_set(self):
        """Draw one support set from ``data_train`` and move it to the model device.

        A new iterator is created on every call, so with an unshuffled
        DataLoader successive draws only differ if the dataset itself samples
        randomly -- TODO confirm.  Returns ``(X, y)``; the label names are dropped.
        """
        X, y, _ = next(iter(self.data_train))
        return X.to(self.device, dtype=torch.float32), y.to(self.device, dtype=torch.long)

    def _class_log_probs(self, X, inputs):
        """Embed the queries and the support sets; return per-class log-probabilities.

        ``inputs`` is squeezed on dim 2 and reshaped to ``(batch, frames, -1)``
        before the BiLSTM; ``X[i]`` is the support set for batch element ``i``.
        Returns a float32 tensor of shape ``(batch, frames, k_way)``.
        """
        inputs = torch.squeeze(inputs, 2)
        inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], -1)  # (batch_size, frames, inp_dim)

        # Query frames: BiLSTM embedding f (uses the sequence context).
        queries = self.model.encoder_f(inputs)

        # Support samples: CNN embedding g, one support set per batch element.
        support = torch.stack([self.model.encoder_g(support_set) for support_set in X])

        similarities = self.dn(support_set=support, query_set=queries)
        attention = self.matching_net_predictions(similarities)
        return F.log_softmax(attention, dim=2).to(self.device, dtype=torch.float32)

    def each_episode_ctc(self, X, y, batch_query, trainflag=True):
        """Run one CTC episode; return ``(loss, token_error_rate)``.

        With ``trainflag`` True the model is put in train mode and the
        optimiser's gradients are zeroed (the caller runs ``backward``/``step``).
        Otherwise the model is put in eval mode and the forward pass runs with
        autograd disabled, so validation builds no graph.  ``y`` is accepted
        for interface compatibility but is not used.
        """
        self.train = trainflag
        if self.train:
            self.model.train()
            self.optimiser.zero_grad()
        else:
            self.model.eval()

        self.batch_size = X.size(0)

        inputs, input_sizes, targets, target_sizes, utt = batch_query
        inputs = inputs.to(self.device)
        input_sizes = input_sizes.to(self.device)
        targets = targets.to(self.device)
        target_sizes = target_sizes.to(self.device)

        with torch.set_grad_enabled(bool(self.train)):
            ypreds = self._class_log_probs(X, inputs)
            ypreds = ypreds.transpose(0, 1)  # (frames, batch_size, num_classes) as CTCLoss expects

            out_len = ypreds.size(0)
            # input_sizes appears to hold each utterance's fraction of the padded
            # length (TODO confirm against pad_batch); scale it to frame counts.
            input_sizes = (input_sizes * out_len).long()

            loss = self.ctc_loss(ypreds, targets, input_sizes, target_sizes)

        _, indices = torch.max(ypreds, dim=-1)

        batch_errs, batch_tokens = self.compute_wer(indices.transpose(0, 1).cpu().numpy(),
                                                    input_sizes.cpu().numpy(),
                                                    targets.cpu().numpy(),
                                                    target_sizes.cpu().numpy())

        error = batch_errs / batch_tokens
        return loss, error

    def each_episode_ctc_test(self, X, y, batch_query, trainflag=False):
        """Beam-decode one query utterance and return its phone error rate.

        Assumes batch size 1: the decoded output and the targets are reduced to
        a single sequence.  Decoding runs with autograd disabled unless
        ``trainflag`` is True.  ``y`` is accepted but not used.
        """
        self.train = trainflag
        if self.train:
            self.model.train()
            self.optimiser.zero_grad()
        else:
            self.model.eval()

        self.batch_size = X.size(0)

        inputs, input_sizes, targets, target_sizes, utt = batch_query
        inputs = inputs.to(self.device)

        with torch.set_grad_enabled(bool(self.train)):
            ypreds = self._class_log_probs(X, inputs)  # (batch_size, frames, num_classes)

        # Every row is decoded over the full padded length.
        ypreds_sizes = torch.full((self.batch_size,), ypreds.shape[1],
                                  dtype=torch.long, device=self.device)

        decoded_targets, _, _ = self.decoder.decode(ypreds, ypreds_sizes)

        target_label = [self.label_dict[ele] for ele in torch.squeeze(targets, 0).tolist()]
        pred_index = decoded_targets.view(decoded_targets.shape[1]).tolist()
        pred_label = [self.label_dict[ele] for ele in pred_index]

        # jiwer treats a *list* of strings as several sentences (recent versions
        # also require equal lengths), so join each label sequence into one
        # space-separated sentence to score the whole utterance.
        return wer(' '.join(target_label), ' '.join(pred_label))

    def matching_net_predictions(self, attention):
        """Calculates Matching Network predictions based on equation (1) of the paper.

        ``attention`` is ``(batch, frames, k_way * n_shot)`` with each class's
        ``n_shot`` columns contiguous.  It is multiplied by the
        ``(k_way * n_shot, k_way)`` one-hot label matrix, so each class score is
        the sum of its columns.  Returns ``(batch, frames, k_way)``.
        """
        k = self.k_way
        n = self.n_shot

        # The label matrix is identical for every batch element: build it once.
        labels = self.create_nshot_task_label(k, n).to(attention.device)
        y_onehot = F.one_hot(labels, k).to(device=attention.device, dtype=attention.dtype)

        # (batch, frames, k*n) @ (k*n, k) broadcasts over the batch dimension.
        return torch.matmul(attention, y_onehot)

    def create_nshot_task_label(self, k, n):
        """Return ``[0]*n + [1]*n + ... + [k-1]*n`` as a LongTensor (class of each support column)."""
        # Integer repetition instead of a float-step arange + truncation, so no
        # rounding can ever shift a label into the neighbouring class.
        return torch.arange(k).repeat_interleave(n)

    def run_training_epoch(self, total_train_batches):
        """Train for one pass over ``batch_train``; return the mean (loss, error) per batch.

        A support set is drawn from ``data_train`` for every query batch.
        The returned loss is a detached tensor.
        """
        total_loss = 0.0
        total_error = 0.0

        with tqdm(total=total_train_batches, desc='train', leave=False) as pbar1:
            for batch_query in self.batch_train:  # to iterate through all utterances in an epoch
                X, y = self._next_support_set()

                loss, err = self.each_episode_ctc(X, y, batch_query, trainflag=True)

                loss.backward()
                self.optimiser.step()

                # Detach so the running sum does not keep every batch's autograd graph alive.
                total_loss += loss.detach()
                total_error += err

                pbar1.update(1)

        total_loss = total_loss / total_train_batches
        total_error = total_error / total_train_batches
        return total_loss, total_error

    def run_val_epoch(self, total_val_batches):
        """Evaluate over ``batch_val``; return the mean (loss, error) per batch.

        Support sets are drawn from ``data_train`` here as well.
        """
        total_loss = 0.0
        total_error = 0.0

        with tqdm(total=total_val_batches, desc='val', leave=False) as pbar1:
            for batch_query in self.batch_val:  # to iterate through all utterances in an epoch
                X, y = self._next_support_set()

                loss, err = self.each_episode_ctc(X, y, batch_query, trainflag=False)

                total_loss += loss.data
                total_error += err

                pbar1.update(1)

        total_loss = total_loss / total_val_batches
        total_error = total_error / total_val_batches
        return total_loss, total_error

    def run_test_epoch(self, total_test_batches):
        """Beam-decode every test utterance; return the mean phone error rate."""
        total_per = 0.0

        with tqdm(total=total_test_batches, desc='test batches:', leave=False) as pbar:
            for batch_query in self.batch_test:  # to iterate through all utterances
                X, y = self._next_support_set()
                total_per += self.each_episode_ctc_test(X, y, batch_query, trainflag=False)
                pbar.update(1)

        return total_per / total_test_batches

    def save_model(self, tepochs, fpath):
        """Save the model ``state_dict`` as ``<stem>_<k>_<n>_<b>_<epoch><ext>``; return that path."""
        stem, ext = os.path.splitext(fpath)
        fpath = '{}_{}_{}_{}_{}{}'.format(stem, self.k_way, self.n_shot, self.b_shot, tepochs, ext)
        torch.save(self.model.state_dict(), fpath)
        return fpath


In [None]:
def build_model(n_train, k_train, q_train, fce=False, use_cuda=False,
                lstm_input_size=39, num_input_channels=1):
    """Create a MatchingNetwork in float32 on its selected device.

    Args:
        n_train, k_train, q_train: episode sizes passed to ``MatchingNetwork``.
        fce: accepted for API compatibility; not used.
        use_cuda: request the GPU (``MatchingNetwork`` falls back to CPU if CUDA is unavailable).
        lstm_input_size: per-frame feature size fed to the BiLSTM (default 39).
            NOTE(review): the original author marked this value "## check" --
            confirm it matches the query feature dimension.
        num_input_channels: channels of the support-set inputs to the CNN (default 1).
    """
    model = MatchingNetwork(n_train, k_train, q_train, num_input_channels, lstm_input_size,
                            use_cuda=use_cuda)
    return model.to(model.device, dtype=torch.float32)

def load_model(fpath, n_train, k_train, q_train, use_cuda=False, eval_flag=True):
    """Rebuild the MatchingNetwork and restore its weights from ``fpath``.

    The weights are mapped onto the model's device and the model is returned
    in float32.  It is switched to eval mode unless ``eval_flag`` is False
    (e.g. when resuming training).
    """
    net = build_model(n_train, k_train, q_train, use_cuda=use_cuda)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    weights = torch.load(fpath, map_location=net.device)
    net.load_state_dict(weights)
    net = net.to(net.device, dtype=torch.float32)
    if eval_flag:
        net.eval()
    return net

In [None]:
# K-way N-shot episode configuration.
n = config.Q                  # shots per class
k = config.P                  # classes (ways) per episode
bshot = config.blanks_train   # blank samples per support set
q = 1                         # As of now q=1 only supported
batch_size = 1                # As of now batch_size=1 supported

# Support sets are always built from the training support file.
train_dataset = SupportDataSet(supportdatafile=config.trainSupportSet, kway=k, nshot=n, bshot=bshot,
                               nqueries=q, phonemes=config.ALLPHONEMES, index_dict=config.index_dict,
                               label_dict=config.label_dict,
                               transform=transforms.Compose([stackup()]))  # changes for handling more blank samples
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_wrapper)


def _query_loader(query_file):
    """Build the QueryDataSet for ``query_file`` and its padding DataLoader."""
    dataset = QueryDataSet(querysetfile=query_file,
                           transform=transforms.Compose([batch_stackup(config.index_dict)]))
    return dataset, DataLoader(dataset, batch_size=batch_size, collate_fn=pad_batch)


train_batch_dataset, train_batch_loader = _query_loader(config.trainQuerySet)
val_batch_dataset, val_batch_loader = _query_loader(config.devQuerySet)
test_batch_dataset, test_batch_loader = _query_loader(config.testQuerySet)

In [None]:
old_epochs = config.prev_model_epochs

# Resume from a checkpoint when prev_model_epochs is non-zero. Both paths
# honour config.use_cuda; the resume path used to force use_cuda=True.
if old_epochs != 0:
    model = load_model(config.prev_model_path, n, k, q, use_cuda=config.use_cuda, eval_flag=False)
else:
    model = build_model(n, k, q, use_cuda=config.use_cuda)

model = model.to(model.device, dtype=torch.float32)

epochs = config.epochs

total_train_batches = len(train_batch_loader)
total_val_batches = len(val_batch_loader)

episode = episodeBuilder(model, n, k, bshot, q, batch_size,
                         data_train=train_dataloader,
                         batch_train=train_batch_loader,
                         batch_val=val_batch_loader,
                         batch_test=test_batch_loader)

logdir = config.model_store_path

# Create the checkpoint/log directory if needed; a no-op when it already exists.
os.makedirs(logdir, exist_ok=True)

In [None]:
print('MN_CTC Training started...')
# Per-epoch histories stored as plain floats (no GPU memory or autograd graph kept alive).
# NOTE: `acc` / `val_acc` hold error rates, not accuracies.
loss, acc, val_loss, val_acc = [], [], [], []


start_time = datetime.now()
with tqdm(total=epochs, desc='epochs', leave=False) as pbar:
    for e in range(epochs):
        epoch_start_time = datetime.now()

        total_c_loss, total_error = episode.run_training_epoch(total_train_batches)
        loss.append(float(total_c_loss))
        acc.append(total_error)
        total_val_c_loss, total_val_error = episode.run_val_epoch(total_val_batches)
        val_loss.append(float(total_val_c_loss))
        val_acc.append(total_val_error)

        print("Epoch {}: train: [loss-{:.6f} error-{:.6f} ], val: [loss-{:.6f} error-{:.6f}]".
              format(e + old_epochs, float(total_c_loss), total_error,
                     float(total_val_c_loss), total_val_error))
        pbar.update(1)

        print('Epoch Training time:', datetime.now() - epoch_start_time)
        # save model
        modelpath = episode.save_model(e + old_epochs, fpath=logdir + '/model.pth')

# Measured after the loop so it is defined even when epochs == 0.
end_time = datetime.now()
print('Total Training time:', end_time - start_time)
print('No. of classes: {}, No. of support samples per class: {}, No. of query samples per class: {}'.format(
    k, n, q))
print('Epochs: {}, No. of batches in train: {}, No. of batches in val: {}, Each batch size: {}'.format(
    epochs, total_train_batches, total_val_batches, batch_size))
print('Training completed...')


In [None]:

total_test_batches = len(test_batch_loader)

# `with` guarantees the log file is closed even if testing raises.
with open(logdir + '/test_log_r1', "w") as test_file:
    episode.test_file = test_file
    per = episode.run_test_epoch(total_test_batches)

    print("Per of network : ", per*100)
    test_file.write('\nper = %f\n' % (per))

print('Testing completed...')

In [None]:
import matplotlib.pyplot as plt


def _plot_curves(train_vals, val_vals, ylabel, title, out_path):
    """Plot per-epoch training vs validation values and save the figure to ``out_path``."""
    plt.figure()
    # float() turns 0-d tensors (possibly on the GPU or still requiring grad)
    # into plain numbers; matplotlib cannot convert such tensors itself.
    plt.plot([float(v) for v in train_vals], label='training')
    plt.plot([float(v) for v in val_vals], label='validation')
    plt.ylabel(ylabel)
    plt.xlabel("Epoch")
    plt.title(title)
    plt.legend()
    plt.savefig(out_path)


_plot_curves(loss, val_loss, "Loss", "Evolution of the loss function", logdir + "/CTC_loss_r2.png")
_plot_curves(acc, val_acc, "Error", "Evolution of the WER function", logdir + "/error_r2.png")