In [None]:
from torch.optim import Adam
from torch import nn
from torch.nn.utils import clip_grad_norm_
import numpy as np
import torch
from torchvision import transforms
import torchvision
import sys
import os
from datetime import datetime
import random

from pytorch_end2end import CTCDecoder
from jiwer import wer
from dataloader.ctc_support_set_loader import SupportDataSet, collate_wrapper, stackup
from dataloader.ctc_batch_loader import QueryDataSet, batch_stackup, pad_batch
from torch.utils.data import DataLoader

import Configuration as config
from tqdm.notebook import tqdm

In [None]:
EPSILON = 1e-8

# Seed every random number generator this notebook imports so that repeated
# Restart-and-Run-All executions are reproducible. Python's `random` module is
# imported above and is seeded too, in case imported helpers draw from it.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# CNN Embedding

In [None]:
# cnn embedding - encoder_g  (SS)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        # Keep dim 0 untouched and let view() infer the flattened length.
        leading = input.shape[0]
        return input.view(leading, -1)

def conv_block(in_channels: int, out_channels: int) -> nn.Module:
    """Build one convolutional stage.

    The stage is, in order: a 3x3 convolution with padding 1, batch
    normalisation, ReLU, and a 2x2 max-pool with stride 2.
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    # Positional unpacking keeps the sub-module names "0".."3".
    return nn.Sequential(*stages)

def few_shot_cnn_encoder(num_input_channels=1, outchnls=64) -> nn.Module:
    """Build the CNN encoder: three conv blocks followed by a flatten layer.

    Args:
        num_input_channels: channels of the input fed to the first block.
        outchnls: output channels of every conv block.

    Returns:
        An ``nn.Sequential`` of the three blocks and a ``Flatten`` layer.
    """
    channel_plan = [
        (num_input_channels, outchnls),
        (outchnls, outchnls),
        (outchnls, outchnls),
    ]
    layers = [conv_block(c_in, c_out) for c_in, c_out in channel_plan]
    layers.append(Flatten())
    return nn.Sequential(*layers)

# LSTM Embedding

In [None]:
# lstm embedding - encoder_f (Batch utterances)
class few_shot_lstm_encoder(nn.Module):
    """Two-layer bidirectional LSTM encoder with batch-first inputs.

    Each direction has 128 hidden units, so the per-frame output width is
    2 * 128 = 256.
    """

    def __init__(self, lstm_inp_dim):
        """Create the LSTM for inputs whose last dimension is ``lstm_inp_dim``."""
        super().__init__()

        self.count = 0  # not read anywhere in this class

        # LSTM hyper-parameters
        self.lstm_inp_dim = lstm_inp_dim
        self.lstm_hidden_dim = 128
        self.lstm_layers = 2
        self.batch_norm = True  # stored flag only; no batch-norm layer is built here

        self.bilstm = nn.LSTM(
            input_size=self.lstm_inp_dim,
            hidden_size=self.lstm_hidden_dim,
            num_layers=self.lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, input):
        """Return the LSTM output sequence; the final hidden/cell states are discarded."""
        outputs, _final_state = self.bilstm(input)
        return outputs

# Matching Network

In [None]:
class MatchingNetwork(nn.Module):
    """Container for the two encoders used by the matching network.

    ``encoder_g`` is the CNN encoder and ``encoder_f`` is the bidirectional
    LSTM encoder. ``forward`` is a placeholder; ``episodeBuilder`` calls the
    two encoders directly.
    """

    def __init__(self, n, k, q, num_input_channels, lstm_input_size, use_cuda=False):
        """Build both encoders and pick the device.

        Args:
            n, k, q: episode configuration values, stored as attributes.
            num_input_channels: input channels for the CNN encoder.
            lstm_input_size: feature dimension of the LSTM encoder input.
            use_cuda: request the GPU; honoured only when CUDA is available.
        """
        super(MatchingNetwork, self).__init__()
        self.n = n
        self.k = k
        self.q = q
        self.num_input_channels = num_input_channels
        self.lstm_input_size = lstm_input_size
        # CNN encoder; the flattened width depends on the input's spatial size.
        self.encoder_g = few_shot_cnn_encoder(self.num_input_channels)
        # LSTM encoder; the output width is fixed by few_shot_lstm_encoder.
        self.encoder_f = few_shot_lstm_encoder(self.lstm_input_size)

        self.use_cuda = use_cuda

        # Logical `and` (not bitwise `&`) so any truthy/falsy use_cuda value
        # works, including None.
        if torch.cuda.is_available() and self.use_cuda:
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

    def forward(self, inputs):
        """Placeholder: does nothing and returns None."""
        pass

# Cosine computation

In [None]:
class DistanceNetwork(nn.Module):
    """Scores query embeddings against support embeddings.

    The score is the dot product divided by the support vector's L2 norm
    (the query norm is not applied). The first ``bshot`` support columns are
    reduced to their ``n`` highest scores; the remaining columns pass through.
    """

    def __init__(self, n, k, q, bshot):
        super().__init__()
        self.n = n
        self.k = k
        self.q = q
        self.bshot = bshot

    def forward(self, support_set, query_set):
        """Return scores of shape (batch, queries, n + supports - bshot).

        Args:
            support_set: 3-D tensor (batch, supports, dim).
            query_set: 3-D tensor (batch, queries, dim).
        """
        eps = 1e-10
        num_blanks = self.bshot  # total number of blanks in the support set

        # Inverse L2 norm of each support vector, clamped to avoid dividing by zero.
        squared_norm = support_set.pow(2).sum(dim=2)
        inv_norm = torch.rsqrt(squared_norm.clamp(min=eps))

        raw_scores = torch.matmul(query_set, support_set.permute(0, 2, 1))
        scores = raw_scores * inv_norm.unsqueeze(1)  # for ctc

        # Keep only the n best blank scores, then append the non-blank scores.
        blank_scores = scores[:, :, :num_blanks]
        other_scores = scores[:, :, num_blanks:]
        best_blanks = blank_scores.topk(self.n, dim=2).values
        return torch.cat([best_blanks, other_scores], dim=2)

In [None]:
class episodeBuilder:
    """Runs matching-network episodes with CTC beam-search decoding.

    For each query batch the class embeds the support set with the model's CNN
    encoder and the query utterance with its LSTM encoder, scores them with
    ``DistanceNetwork``, pools the scores per class, decodes the log-softmax
    output with a CTC decoder and scores the decoded labels with ``wer``.
    """

    def __init__(self, model, n_shot, k_way, b_shot, q_queries,
                 batch_size,
                 data_train=None, data_val=None, data_test=None, batch_test=None):
        """Store the episode configuration and build optimiser, decoder and scorer.

        Args:
            model: network exposing ``encoder_f``, ``encoder_g``, ``device``
                and ``parameters()``.
            n_shot: support samples per class.
            k_way: number of classes.
            b_shot: number of blank samples at the front of the support set.
            q_queries: stored and forwarded to ``DistanceNetwork``.
            batch_size: initial batch size; overwritten on every episode from
                the support tensor.
            data_train: iterable of ``(X, y, ylabels)`` support batches for training.
            data_val: stored only; not used by any method of this class.
            data_test: iterable of ``(X, y, ylabels)`` support batches for testing.
            batch_test: iterable of query batches for ``run_test_epoch``.
        """
        self.optimiser = Adam(model.parameters(), lr=1e-3)

        self.label_dict = config.label_dict
        self.ctc_labels = config.ctc_labels

        self.decoder = CTCDecoder(blank_idx=0, beam_width=100, time_major=False,
                                  after_logsoftmax=True, labels=self.ctc_labels)

        self.model = model

        self.n_shot = n_shot
        self.b_shot = b_shot  # number of blank samples in the support set
        self.k_way = k_way
        self.q_queries = q_queries
        self.batch_size = batch_size
        self.batch_test = batch_test
        self.dn = DistanceNetwork(n_shot, k_way, q_queries, b_shot)

        # Keep every data source that was passed in; run_training_epoch reads
        # self.data_train.
        self.data_train = data_train
        self.data_val = data_val
        self.data_test = data_test

        self.device = self.model.device

    def each_episode_ctc(self, X, y, batch_query, trainflag=True):
        """Run one episode and return the error rate of the decoded sequence.

        The decoding and scoring steps flatten ``decoded_targets`` to 1-D, so
        they handle one utterance per batch.

        Args:
            X: batch of support sets; ``X[i]`` is fed to ``model.encoder_g``.
            y: support labels (not used in this method).
            batch_query: tuple ``(inputs, input_sizes, targets, target_sizes, utt)``.
            trainflag: when True, put the model in train mode and zero the
                optimiser gradients; otherwise put the model in eval mode.

        Returns:
            The value returned by ``wer`` for the target and decoded labels.
        """
        self.train = trainflag

        if self.train:
            # Zero gradients
            self.model.train()
            self.optimiser.zero_grad()
        else:
            self.model.eval()

        self.batch_size = X.size(0)

        inputs, input_sizes, targets, target_sizes, utt = batch_query
        inputs = inputs.to(self.device)

        inputs = torch.squeeze(inputs, 2)
        inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], -1)  # (batch_size, frames, inp_dim)

        # Embed all query utterances with the LSTM encoder (f) to use sequence information.
        queries = self.model.encoder_f(inputs)

        # Embed each support set with the CNN encoder (g).
        support = torch.stack([self.model.encoder_g(support_items) for support_items in X])

        similarities = self.dn(support_set=support, query_set=queries)
        attention = self.matching_net_predictions(similarities)

        ypreds = nn.LogSoftmax(dim=2)(attention)  # [batch_size, len, num_classes]
        ypreds = ypreds.to(self.device, dtype=torch.float32)

        # Every sequence in the batch gets the full time length (previously only
        # entry 0 was filled, leaving the rest at zero).
        # NOTE(review): padded frames are included; consider input_sizes if it
        # carries the true per-utterance lengths -- confirm against pad_batch.
        ypreds_sizes = torch.zeros(self.batch_size)
        ypreds_sizes[:] = ypreds.shape[1]
        ypreds_sizes = ypreds_sizes.long().to(self.device)

        decoded_targets, decoded_targets_lengths, decoded_sentences = self.decoder.decode(ypreds, ypreds_sizes)

        target_label_indices = torch.squeeze(targets, 0).tolist()
        target_label = [self.label_dict[ele] for ele in target_label_indices]

        pred_index = decoded_targets.view(decoded_targets.shape[1]).tolist()
        pred_label = [self.label_dict[ele] for ele in pred_index]

        per = wer(target_label, pred_label)

        return per

    def matching_net_predictions(self, attention):
        """Calculates Matching Network predictions based on equation (1) of the paper.

        Args:
            attention: tensor of shape (batch, len, k_way * n_shot) with one
                score per support sample.

        Returns:
            Tensor of shape (batch, len, k_way): for every class, the sum of
            the scores of its ``n_shot`` support samples.
        """
        k = self.k_way
        n = self.n_shot

        # The one-hot support label matrix is identical for every batch entry,
        # so build it (and move it to the device) once.
        labels = self.create_nshot_task_label(k, n).unsqueeze(-1)
        y_onehot = torch.zeros(k * n, k).scatter(1, labels, 1)
        y_onehot = y_onehot.to(self.device, dtype=torch.float32)

        return torch.stack([torch.mm(scores, y_onehot) for scores in attention])

    def create_nshot_task_label(self, k, n):
        """Return the int64 labels ``[0]*n + [1]*n + ... + [k-1]*n``.

        Integer arithmetic is used instead of a fractional-step ``arange``,
        whose length and truncated values depend on floating-point rounding.
        """
        return torch.arange(k).repeat_interleave(n)

    def run_training_epoch(self, total_train_batches):
        """Run ``total_train_batches`` optimisation steps; return mean loss and accuracy.

        NOTE(review): this calls ``self.each_episode``, which is not defined in
        this class (only ``each_episode_ctc`` is), so the call raises
        AttributeError unless that method is supplied elsewhere.
        """
        total_loss = 0.0
        total_accuracy = 0.0
        with tqdm(total=total_train_batches, desc='train batches:', leave=False) as pbar:
            for i in range(total_train_batches):
                # A fresh iterator is created on every step, so this takes the
                # first batch that iterator yields each time.
                data = next(iter(self.data_train))
                X, y, ylabels = data
                X = X.to(self.device, dtype=torch.float32)
                y = y.to(self.device, dtype=torch.long)
                loss, acc, _ = self.each_episode(X, y, trainflag=True)

                loss.backward()
                clip_grad_norm_(self.model.parameters(), 1)
                self.optimiser.step()
                pbar.update(1)
                total_loss += loss.data
                total_accuracy += acc.data

            total_loss = total_loss / total_train_batches
            total_accuracy = total_accuracy / total_train_batches
            return total_loss, total_accuracy

    def run_test_epoch(self, total_test_batches):
        """Evaluate every query batch in ``self.batch_test``; return the mean error rate.

        The loop iterates over all of ``self.batch_test`` regardless of
        ``total_test_batches``, which only sizes the progress bar and divides
        the accumulated error.
        """
        total_per = 0.0
        with tqdm(total=total_test_batches, desc='test batches:', leave=False) as pbar:
            # Evaluation never back-propagates, so do not build autograd graphs.
            with torch.no_grad():
                for i, batch_query in enumerate(self.batch_test):   # all utterances in an epoch
                    # A fresh iterator is created on every step, so this takes
                    # the first batch that iterator yields each time.
                    support_data = next(iter(self.data_test))
                    X, y, ylabels = support_data
                    X = X.to(self.device, dtype=torch.float32)
                    y = y.to(self.device, dtype=torch.long)
                    per = self.each_episode_ctc(X, y, batch_query, trainflag=False)
                    total_per += per
                    print("test epoch  ", i,  "PER => ", per, "  and avg PER  ",total_per/(i+1))
                    pbar.update(1)

        total_per = total_per / total_test_batches
        return total_per

In [None]:
def build_model(n_train, k_train, q_train, fce=False, use_cuda=False):
    """Create a float32 ``MatchingNetwork`` placed on its own selected device.

    Args:
        n_train, k_train, q_train: episode configuration passed to the network.
        fce: accepted but not used by this function.
        use_cuda: forwarded to ``MatchingNetwork`` for device selection.
    """
    cnn_input_channels = 1
    lstm_feature_dim = 39  # encoder input feature size  ## check

    network = MatchingNetwork(
        n_train,
        k_train,
        q_train,
        cnn_input_channels,
        lstm_feature_dim,
        use_cuda=use_cuda,
    )
    return network.to(network.device, dtype=torch.float32)

def load_model(fpath, n_train, k_train, q_train, use_cuda=False, eval_flag=True):
    """Build a model and restore its weights from the checkpoint at ``fpath``.

    Args:
        fpath: path to a saved ``state_dict``.
        n_train, k_train, q_train: forwarded to ``build_model``.
        use_cuda: forwarded to ``build_model``.
        eval_flag: when True, switch the model to evaluation mode.

    Returns:
        The float32 model on its selected device.
    """
    network = build_model(n_train, k_train, q_train, use_cuda=use_cuda)
    target_device = network.device

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    state = torch.load(fpath, map_location=target_device)
    network.load_state_dict(state)
    network = network.to(target_device, dtype=torch.float32)

    if not eval_flag:
        return network
    network.eval()
    return network

In [None]:
# Episode configuration
q = 1           # q=1 is supported as of now
batch_size = 1  # batch_size=1 is recommended as of now

n = config.N
k = config.K
bshot = config.blanks_inference

# Restore the trained network in evaluation mode.
inference_model = load_model(
    config.saved_model_for_inference, n, k, q,
    use_cuda=config.use_cuda,
    eval_flag=True,
)

# Support-set loader
test_dataset = SupportDataSet(
    supportdatafile=config.devSupportSet,
    kway=k,
    nshot=n,
    bshot=bshot,
    nqueries=q,
    phonemes=config.ALLPHONEMES,
    index_dict=config.index_dict,
    label_dict=config.label_dict,
    transform=transforms.Compose([stackup()]),
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_wrapper)

# Query-utterance loader
batch_dataset = QueryDataSet(
    querysetfile=config.testQuerySet,
    transform=transforms.Compose([batch_stackup(config.index_dict)]),
)
batch_loader = DataLoader(batch_dataset, batch_size=batch_size, collate_fn=pad_batch)

episode1 = episodeBuilder(
    inference_model, n, k, bshot, q, batch_size,
    data_test=test_dataloader,
    batch_test=batch_loader,
)

# One test epoch covers every batch of the query loader.
total_test_batches = len(batch_loader)
per = episode1.run_test_epoch(total_test_batches)
print("Per of network : ", per*100)  # blanks 10 shots per phn