In [1]:
import pandas as pd
import numpy as np
from path import Path
import torch.utils.data as data
from imageio import imread
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F 
import matplotlib.pyplot as plt
%matplotlib inline 
from torch.utils.data.sampler import SubsetRandomSampler

from PIL import *
import ast

In [2]:
import numpy as np

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''
    # given query, key,value it finds the rightful weighted component of v to get the attention applied ouput
    #q,v,k- batch X length of sequence X features or encoding
    #attention sholuld be -batchX7X7
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
#         print(k.transpose(1,2).shape)

        attn = torch.bmm(q, k.transpose(1, 2)) 
#         print(attn.shape)
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
#         print(str(attn.shape)+" "+str(v.shape))
        output = torch.bmm(attn, v)

        return output, attn
SDP=ScaledDotProductAttention(5)
Ss=SDP(torch.zeros(5,6,100),torch.zeros(5,6,100),torch.zeros(5,6,100))
# print(Ss[0].shape)
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)


    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
#         print(str(sz_b)+"die")

        residual = q

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
#         print("v-"+str(v.shape))
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)
#         print(q.shape,k.shape,v.shape)
        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn
MHA=MultiHeadAttention(4,15,15,15)
op=MHA(torch.zeros(5,7,15),torch.zeros(5,7,15),torch.zeros(5,7,15))
# print(op[0].shape)
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
#         print("FCC-"+str(output.shape))
#         print("FFC_out-"+str(self.w_1(output).shape))
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output

    

In [3]:
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
#         print("ENC_o")
#         print(enc_output.shape)


        enc_output = self.pos_ffn(enc_output)


        return enc_output, enc_slf_attn
    
XX=EncoderLayer(15,10,4,10,10)

zz=XX(torch.zeros(5,7,15))
# print("ENc")
# print(zz[0].shape)
# print("start")
class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(self,n_modality,d_model,n_head,d_k,d_v,dropout,n_layers,d_inner=500):
        #d_model - number of features in input 100 here
        #n_head - number of heads of multihaded attention
        #d_k=d_q=  number of features in query, key
        #d_v = number of features in value whose weighted(attentioned) sum we gonna take
        

        super().__init__()
        self.n_modality=n_modality
#         self.stn=nn.ModuleList([SpatialTransformer(3, (240,240), 8) for _ in range(n_ref)])
        
        self.layer_stack = nn.ModuleList([EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 
                                          for _ in range(n_layers)])
        self.em=nn.Linear(225,100)
        self.fc1=nn.Linear(d_model*n_modality,300)
        self.relu=nn.ReLU()
        self.fc2=nn.Linear(300,100)
        self.fc3=nn.Linear(100,3)
#         self.fc4=nn.Linear(50,3)
        self.bn1 = nn.BatchNorm1d(num_features=300)
        self.bn2 = nn.BatchNorm1d(num_features=100)
        self.softmax=nn.Softmax(1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embeddings1,embeddings2 ):


        
        encodings_total=[self.em(embeddings1),embeddings2]

        enc_output=torch.stack(encodings_total,0)

        
        enc_output=enc_output.permute(1,0,2)
#         print("encoding_OUTPUT2-"+str(enc_output.shape))

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output,non_pad_mask=None,slf_attn_mask=None)
           
        
        final_input=enc_output.reshape(enc_output.shape[0],-1)

        final=self.relu(self.fc3(self.bn2(self.relu((self.fc2(self.bn1(self.relu(self.fc1(final_input)))))))))
        
        return(final)
     

In [4]:
e1=torch.zeros(4,100)
e2=torch.zeros(4,100)
e3=torch.stack([e1,e2],0)
e3.shape

torch.Size([2, 4, 100])

In [5]:
feature_set_MLP=pd.read_csv('ILD_MLP_features_ankit.csv')
feature_set_LSTM=pd.read_csv('ILD_LSTM_features_ankit.csv')
feature_set_MLP.head()

Unnamed: 0,gene_name,tag_0,tag_1,tag_2,tag_3,tag_4,tag_5,tag_6,tag_7,tag_8,...,tag_215,tag_216,tag_217,tag_218,tag_219,tag_220,tag_221,tag_222,tag_223,tag_224
0,hspa6,0.0,3.500152,0.0,4.197427,0.0,0.0,2.50499,4.630731,3.426645,...,0.0,3.367921,2.667888,0.0,0.0,0.0,0.0,0.0,0.702599,3.955442
1,scarb1,0.0,3.895088,0.0,4.644502,0.0,0.0,2.67644,5.127488,4.194887,...,0.0,3.544901,2.634494,0.0,0.0,0.0,0.0,0.0,1.263837,4.390254
2,mapk1,0.0,3.197021,0.0,3.810046,0.0,0.0,2.164615,3.919276,3.232924,...,0.0,2.883429,2.289313,0.0,0.0,0.0,0.0,0.0,0.687884,3.244735
3,adam32,0.0,1.336279,0.0,1.62095,0.0,0.0,0.971384,1.8307,1.393075,...,0.0,1.329008,0.997157,0.0,0.0,0.0,0.0,0.0,0.352354,1.665234
4,spata17,0.0,1.768997,0.0,1.791337,0.0,0.0,0.706958,2.092716,1.662794,...,0.0,2.180046,1.217243,0.0,0.0,0.0,0.0,0.0,0.788542,1.739285


In [6]:
feature_set_MLP.head()

Unnamed: 0,gene_name,tag_0,tag_1,tag_2,tag_3,tag_4,tag_5,tag_6,tag_7,tag_8,...,tag_215,tag_216,tag_217,tag_218,tag_219,tag_220,tag_221,tag_222,tag_223,tag_224
0,hspa6,0.0,3.500152,0.0,4.197427,0.0,0.0,2.50499,4.630731,3.426645,...,0.0,3.367921,2.667888,0.0,0.0,0.0,0.0,0.0,0.702599,3.955442
1,scarb1,0.0,3.895088,0.0,4.644502,0.0,0.0,2.67644,5.127488,4.194887,...,0.0,3.544901,2.634494,0.0,0.0,0.0,0.0,0.0,1.263837,4.390254
2,mapk1,0.0,3.197021,0.0,3.810046,0.0,0.0,2.164615,3.919276,3.232924,...,0.0,2.883429,2.289313,0.0,0.0,0.0,0.0,0.0,0.687884,3.244735
3,adam32,0.0,1.336279,0.0,1.62095,0.0,0.0,0.971384,1.8307,1.393075,...,0.0,1.329008,0.997157,0.0,0.0,0.0,0.0,0.0,0.352354,1.665234
4,spata17,0.0,1.768997,0.0,1.791337,0.0,0.0,0.706958,2.092716,1.662794,...,0.0,2.180046,1.217243,0.0,0.0,0.0,0.0,0.0,0.788542,1.739285


In [7]:
header_of_MLP=['tag_'+str(i) for i in range(feature_set_MLP.shape[1]-1)]
features_MLP=np.array(feature_set_MLP[header_of_MLP])
gene_MLP=feature_set_MLP['gene_name']
print(features_MLP.shape)
print(len(gene_MLP))

(18144, 225)
18144


In [8]:
dictionary_MLP={}
u=0
for gn in gene_MLP:
    dictionary_MLP[gn]=features_MLP[u]
    u=u+1

In [9]:
header_of_LSTM=['tag_'+str(i) for i in range(feature_set_LSTM.shape[1]-1)]
features_LSTM=np.array(feature_set_LSTM[header_of_LSTM])
gene_LSTM=feature_set_LSTM['gene_name']
print(features_LSTM.shape)
print(len(gene_LSTM))

(13716, 100)
13716


In [10]:
dictionary_LSTM={}
u=f=0
for gn in gene_LSTM:
    if gn in dictionary_LSTM.keys():
#         print(gn)
        f=f+1
    dictionary_LSTM[gn]=features_LSTM[u]
    u=u+1
print(f)

4022


In [11]:
print(len(dictionary_LSTM.keys()))

9694


In [12]:
fil=open('../../Multi-modality/Model/ILD/data/labels_ILD.txt','r')
tmp=list()
for line in fil:
	tmp.append(int(line))

label_ILD=np.array(tmp)
print(label_ILD.shape)

(18144,)


In [13]:
class Sequenceloader(data.Dataset):
    def __init__(self,GN,Feat,label):
        self.gene_names=GN
        self.features_mlp=Feat
        self.label=label
        self.coincdgene_name=[]
        self.coincidfeature_MLP=[]
        self.coincidfeature_LSTM=[]
        self.coincidlabel=[]
        for i in range(len(self.gene_names)):
            u=self.gene_names[i]
            if u in dictionary_LSTM.keys():
                
                if np.array(self.label[i])==2:
                    ch=3
                else:
                    ch=1
                    
                for jj in range(ch):
                    self.coincdgene_name.append(u)
                    self.coincidfeature_MLP.append(self.features_mlp[i])
                    self.coincidfeature_LSTM.append(dictionary_LSTM[u])
                    self.coincidlabel.append(self.label[i])
        
        
    
    def __len__(self):
#         print(len(self.dataset))
        return len(self.coincdgene_name)       

    def __getitem__(self, index):
         return np.array(self.coincidfeature_MLP[index]),np.array(self.coincidfeature_LSTM[index]),np.array(self.coincidlabel[index])
#         print(self.dataset['gen_name'][index])          
#         try:
           
#         except :
            
total_set=Sequenceloader(gene_MLP,features_MLP,label_ILD)  
a=b=c=0
for x,y,z in total_set:
    if(z==0):
        a=a+1
    elif z==1:
        b=b+1
    else:
        c=c+1
print(a,b,c)
    

6477 5042 6591


In [14]:
batch_size = 4
validation_split = .2
shuffle_dataset = True
random_seed= 42

# Creating data indices for training and validation splits:
dataset_size = len(total_set)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(total_set, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(total_set, batch_size=batch_size,
                                                sampler=valid_sampler)


In [130]:
for a,b,c in train_loader:
    print(a.shape,b.shape,c.shape)

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4, 100]) torch.Size([4, 100]) torch.Size([4])
torch.Size([4,

In [15]:
def test(test_loader,model):

    total_imgs=0;
    total_corrects=0
    u=0
    nb_classes=3
    confusion_matrix = torch.zeros(nb_classes, nb_classes)
    for i1,i2,label in test_loader:
                

        output=model(i1.to(device).float(),i2.to(device).float())
        total_imgs=total_imgs+label.shape[0]
        z=torch.max(output,1)[1]==label.to(device)
        _, preds = torch.max(output, 1)
#         print(output.shape)
 
        num_corrects=torch.sum(z)
        total_corrects=total_corrects+num_corrects
        for t, p in zip(label.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1


        u=u+1
    
    
    print(confusion_matrix)
    return(total_corrects,total_imgs)

In [21]:
device='cuda:7'
model_lstmXMLP=Encoder(2,100,4,300,300,True,4).to(device) #1)mlp 2lstm
uz=torch.rand(4, 100).to(device)
vz=torch.rand(4,225).to(device)
model_lstmXMLP(vz,uz).shape


torch.Size([4, 3])

In [22]:

#model_lstmXMLP.load_state_dict(torch.load(Path('1ANKIT_ILD COMBO_LSTMXMLP.pt')))
optim_params = [
    {'params': model_lstmXMLP.parameters(), 'lr': 0.0001}
]
optimizer = torch.optim.Adam(optim_params)
criterion = nn.CrossEntropyLoss()

In [23]:
epoch=10000
for i in range(epoch):
    torch.save(model_lstmXMLP.state_dict(), '2ANKIT_ILD COMBO_LSTMXMLP.pt')
    print("Accuracy-"+str(test(validation_loader,model_lstmXMLP)))
    total_loss=0
    for inp1,inp2,lab in train_loader:

        

        output=model_lstmXMLP(inp1.to(device).float(),inp2.to(device).float())

        loss_batch=criterion(output,lab.to(device))/4
        optimizer.zero_grad()  
        loss_batch.backward()
        optimizer.step()
        total_loss+=loss_batch
       
    print(total_loss)
    
    
        
        

tensor([[502., 377., 416.],
        [431., 179., 414.],
        [323., 347., 633.]])
Accuracy-(tensor(1314, device='cuda:7'), 3622)
tensor(509.4636, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1134.,   57.,  104.],
        [  57.,  815.,  152.],
        [ 121.,  156., 1026.]])
Accuracy-(tensor(2975, device='cuda:7'), 3622)
tensor(404.8477, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1173.,   41.,   81.],
        [  61.,  798.,  165.],
        [ 130.,  105., 1068.]])
Accuracy-(tensor(3039, device='cuda:7'), 3622)
tensor(376.5764, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1171.,   40.,   84.],
        [  38.,  818.,  168.],
        [  97.,  131., 1075.]])
Accuracy-(tensor(3064, device='cuda:7'), 3622)
tensor(357.3091, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1186.,   47.,   62.],
        [  36.,  808.,  180.],
        [ 104.,  133., 1066.]])
Accuracy-(tensor(3060, device='cuda:7'), 3622)
tensor(351.8315, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1163.

tensor(267.8316, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1210.,   40.,   45.],
        [  15.,  847.,  162.],
        [ 101.,  126., 1076.]])
Accuracy-(tensor(3133, device='cuda:7'), 3622)
tensor(248.2769, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1240.,   18.,   37.],
        [  32.,  821.,  171.],
        [ 118.,  121., 1064.]])
Accuracy-(tensor(3125, device='cuda:7'), 3622)
tensor(249.8990, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1185.,   34.,   76.],
        [   6.,  871.,  147.],
        [  73.,  108., 1122.]])
Accuracy-(tensor(3178, device='cuda:7'), 3622)
tensor(242.5143, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1227.,   22.,   46.],
        [   9.,  847.,  168.],
        [  65.,  119., 1119.]])
Accuracy-(tensor(3193, device='cuda:7'), 3622)
tensor(234.4445, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1193.,   33.,   69.],
        [   6.,  897.,  121.],
        [  39.,  149., 1115.]])
Accuracy-(tensor(3205, device='cuda:7'), 3622)
tenso

tensor([[1236.,    5.,   54.],
        [   7.,  874.,  143.],
        [  59.,   74., 1170.]])
Accuracy-(tensor(3280, device='cuda:7'), 3622)
tensor(206.8943, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2220e+03, 9.0000e+00, 6.4000e+01],
        [1.0000e+00, 9.4700e+02, 7.6000e+01],
        [3.2000e+01, 1.3800e+02, 1.1330e+03]])
Accuracy-(tensor(3302, device='cuda:7'), 3622)
tensor(189.8052, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2390e+03, 3.0000e+00, 5.3000e+01],
        [1.0000e+00, 9.3300e+02, 9.0000e+01],
        [3.5000e+01, 1.0300e+02, 1.1650e+03]])
Accuracy-(tensor(3337, device='cuda:7'), 3622)
tensor(197.8261, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1207.,    3.,   85.],
        [   4.,  879.,  141.],
        [  28.,   81., 1194.]])
Accuracy-(tensor(3280, device='cuda:7'), 3622)
tensor(185.1586, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1245.,    4.,   46.],
        [   2.,  897.,  125.],
        [  25.,   95., 1183.]])
Accuracy-(tensor(332

tensor([[1256.,    0.,   39.],
        [   2.,  928.,   94.],
        [  55.,   45., 1203.]])
Accuracy-(tensor(3387, device='cuda:7'), 3622)
tensor(148.3191, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2150e+03, 1.0000e+00, 7.9000e+01],
        [0.0000e+00, 9.7700e+02, 4.7000e+01],
        [1.7000e+01, 1.9900e+02, 1.0870e+03]])
Accuracy-(tensor(3279, device='cuda:7'), 3622)
tensor(176.6240, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1253.,    3.,   39.],
        [   0.,  937.,   87.],
        [  39.,   94., 1170.]])
Accuracy-(tensor(3360, device='cuda:7'), 3622)
tensor(159.5724, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1243.,    2.,   50.],
        [   3.,  915.,  106.],
        [  35.,   49., 1219.]])
Accuracy-(tensor(3377, device='cuda:7'), 3622)
tensor(161.3827, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2280e+03, 4.0000e+00, 6.3000e+01],
        [1.0000e+00, 9.4600e+02, 7.7000e+01],
        [2.8000e+01, 9.8000e+01, 1.1770e+03]])
Accuracy-(tensor(335

tensor(186.4704, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1252.,    7.,   36.],
        [  19.,  820.,  185.],
        [ 117.,   65., 1121.]])
Accuracy-(tensor(3193, device='cuda:7'), 3622)
tensor(187.1941, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1244.,    2.,   49.],
        [   4.,  830.,  190.],
        [  64.,   49., 1190.]])
Accuracy-(tensor(3264, device='cuda:7'), 3622)
tensor(162.8997, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1262.,    0.,   33.],
        [   0.,  909.,  115.],
        [  55.,   40., 1208.]])
Accuracy-(tensor(3379, device='cuda:7'), 3622)
tensor(152.9481, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2610e+03, 1.0000e+00, 3.3000e+01],
        [2.0000e+00, 9.0600e+02, 1.1600e+02],
        [6.3000e+01, 5.3000e+01, 1.1870e+03]])
Accuracy-(tensor(3354, device='cuda:7'), 3622)
tensor(141.6540, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2290e+03, 0.0000e+00, 6.6000e+01],
        [1.0000e+00, 9.4900e+02, 7.4000e+01],
        [

tensor(140.6162, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2550e+03, 1.0000e+00, 3.9000e+01],
        [0.0000e+00, 9.5100e+02, 7.3000e+01],
        [2.3000e+01, 7.2000e+01, 1.2080e+03]])
Accuracy-(tensor(3414, device='cuda:7'), 3622)
tensor(140.2526, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1205.,    0.,   90.],
        [   0., 1011.,   13.],
        [   9.,  280., 1014.]])
Accuracy-(tensor(3230, device='cuda:7'), 3622)
tensor(112.7682, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2230e+03, 0.0000e+00, 7.2000e+01],
        [1.0000e+00, 9.9200e+02, 3.1000e+01],
        [6.0000e+00, 1.6000e+02, 1.1370e+03]])
Accuracy-(tensor(3352, device='cuda:7'), 3622)
tensor(122.2133, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2690e+03, 0.0000e+00, 2.6000e+01],
        [1.0000e+00, 9.1700e+02, 1.0600e+02],
        [2.4000e+01, 2.3000e+01, 1.2560e+03]])
Accuracy-(tensor(3442, device='cuda:7'), 3622)
tensor(109.3678, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[12

tensor([[1260.,    0.,   35.],
        [   0.,  963.,   61.],
        [  22.,   49., 1232.]])
Accuracy-(tensor(3455, device='cuda:7'), 3622)
tensor(116.1352, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2580e+03, 0.0000e+00, 3.7000e+01],
        [1.0000e+00, 9.2300e+02, 1.0000e+02],
        [2.4000e+01, 2.4000e+01, 1.2550e+03]])
Accuracy-(tensor(3436, device='cuda:7'), 3622)
tensor(134.8658, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2640e+03, 0.0000e+00, 3.1000e+01],
        [1.0000e+00, 8.8600e+02, 1.3700e+02],
        [4.0000e+01, 2.6000e+01, 1.2370e+03]])
Accuracy-(tensor(3387, device='cuda:7'), 3622)
tensor(116.1012, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1231.,    0.,   64.],
        [   0.,  923.,  101.],
        [   5.,   27., 1271.]])
Accuracy-(tensor(3425, device='cuda:7'), 3622)
tensor(118.2661, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2430e+03, 0.0000e+00, 5.2000e+01],
        [1.0000e+00, 9.6800e+02, 5.5000e+01],
        [6.0000e+00, 5

tensor(110.1406, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1278.,    0.,   17.],
        [   8.,  412.,  604.],
        [ 259.,    0., 1044.]])
Accuracy-(tensor(2734, device='cuda:7'), 3622)
tensor(165.8160, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2380e+03, 1.0000e+00, 5.6000e+01],
        [0.0000e+00, 9.5700e+02, 6.7000e+01],
        [1.6000e+01, 9.0000e+01, 1.1970e+03]])
Accuracy-(tensor(3392, device='cuda:7'), 3622)
tensor(133.7591, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1244.,    0.,   51.],
        [   3.,  962.,   59.],
        [   7.,   51., 1245.]])
Accuracy-(tensor(3451, device='cuda:7'), 3622)
tensor(118.6767, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2730e+03, 1.0000e+00, 2.1000e+01],
        [0.0000e+00, 8.2300e+02, 2.0100e+02],
        [7.1000e+01, 1.1000e+01, 1.2210e+03]])
Accuracy-(tensor(3317, device='cuda:7'), 3622)
tensor(122.7283, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2150e+03, 0.0000e+00, 8.0000e+01],
        [1

tensor(109.8938, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1267.,    0.,   28.],
        [   0.,  919.,  105.],
        [  22.,   28., 1253.]])
Accuracy-(tensor(3439, device='cuda:7'), 3622)
tensor(111.0982, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1256.,    0.,   39.],
        [   3.,  959.,   62.],
        [  16.,   24., 1263.]])
Accuracy-(tensor(3478, device='cuda:7'), 3622)
tensor(96.4973, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1250.,    0.,   45.],
        [   3.,  950.,   71.],
        [  12.,   20., 1271.]])
Accuracy-(tensor(3471, device='cuda:7'), 3622)
tensor(102.1496, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1242.,    0.,   53.],
        [   4.,  958.,   62.],
        [  11.,   39., 1253.]])
Accuracy-(tensor(3453, device='cuda:7'), 3622)
tensor(101.9795, device='cuda:7', grad_fn=<AddBackward0>)
tensor([[1.2430e+03, 0.0000e+00, 5.2000e+01],
        [1.0000e+00, 9.8800e+02, 3.5000e+01],
        [5.0000e+00, 5.6000e+01, 1.2420e+03]])
Accuracy

KeyboardInterrupt: 

In [None]:
tensor([[1243.,    9.,   43.],
        [  31.,  778.,  215.],
        [  99.,   76., 1128.]])
Accuracy-(tensor(3149, device='cuda:7'), 3622)