In [1]:
import torch as T
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torchvision.models as models
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, ConcatDataset, IterableDataset
import numpy as np
import pandas as pd
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import glob
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
from pickle import dump,load

In [15]:
# SEQ_LEN=20
# D_HID_SIZE = 10

In [16]:
def get_cuda(tensor):
    """Return `tensor` on CUDA device 0 when CUDA is available, else unchanged."""
    return tensor.cuda(0) if T.cuda.is_available() else tensor

In [17]:
class Custom_Embedding(nn.Module):
    """Dense embedding layer computing ``tanh(x @ W_emb + b_emb)``.

    Unlike ``nn.Embedding`` this takes a feature vector (e.g. multi-hot)
    rather than integer indices.

    Args:
        inputDimSize: size of the last dimension of the input.
        embSize: size of the produced embedding.
    """

    def __init__(self, inputDimSize, embSize):
        super().__init__()
        self.inputDimSize = inputDimSize
        self.embSize = embSize

        # Small random weights, zero bias.
        weight_init = torch.randn(inputDimSize, embSize) * 0.01
        bias_init = torch.zeros(embSize)
        self.W_emb = nn.Parameter(weight_init)
        self.b_emb = nn.Parameter(bias_init)

    def forward(self, x):
        """Project `x` (last dim == inputDimSize) and squash with tanh."""
        projected = x @ self.W_emb + self.b_emb
        return torch.tanh(projected)


In [18]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over summed visit, drug and age embeddings.

    Each of the three inputs is embedded per time step with its own
    ``Custom_Embedding``; the embeddings are summed, passed through dropout
    and fed to a batch-first bidirectional GRU.

    Args:
        input_dim: feature size of `src` (last dimension).
        drug_dim: feature size of `drug` (last dimension).
        age_dim: feature size of `age` (last dimension).
        emb_dim: shared embedding size.
        enc_hid_dim: GRU hidden size per direction.
        dec_hid_dim: size of the returned initial decoder hidden state.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, input_dim, drug_dim, age_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.visitEmbedding = Custom_Embedding(input_dim, emb_dim)
        self.drugEmbedding = Custom_Embedding(drug_dim, emb_dim)
        self.ageEmbedding = Custom_Embedding(age_dim, emb_dim)

        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _embed_sequence(embedding, seq):
        """Apply `embedding` to every time step of a [batch, seq_len, feats] tensor."""
        batch_size = seq.shape[0]
        seq_len = seq.shape[1]
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work too.
        flat = seq.reshape(-1, seq.size(2))  # (batch * seq_len, feats)
        return embedding(flat).reshape(batch_size, seq_len, -1)

    def forward(self, src, drug, age):
        """Encode a batch of sequences.

        Args:
            src: [batch, seq_len, input_dim] visit features.
            drug: [batch, seq_len, drug_dim] drug features.
            age: [batch, seq_len, age_dim] age features.

        Returns:
            outputs: [batch, seq_len, enc_hid_dim * 2], the GRU output at every
                step (forward and backward directions concatenated).
            hidden: [batch, dec_hid_dim], tanh(fc(.)) of the final forward and
                backward hidden states, used as the initial decoder state.
        """
        visitEmbedded = self._embed_sequence(self.visitEmbedding, src)
        drugEmbedded = self._embed_sequence(self.drugEmbedding, drug)
        ageEmbedded = self._embed_sequence(self.ageEmbedding, age)

        # embedded = [batch, seq_len, emb_dim]
        embedded = self.dropout(visitEmbedded + ageEmbedded + drugEmbedded)

        # outputs = [batch, seq_len, enc_hid_dim * 2]  (batch_first=True)
        # hidden  = [num_directions, batch, enc_hid_dim]
        outputs, hidden = self.rnn(embedded)

        # hidden[-2] is the final forward state, hidden[-1] the final backward state.
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = torch.tanh(self.fc(hidden))

        return outputs, hidden

    def printWeights(self):
        """Debug hook; currently only prints a banner."""
        print("printing encoder weights")
        

In [19]:
class Attention(nn.Module):
    """Additive attention scoring every encoder step against the decoder state."""

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        joint_dim = (enc_hid_dim * 2) + dec_hid_dim
        self.attn = nn.Linear(joint_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch, src_len].

        Args:
            hidden: [batch, dec_hid_dim] current decoder state.
            encoder_outputs: [batch, src_len, enc_hid_dim * 2].
        """
        steps = encoder_outputs.shape[1]

        # Broadcast the decoder state along the source-time axis.
        tiled = hidden.unsqueeze(1).expand(-1, steps, -1)

        # energy = [batch, src_len, dec_hid_dim]
        joined = torch.cat((tiled, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(joined))

        # scores = [batch, src_len]
        scores = self.v(energy).squeeze(2)
        return F.softmax(scores, dim=1)

In [20]:
class Decoder(nn.Module):
    """One-step GRU decoder with attention over the encoder outputs.

    Args:
        output_dim: size of the predicted feature vector (and of `input`).
        age_dim: feature size of the per-step age input.
        emb_dim: embedding size shared by visit and age embeddings.
        enc_hid_dim: encoder hidden size per direction.
        dec_hid_dim: decoder GRU hidden size.
        dropout: dropout probability on the summed embedding.
        attention: module called as ``attention(hidden, encoder_outputs)``
            returning weights of shape [batch, src_len].
    """

    def __init__(self, output_dim, age_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        context_dim = enc_hid_dim * 2

        self.output_dim = output_dim
        self.attention = attention

        self.visitEmbedding = Custom_Embedding(output_dim, emb_dim)
        self.ageEmbedding = Custom_Embedding(age_dim, emb_dim)

        self.rnn = nn.GRU(context_dim + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear(context_dim + dec_hid_dim + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input, age, hidden, encoder_outputs):
        """Run a single decoding step.

        Args:
            input: [batch, output_dim] previous visit vector (cast to float).
            age: [batch, age_dim] age features for this step.
            hidden: [batch, dec_hid_dim] previous decoder state.
            encoder_outputs: [batch, src_len, enc_hid_dim * 2].

        Returns:
            (prediction [batch, output_dim],
             new hidden [batch, dec_hid_dim],
             attention weights [batch, src_len])
        """
        # step_emb = [1, batch, emb_dim]
        step_emb = self.visitEmbedding(input.float()) + self.ageEmbedding(age)
        step_emb = self.dropout(step_emb).unsqueeze(0)

        # attn_weights = [batch, 1, src_len]
        attn_weights = self.attention(hidden, encoder_outputs).unsqueeze(1)

        # context = [1, batch, enc_hid_dim * 2]: attention-weighted sum of encoder outputs
        context = torch.bmm(attn_weights, encoder_outputs).permute(1, 0, 2)

        gru_in = torch.cat((step_emb, context), dim=2)
        output, hidden = self.rnn(gru_in, hidden.unsqueeze(0))

        # Single step, single layer, unidirectional: output and hidden must coincide.
        assert (output == hidden).all()

        features = torch.cat(
            (output.squeeze(0), context.squeeze(0), step_emb.squeeze(0)), dim=1
        )
        prediction = self.fc_out(features)

        return prediction, hidden.squeeze(0), attn_weights.squeeze(1)
    

In [21]:
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that decodes one step at a time with attention.

    Args:
        encoder: module called as ``encoder(src, drug, age)`` returning
            ``(encoder_outputs [batch, src_len, H], hidden [batch, dec_hid])``.
        decoder: module exposing ``output_dim`` and called as
            ``decoder(input, age_t, hidden, encoder_outputs)`` returning
            ``(prediction [batch, output_dim], hidden, attention [batch, src_len])``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, drug, trg, age, mask, train, train_on_gpu):
        """Decode `trg_len` steps.

        Args:
            src: [batch, src_len, input_dim] encoder visit features.
            drug: [batch, src_len, drug_dim] encoder drug features.
            trg: [batch, trg_len, output_dim] target visits (teacher forcing).
            age: [batch, seq_len, age_dim]; step ``t`` is fed to the decoder at
                step ``t``, so seq_len must be at least trg_len.
            mask: [batch, trg_len]; where 1 the target is fed back, where 0 the
                model's own one-hot prediction is fed back (training only).
            train: if True use the mask-mixed target/prediction as next input,
                otherwise always feed back the prediction.
            train_on_gpu: if True the returned buffers live on CUDA device 0,
                otherwise on CPU.

        Returns:
            outputs: [trg_len, batch, output_dim] raw decoder predictions.
            attn: [trg_len, batch, src_len] attention weights per step.
        """
        # Work time-major on the target: [trg_len, batch, output_dim]
        trg = trg.permute(1, 0, 2)

        batch_size = src.shape[0]
        trg_len = trg.shape[0]
        trg_features = self.decoder.output_dim

        # encoder_outputs: all encoder states; hidden: initial decoder state.
        encoder_outputs, hidden = self.encoder(src, drug, age)

        # Attention spans the encoder steps, so size the buffer by src_len rather
        # than trg_len (the two only coincide when source and target lengths match).
        src_len = encoder_outputs.shape[1]
        out_device = torch.device("cuda", 0) if train_on_gpu else torch.device("cpu")
        outputs = torch.zeros(trg_len, batch_size, trg_features, device=out_device)
        attn = torch.zeros(trg_len, batch_size, src_len, device=out_device)

        # First decoder input is an all-zero "start" visit, created on the same
        # device as the encoder outputs so the decoder never mixes devices.
        input = torch.zeros(
            batch_size, trg_features, dtype=torch.long, device=encoder_outputs.device
        )

        # Time-major age: [seq_len, batch, age_dim]. A bare squeeze() would also
        # drop the batch axis when batch_size == 1, so permute explicitly instead.
        age = age.permute(1, 0, 2)
        mask = mask.permute(1, 0)

        for t in range(trg_len):
            output, hidden, a = self.decoder(input, age[t], hidden, encoder_outputs)

            outputs[t] = output
            attn[t] = a

            # One-hot encoding of the highest-scoring feature.
            top1 = output.argmax(1)
            prediction = F.one_hot(top1, num_classes=trg_features).float()

            if train:
                # Mix ground truth and prediction per sample according to the mask.
                step_mask = mask[t].unsqueeze(1).float()
                input = trg[t].float() * step_mask + prediction * (1 - step_mask)
            else:
                input = prediction

        return outputs, attn
    

In [22]:
class Discriminator(nn.Module):
    """LSTM-cell discriminator scoring each time step as observed vs. imputed.

    At every step a sigmoid score is produced from the hidden state *before*
    the current value is consumed; the scores are compared against the
    observation mask with binary cross-entropy.

    Args:
        MAX_SEQ_LEN: number of time steps read from the inputs.
        D_HID_SIZE: hidden size of the LSTM cell.
    """

    def __init__(self, MAX_SEQ_LEN, D_HID_SIZE):
        super(Discriminator, self).__init__()

        self.SEQ_LEN = MAX_SEQ_LEN
        self.D_HID_SIZE = D_HID_SIZE
        self.build()

    def build(self):
        """Create the sub-modules (LSTM cell and two-layer scoring head)."""
        self.rnn_cell = nn.LSTMCell(1, self.D_HID_SIZE)
        self.regression1 = nn.Linear(self.D_HID_SIZE, 5)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.regression2 = nn.Linear(5, 1)
        self.sig = nn.Sigmoid()

    def merge_score(self, score_f):
        """Turn per-step scores and the mask into discriminator/generator losses.

        Args:
            score_f: dict with ``'scoresSig'`` (sigmoid scores) and ``'missing'``
                (mask; non-zero marks an observed/"real" position, values
                different from 1 mark an imputed/"fake" one). Both are flattened
                here; the dict itself is not modified.

        Returns:
            dict with
              ``'loss_d'``: mean of the BCE on real positions and the BCE on fake
                  positions (only the available one when the other set is empty);
              ``'loss_g'``: BCE pushing fake positions towards "real", or a
                  zero tensor of shape [1] (requires_grad=True) when there are
                  no fake positions.
        """
        scores = torch.flatten(score_f['scoresSig'])
        missing = torch.flatten(score_f['missing'])

        real = missing != 0
        fake = (1 - missing) != 0
        has_real = bool(real.any())
        has_fake = bool(fake.any())

        adversarial_loss = torch.nn.BCELoss()

        if has_fake:
            fake_scores = scores[fake]
            # Generator wants fake positions to be scored as real.
            loss_g = adversarial_loss(fake_scores, (1 - missing)[fake].detach())
            loss_d_fake = adversarial_loss(fake_scores, missing[fake].detach())
        else:
            # Device-agnostic zero loss (the CUDA-only constructor crashed on CPU).
            loss_g = torch.zeros(1, device=scores.device, requires_grad=True)
            loss_d_fake = None

        if has_real:
            loss_d_real = adversarial_loss(scores[real], missing[real].detach())
        else:
            # BCE over an empty selection would be NaN; skip the term instead.
            loss_d_real = None

        if loss_d_real is not None and loss_d_fake is not None:
            loss_d = (loss_d_real + loss_d_fake) / 2
        elif loss_d_real is not None:
            loss_d = loss_d_real
        elif loss_d_fake is not None:
            loss_d = loss_d_fake
        else:
            loss_d = torch.zeros(1, device=scores.device, requires_grad=True)

        return {'loss_d': loss_d, 'loss_g': loss_g}

    def forward(self, values, masks, direct):
        """Score a batch of sequences and return the adversarial losses.

        Args:
            values: [batch, >=SEQ_LEN] one scalar per time step.
            masks: [batch, >=SEQ_LEN] observation mask aligned with `values`.
            direct: ``"forward"`` to scan t = 0..SEQ_LEN-1 or ``"backward"`` to
                scan t = SEQ_LEN-1..0.

        Returns:
            The dict produced by :meth:`merge_score`.

        Raises:
            ValueError: if `direct` is neither ``"forward"`` nor ``"backward"``.
        """
        if direct == "forward":
            steps = range(self.SEQ_LEN)
        elif direct == "backward":
            steps = range(self.SEQ_LEN - 1, -1, -1)
        else:
            raise ValueError(
                "direct must be 'forward' or 'backward', got {!r}".format(direct)
            )

        # Run on whatever device the module's parameters live on, so a CPU model
        # keeps working on a machine that also has a GPU.
        device = next(self.parameters()).device
        values, masks = values.to(device), masks.to(device)

        h = torch.zeros((values.size()[0], self.D_HID_SIZE), device=device)
        c = torch.zeros((values.size()[0], self.D_HID_SIZE), device=device)

        scoresSig = []
        missing = []
        for t in steps:
            x = values[:, t].unsqueeze(dim=1)
            m = masks[:, t].unsqueeze(dim=1)

            # Score from the state accumulated so far (before seeing x).
            x_h = self.regression1(h)
            x_h = self.leaky(x_h)
            x_h = self.regression2(x_h)
            x_h2 = self.sig(x_h)

            h, c = self.rnn_cell(x, (h, c))

            scoresSig.append(x_h2[:, 0].unsqueeze(dim=1))
            missing.append(m)

        scoresSig = torch.cat(scoresSig, dim=1)
        missing = torch.cat(missing, dim=1)
        return self.merge_score({'scoresSig': scoresSig, 'missing': missing})
    
        
    

In [23]:
# Create Dataset
class CSVDataset(Dataset):
    """Streams a CSV file in fixed-size chunks and reshapes each into sequences.

    NOTE: items are read *sequentially* from a single pandas chunk reader;
    ``__getitem__`` ignores `index` and simply returns the next chunk. The
    dataset is therefore only meaningful when consumed in order, once, without
    shuffling.

    Args:
        path: CSV file path (first row is the header).
        chunksize: number of CSV rows per item; should be a multiple of `seq_len`.
        length: value reported by ``__len__`` (number of chunks to draw).
        seq_len: number of consecutive rows forming one sequence.
        flag: ``0`` to additionally return the ``person_id`` column of the
            chunk; any other value returns the data tensor only.
    """

    def __init__(self, path, chunksize, length, seq_len, flag):
        self.path = path
        self.chunksize = chunksize
        self.len = int(length)  # number of times __getitem__ is expected to be called
        self.seq_len = seq_len
        self.flag = flag
        self.reader = pd.read_csv(
            self.path, header=0,
            chunksize=self.chunksize)

    def _to_sequences(self, frame):
        """Convert a cleaned chunk to a float32 tensor [n_seq, seq_len, n_cols]."""
        tensor = T.as_tensor(frame.values.astype(float), dtype=T.float32)
        return tensor.view(tensor.shape[0] // self.seq_len, self.seq_len, tensor.shape[1])

    def __getitem__(self, index):
        """Return the next chunk as ``[chunksize / seq_len, seq_len, n_cols]``.

        `index` is ignored (see class note). With ``flag == 0`` the result is
        ``(data, pids)`` where `pids` is a long tensor holding ``person_id``
        for every row of the chunk; otherwise only `data` is returned.
        """
        data = self.reader.get_chunk(self.chunksize)

        # Neutralise non-finite cells: +inf, -inf and NaN all become 0.
        data = data.replace([np.inf, -np.inf], 0).fillna(0)

        if self.flag == 0:
            pids = T.as_tensor(data['person_id'].values.astype(float), dtype=T.long)
            return self._to_sequences(data), pids

        return self._to_sequences(data)

    def __len__(self):
        return self.len

In [25]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, args, patience=7, verbose=False, delta=0):
        """
        Args:
            args: object with a boolean ``discriminator`` attribute; when true the
                checkpoint also stores the discriminator model and optimizer.
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.args = args

    def __call__(self, val_loss, model, optimizer, save_path):
        """Record one validation result.

        Saves a checkpoint when the loss is at least as good as
        ``best + delta``; otherwise increments the patience counter and sets
        ``early_stop`` once it reaches ``patience``.

        Args:
            val_loss: validation loss of the current epoch (lower is better).
            model: dict with key ``'g'`` (and ``'d'`` when using a discriminator).
            optimizer: dict with the matching optimizers under the same keys.
            save_path: destination passed to ``torch.save``.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer, save_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # With a negative delta a slightly worse score still lands here;
            # only a strictly better score replaces the recorded best.
            if score > self.best_score:
                self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer, save_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, optimizer, save_path):
        '''Saves model and optimizer state dicts to `save_path`.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if self.args.discriminator:
            T.save({
                "G_model": model['g'].state_dict(),
                "D_model": model['d'].state_dict(),
                'G_trainer': optimizer['g'].state_dict(),
                'D_trainer': optimizer['d'].state_dict()
            }, save_path)
        else:
            T.save({
                "G_model": model['g'].state_dict(),
                'G_trainer': optimizer['g'].state_dict()
            }, save_path)
        self.val_loss_min = -self.best_score

In [3]:
#W_emb = nn.Parameter(torch.randn(10, 5) * 0.01)
#b_emb = nn.Parameter(torch.zeros(5) * 0.01) 
       
    
#torch.tanh(x@self.W_emb + self.b_emb)

In [4]:
#W_emb

Parameter containing:
tensor([[-6.0893e-03, -1.2221e-02, -1.9950e-04, -6.2818e-03,  8.7840e-04],
        [ 1.6809e-02,  2.4981e-02, -2.9280e-03, -2.9692e-03, -7.5641e-03],
        [ 1.7082e-03,  2.8238e-03, -4.5027e-03, -1.8286e-03, -1.1840e-03],
        [ 1.5375e-02,  1.2315e-02,  5.0763e-03,  1.3537e-02, -9.2942e-04],
        [ 3.8756e-03, -1.6691e-03,  6.6457e-03, -6.2233e-03, -7.0930e-03],
        [ 1.1251e-02,  4.6324e-03, -5.9148e-03,  1.3852e-03,  3.3464e-03],
        [-1.5402e-02,  2.3139e-03,  4.3486e-03, -4.2589e-03, -3.6165e-03],
        [ 4.7201e-03, -2.8067e-03,  2.9867e-03,  4.7570e-03, -7.5531e-03],
        [-1.8638e-02, -1.8305e-03,  4.3261e-03,  7.7344e-03,  7.2551e-03],
        [ 1.6162e-04, -2.5067e-03, -1.3397e-02, -7.6573e-05,  1.6650e-02]],
       requires_grad=True)

In [5]:
#x=torch.rand((20, 10))

In [6]:
#x

tensor([[0.9394, 0.2193, 0.5036, 0.6045, 0.2243, 0.8799, 0.5826, 0.5922, 0.2516,
         0.1680],
        [0.6566, 0.8166, 0.8150, 0.8155, 0.4225, 0.6462, 0.8814, 0.9544, 0.0511,
         0.5803],
        [0.6730, 0.2475, 0.6231, 0.0054, 0.2907, 0.3123, 0.0914, 0.0269, 0.3092,
         0.0061],
        [0.8830, 0.0163, 0.2944, 0.0123, 0.6187, 0.7683, 0.4872, 0.0553, 0.7860,
         0.7100],
        [0.9427, 0.4828, 0.4373, 0.3135, 0.9343, 0.7308, 0.6924, 0.0260, 0.3376,
         0.7189],
        [0.6380, 0.6036, 0.2492, 0.2471, 0.7153, 0.2933, 0.0039, 0.9437, 0.6519,
         0.5304],
        [0.3565, 0.2318, 0.4736, 0.3471, 0.6950, 0.9692, 0.3261, 0.4120, 0.5497,
         0.4465],
        [0.7704, 0.6597, 0.7404, 0.9754, 0.1743, 0.2654, 0.2726, 0.9924, 0.6058,
         0.5487],
        [0.0354, 0.7807, 0.3702, 0.8434, 0.8027, 0.5034, 0.9833, 0.3385, 0.6894,
         0.6097],
        [0.0886, 0.4313, 0.6243, 0.8636, 0.5273, 0.6590, 0.0606, 0.9589, 0.7804,
         0.4070],
        [0

In [7]:
#torch.tanh(x@W_emb + b_emb)

tensor([[ 0.0080,  0.0054, -0.0006,  0.0028, -0.0026],
        [ 0.0226,  0.0248, -0.0039,  0.0024, -0.0085],
        [-0.0012,  0.0003, -0.0018, -0.0053, -0.0012],
        [-0.0151, -0.0091, -0.0057, -0.0045,  0.0138],
        [ 0.0031,  0.0066, -0.0052, -0.0090,  0.0039],
        [ 0.0089,  0.0060, -0.0002,  0.0025, -0.0022],
        [ 0.0082,  0.0078, -0.0032,  0.0027,  0.0031],
        [ 0.0156,  0.0175, -0.0015,  0.0125, -0.0014],
        [ 0.0090,  0.0300,  0.0028,  0.0066, -0.0021],
        [ 0.0196,  0.0193,  0.0009,  0.0167, -0.0013],
        [ 0.0052,  0.0140, -0.0009, -0.0018,  0.0007],
        [ 0.0043,  0.0119, -0.0078,  0.0136,  0.0162],
        [ 0.0135,  0.0184, -0.0014, -0.0007, -0.0067],
        [-0.0113, -0.0044, -0.0054,  0.0028,  0.0100],
        [ 0.0188,  0.0114, -0.0067,  0.0148,  0.0121],
        [ 0.0134,  0.0221,  0.0040,  0.0080, -0.0017],
        [ 0.0056,  0.0272, -0.0096,  0.0099,  0.0083],
        [ 0.0111,  0.0208, -0.0056, -0.0015, -0.0005],
        [-