In [1]:
import os

# Root folder that holds data/ and saved_models/. Set the MIRNA_DIR environment
# variable to run elsewhere; the default is the original machine's path.
link = os.environ.get('MIRNA_DIR', 'D:/users/Marko/downloads/mirna/')

# Imports

In [2]:
# Enable the %tensorboard magic used further down to view the training curves.
%load_ext tensorboard

In [3]:
# All notebook imports, plus the shared TensorBoard writer.
import sys
# Make modules stored in the project folder importable
# (the commented line is the equivalent Google Drive path).
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


# Module-level TensorBoard writer; train() logs to it whenever it is not None.
writer = SummaryWriter(f"{link}/saved_models/new/RHVAE/tensorboard")

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Display the selected device.
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:
    """Hyper-parameter bundle passed to ``RHVAE`` and ``train``.

    Every keyword is stored unchanged as an attribute of the same name:
      z1_dim, z2_dim       -- latent sizes
      d_dim, x_dim, y_dim  -- forwarded to the encoder/decoder constructors
      h_dim, h2_dim        -- hidden-layer widths
      number_components    -- number of pseudo-inputs in the z2 prior
      beta                 -- KL weight (target of the warm-up schedule)
      rec_beta, rec_gamma  -- weights of the bar-length and bar-colour losses
      warmup, prewarmup    -- epoch counts for the KL warm-up schedule
    """

    def __init__(self, z1_dim=256, z2_dim=256, d_dim=45, x_dim=7500, y_dim=2,
                 h_dim=256, h2_dim=256, number_components=500,
                 beta=1, rec_beta=20, rec_gamma=1, warmup=1, prewarmup=1):
        # latent and data dimensions
        self.z1_dim, self.z2_dim = z1_dim, z2_dim
        self.d_dim, self.x_dim, self.y_dim = d_dim, x_dim, y_dim
        # hidden widths
        self.h_dim, self.h2_dim = h_dim, h2_dim
        # size of the pseudo-input mixture
        self.number_components = number_components
        # loss weights and KL schedule
        self.beta, self.rec_beta, self.rec_gamma = beta, rec_beta, rec_gamma
        self.warmup, self.prewarmup = warmup, prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):
    """miRNA image dataset read from ``{link}/data/modmirbase_{ds}_*`` files.

    Each item is the tuple ``(x, y, d, x2, mount)``:
      x     -- the image transposed to channels-first, values divided by 255
      y     -- one-hot label
      d     -- one-hot species code (species in ``self.species``; 'random' or
               all-digit names map to 'hsa', anything else to 'notfound')
      x2    -- per-column encoding from ``get_encoded_values`` (or its cache)
      mount -- the matching row of ``modmirbase_{ds}_mountain.npy``
    """

    def __init__(self, ds='train', create_encodings=False, use_subset=False):

        # loading images (scaled from 0..255 to 0..1)
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255

        # loading labels
        print('Loading Labels! (~10s)')
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = self._one_hot_encoder().fit_transform(labels)

        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x = self.get_encoded_values(self.images, ds)
        else:
            # get_encoded_values writes this cache with np.save, so despite the
            # .npz suffix it holds a single .npy array and np.load returns it
            # directly (no ['arr_0'] lookup needed).
            x = np.load(f'{link}/data/modmirbase_{ds}_images_RNN.npz')

        self.x = x

        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')

        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin',
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        # (first species code found as a substring of the lower-cased name wins)
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')

        # performing one hot encoding
        self.names_ohe = self._one_hot_encoder().fit_transform(np.array(self.names).reshape(-1,1))

        if use_subset:
            # keep a random (unseeded) ~half of the 'hsa' samples and drop the rest
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x = self.x[idxes]
            self.mountain = self.mountain[idxes]

    @staticmethod
    def _one_hot_encoder():
        """Return a dense-output ``OneHotEncoder`` across scikit-learn versions.

        scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed
        ``sparse``; try the new keyword first and fall back for older releases.
        """
        try:
            return OneHotEncoder(categories='auto', sparse_output=False)
        except TypeError:
            return OneHotEncoder(categories='auto', sparse=False)

    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))  # HWC -> CHW for the conv encoder
        x2 = self.x[idx]
        mount = self.mountain[idx]
        return (x, y, d, x2, mount)


    def get_encoded_values(self, x, ds):
        """
        given a batch of images (n, rows, cols, 3) with values in [0, 1]
        returns length of strand, length of bars and colors of bars, as a
        uint8 array (n, 100, 14) that is also cached to
        {link}/data/modmirbase_{ds}_images_RNN.npz (written with np.save)
        idx 0 - 5: color top bar
        idx 6 - 11: color bot bar
        idx 12: length top bar
        idx 13: length bot bar
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))  # -> (n, 3, rows, cols)
        out = np.zeros((n,100,14), dtype=np.uint8)
        white = np.array([1.,1.,1.])

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            for j in range(100):

                # check color of top bar (row 12 is the row just above row 13)
                col = self.get_color(x[i,:,12,j])
                out[i, j, col] = 1
                if col < 5:
                    # length of top bar: count non-white pixels upward from row 12
                    len1 = 0
                    while not (x[i,:,12-len1,j] == white).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out[i, j, 12] = len1

                # check color of bottom bar (starts at row 13 and grows downward)
                col = self.get_color(x[i,:,13,j])
                out[i, j, col+6] = 1
                if col < 5:
                    len2 = 0
                    while not (x[i,:,13+len2,j] == white).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out[i, j, 13] = len2

        with open(f'{link}/data/modmirbase_{ds}_images_RNN.npz', 'wb') as f:
            np.save(f, out)

        return out

    def get_color(self, pixel):
        """
        returns the encoded value (0-5) for an RGB pixel:
        0 black, 1 red, 2 blue, 3 green, 4 yellow, 5 white.

        Raises ValueError for any other pixel. Previously this printed a message
        and returned None, which corrupted the encoding row in
        get_encoded_values before crashing on ``None < 5``.
        """
        if (pixel == np.array([0,0,0])).all():
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():
            return 4 # yellow
        elif (pixel == np.array([1,1,1])).all():
            return 5 # white
        raise ValueError(f'Unexpected pixel value {pixel!r}: not one of the six palette colours')


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    """Decoder: conditional prior p(z1 | z2, m) and autoregressive p(x2 | z1, z2, m).

    ``mz2`` given to ``forward`` is ``[z2, m]`` concatenated along dim 1, so it
    has ``z2_dim + m_dim`` columns. ``d_dim``, ``x_dim``, ``y_dim``, ``dim0``,
    ``dim1`` and ``dim2`` are accepted for signature compatibility but unused.
    """
    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim,
                 h_dim, h2_dim, dim0=2000, dim1=800, dim2=400, m_dim=200):
        super(px, self).__init__()

        # p(z1 | z2, m): Gaussian with mean mu_z1 and Softplus-positive scale si_z1
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim + m_dim, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

        # p(x | z1, z2, m)
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim + m_dim, h_dim),
                                   nn.ReLU())
        # project the latent summary to one 14-wide step that seeds the LSTMs
        self.fc = nn.Sequential(nn.Linear(2*h_dim, 14),
                                nn.ReLU())

        self.rnn1 = nn.LSTM(14, 64, num_layers=2, batch_first=True)
        self.rnn2 = nn.LSTM(64,128, num_layers=2, batch_first=True)
        # per-step head: 6 top-colour logits, 6 bottom-colour logits, 2 bar lengths
        self.fc2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Dropout(0.4), nn.Linear(128,14))

    def forward(self, z1, mz2, x2):
        """Decode with teacher forcing.

        Args:
            z1: z1 sample, shape (batch, z1_dim).
            mz2: ``[z2, m]``, shape (batch, z2_dim + m_dim).
            x2: target sequence (batch, seq_len, 14); only ``x2[:, :-1]`` is fed back.

        Returns:
            ``(output, pz1_m, pz1_s)``: per-step predictions (batch, seq_len, 14)
            and the mean / scale of p(z1 | z2, m).
        """
        # p(z1 | z2, m)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)

        # p(x | z1, z2, m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = self.fc(torch.cat([hz1, hz2], 1)).unsqueeze(1)

        bs = x2.shape[0]
        # Sequence = [latent step, zero start token, x2[:, :-1]]. The start token
        # is created on x2's device and dtype (not the global DEVICE), so the
        # decoder also works when the model lives on another device.
        start = x2.new_zeros((bs, 1, x2.shape[-1]))
        inputs = torch.cat([h, start, x2[:, :-1, :]], dim=1)
        rnn_o, _ = self.rnn1(inputs)
        rnn_o, _ = self.rnn2(rnn_o)
        # drop the prediction made at the latent step: output[:, t] predicts x2[:, t]
        output = self.fc2(rnn_o)[:, 1:]

        return output, pz1_m, pz1_s

    def reconstruct_image(self, out, sample=False):
        """Render decoder output as RGB images.

        ``out`` is (n, width, 14): features 0-5 / 6-11 are top / bottom colour
        scores (argmax picks the colour), 12 / 13 are top / bottom bar lengths.
        The top bar is painted upward from row 12, the bottom bar downward from
        row 13. Returns a float array (n, 25, width, 3), white background.
        ``sample`` is unused (sampling is not supported).

        Colour indices: 0 black, 1 red, 2 blue, 3 green, 4 yellow, 5 white.
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  2: np.array([0,0,1]), # blue
                  3: np.array([0,1,0]), # green
                  4: np.array([1,1,0]), # yellow
                  5: np.array([1,1,1])  # white
                  }

        out = out.detach().cpu().numpy()
        n, width = out.shape[0], out.shape[1]
        output = np.ones((n, 25, width, 3))

        for i in range(n):
            for j in range(width):
                col1 = color_dict[int(np.argmax(out[i, j, :6]))]
                col2 = color_dict[int(np.argmax(out[i, j, 6:12]))]
                # Lengths are unconstrained regression outputs (possibly negative
                # or too large): clamp both ends so each bar stays on its own side
                # of row 13. A negative bottom length used to give a negative
                # slice end (e.g. 13:-7), which painted a spurious bar.
                h1 = min(13, max(0, 13 - round(float(out[i, j, 12]))))
                output[i, h1:13, j] = col1
                h2 = max(13, min(25, 13 + round(float(out[i, j, 13]))))
                output[i, 13:h2, j] = col2

        return output


In [9]:
# Scratch check: np.round rounds to the nearest value (3.7 -> 4), int() truncates (3.7 -> 3).
int(np.round(3.7, 0)),int(3.7)

(4, 3)

In [10]:
# Smoke-test the decoder with z1_dim = z2_dim = 256: inputs are z1 (1, 256),
# [z2, m] (1, 256 + 200) and the encoded sequence x2 (1, 100, 14).
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
pzy_ = px(45, 7500, 2, 256,256,256,256)
summary(pzy_, [(1,256),(1,456),(1,100,14)])

Layer (type:depth-idx)                   Output Shape              Param #
px                                       --                        --
├─Sequential: 1-1                        [1, 256]                  --
│    └─Linear: 2-1                       [1, 256]                  116,992
│    └─ReLU: 2-2                         [1, 256]                  --
│    └─Linear: 2-3                       [1, 256]                  65,792
│    └─ReLU: 2-4                         [1, 256]                  --
├─Sequential: 1-2                        [1, 256]                  --
│    └─Linear: 2-5                       [1, 256]                  65,792
├─Sequential: 1-3                        [1, 256]                  --
│    └─Linear: 2-6                       [1, 256]                  65,792
│    └─Softplus: 2-7                     [1, 256]                  --
├─Sequential: 1-4                        [1, 256]                  --
│    └─Linear: 2-8                       [1, 256]                  6

## Encoder Classes

In [11]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [12]:
class qz(nn.Module):
    """Encoder: q(z2 | x) and q(z1 | x, z2, m), both diagonal Gaussians.

    Each posterior has its own copy of the same convolutional stack; for a
    (3, 25, 100) image that stack yields 72 x 3 x 12 = 2592 features.
    ``d_dim``, ``x_dim`` and ``y_dim`` are accepted for signature compatibility
    but unused; ``m_dim`` is the width of the vector ``m`` appended to z2.
    """
    # flattened size of _conv_encoder()'s output for a (3, 25, 100) image
    _CONV_FEATURES = 72 * 3 * 12

    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim, m_dim=200):
        super(qz, self).__init__()

        # q(z2 | x)
        self.encoder_z2 = self._conv_encoder()
        self.mu_z2 = nn.Sequential(nn.Linear(self._CONV_FEATURES, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(self._CONV_FEATURES, z2_dim), nn.Softplus())

        # q(z1 | x, z2, m)
        self.encoder_z1 = self._conv_encoder()
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim + m_dim, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(self._CONV_FEATURES, h_dim), nn.ReLU())

        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())

        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

    @staticmethod
    def _conv_encoder():
        """Three (conv3x3 -> ReLU -> conv3x3 -> ReLU -> maxpool 2x2) stages, 3 -> 48 -> 60 -> 72 channels.

        Layer order (and hence state_dict keys) matches the original inline Sequential.
        """
        layers = []
        in_ch = 3
        for out_ch in (48, 60, 72):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding='same'),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
            in_ch = out_ch
        return nn.Sequential(*layers)

    def q_z2(self, x):
        """Return mean and (Softplus) scale of q(z2 | x) for images x of shape (batch, 3, 25, 100)."""
        # flatten(1) keeps the batch axis: a wrongly sized image now fails in the
        # Linear layer instead of being silently re-split by view(-1, 2592).
        h = self.encoder_z2(x).flatten(1)
        return self.mu_z2(h), self.si_z2(h)

    def forward(self, x, m):
        """Sample z2 ~ q(z2 | x), then z1 ~ q(z1 | x, z2, m), via the reparameterization trick.

        Returns ``(z1, z2, mz2, z1_m, z1_s, z2_m, z2_s)`` where ``mz2 = [z2, m]``.
        """
        # q(z2 | x)
        z2_m, z2_s = self.q_z2(x)
        z2 = dist.Normal(z2_m, z2_s).rsample()
        mz2 = torch.cat([z2, m], 1)

        # q(z1 | x, z2, m)
        h_x = self.fc_z1(self.encoder_z1(x).flatten(1))
        h_z2 = self.fc_z2(mz2)
        h = self.fc_z1_z2(torch.cat([h_z2, h_x], 1))
        z1_m = self.mu_z1(h)
        z1_s = self.si_z1(h)
        z1 = dist.Normal(z1_m, z1_s).rsample()

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [13]:
# Scratch check: torch.cat along dim 1 appends columns of tensors with equal row counts.
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [14]:
# Smoke-test the encoder on one blank 3x25x100 image with a 200-wide m vector,
# then print its layer summary.
enc = qz(128,10,10,256,256,256,256)
enc(torch.zeros((1,3,25,100)), torch.zeros((1,200)))
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 72, 3, 12]            --
│    └─Conv2d: 2-1                       [1, 48, 25, 100]          1,344
│    └─ReLU: 2-2                         [1, 48, 25, 100]          --
│    └─Conv2d: 2-3                       [1, 48, 25, 100]          20,784
│    └─ReLU: 2-4                         [1, 48, 25, 100]          --
│    └─MaxPool2d: 2-5                    [1, 48, 12, 50]           --
│    └─Conv2d: 2-6                       [1, 60, 12, 50]           25,980
│    └─ReLU: 2-7                         [1, 60, 12, 50]           --
│    └─Conv2d: 2-8                       [1, 60, 12, 50]           32,460
│    └─ReLU: 2-9                         [1, 60, 12, 50]           --
│    └─MaxPool2d: 2-10                   [1, 60, 6, 25]            --
│    └─Conv2d: 2-11                      [1, 72, 6, 25]            38,

In [15]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    """Diagonal-Gaussian log-density of ``x``, omitting the -0.5*log(2*pi) constant.

    The constant cancels whenever two such terms are subtracted (as in the KL
    estimate), so it is left out.

    Args:
        x, mean, std: broadcastable tensors; ``std`` must be positive.
        average: reduce with ``torch.mean`` instead of ``torch.sum``.
        dim: dimension(s) passed to the reduction.
    """
    log_var = 2 * torch.log(std)
    squared_error = torch.pow(x - mean, 2)
    per_element = -0.5 * (log_var + squared_error / torch.exp(log_var))
    reduce = torch.mean if average else torch.sum
    return reduce(per_element, dim)

## Full model class

In [16]:
class RHVAE(nn.Module):
    """Two-level VAE built from the encoder ``qz`` and decoder ``px``.

    The prior on z2 is a uniform mixture of q(z2 | u_k) over
    ``number_components`` learned pseudo-images u_k (``add_pseudoinputs`` /
    ``log_p_z2``); the prior on z1 is p(z1 | z2, m) from the decoder.

    Args:
        args: a ``diva_args``-like object with the dimensions and loss weights read below.
    """
    def __init__(self, args):
        super(RHVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components

        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        # loss weights (beta is overwritten per epoch by the warm-up schedule in train())
        self.beta = args.beta
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.add_pseudoinputs()

        # Bookkeeping lists kept for compatibility; nothing in this class fills them.
        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []

        self.bar = []
        self.len = []
        self.col = []

        # Follow the notebook-wide DEVICE (CUDA if available, else CPU) rather
        # than self.cuda(), which failed on machines without a GPU.
        self.to(DEVICE)

    def forward(self, d, x, y, m, x2):
        """Encode ``x`` (with ``m``), then decode ``x2`` by teacher forcing. ``d`` and ``y`` are unused."""
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x2_hat, pz1_m, pz1_s = self.px(z1, mz2, x2)

        return x2_hat, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s

    def log_p_z2(self, z2):
        """Log-density of z2 under the pseudo-input mixture prior; returns shape (batch,)."""
        C = self.number_components

        # one pseudo-image per component, encoded through the z2 posterior
        X = self.means(self.idle_input).view(-1,3,25,100)
        pz2_m, pz2_s = self.qz.q_z2(X)

        # (batch, 1, z) against (1, C, z) -> per-component log-densities (batch, C)
        a = log_Normal_diag(z2.unsqueeze(1), pz2_m.unsqueeze(0), pz2_s.unsqueeze(0), dim=2) - math.log(C)

        # log of the uniform mixture: log sum_k exp(a_k), computed stably
        return torch.logsumexp(a, 1)

    def loss_function(self, d, x, y, m, x2):
        """Return the weighted loss and its parts.

        Returns the 9-tuple ``(loss, RE_bar, RE_col, mse_bar, acc_bar, acc_bar2,
        acc_bar3, acc_bar4, acc_bar5)``. The five ``acc_*`` entries are constant
        placeholders (0, 1, 1, 1, 1) kept so callers can keep unpacking 9 values.
        """
        x2_hat, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m, x2)

        # Colour terms: cross-entropy of the top (0-5) and bottom (6-11) colour
        # scores against the one-hot targets, classes on dim 1 after permuting
        # to (batch, features, seq_len).
        x_rec_hat = x2_hat.permute((0,2,1))
        x_rec = x2.permute((0,2,1))
        RE_col = (F.cross_entropy(x_rec_hat[:,:6,:], x_rec[:,:6,:])
                  + F.cross_entropy(x_rec_hat[:,6:12,:], x_rec[:,6:12,:]))

        # Bar-length term: MSE on features 12-13.
        mse_bar = F.mse_loss(x2_hat[:,:,12:], x2[:,:,12:], reduction='mean')
        RE_bar = mse_bar

        # Single-sample KL estimate: log q(z1,z2|.) - log p(z1,z2), summed over the batch.
        # NOTE(review): the reconstruction terms above are batch means while the KL
        # is a batch sum, so the effective KL weight scales with batch size —
        # confirm this is intended before changing batch_size.
        log_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        log_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        log_p_z2 = self.log_p_z2(z2).sum()
        log_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        loss = self.rec_beta * RE_bar + self.rec_gamma * RE_col + self.beta * KL
        return loss, RE_bar, RE_col, mse_bar, 0, 1, 1, 1, 1

    def add_pseudoinputs(self):
        """Create the pseudo-input generator used by the z2 prior.

        ``self.means`` maps a one-hot component index to a 3*25*100 image clamped
        to [0, 1]; ``self.idle_input`` is the identity matrix selecting every
        component at once.
        """
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        # A non-persistent buffer follows .to()/.cuda()/.cpu() with the module
        # but stays out of state_dict, so existing checkpoints still load.
        self.register_buffer('idle_input', torch.eye(self.number_components), persistent=False)

In [17]:
# Scratch check: log-density of N(0, 1) at 10, i.e. -0.5*10**2 - 0.5*log(2*pi).
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [18]:
# Layer summary of the full model with default hyper-parameters. Input shapes
# follow RHVAE.forward(d, x, y, m, x2); d and y are not used by forward.
default_args = diva_args()
enc = RHVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200),(1,100,14)])

Layer (type:depth-idx)                   Output Shape              Param #
RHVAE                                    --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [19]:
# Training split; create_encodings=False loads the cached *_images_RNN.npz encodings.
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [20]:
# Held-out test split, loaded the same way.
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [21]:
#RNA_dataset.x_bar.shape, RNA_dataset.x_col.shape 

In [22]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns the 9 outputs of ``model.loss_function`` summed over all batches and
    divided by ``len(train_loader.dataset)``:
    ``(loss, bar_loss, col_loss, mse_bar, acc_bar, acc_bar2, ..., acc_bar5)``.
    Tensor entries are detached 0-d tensors.

    NOTE(review): the bar/colour losses and MSE are per-batch means, so dividing
    their sum by the number of samples (not batches) under-reports them by about
    the batch size. Kept as-is because ``train`` derives its reconstruction /
    disentanglement split from these same values.
    """
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    mse_bar = 0
    acc_bar = acc_bar2 = acc_bar3 = acc_bar4 = acc_bar5 = 0
    pbar = tqdm(enumerate(train_loader), unit="batch",
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x2, m) in pbar:
        # To device
        x, y, d, x2, m = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x2.to(DEVICE), m.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, col_loss, mse, acc, acc2, acc3, acc4, acc5 = model.loss_function(
            d.float(), x.float(), y.float(), m.float(), x2.float())

        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])

        # Accumulate detached values: summing the live tensors kept every batch's
        # autograd graph (and its activations) alive until the end of the epoch.
        train_loss += loss.detach()
        epoch_bar_loss += bar_loss.detach()
        epoch_col_loss += col_loss.detach()
        mse_bar += mse.detach()
        acc_bar += acc
        acc_bar2 += acc2
        acc_bar3 += acc3
        acc_bar4 += acc4
        acc_bar5 += acc5

    n = len(train_loader.dataset)
    return (train_loss / n, epoch_bar_loss / n, epoch_col_loss / n, mse_bar / n,
            acc_bar / n, acc_bar2 / n, acc_bar3 / n, acc_bar4 / n, acc_bar5 / n)

In [23]:
def test_single_epoch(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Returns the 9 outputs of ``model.loss_function`` (same order) summed over
    all batches and divided by ``len(test_loader.dataset)``. ``epoch`` is unused.
    """
    model.eval()
    totals = [0] * 9
    with torch.no_grad():
        for x, y, d, x2, m in test_loader:
            x, y, d, x2, m = [t.to(DEVICE) for t in (x, y, d, x2, m)]
            parts = model.loss_function(d.float(), x.float(), y.float(), m.float(), x2.float())
            totals = [running + part for running, part in zip(totals, parts)]
    n_samples = len(test_loader.dataset)
    return tuple(total / n_samples for total in totals)
  

In [24]:
def _kl_weight(args, epoch):
    """KL weight (``beta``) for ``epoch`` under the warm-up schedule.

    Before ``args.prewarmup`` the weight is ``args.beta / args.prewarmup``;
    from then on it ramps linearly from 0 over ``args.warmup`` epochs, capped at
    ``args.beta``. ``warmup <= 0`` means no ramp (this used to raise
    ZeroDivisionError).
    """
    if epoch < args.prewarmup:
        return args.beta / args.prewarmup
    if args.warmup <= 0:
        return args.beta
    return min(args.beta, args.beta * (epoch - args.prewarmup * 1.) / args.warmup)


def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0,
          save_folder='sd_1.0.0', save_interval=5, checkpoint_interval=50):
    """Train ``diva`` for epochs ``start_epoch+1 .. end_epoch``.

    Each epoch sets ``diva.beta`` from the warm-up schedule, trains, evaluates,
    prints the averages, and logs to the module-level ``writer`` if it is set.
    Reconstructions are saved every ``save_interval`` epochs and a state_dict
    checkpoint every ``checkpoint_interval`` epochs (0 disables checkpoints)
    under ``{link}/saved_models/{save_folder}/``.

    Returns ``(train_losses, test_losses)``: per-epoch average losses as numpy scalars.
    """
    epoch_loss_sup = []
    test_loss = []

    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = _kl_weight(args, epoch)

        train_loss, avg_loss_bar, avg_loss_col, mtr, *_ = train_single_epoch(train_loader, diva, optimizer, epoch)
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, train_loss)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        # whatever is not weighted reconstruction loss is attributed to the KL term
        rec_loss_train = diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_col_test, mte, *_ = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)

        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        if writer is not None:
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("coll loss", {'train':avg_loss_col, 'test':avg_loss_col_test}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')

        if checkpoint_interval and epoch % checkpoint_interval == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [25]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr='', n_samples=10, figsize=None):
    """Save originals next to reconstructions for the first batch of ``test_loader``.

    Writes ``{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png`` with
    one row per sample: the original image on the left and
    ``diva.px.reconstruct_image`` on the right.

    Args:
        n_samples: show at most this many samples (fewer if the batch is smaller).
        figsize: forwarded to ``plt.subplots``; None keeps matplotlib's default.
            The former ``plt.figure(figsize=(80, 20))`` call never affected the
            saved image because ``plt.subplots`` opened a separate figure, so the
            default reproduces the previous output.
    """
    x, y, d, x2, m = (t[:n_samples].to(DEVICE).float() for t in next(iter(test_loader)))

    was_training = diva.training
    with torch.no_grad():
        diva.eval()
        x2_hat = diva(d, x, y, m, x2)[0]
        out = diva.px.reconstruct_image(x2_hat)
    diva.train(was_training)

    n = x.shape[0]
    # squeeze=False keeps ax 2-D even for a single row
    fig, ax = plt.subplots(nrows=n, ncols=2, figsize=figsize, squeeze=False)

    ax[0, 0].set_title("Original")
    ax[0, 1].set_title("Reconstructed")

    for i in range(n):
        ax[i, 0].imshow(x[i].cpu().permute(1, 2, 0))
        ax[i, 1].imshow(out[i])
        for cell in ax[i]:
            cell.xaxis.set_visible(False)
            cell.yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    fig.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    # close only this figure, not every figure open in the notebook
    plt.close(fig)

In [26]:
# Display the selected device again before building the model.
DEVICE

device(type='cuda')

## Model Training

In [27]:
# Run configuration: no KL pre-warmup and 50 pseudo-inputs for the z2 prior.
default_args = diva_args(prewarmup=0, number_components=50, z1_dim=256, z2_dim=256)

In [28]:
# Build the model on DEVICE.
diva = RHVAE(default_args).to(DEVICE)

In [29]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/new/HVAE3/checkpoints/100.pth'))

In [30]:
# Only the training data is shuffled; the test loader keeps a fixed order, so
# save_reconstructions always shows the same first test batch.
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [31]:
# Adam with lr=1e-3; the commented SGD line is a disabled alternative.
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=0.001)

In [32]:
#RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

In [33]:
# Write any pending TensorBoard events to disk.
writer.flush()

In [34]:
#diva.rec_gamma = 3

In [35]:
%tensorboard  --logdir="D:/users/Marko/downloads/mirna/saved_models/new/RHVAE/tensorboard/"

Reusing TensorBoard on port 6006 (pid 18760), started 1:03:36 ago. (Use '!kill 18760' to kill it.)

In [None]:
# First training run; output goes under saved_models/new/RHVAE.
# NOTE(review): 500 and 0 look like the final and starting epoch, and save_interval=5
# like the checkpoint frequency -- confirm against train's signature.
# NOTE(review): the captured output stops mid-epoch 47 (cell shows In [None]), i.e. the run
# was interrupted, so lss / lss_t were not assigned by this cell.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 0, save_folder="new/RHVAE",save_interval=5)

Epoch 1: 272batch [00:37,  7.21batch/s, loss=1.79] 


epoch 1: avg train loss 1.01, bar train loss 0.042, col train loss 0.017
epoch 1: avg test  loss 0.46, bar  test loss 0.020, col  test loss 0.013


Epoch 2: 272batch [00:37,  7.24batch/s, loss=1.75] 


epoch 2: avg train loss 0.48, bar train loss 0.022, col train loss 0.013


Epoch 3: 1batch [00:00,  7.30batch/s, loss=0.506]

epoch 2: avg test  loss 0.44, bar  test loss 0.019, col  test loss 0.013


Epoch 3: 272batch [00:37,  7.21batch/s, loss=1.74] 


epoch 3: avg train loss 0.47, bar train loss 0.021, col train loss 0.013


Epoch 4: 1batch [00:00,  7.30batch/s, loss=0.476]

epoch 3: avg test  loss 0.41, bar  test loss 0.019, col  test loss 0.013


Epoch 4: 272batch [00:37,  7.26batch/s, loss=1.82] 


epoch 4: avg train loss 0.46, bar train loss 0.021, col train loss 0.013


Epoch 5: 1batch [00:00,  7.25batch/s, loss=0.451]

epoch 4: avg test  loss 0.41, bar  test loss 0.019, col  test loss 0.013


Epoch 5: 272batch [00:37,  7.17batch/s, loss=1.78] 


epoch 5: avg train loss 0.45, bar train loss 0.020, col train loss 0.013
epoch 5: avg test  loss 0.41, bar  test loss 0.018, col  test loss 0.012


Epoch 6: 272batch [00:38,  7.15batch/s, loss=1.66] 


epoch 6: avg train loss 0.44, bar train loss 0.020, col train loss 0.013


Epoch 7: 1batch [00:00,  7.30batch/s, loss=0.436]

epoch 6: avg test  loss 0.40, bar  test loss 0.018, col  test loss 0.012


Epoch 7: 272batch [00:37,  7.24batch/s, loss=1.93] 


epoch 7: avg train loss 0.45, bar train loss 0.020, col train loss 0.013


Epoch 8: 1batch [00:00,  7.30batch/s, loss=0.429]

epoch 7: avg test  loss 0.39, bar  test loss 0.018, col  test loss 0.012


Epoch 8: 272batch [00:37,  7.19batch/s, loss=1.73] 


epoch 8: avg train loss 0.43, bar train loss 0.019, col train loss 0.013


Epoch 9: 1batch [00:00,  6.85batch/s, loss=0.459]

epoch 8: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 9: 272batch [00:38,  7.14batch/s, loss=1.63] 


epoch 9: avg train loss 0.42, bar train loss 0.019, col train loss 0.012


Epoch 10: 1batch [00:00,  7.19batch/s, loss=0.452]

epoch 9: avg test  loss 0.39, bar  test loss 0.017, col  test loss 0.012


Epoch 10: 272batch [00:37,  7.16batch/s, loss=1.62] 


epoch 10: avg train loss 0.42, bar train loss 0.019, col train loss 0.012
epoch 10: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 11: 272batch [00:38,  7.15batch/s, loss=1.47] 


epoch 11: avg train loss 0.41, bar train loss 0.019, col train loss 0.012


Epoch 12: 1batch [00:00,  7.25batch/s, loss=0.398]

epoch 11: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 12: 272batch [00:37,  7.18batch/s, loss=1.65] 


epoch 12: avg train loss 0.41, bar train loss 0.018, col train loss 0.012


Epoch 13: 1batch [00:00,  7.30batch/s, loss=0.419]

epoch 12: avg test  loss 0.37, bar  test loss 0.017, col  test loss 0.012


Epoch 13: 272batch [00:38,  7.14batch/s, loss=1.45] 


epoch 13: avg train loss 0.40, bar train loss 0.018, col train loss 0.012


Epoch 14: 1batch [00:00,  7.30batch/s, loss=0.412]

epoch 13: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 14: 272batch [00:37,  7.19batch/s, loss=1.48] 


epoch 14: avg train loss 0.40, bar train loss 0.018, col train loss 0.012


Epoch 15: 1batch [00:00,  7.25batch/s, loss=0.444]

epoch 14: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 15: 272batch [00:37,  7.22batch/s, loss=1.38] 


epoch 15: avg train loss 0.39, bar train loss 0.018, col train loss 0.012
epoch 15: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 16: 272batch [00:37,  7.22batch/s, loss=1.45] 


epoch 16: avg train loss 0.39, bar train loss 0.017, col train loss 0.012


Epoch 17: 1batch [00:00,  7.25batch/s, loss=0.396]

epoch 16: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 17: 272batch [00:38,  7.15batch/s, loss=1.23] 


epoch 17: avg train loss 0.39, bar train loss 0.017, col train loss 0.012


Epoch 18: 1batch [00:00,  7.14batch/s, loss=0.369]

epoch 17: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 18: 272batch [00:38,  7.13batch/s, loss=1.51] 


epoch 18: avg train loss 0.38, bar train loss 0.017, col train loss 0.012


Epoch 19: 1batch [00:00,  7.25batch/s, loss=0.367]

epoch 18: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 19: 272batch [00:37,  7.16batch/s, loss=1.43] 


epoch 19: avg train loss 0.38, bar train loss 0.017, col train loss 0.012


Epoch 20: 1batch [00:00,  7.25batch/s, loss=0.378]

epoch 19: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 20: 272batch [00:38,  7.14batch/s, loss=1.4]  


epoch 20: avg train loss 0.38, bar train loss 0.017, col train loss 0.012
epoch 20: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 21: 272batch [00:38,  7.11batch/s, loss=1.38] 


epoch 21: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 22: 1batch [00:00,  7.04batch/s, loss=0.388]

epoch 21: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 22: 272batch [00:38,  7.11batch/s, loss=1.34] 


epoch 22: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 23: 1batch [00:00,  7.19batch/s, loss=0.38]

epoch 22: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 23: 272batch [00:37,  7.21batch/s, loss=1.3]  


epoch 23: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 24: 1batch [00:00,  7.19batch/s, loss=0.367]

epoch 23: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 24: 272batch [00:38,  7.12batch/s, loss=1.33] 


epoch 24: avg train loss 0.37, bar train loss 0.016, col train loss 0.012


Epoch 25: 1batch [00:00,  7.19batch/s, loss=0.41]

epoch 24: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 25: 272batch [00:38,  7.11batch/s, loss=1.42] 


epoch 25: avg train loss 0.37, bar train loss 0.016, col train loss 0.012
epoch 25: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 26: 272batch [00:38,  7.07batch/s, loss=1.41] 


epoch 26: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 27: 1batch [00:00,  7.19batch/s, loss=0.345]

epoch 26: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 27: 272batch [00:38,  7.09batch/s, loss=1.49] 


epoch 27: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 28: 1batch [00:00,  7.09batch/s, loss=0.391]

epoch 27: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 28: 272batch [00:37,  7.17batch/s, loss=1.35] 


epoch 28: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 29: 1batch [00:00,  6.62batch/s, loss=0.402]

epoch 28: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 29: 272batch [00:38,  7.07batch/s, loss=1.23] 


epoch 29: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 30: 1batch [00:00,  6.99batch/s, loss=0.354]

epoch 29: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 30: 272batch [00:38,  7.04batch/s, loss=1.26] 


epoch 30: avg train loss 0.36, bar train loss 0.016, col train loss 0.012
epoch 30: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 31: 272batch [00:38,  7.09batch/s, loss=1.34] 


epoch 31: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 32: 1batch [00:00,  7.14batch/s, loss=0.378]

epoch 31: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 32: 272batch [00:38,  7.05batch/s, loss=1.29] 


epoch 32: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 33: 1batch [00:00,  7.14batch/s, loss=0.376]

epoch 32: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 33: 272batch [00:38,  7.10batch/s, loss=1.34] 


epoch 33: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 34: 1batch [00:00,  7.14batch/s, loss=0.332]

epoch 33: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 34: 272batch [00:38,  7.06batch/s, loss=1.3]  


epoch 34: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 35: 1batch [00:00,  7.19batch/s, loss=0.317]

epoch 34: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 35: 272batch [00:38,  7.06batch/s, loss=1.38] 


epoch 35: avg train loss 0.35, bar train loss 0.015, col train loss 0.012
epoch 35: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 36: 272batch [00:38,  7.06batch/s, loss=1.24] 


epoch 36: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 37: 1batch [00:00,  7.19batch/s, loss=0.373]

epoch 36: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 37: 272batch [00:38,  7.08batch/s, loss=1.34] 


epoch 37: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 38: 1batch [00:00,  6.99batch/s, loss=0.328]

epoch 37: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 38: 272batch [00:38,  7.00batch/s, loss=1.28] 


epoch 38: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 39: 1batch [00:00,  7.04batch/s, loss=0.332]

epoch 38: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 39: 272batch [00:38,  7.01batch/s, loss=1.3]  


epoch 39: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 40: 1batch [00:00,  7.04batch/s, loss=0.345]

epoch 39: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 40: 272batch [00:38,  7.01batch/s, loss=1.13] 


epoch 40: avg train loss 0.34, bar train loss 0.015, col train loss 0.012
epoch 40: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 41: 272batch [00:38,  7.02batch/s, loss=1.29] 


epoch 41: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 42: 1batch [00:00,  6.99batch/s, loss=0.298]

epoch 41: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 42: 272batch [00:38,  7.01batch/s, loss=1.32] 


epoch 42: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 43: 1batch [00:00,  7.09batch/s, loss=0.369]

epoch 42: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 43: 272batch [00:38,  7.02batch/s, loss=1.1]  


epoch 43: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 44: 1batch [00:00,  7.04batch/s, loss=0.349]

epoch 43: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 44: 272batch [00:39,  6.88batch/s, loss=1.21] 


epoch 44: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 45: 1batch [00:00,  6.94batch/s, loss=0.292]

epoch 44: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 45: 272batch [00:39,  6.97batch/s, loss=1.1]  


epoch 45: avg train loss 0.34, bar train loss 0.015, col train loss 0.012
epoch 45: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 46: 272batch [00:39,  6.88batch/s, loss=1.11] 


epoch 46: avg train loss 0.33, bar train loss 0.015, col train loss 0.012


Epoch 47: 1batch [00:00,  6.54batch/s, loss=0.291]

epoch 46: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 47: 269batch [00:39,  6.75batch/s, loss=0.343]

In [None]:
#lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1000,500,save_folder="new/HVAE2",save_interval=5)

In [None]:
# Follow-up training run with results kept in lss2 / lss_t2.
# NOTE(review): this writes to save_folder="VAEFC", not "new/RHVAE" used by the first run --
# confirm that is intended. 1000 / 500 presumably mean final / starting epoch.
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 500, save_folder="VAEFC")

In [None]:
# Further training run (1600 / 1000 presumably final / starting epoch -- confirm).
# NOTE(review): this rebinds lss / lss_t, discarding the losses of the first run, and also
# saves to "VAEFC".
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t, yacc=None, yacc_t=None):
    """Plot train/test loss curves, optionally with accuracy on a second y-axis.

    Args:
        lss: per-epoch train losses (numbers, 0-d NumPy arrays or scalar tensors).
        lss_t: per-epoch test losses, same element types as ``lss``.
        yacc: optional per-epoch train accuracies, drawn dashed on a twin axis.
        yacc_t: optional per-epoch test accuracies, drawn dashed on a twin axis.

    The legend combines the loss and (if plotted) accuracy lines. Before this change
    the function accepted only two arguments, so four-argument calls raised TypeError.
    """
    fig, ax = plt.subplots()
    # float() accepts plain numbers, 0-d arrays and scalar tensors (CPU or CUDA).
    ax.plot([float(v) for v in lss], label="train loss")
    ax.plot([float(v) for v in lss_t], label="test loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")

    lines, labels = ax.get_legend_handles_labels()
    legend_ax = ax
    if yacc is not None or yacc_t is not None:
        ax1 = ax.twinx()
        if yacc is not None:
            ax1.plot(yacc, label="train accuracy", ls='--')
        if yacc_t is not None:
            ax1.plot(yacc_t, label="test accuracy", ls='--')
        ax1.set_ylabel("accuracy")
        lines2, labels2 = ax1.get_legend_handles_labels()
        lines, labels = lines + lines2, labels + labels2
        # Draw the legend on the top-most axes so twin-axis lines don't cover it.
        legend_ax = ax1

    legend_ax.legend(lines, labels)

In [None]:
# Loss curves of the most recent run stored in lss / lss_t.
plot_loss_acc(lss=lss, lss_t=lss_t)

In [None]:
# FIXME: lss3, lss_t3, yacc3 and yacc_t3 are not assigned by any visible cell (apparently
# left over from a deleted run), so this raises NameError on a fresh kernel.
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=None, step = 5, loader=None):
    """Visualise how decoded images change while sweeping latent dimensions.

    Encodes the first ``len(var_idx)`` inputs of the first batch of ``loader`` with
    ``diva.qzx`` / ``diva.qzy`` / ``diva.qzd``, sweeps the dimensions ``var_idx`` of the
    latent chosen by ``lat_space`` via ``change`` and shows the decoded images in a
    grid: one column per sweep value, one row per sample.

    Args:
        diva: model exposing ``qzx``, ``qzy``, ``qzd`` (each returning a mean/scale
            pair) and ``px``. NOTE(review): ``save_reconstructions`` calls the model as
            ``diva(d, x, y, m, x2)`` instead -- confirm the current model still exposes
            these encoders.
        lat_space: which latent to sweep: "y", "x" or "d".
        var_idx: latent dimensions to sweep; defaults to ``[0, ..., 7]``.
        step: spacing of the sweep values passed to ``change``.
        loader: DataLoader to take the batch from; defaults to the global ``test_loader``.
    """
    if var_idx is None:
        var_idx = [0, 1, 2, 3, 4, 5, 6, 7]
    if loader is None:
        loader = test_loader
    n = len(var_idx)

    # Only the inputs (batch[0]) are needed; the encoders are called on x alone.
    x = next(iter(loader))[0][:n].to(DEVICE).float()

    # Restore the caller's train/eval mode afterwards instead of leaving eval mode on.
    was_training = diva.training
    try:
        with torch.no_grad():
            diva.eval()
            zx, zx_sc = diva.qzx(x)
            zy, zy_sc = diva.qzy(x)
            zd, zd_sc = diva.qzd(x)

            print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

            out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    finally:
        diva.train(was_training)

    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(ncols=out.shape[0], nrows=n,
                           figsize=(10*4*out.shape[0], 10*n), squeeze=False)
    for i in range(out.shape[0]):
        for j in range(n):
            ax[j, i].imshow(out[i, j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=None, step = 2,
           value_range=(-30, 15), img_shape=(25, 100, 3)):
    """Decode latent codes while overwriting selected latent dimensions with a sweep.

    For every value ``v`` in ``np.arange(value_range[0], value_range[1], step)`` and every
    sample ``j`` in ``range(len(idx))``, the dimensions ``idx`` of sample ``j`` in the
    latent chosen by ``lat`` are set to ``v``. The sample is then decoded with
    ``model.px(zd[j], zx[j], zy[j])`` followed by ``model.px.reconstruct_image``.

    Args:
        zx, zy, zd: latent tensors indexed as ``z[j, idx]`` (first dim = sample).
            They are cloned first, so the caller's tensors are left unchanged.
        idx: latent dimensions to overwrite; ``len(idx)`` samples are decoded.
        lat: which latent to sweep: "y", "x" or "d".
        model: object exposing ``px``. Defaults to the global ``diva``, looked up at
            call time rather than frozen at definition time.
        step: spacing of the sweep values.
        value_range: (start, stop) for ``np.arange``; stop is excluded.
        img_shape: shape of one image returned by ``reconstruct_image``.

    Returns:
        np.ndarray of shape ``(n_values, len(idx), *img_shape)``.

    Raises:
        ValueError: if ``lat`` is not "y", "x" or "d" (previously ignored silently).
    """
    if model is None:
        model = diva

    latents = {"x": zx.clone(), "y": zy.clone(), "d": zd.clone()}
    if lat not in latents:
        raise ValueError(f"lat must be 'y', 'x' or 'd', got {lat!r}")
    target = latents[lat]

    values = np.arange(value_range[0], value_range[1], step)
    out = np.zeros((values.shape[0], len(idx), *img_shape))
    with torch.no_grad():
        for i, value in enumerate(values):
            for j in range(len(idx)):
                target[j, idx] = float(value)
                len_, bar, col = model.px(latents["d"][j], latents["x"][j], latents["y"][j])
                out[i, j] = model.px.reconstruct_image(len_[None, :], bar, col)

    return out



In [None]:
# Sweep the default y-latent dimensions of the current model.
plot_change_latent_var(diva=diva)

In [None]:
# Loss (left axis) and accuracy (right axis) for the run stored in lss2 / lss_t2.
# `train` already returns NumPy values, so calling .cpu() on them raised AttributeError;
# float() works for both 0-d arrays and scalar tensors.
# FIXME: yacc2 / yacc_t2 are not assigned by any visible cell.
epochs2 = np.arange(50, 50 + len(lss2))
fig, ax = plt.subplots()
ax.plot(epochs2, [float(v) for v in lss2], label="train loss")
ax.plot(epochs2, [float(v) for v in lss_t2], label="testloss")
ax1 = ax.twinx()
ax1.plot(epochs2, yacc2, label="train")
ax1.plot(epochs2, yacc_t2, label="test")

# plt.legend() only picked up the twin axis (accuracy); combine both axes' lines.
lines, labels = ax.get_legend_handles_labels()
lines1, labels1 = ax1.get_legend_handles_labels()
ax1.legend(lines + lines1, labels + labels1)

In [None]:
# Loss (left axis) and accuracy (right axis) for the run stored in lss3 / lss_t3.
# float() handles both 0-d NumPy arrays (what `train` returns) and scalar tensors.
# FIXME: lss3, lss_t3, yacc3 and yacc_t3 are not assigned by any visible cell.
epochs3 = np.arange(120, 120 + len(lss3))
fig, ax = plt.subplots()
ax.plot(epochs3, [float(v) for v in lss3], label="train loss")
ax.plot(epochs3, [float(v) for v in lss_t3], label="testloss")
ax1 = ax.twinx()
ax1.plot(epochs3, yacc3, label="train", c='green')
ax1.plot(epochs3, yacc_t3, label="test")

# plt.legend() only picked up the twin axis (accuracy); combine both axes' lines.
lines, labels = ax.get_legend_handles_labels()
lines1, labels1 = ax1.get_legend_handles_labels()
ax1.legend(lines + lines1, labels + labels1)

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''Plot the learned latent space (not implemented yet).

    lat_space: which latent to plot -- "y", "d" or "x".

    Raises:
        NotImplementedError: always. The previous body was an empty stub that
            silently returned None, which made calls look like they succeeded.
    '''
    raise NotImplementedError("plot_latent_space is not implemented yet")
    

In [None]:
# FIXME: `plot`, `x` and `out` are not defined at module level in any visible cell
# (`x`/`out` only exist as locals inside functions), so this fails on a fresh kernel.
plot(x, out, 0)

In [None]:
# Show the first nine inputs of `x` in a 3x3 grid (row-major) and save it to divastamporg.png.
# NOTE(review): `x` is only assigned inside functions in the visible cells -- confirm it
# exists at module level before running this cell.
fig, ax = plt.subplots(nrows=3, ncols=3)
for i, axis in enumerate(ax.flat):
    axis.imshow(x[i].cpu().permute(1, 2, 0))

fig.savefig('divastamporg.png')