Note: some of this code comes from the deepspin/quati github repository of Andre Martins's lab.

## To do:
1. Use dataloader and dataset classes for test
2. Let the number of MLP layers and their widths at the end be easily controllable.
3. Write a better test set evaluation wrapper.
4. Use a closed form cts sparsemax and Gaussian mixture.
5. Refactor numerical integration code to reuse when possible.
6. Refactor code in general.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import math
from basis_functions import GaussianBasisFunctions
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.metrics import f1_score, accuracy_score
import torch.nn.functional as F
import random
import pickle as pkl
import time

In [2]:
#this is based on https://www.youtube.com/watch?v=PXOzkkB5eH0
class FordADataset(Dataset):
    def __init__(self,dset_string):
        '''
        @dset_string: set to "FordA_TRAIN.tsv" or "FordA_TEST.tsv"
        '''
        root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"
        x, y = self.readucr(root_url + dset_string)
        self.x = x
        self.y = y
        self.n_samples = len(y)
    
    def readucr(self,filename):
        data = np.loadtxt(filename, delimiter="\t")
        y = torch.from_numpy(data[:, 0]).long()
        y[y==-1]=0
        x = torch.from_numpy(data[:, 1:]).float()
        return x,y
    
    def __getitem__(self,index,train=True):
        return self.x[index], self.y[index]
    
    def n_classes(self):
        return len(np.unique(self.y))
    
    def __len__(self):
        return self.n_samples

In [5]:
device = 1

In [8]:
def get_batch(X,Y,i,bs):
    batch_len = min(bs, X.shape[0] - 1 - i)
    if bs==1:
        batch_len = min(bs,X.shape[0]-i)
    data = X[i:i+batch_len,:]
    target = Y[i:i+batch_len]
    return data,target

In [9]:
def add_gaussian_basis_functions(nb_basis, sigmas, device=None):
    mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)),
                               torch.Tensor(sigmas))
    mus = mu.flatten().to(device)
    sigmas = sigma.flatten().to(device)
    return GaussianBasisFunctions(mus, sigmas)

def create_psi(length,nb_basis,device=None):
    psi = []
    nb_waves = nb_basis
    nb_waves = max(2,nb_waves)
    psi.append(
                add_gaussian_basis_functions(nb_waves,
                                             sigmas=[.025,.1, .3],
                                             # sigmas=[.03, .1, .3],
                                             device=device)
            )
    return psi

In [10]:
def get_G(max_length,nb_basis,device=None):
    psis=[]
    Gs = []

    for length in range(1,max_length+1):
        psi = create_psi(length,nb_basis,device=device)
        shift = 1 / float(2 * length)
        positions = torch.linspace(shift, 1 - shift, length)
        positions = positions.unsqueeze(1).to(device)
        all_basis = [basis_function.evaluate(positions)
                             for basis_function in psi]
        F = torch.cat(all_basis, dim=-1).t().to(device)
        nb_basis = sum([len(b) for b in psi])
        assert F.size(0) == nb_basis

        # compute G with a ridge penalty
        penalty = 5e-1
        I = torch.eye(nb_basis).to(device)
        G = F.t().matmul((F.matmul(F.t()) + penalty * I).inverse())
        psis.append(psi)
        Gs.append(G)
    G = Gs[max_length-1].to(device)
    return G,F

In [11]:
class Conv(torch.nn.Module):
    def __init__(self, input_size,n_classes):
        super(Conv, self).__init__()
        pool_size = 2
        self.relu = torch.nn.ReLU()

        self.conv1 = torch.nn.Conv1d(1,24,5,padding='same')
        self.conv2 = torch.nn.Conv1d(24,24,5,padding='same')

    def forward(self, x):
        global_conv = 1
        #two convs, 16 filters
        convolved = self.relu(self.conv1(x))
        convolved = self.relu(self.conv2(convolved))
        return convolved

In [12]:
class RNN(nn.Module):
    """RNN module(cell type lstm or gru)"""
    def __init__(
        self,
        input_size,
        hid_size,
        num_rnn_layers=1,
        dropout_p = 0.2,
        bidirectional = False,
        rnn_type = 'lstm',
    ):
        super().__init__()
        
        if rnn_type == 'lstm':
            self.rnn_layer = nn.LSTM(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
            
        else:
            self.rnn_layer = nn.GRU(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
    def forward(self, input):
        outputs, hidden_states = self.rnn_layer(input)
        return outputs, hidden_states

In [13]:
class RNNAttn(nn.Module):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        dx=None,
        bandwidth=None,
        nb_basis=32,
        n_classes=5,
        attn_type='cts_softmax',
        inducing_points = 24,
        device=None,
        components = 0
    ):
        super().__init__()
        
        self.attn_type = attn_type
        G,F = get_G(input_size,nb_basis,device=device)
        #G: 1 x input_size x nb_basis
        self.G = G.unsqueeze(0)
        self.F = F
        self.a = 0
        self.b = 1
        self.components = components
        
        if bandwidth==None:
            self.bandwidth = 5./input_size
        else:
            self.bandwidth = bandwidth
        if dx==None:
            self.dx = input_size
        else:
            self.dx = dx
        GB = add_gaussian_basis_functions(nb_basis,sigmas=[0.025,.1, .5])
        self.mu_basis = GB.mu
        self.sigma_basis = GB.sigma
        self.inducing_points = inducing_points
        self.threshold = 1.5
        
        self.rnn_layer = RNN(
            input_size=24,#hid_size * 2 if bidirectional else hid_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv = Conv(input_size,n_classes)
        self.relu = torch.nn.ReLU()
        self.plus = torch.nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool1d((1))
        
        self.attn_fc1 = nn.Linear(hid_size, hid_size, bias=True)
        self.attn_fc2 = nn.Linear(hid_size,1,bias=True)
        self.attn_softmax = torch.nn.Softmax(dim=1)
        self.inducing_map = nn.Identity(input_size,inducing_points)
        
        if components>0:
            self.encode_pi1 = torch.nn.Linear(input_size, components)
            self.encode_mu = torch.nn.Linear(input_size, components)
            self.encode_sigma_sq1 = torch.nn.Linear(input_size, components)
            self.encode_sigma_sq2 = torch.nn.Softplus()
        
        self.fc1 = nn.Linear(in_features=hid_size, out_features=2*hid_size)
        self.fc2 = nn.Linear(in_features=2*hid_size, out_features=hid_size)
        self.fc3 = nn.Linear(in_features=hid_size, out_features=n_classes)
    
    def update_device(self,device):
        self.G = self.G.to(device)
        self.F = self.F.to(device)
        self.mu_basis = self.mu_basis.to(device)
        self.sigma_basis = self.sigma_basis.to(device)
        
    def get_alpha(self,unscaled_attn):
        return unscaled_attn
    #torch.clamp(unscaled_attn,max=1)+torch.clamp(torch.sqrt(self.relu(unscaled_attn)),max=5)

    def attention_weights(self,input):
        #x: bs x features x T
        x = self.conv(input)
        #x: bs x T x features
        x = torch.transpose(x,1,2)
        #x_out: bs x T x hid_size
        x_out, _ = self.rnn_layer(x)
        unscaled_attn = self.attn_fc1(x_out)
        unscaled_attn = torch.tanh(unscaled_attn)
        unscaled_attn = self.attn_fc2(unscaled_attn)
        attn_weights = self.attn_softmax(unscaled_attn)
        attn_weights = torch.transpose(attn_weights,1,2)
        if self.attn_type=='discrete':
            return attn_weights
        elif self.attn_type=='cts_softmax':
            mu = torch.matmul(attn_weights,torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float())/attn_weights.shape[-1]
            sigma_sq = torch.matmul(attn_weights,(torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float()/attn_weights.shape[-1])**2)-mu**2
            attn_weights = self._integrate_product_of_gaussians(mu,sigma_sq,False)
            return attn_weights
        elif self.attn_type=='cts_sparsemax':
            mu = torch.matmul(attn_weights,torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float())/attn_weights.shape[-1]
            sigma_sq = torch.matmul(attn_weights,(torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float()/attn_weights.shape[-1])**2)-mu**2
            attn_weights = self._integrate_wrt_truncated_parabaloid(mu,sigma_sq,False)
            return attn_weights
        elif self.attn_type=='kernel_softmax':
            mu = torch.matmul(attn_weights,torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float())/attn_weights.shape[-1]
            sigma_sq = torch.matmul(attn_weights,(torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float()/attn_weights.shape[-1])**2)-mu**2
            alpha = self.get_alpha(unscaled_attn.transpose(1,2))
            attn_weights = self._integrate_kernel_exp_wrt_gaussian(mu,sigma_sq,alpha,False)
            return attn_weights
        elif self.attn_type=='gaussian_mixture':
            alpha  = self.get_alpha(unscaled_attn.transpose(1,2))
            #Compute mu, sigma_sq
            mu = self.encode_mu(alpha).squeeze(1)
            sigma_sq = self.encode_sigma_sq1(alpha)
            sigma_sq = self.encode_sigma_sq2(sigma_sq).squeeze(1)
            #compute pi
            pi = self.encode_pi1(alpha).squeeze(1)
            attn_weights = self._integrate_gaussian_mixture(mu,sigma_sq,pi,False)
            return attn_weights
        elif self.attn_type=='kernel_sparsemax':
            mu = torch.matmul(attn_weights,torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float())/attn_weights.shape[-1]
            sigma_sq = torch.matmul(attn_weights,(torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float()/attn_weights.shape[-1])**2)-mu**2
            alpha = self.get_alpha(unscaled_attn.transpose(1,2))
            attn_weights = self._integrate_wrt_kernel_deformed(mu,sigma_sq,alpha,False)
            return attn_weights
        
    def forward(self, input):
        #x: bs x features x T
        x = self.conv(input)
        #x: bs x T x features
        x = torch.transpose(x,1,2)
        layer_norm = nn.LayerNorm(x.shape[2],elementwise_affine=False)
        x = layer_norm(x)
        #x_out: bs x T x hid_size
        x_out, _ = self.rnn_layer(x)
        layer_norm = nn.LayerNorm(x_out.shape[2],elementwise_affine=False)
        x_out = layer_norm(x_out)
        if torch.isnan(x_out).any():
            print('hidden states have nans problem')
            print(5/0)
            x_out = torch.nan_to_num(x_out,0)
        #first get discrete attention weights
        unscaled_attn = self.attn_fc1(x_out)
        unscaled_attn = torch.tanh(unscaled_attn)
        unscaled_attn = self.attn_fc2(unscaled_attn)
        attn_weights = self.attn_softmax(unscaled_attn)
        
        attn_weights = torch.transpose(attn_weights,1,2)
        x = x_out
        
        if self.attn_type=='discrete':
            #x: bs x hid_size x T
            x = torch.transpose(x_out,1,2)
            c = self.avgpool(x*attn_weights)
            c = torch.transpose(c,1,2)
        else:
            mu = torch.matmul(attn_weights,torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float())/attn_weights.shape[-1]
            sigma_sq = torch.matmul(attn_weights,(torch.arange(1,attn_weights.shape[-1]+1,device=attn_weights.device).float()/attn_weights.shape[-1])**2)-mu**2
            alpha = self.get_alpha(unscaled_attn.transpose(1,2))
            #B: bs x hid_size x nb
            B = torch.matmul(torch.transpose(x,1,2),self.G)
            if self.attn_type=='cts_softmax':
                #numerical_integral: bs x nb x 1
                numerical_integral = self._integrate_product_of_gaussians(mu,sigma_sq)
            elif self.attn_type=='kernel_softmax':
                numerical_integral = self._integrate_kernel_exp_wrt_gaussian(mu,sigma_sq,alpha).transpose(1,2)#.unsqueeze(-1)
                #print(torch.max(numerical_integral))
            elif self.attn_type=='cts_sparsemax':
                numerical_integral = self._integrate_wrt_truncated_parabaloid(mu,sigma_sq).unsqueeze(-1)
            elif self.attn_type=='kernel_sparsemax':
                numerical_integral = self._integrate_wrt_kernel_deformed(mu,sigma_sq,alpha).transpose(1,2)
            elif self.attn_type=='gaussian_mixture':
                #Compute mu, sigma_sq
                mu = self.encode_mu(alpha).squeeze(1)
                sigma_sq = self.encode_sigma_sq1(alpha)
                sigma_sq = self.encode_sigma_sq2(sigma_sq).squeeze(1)
                #compute pi
                pi = self.encode_pi1(alpha).squeeze(1)
                #integrals gives shape: bs x heads x nb
                numerical_integral = self._integrate_gaussian_mixture(mu,sigma_sq,pi).unsqueeze(-1)
            elif self.attn_type=='None':
                numerical_integral = self._integrate_product_of_gaussians(mu,sigma_sq)
                x = torch.transpose(x_out,1,2)
                c = self.avgpool(x)
            if torch.isnan(numerical_integral).any():
                print('numerical integral has nans problem')
                print(5/0)
            numerical_integral = torch.nan_to_num(numerical_integral,0)
            c = torch.bmm(B,numerical_integral)
            c = torch.transpose(c,1,2)
            if torch.isnan(x).any():
                print('output has nans problem')
                print(5/0)
                x = torch.nan_to_num(x,0)
        x = self.fc1(c)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
        
    def _phi(self,t):
        return 1.0/math.sqrt(2*math.pi)*torch.exp(-0.5*t**2)
    
    def gaussian_rbf(self,t,bandwidth):
        return torch.exp(-0.5*t**2/bandwidth)

    def exp_kernel(self,t,bandwidth):
        return torch.exp(-torch.abs(t)/bandwidth)
    
    def beta_exp(self,t):
        q = 0
        return self.plus(1+(1-q)*t)**(1./(1-q))
        
    def truncated_parabola(self,t,mu,sigma_sq,m,integrate=True):
        return self.plus(1+t*(mu/(sigma_sq))-t**2/(2*sigma_sq)-m)
        #return self.plus(-(t-mu)**2/(2*sigma_sq)+0.5*(3/(2*torch.sqrt(sigma_sq)))**(2./3.))
    
    def _integrate_product_of_gaussians(self,mu,sigma_sq,integrate=True):
        sigma = torch.sqrt(self.sigma_basis.unsqueeze(-1)**2+sigma_sq.unsqueeze(-1))
        if integrate:
            return self._phi((mu.unsqueeze(-1)-self.mu_basis.unsqueeze(-1))/sigma_sq.unsqueeze(-1))/self.sigma_basis.unsqueeze(-1)
        else:
            T = torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(0).unsqueeze(0)
            return self._phi((mu.unsqueeze(-1)-T))/sigma_sq.unsqueeze(-1)
    
    def _integrate_wrt_truncated_parabaloid(self,mu,sigma_sq,integrate=True):
        T = torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(0).unsqueeze(0)
        #phi1_upper: size 1 x x nb x dx
        phi1_upper = self.mu_basis.unsqueeze(-1)-T
        #phi1_lower: size 1 x nb x 1
        phi1_lower = self.sigma_basis.unsqueeze(-1)
        #phi1: size 1 x 1 x nb x dx
        phi1 = self._phi(phi1_upper/phi1_lower)/phi1_lower
        m = torch.max(T*(mu.unsqueeze(-1)/(sigma_sq.unsqueeze(-1)))-T**2/(2*sigma_sq.unsqueeze(-1)),-1)[0].unsqueeze(-1)
        deformed_term = self.truncated_parabola(T,mu.unsqueeze(-1),sigma_sq.unsqueeze(-1),m)
        unnormalized_density = deformed_term
        Z = torch.trapz(unnormalized_density,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1).unsqueeze(-1)
        numerical_integral = torch.trapz(phi1*unnormalized_density/Z,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1)
        #numerical_integral = torch.nan_to_num(numerical_integral,0)
        if integrate:
            if torch.isnan(numerical_integral).any():
                print('alpha')
                print(alpha)
                print('is alpha negative anywhere')
                print(torch.sum(alpha<=0))
                print('K')
                print(K)
                print('exp_f')
                print(exp_f)
                print('Z')
                print(Z)
                print(5/0)
            return numerical_integral
        else:
            return unnormalized_density
    
    def _integrate_kernel_exp_wrt_gaussian(self,mu,sigma_sq,alpha,integrate=True):
        T = torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(0).unsqueeze(0)
        inducing_locations = torch.linspace(0,1,self.inducing_points,device=mu.device)
        #phi1_upper: size 1 x nb x dx
        phi1_upper = self.mu_basis.unsqueeze(-1)-T
        #phi1_lower: size 1 x 1 x nb x 1
        phi1_lower = self.sigma_basis.unsqueeze(-1)
        #phi1: size 1 x 1 x nb x dx
        phi1 = self._phi(phi1_upper/phi1_lower)/phi1_lower
        K_inputs = torch.cdist(inducing_locations.unsqueeze(-1),torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(-1))
        #K = self.gaussian_rbf(K_inputs,self.bandwidth)
        K = self.exp_kernel(K_inputs,self.bandwidth)
        m = torch.max(torch.matmul(alpha,K).unsqueeze(-2),-1)[0].unsqueeze(-1)
        exp_terms = torch.exp(torch.matmul(alpha,K).unsqueeze(-2)-m)
        #m = Sparsemax(dim=3)
        #exp_terms = m(torch.matmul(alpha,K).unsqueeze(-2))
        Z = torch.trapz(exp_terms,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1).unsqueeze(-1)
        numerical_integral = torch.trapz(phi1*exp_terms/Z,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1)
        if torch.isnan(numerical_integral).any():
            print('had to rescale')
            exp_terms = self.beta_exp(torch.matmul(alpha,K).unsqueeze(-2)/torch.abs(m))
            Z = torch.trapz(exp_terms,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1).unsqueeze(-1)
            numerical_integral = torch.trapz(phi1*exp_terms/Z,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1)
        if integrate:
            return numerical_integral
        else:
            return exp_terms

    def _integrate_wrt_kernel_deformed(self,mu,sigma_sq,alpha,integrate=True):
        T = torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(0).unsqueeze(0)
        inducing_locations = torch.linspace(0,1,self.inducing_points,device=mu.device)
        #phi1_upper: size 1 x nb x dx
        phi1_upper = self.mu_basis.unsqueeze(-1)-T
        #phi1_lower: size 1 x 1 x nb x 1
        phi1_lower = self.sigma_basis.unsqueeze(-1)
        #phi1: size 1 x 1 x nb x dx
        phi1 = self._phi(phi1_upper/phi1_lower)/phi1_lower
        K_inputs = torch.cdist(inducing_locations.unsqueeze(-1),torch.linspace(self.a,self.b,self.dx,device=mu.device).unsqueeze(-1))
        #K = self.gaussian_rbf(K_inputs,self.bandwidth)
        K = self.exp_kernel(K_inputs,self.bandwidth)
        m = torch.max(torch.matmul(alpha,K).unsqueeze(-2),-1)[0].unsqueeze(-1)
        exp_terms = self.beta_exp(torch.matmul(alpha,K).unsqueeze(-2)-m)
        #exp_terms = self.beta_exp(torch.matmul(alpha,K).unsqueeze(-2))
        #m = Sparsemax(dim=3)
        #exp_terms = m(torch.matmul(alpha,K).unsqueeze(-2))
        Z = torch.trapz(exp_terms,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1).unsqueeze(-1)
        numerical_integral = torch.trapz(phi1*exp_terms/Z,torch.linspace(self.a,self.b,self.dx,device=mu.device),dim=-1)
        if integrate:
            return numerical_integral
        else:
            return exp_terms
    
    def _integrate_gaussian_mixture(self,mu,sigma_sq,pi,integrate=True):
        #T: size 1 x 1 x 1 x dx
        T = torch.linspace(self.a,self.b,self.dx,device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        #phi1_upper: size 1 x 1 x nb x dx
        phi1_upper = self.mu_basis.unsqueeze(-1)-T
        #phi1_lower: size 1 x 1 x nb x 1
        phi1_lower = self.sigma_basis.unsqueeze(-1)
        #phi1: size 1 x 1 x nb x dx
        phi1 = self._phi(phi1_upper/phi1_lower)/phi1_lower
        #phi2_upper: size bs x components x 1 x dx
        phi2_upper = mu.unsqueeze(-1).unsqueeze(-1)-T
        #phi2_lower: size bs x components x 1 x 1
        phi2_lower = sigma_sq.unsqueeze(-1).unsqueeze(-1).pow(0.5)
        #phi2: size bs x components x 1 x dx
        phi2 = self._phi(phi2_upper/phi2_lower)/phi2_lower
        bs = phi2.shape[0]
        phi2 = torch.reshape(phi2,(bs,self.components,1,self.dx))
        pi = torch.reshape(pi,(bs,self.components))
        pi = torch.softmax(pi,dim=1)
        pi = pi.unsqueeze(-1).unsqueeze(-1)
        phi2 = pi*phi2
        phi2 = torch.sum(phi2, dim=1)
        #phi1*phi2: size bs x nb x dx
        numerical_integral = torch.trapz(phi1.squeeze(1)*phi2,torch.linspace(self.a,self.b,self.dx,device=device),dim=-1)   
        if integrate:
            return numerical_integral
        else:
            return phi2

In [14]:
def train_model(model,dataloader,epochs,device,scaler=None,lr=1e-3):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay=1e-5)
    model.to(device,dtype=torch.float32)
    model.update_device(device)
    model.train() 
    for epoch in range(epochs):
        dim = 0
        batch = 0
        for i, (data, target) in enumerate(dataloader):
            data = data.to(device=device).unsqueeze(1)
            target = target.to(device=device)
            optimizer.zero_grad()
            # Forward pass
            
            # Compute Loss
            if scaler!=None:
                with torch.cuda.amp.autocast(dtype=torch.float32,enabled=False):
                    y_pred = model(data)
                    loss = criterion(y_pred.squeeze(), target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                # Updates the scale for next iteration
                scaler.update()
            else:
                with torch.cuda.amp.autocast():
                    y_pred = model(data)
                    loss = criterion(y_pred.squeeze(), target)
                    loss.backward()
                    optimizer.step()
            if i%10==0:
                print(loss)
                print(i)
                print('epoch')
                print(epoch)
    return model

In [15]:
def test_model(model,x_test,y_test):
    with torch.no_grad():
            y_pred = []
            y_test_cpu = []
            for j in range(len(y_test)):
                data,target = get_batch(x_test,y_test,j,1)
                data = data.to(device).unsqueeze(1)
                
                prediction = model(data).argmax(dim=2).cpu()
                y_pred.append(prediction.flatten().item())
                y_test_cpu.append(target.item())
                if j%100==0:
                    print(j)
    return y_pred, y_test_cpu

In [16]:
def get_metrics(y_test_cpu,y_pred):
    confusion = metrics.confusion_matrix(y_test_cpu,y_pred)
    for i in range(len(confusion)):
        print(confusion[i,:]/np.sum(confusion[i,:]),np.sum(confusion[i,:]))
    f1 = f1_score(y_test_cpu, y_pred, average="macro")
    print("Test f1 score : %s "% f1)
    acc = accuracy_score(y_test_cpu, y_pred)
    print("Test accuracy score : %s "% acc)
    return f1, acc

In [17]:
bs = 64
nb_basis = 256
epochs = 1
j=0

FordATrain = FordADataset("FordA_TRAIN.tsv")
DataLoaderTrain = DataLoader(dataset = FordATrain, batch_size = bs,shuffle = True)
n_classes = FordATrain.n_classes()
features, labels = next(iter(DataLoaderTrain))
input_size = features.shape[1]

dx = input_size
components = 1#input_size
runs = 1
hid_sizes = [128]
lr = 1e-3
inducing_points = input_size

rnn_types = ['lstm']
attn_types=['cts_softmax']
bidirectional = False
models = []
bandwidth = 1./input_size
for run in range(runs):
    for hid_size in hid_sizes:
        for rnn_type in rnn_types:
            for attn_type in attn_types:
                model = RNNAttn(input_size,hid_size,rnn_type,bidirectional,dx=dx,nb_basis=nb_basis,bandwidth=bandwidth,attn_type=attn_type,device=device,inducing_points=inducing_points,components=components)
                models.append(model)
                model.to(device)

trained_models = []
np.set_printoptions(suppress=True)
scaler=None
start = time.time()
while len(models)>0:
    model = models.pop(0)
    model = train_model(model,DataLoaderTrain,epochs,device,scaler=scaler,lr=lr)
    trained_models.append(model)
end = time.time()
print(end-start)
    

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tensor(1.6293, device='cuda:1', grad_fn=<NllLossBackward0>)
0
epoch
0
tensor(0.7728, device='cuda:1', grad_fn=<NllLossBackward0>)
10
epoch
0
tensor(0.6871, device='cuda:1', grad_fn=<NllLossBackward0>)
20
epoch
0
tensor(0.7145, device='cuda:1', grad_fn=<NllLossBackward0>)
30
epoch
0
tensor(0.7027, device='cuda:1', grad_fn=<NllLossBackward0>)
40
epoch
0
tensor(0.6890, device='cuda:1', grad_fn=<NllLossBackward0>)
50
epoch
0
2.3791890144348145


In [18]:
d_nice_names = {'cts_softmax':'Continuous Softmax','cts_sparsemax':'Continuous Sparsemax','kernel_softmax':'Kernel Softmax','kernel_sparsemax':'Kernel Sparsemax','gaussian_mixture':'Gaussian Mixture'}

In [19]:
# total = 0
# max_display = 4
# display_class = 0
# selections = np.random.choice(np.arange(len(y_test)),max_display)
# i = 1
# for k in selections:
#     print('iterating')
#     data = x_test[k,:].unsqueeze(0).unsqueeze(0).to(device)
#     for model in trained_models:
#         plt.plot(model.attention_weights(data)[0,:,:].cpu().flatten().detach().numpy()/torch.max(model.attention_weights(data)).item()*torch.max(data).item(),label=d_nice_names[attn_types[0]])
#         plt.plot(data[0,:,:].cpu().flatten().detach().numpy(),label='Series')
#         #plt.plot(model.attention_weights(data)[0,:,:].cpu().flatten().detach().numpy()/torch.max(model.attention_weights(data)).item(),label='Kernel Sparsemax')
#         #plt.plot(model.attention_weights(data)[0,:,:].cpu().flatten().detach().numpy()/torch.max(model.attention_weights(data)),label='Kernel Sparsemax')
#         plt.title('FordA: Original Series vs Rescaled Attention')
#         plt.xlabel('Time')
#         #plt.savefig('Kernel Sparsemax %i'%(k))
#         plt.legend()
#         plt.savefig('figures/forda0%i%s.png'%(i,attn_types[0]))
#         plt.show()
#         total+=1
#     i+=1

In [20]:
# test_results = []
# for model in trained_models:
#     y_pred,y_cpu = test_model(model,x_test,y_test)
#     test_results.append(y_pred)

In [21]:
# accuracies = []
# f1_scores = []
# for y_pred in test_results:
#     print('model')
#     f1, acc = get_metrics(y_cpu,y_pred)
#     f1_scores.append(f1)
#     accuracies.append(acc)

In [22]:
# results = (f1_scores,accuracies)

In [23]:
# print('mean f1 score')
# print(np.mean(np.array(f1_scores)))
# print('1.96*sd')
# print(1.96*np.std(np.array(f1_scores)))
# print('max f1 score')
# print(np.max(np.array(f1_scores)))
# print('min f1 score')
# print(np.min(np.array(f1_scores)))

# print('mean accuracy')
# print(np.mean(np.array(accuracies)))
# print('1.96*sd')
# print(1.96*np.std(np.array(accuracies)))
# print('max accuracy')
# print(np.max(np.array(accuracies)))
# print('min accuracy')
# print(np.min(np.array(accuracies)))


In [24]:
# f = open('results+%s.pkl'%(attn_types[0]), 'wb')   # Pickle file is newly created where foo1.py is
# pkl.dump(results, f)          # dump data to f
# f.close()           

In [25]:
# with open('results+%s.pkl'%(attn_types[0]), 'rb') as f:
#      results = pkl.load(f)

In [26]:
# f1_scores, accuracies = results[0],results[1]