In [2]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [3]:
%load_ext tensorboard

In [76]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


writer = SummaryWriter(f"{link}/saved_models/HVAEFCP1/tensorboard")

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
DEVICE

device(type='cuda')

# Model Classes

In [7]:
class diva_args:

    def __init__(self, z1_dim=1000, z2_dim=1000, d_dim=45, x_dim=7500, y_dim=2,
                 h_dim = 600, h2_dim = 600, number_components = 500,
                 beta=1, rec_alpha = 100, rec_beta = 20, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z1_dim = z1_dim
        self.z2_dim = z2_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.h_dim = h_dim
        self.h2_dim = h2_dim
        
        self.number_components = number_components
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [8]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_col, x_bar = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len2.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar2.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col2.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        
        
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')
        
        
        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]
    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]                        
        return (x, y, d, x_len, x_col, x_bar, mount)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,5,200), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]) ,2*j] = 1 
                    out_col[i, self.get_color(x[i,:,13,j]), 2*j+1] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len2.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col2.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar2.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():  
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():  
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():  
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():  
            return 4 # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [65]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim, 
                 h_dim, h2_dim, dim0=2000, dim1=1200, dim2=400):
        super(px, self).__init__()

        
        # p(z1|z2)
        
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim+200, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())
        
        
        # p(x|z1,z2,m)
        
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim),
                                   nn.ReLU())
        
        self.fc1 = nn.Sequential(nn.Linear(2*h_dim, dim0, bias=False),  
                                 nn.ReLU())
        
        self.fc2 = nn.Sequential(nn.Linear(dim0, dim2, bias=False),  
                                 nn.ReLU())
#         self.fc3 = nn.Sequential(nn.Linear(dim1, dim2, bias=False),
#                                  nn.ReLU())
        
        # Predicting length and color of each bar
        #self.color = nn.Sequential(nn.Conv1d(1,5, kernel_size=1, bias=False), 
                                  # nn.Softmax(dim=1))
        # Predicting color of each bar
        self.color_bar_black = nn.Linear(dim2,200)
        self.color_bar_reddd = nn.Linear(dim2,200)
        self.color_bar_bluee = nn.Linear(dim2,200)
        self.color_bar_green = nn.Linear(dim2,200)
        self.color_bar_yelow = nn.Linear(dim2,200)
        
        # Predicting the length of each bar
        self.length_bar_top = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())
        self.length_bar_bot = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())
        #self.length_bar_scale = nn.Sequential(nn.Conv1d(100, 1, kernel_size = 3, padding = 'same', bias=False), nn.Sigmoid())
        
        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(dim2,400), nn.ReLU(),nn.Linear(400,1), nn.Softplus())
        #self.length_RNA_scale = nn.Sequential(nn.Linear(400,1, bias=False), nn.Sigmoid())
        
    def forward(self, z1, mz2):
        
        # p(z1|z2)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)
        
        # p(x|z1,z2,m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = torch.cat([hz1,hz2],1)
        h = self.fc1(h)
        h = nn.Dropout(0.2)(h)
        h = self.fc2(h)
        h = nn.Dropout(0.2)(h)
        len_RNA = self.length_RNA(h)
        len_bar = torch.cat([self.length_bar_top(h)[:,None,:],self.length_bar_bot(h)[:,None,:]], dim=1) 
#         h = self.fc3(h)
#         h = nn.Dropout(0.3)(h)
        
        len_RNA_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_RNA_sc = torch.exp(self.length_RNA_scale(h))
        
        
        
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_bar_sc = torch.exp(self.length_bar_scale(h))
        
#        col_bar = self.color(h[:,None,:])
        black = self.color_bar_black(h)
        black = nn.Sigmoid()(black)[:,None,:]
        reddd = self.color_bar_reddd(h)[:,None,:]
        bluee = self.color_bar_bluee(h)[:,None,:]
        green = self.color_bar_green(h)[:,None,:]
        yelow = self.color_bar_yelow(h)[:,None,:]
        
        col = torch.cat([reddd,bluee,green,yelow], dim=1)
        
        col = nn.Softmax(dim=1)(col)
        
        
        col_bar = torch.cat([black,col*((1-black).repeat(1,4,1))], 1)
        
        return len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar, pz1_m, pz1_s

    def reconstruct_image(self, len_RNA, var_RNA, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }
    
        
        len_RNA = len_RNA.cpu().numpy()
        var_RNA = var_RNA.cpu().numpy()
        #.reshape((100,))
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_RNA.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = int(np.round(np.random.normal(loc=len_RNA[i], scale=var_RNA[i])))
            else:
                limit = int(np.round(len_RNA[i]))
            limit = min(100, limit)
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1 = np.argmax(col_bar[i,:, 2*j])
                    _col_bar_2 = np.argmax(col_bar[i,:, 2*j+1])
                
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output


In [66]:
int(np.round(3.7, 0))
int(3.7)

3

In [67]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [68]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [69]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim):
        super(qz, self).__init__()

        # q(z2 | x)
        self.encoder_z2 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'valid', bias=False),
            nn.ReLU(), 
#             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
#             nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.mu_z2 = nn.Sequential(nn.Linear(2560, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(2560, z2_dim), nn.Softplus())
        
        
        # q(z1 | x, z2)
        self.encoder_z1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'valid', bias=False),
            nn.ReLU(), 
#             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
#             nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(2560, h_dim), nn.ReLU())
        
        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())
        
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

#         torch.nn.init.xavier_uniform_(self.encoder[0].weight)
#         torch.nn.init.xavier_uniform_(self.encoder[3].weight)
#         torch.nn.init.xavier_uniform_(self.fc11[0].weight)
#         self.fc11[0].bias.data.zero_()
#         torch.nn.init.xavier_uniform_(self.fc12[0].weight)
#         self.fc12[0].bias.data.zero_()
    
    def q_z2(self, x):
        z2 = self.encoder_z2(x)
        z2 = z2.view(-1, 2560)
        z2_m = self.mu_z2(z2) 
        z2_s = self.si_z2(z2)
        
        return z2_m, z2_s
    
    def forward(self, x, m):
        
        # q(z2 | x) & m
        z2_m, z2_s = self.q_z2(x)
        # reparameterization trick
        qz2 = dist.Normal(z2_m, z2_s)
        z2 = qz2.rsample()
        # z2 & m
        mz2 = torch.cat([z2, m],1)
        
        
        
        # q(z1 | x, z2, m)
        z1 = self.encoder_z1(x)
        z1 = z1.view(-1, 2560)
        z1 = self.fc_z1(z1)
        
        mz2_ = self.fc_z2(mz2)
        
        z1 = torch.cat([mz2_, z1],1)
        z1 = self.fc_z1_z2(z1)
        z1_m = self.mu_z1(z1)
        z1_s = self.si_z1(z1)
        
        qz1 = dist.Normal(z1_m, z1_s)
        z1 = qz1.rsample()
        
        
        #z_loc = self.fc11(h)
        #z_scale = self.fc12(h) + 1e-7

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [70]:
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [71]:
enc = qz(128,10,10,10,500,400,400)
enc(torch.zeros((1,3,25,100)), torch.zeros((1,200)))
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 1, 10]           --
│    └─Conv2d: 2-1                       [1, 32, 23, 98]           864
│    └─ReLU: 2-2                         [1, 32, 23, 98]           --
│    └─Conv2d: 2-3                       [1, 64, 21, 96]           18,432
│    └─ReLU: 2-4                         [1, 64, 21, 96]           --
│    └─MaxPool2d: 2-5                    [1, 64, 10, 48]           --
│    └─Conv2d: 2-6                       [1, 128, 8, 46]           73,728
│    └─ReLU: 2-7                         [1, 128, 8, 46]           --
│    └─MaxPool2d: 2-8                    [1, 128, 4, 23]           --
│    └─Conv2d: 2-9                       [1, 256, 2, 21]           294,912
│    └─ReLU: 2-10                        [1, 256, 2, 21]           --
│    └─MaxPool2d: 2-11                   [1, 256, 1, 10]           --
├

In [72]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    log_var = 2*torch.log(std)
    log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

## Full model class

In [79]:
class HVAE(nn.Module):
    def __init__(self, args):
        super(HVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components
        
        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup
        
        self.add_pseudoinputs()
        
        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []
        
        self.bar = []
        self.len = []
        self.col = []
        
        self.cuda()

    def forward(self, d, x, y, m):
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, pz1_m, pz1_s = self.px(z1, mz2)
        
        return x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s
    
    
    def log_p_z2(self, z2):
        C = self.number_components
        
        X = self.means(self.idle_input).view(-1,3,25,100)
        
        pz2_m, pz2_s = self.qz.q_z2(X)
        
        z_expand = z2.unsqueeze(1)
        means = pz2_m.unsqueeze(0)
        stds = pz2_s.unsqueeze(0)
        
        a = log_Normal_diag(z_expand, means, stds, dim=2) - math.log(C)
        a_max, _ = torch.max(a,1)
        
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a-a_max.unsqueeze(1)),1)))
        
        return log_prior
    
    def loss_function(self, d, x, y, m, out_len, out_bar, out_col):
        
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m)
        
        # Reconstruction Loss
        mask = 1 - F.one_hot(torch.round(out_len).to(torch.int64)*2-1, 200).cumsum(dim=1)[:,None,:]
        mask1 = (1 - F.one_hot(torch.round(out_len).to(torch.int64)-1, 100).cumsum(dim=1)[:,None,:]).repeat(1,2,1)

        x_col = mask.repeat(1,5,1)*x_col
        
#         x_gap = x_col[:,0,:]
#         x_nuc = x_col[:,1:,:]
                
        
        
        dist_len = dist.Normal(x_len, x_len_scale+1e-7)
        log_len = dist_len.log_prob(out_len[:,None]).mean()
         
        mse_bar = ((((x_bar - out_bar)**2)*mask1).sum(dim=(1,2))/(mask1.sum(dim=(1,2)))).sum()#.detach().item()
        
        max_bar = torch.argmax(x_col, dim=1)
        #print(max_bar.shape, out_col.shape, mask.shape)
        acc_bar = (((max_bar==torch.argmax(out_col, dim=1))*mask).sum().float()/(mask.sum(1))).sum()
        acc_bar2 = (((max_bar==torch.argsort(out_col, dim=1)[:,1,:])*mask).sum().float()/(mask.sum(1))).sum() + acc_bar
        #print(acc_bar.shape)
        RE_len = -log_len
        RE_bar = mse_bar#-log_bar
        RE_col = F.cross_entropy(x_col, out_col, reduction='sum')
          
            
        # KL loss
        KL_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        KL_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        KL_p_z2 = self.log_p_z2(z2).sum()
        KL_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(KL_p_z1 + KL_p_z2 - KL_q_z1 - KL_q_z2)
        #print(KL_p_z1.shape,KL_p_z2.shape,KL_q_z1.shape,KL_q_z2.shape)
        
        
#         self.lpz1.append(KL_p_z1.detach().item())
#         self.lpz2.append(KL_p_z2.detach().item())
#         self.lqz1.append(KL_q_z1.detach().item())
#         self.lqz2.append(KL_q_z2.detach().item())
        
#         self.bar.append(RE_bar.detach())
#         self.col.append(RE_col.detach())
#         self.len.append(RE_len.detach())
        
        
        return self.rec_alpha * RE_len \
                  + self.rec_beta * RE_bar \
                  + self.rec_gamma * RE_col \
                  + self.beta * KL, \
                  RE_bar, RE_len, RE_col, mse_bar, acc_bar, acc_bar2
    
    def add_pseudoinputs(self):
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        self.idle_input = Variable(torch.eye(self.number_components, self.number_components), requires_grad=False).to(DEVICE)

In [80]:
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [81]:
default_args = diva_args()
enc = HVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
HVAE                                     --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [77]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [78]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [24]:
len(RNA_dataset)

34721

In [82]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    no_batches = 0
    train_corr = 0
    mse_bar = 0
    acc_bar = 0
    acc_bar2 = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar, m) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col, m= x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, len_loss, col_loss, mse, acc, acc2 = model.loss_function(d.float(), x.float(), y.float(), m.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        epoch_len_loss += len_loss
        mse_bar += mse
        acc_bar += acc
        acc_bar2 += acc2
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    acc_bar2 /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar, acc_bar2

In [83]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    mse_bar = 0
    acc_bar = 0   
    acc_bar2 = 0
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar, m) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col, m = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)
            loss, bar_loss, len_loss, col_loss, mse, acc, acc2 = model.loss_function(d.float(), x.float(), y.float(),m.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            epoch_len_loss += len_loss
            mse_bar += mse
            acc_bar += acc
            acc_bar2 += acc2
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_len_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    acc_bar2 /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar, acc_bar2
  

In [97]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col, mtr, atr, atr2 = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test, mte, ate, ate2 = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train-top1': atr, 'test-top1':ate, 'train-top2': atr2, 'test-top2':ate2}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')
        
        
        if epoch % 50 == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [98]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        m = a[1][-1][:10].to(DEVICE).float()
        x_1, x_1var, x_2, x_2var, x_3 ,z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = diva(d,x,y,m)
        out = diva.px.reconstruct_image(x_1, x_1var, x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [99]:
DEVICE

device(type='cuda')

## Model Training

In [100]:
default_args = diva_args(prewarmup=0, number_components=50)

In [101]:
diva = HVAE(default_args).to(DEVICE)

In [102]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE10/checkpoints/905.pth'))

In [103]:
train_loader = DataLoader(RNA_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=64)

In [104]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=0.001)

In [105]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [106]:
writer.flush()

In [107]:
diva.rec_gamma = 3

In [95]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/HVAEFCP1/tensorboard/"

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 0, save_folder="HVAEFCP1",save_interval=5)

Epoch 1: 543batch [00:36, 14.81batch/s, loss=762]   


epoch 1: avg train loss 818.13, bar train loss 11.790, len train loss 0.638, col train loss 171.595
epoch 1: avg test  loss 730.94, bar  test loss 9.307, len  test loss 0.301, col  test loss 169.814


Epoch 2: 543batch [00:36, 14.75batch/s, loss=774]


epoch 2: avg train loss 706.39, bar train loss 8.756, len train loss 0.193, col train loss 168.990


Epoch 3: 2batch [00:00, 13.89batch/s, loss=744]

epoch 2: avg test  loss 692.03, bar  test loss 8.329, len  test loss 0.153, col  test loss 168.410


Epoch 3: 543batch [00:37, 14.50batch/s, loss=743]


epoch 3: avg train loss 688.16, bar train loss 8.104, len train loss 0.167, col train loss 167.908


Epoch 4: 2batch [00:00, 14.39batch/s, loss=630]

epoch 3: avg test  loss 674.53, bar  test loss 7.725, len  test loss 0.132, col  test loss 166.789


Epoch 4: 543batch [00:37, 14.42batch/s, loss=693]


epoch 4: avg train loss 670.80, bar train loss 7.247, len train loss 0.207, col train loss 166.047


Epoch 5: 2batch [00:00, 14.18batch/s, loss=660]

epoch 4: avg test  loss 651.84, bar  test loss 6.765, len  test loss 0.129, col  test loss 165.391


Epoch 5: 543batch [00:37, 14.61batch/s, loss=611]


epoch 5: avg train loss 650.31, bar train loss 6.537, len train loss 0.170, col train loss 165.018
epoch 5: avg test  loss 642.37, bar  test loss 6.364, len  test loss 0.138, col  test loss 164.764


Epoch 6: 543batch [00:36, 14.72batch/s, loss=642]


epoch 6: avg train loss 641.45, bar train loss 6.157, len train loss 0.171, col train loss 164.480


Epoch 7: 2batch [00:00, 14.39batch/s, loss=644]

epoch 6: avg test  loss 633.91, bar  test loss 5.987, len  test loss 0.140, col  test loss 164.221


Epoch 7: 543batch [00:37, 14.53batch/s, loss=682]


epoch 7: avg train loss 633.40, bar train loss 5.870, len train loss 0.161, col train loss 163.988


Epoch 8: 2batch [00:00, 14.39batch/s, loss=645]

epoch 7: avg test  loss 628.93, bar  test loss 5.739, len  test loss 0.155, col  test loss 163.680


Epoch 8: 543batch [00:37, 14.60batch/s, loss=598]


epoch 8: avg train loss 626.43, bar train loss 5.623, len train loss 0.154, col train loss 163.606


Epoch 9: 2batch [00:00, 13.61batch/s, loss=616]

epoch 8: avg test  loss 620.14, bar  test loss 5.497, len  test loss 0.121, col  test loss 163.501


Epoch 9: 543batch [00:37, 14.50batch/s, loss=652]


epoch 9: avg train loss 620.35, bar train loss 5.438, len train loss 0.142, col train loss 163.251


Epoch 10: 2batch [00:00, 13.89batch/s, loss=619]

epoch 9: avg test  loss 627.33, bar  test loss 5.454, len  test loss 0.211, col  test loss 163.145


Epoch 10: 543batch [00:37, 14.49batch/s, loss=601]


epoch 10: avg train loss 616.21, bar train loss 5.291, len train loss 0.140, col train loss 162.945
epoch 10: avg test  loss 612.13, bar  test loss 5.216, len  test loss 0.118, col  test loss 162.933


Epoch 11: 543batch [00:37, 14.58batch/s, loss=623]


epoch 11: avg train loss 611.44, bar train loss 5.137, len train loss 0.132, col train loss 162.655


Epoch 12: 2batch [00:00, 14.39batch/s, loss=609]

epoch 11: avg test  loss 610.95, bar  test loss 5.095, len  test loss 0.130, col  test loss 162.675


Epoch 12: 543batch [00:36, 14.68batch/s, loss=590]


epoch 12: avg train loss 606.94, bar train loss 5.000, len train loss 0.122, col train loss 162.386


Epoch 13: 2batch [00:00, 14.29batch/s, loss=605]

epoch 12: avg test  loss 606.62, bar  test loss 4.986, len  test loss 0.123, col  test loss 162.496


Epoch 13: 543batch [00:36, 14.72batch/s, loss=632]


epoch 13: avg train loss 604.54, bar train loss 4.904, len train loss 0.125, col train loss 162.112


Epoch 14: 2batch [00:00, 14.49batch/s, loss=557]

epoch 13: avg test  loss 606.59, bar  test loss 4.889, len  test loss 0.134, col  test loss 162.200


Epoch 14: 543batch [00:36, 14.73batch/s, loss=565]


epoch 14: avg train loss 601.00, bar train loss 4.804, len train loss 0.116, col train loss 161.905


Epoch 15: 2batch [00:00, 13.99batch/s, loss=604]

epoch 14: avg test  loss 623.56, bar  test loss 4.847, len  test loss 0.330, col  test loss 162.079


Epoch 15: 543batch [00:37, 14.57batch/s, loss=678]


epoch 15: avg train loss 597.11, bar train loss 4.691, len train loss 0.106, col train loss 161.696
epoch 15: avg test  loss 595.60, bar  test loss 4.682, len  test loss 0.085, col  test loss 161.909


Epoch 16: 543batch [00:36, 14.72batch/s, loss=570]


epoch 16: avg train loss 593.15, bar train loss 4.594, len train loss 0.092, col train loss 161.485


Epoch 17: 2batch [00:00, 14.39batch/s, loss=634]

epoch 16: avg test  loss 592.92, bar  test loss 4.617, len  test loss 0.080, col  test loss 161.661


Epoch 17: 543batch [00:37, 14.62batch/s, loss=570]


epoch 17: avg train loss 591.01, bar train loss 4.528, len train loss 0.089, col train loss 161.307


Epoch 18: 2batch [00:00, 14.60batch/s, loss=591]

epoch 17: avg test  loss 595.07, bar  test loss 4.530, len  test loss 0.124, col  test loss 161.550


Epoch 18: 543batch [00:36, 14.72batch/s, loss=578]


epoch 18: avg train loss 588.62, bar train loss 4.469, len train loss 0.083, col train loss 161.118


Epoch 19: 2batch [00:00, 14.18batch/s, loss=568]

epoch 18: avg test  loss 588.97, bar  test loss 4.433, len  test loss 0.076, col  test loss 161.401


Epoch 19: 543batch [00:37, 14.67batch/s, loss=614]


epoch 19: avg train loss 586.60, bar train loss 4.404, len train loss 0.080, col train loss 160.980


Epoch 20: 2batch [00:00, 14.18batch/s, loss=606]

epoch 19: avg test  loss 587.59, bar  test loss 4.403, len  test loss 0.072, col  test loss 161.381


Epoch 20: 543batch [00:36, 14.70batch/s, loss=609]


epoch 20: avg train loss 584.29, bar train loss 4.341, len train loss 0.074, col train loss 160.816
epoch 20: avg test  loss 587.76, bar  test loss 4.487, len  test loss 0.074, col  test loss 161.212


Epoch 21: 543batch [00:37, 14.65batch/s, loss=606]


epoch 21: avg train loss 582.26, bar train loss 4.292, len train loss 0.069, col train loss 160.635


Epoch 22: 2batch [00:00, 14.29batch/s, loss=604]

epoch 21: avg test  loss 584.81, bar  test loss 4.329, len  test loss 0.066, col  test loss 161.129


Epoch 22: 543batch [00:36, 14.71batch/s, loss=535]


epoch 22: avg train loss 581.00, bar train loss 4.248, len train loss 0.070, col train loss 160.492


Epoch 23: 2batch [00:00, 14.60batch/s, loss=559]

epoch 22: avg test  loss 583.31, bar  test loss 4.346, len  test loss 0.064, col  test loss 161.033


Epoch 23: 543batch [00:36, 14.70batch/s, loss=594]


epoch 23: avg train loss 579.39, bar train loss 4.204, len train loss 0.068, col train loss 160.320


Epoch 24: 2batch [00:00, 14.29batch/s, loss=567]

epoch 23: avg test  loss 582.57, bar  test loss 4.290, len  test loss 0.069, col  test loss 160.897


Epoch 24: 543batch [00:36, 14.72batch/s, loss=618]


epoch 24: avg train loss 578.13, bar train loss 4.155, len train loss 0.069, col train loss 160.208


Epoch 25: 2batch [00:00, 14.29batch/s, loss=565]

epoch 24: avg test  loss 581.31, bar  test loss 4.248, len  test loss 0.064, col  test loss 160.783


Epoch 25: 543batch [00:36, 14.70batch/s, loss=579]


epoch 25: avg train loss 577.00, bar train loss 4.129, len train loss 0.067, col train loss 160.050
epoch 25: avg test  loss 580.37, bar  test loss 4.249, len  test loss 0.059, col  test loss 160.738


Epoch 26: 543batch [00:36, 14.71batch/s, loss=576]


epoch 26: avg train loss 575.70, bar train loss 4.088, len train loss 0.066, col train loss 159.891


Epoch 27: 2batch [00:00, 14.29batch/s, loss=585]

epoch 26: avg test  loss 579.77, bar  test loss 4.215, len  test loss 0.064, col  test loss 160.594


Epoch 27: 543batch [00:37, 14.61batch/s, loss=595]


epoch 27: avg train loss 574.32, bar train loss 4.055, len train loss 0.064, col train loss 159.757


Epoch 28: 2batch [00:00, 14.29batch/s, loss=586]

epoch 27: avg test  loss 579.60, bar  test loss 4.203, len  test loss 0.060, col  test loss 160.447


Epoch 28: 543batch [00:37, 14.66batch/s, loss=614]


epoch 28: avg train loss 573.26, bar train loss 4.025, len train loss 0.064, col train loss 159.611


Epoch 29: 2batch [00:00, 14.29batch/s, loss=550]

epoch 28: avg test  loss 578.57, bar  test loss 4.167, len  test loss 0.063, col  test loss 160.485


Epoch 29: 543batch [00:36, 14.70batch/s, loss=549]


epoch 29: avg train loss 572.64, bar train loss 3.996, len train loss 0.065, col train loss 159.516


Epoch 30: 2batch [00:00, 14.18batch/s, loss=543]

epoch 29: avg test  loss 578.37, bar  test loss 4.117, len  test loss 0.070, col  test loss 160.340


Epoch 30: 543batch [00:37, 14.67batch/s, loss=592]


epoch 30: avg train loss 571.07, bar train loss 3.967, len train loss 0.061, col train loss 159.377
epoch 30: avg test  loss 576.97, bar  test loss 4.116, len  test loss 0.059, col  test loss 160.185


Epoch 31: 543batch [00:36, 14.69batch/s, loss=547]


epoch 31: avg train loss 570.34, bar train loss 3.928, len train loss 0.063, col train loss 159.272


Epoch 32: 2batch [00:00, 14.39batch/s, loss=560]

epoch 31: avg test  loss 576.28, bar  test loss 4.074, len  test loss 0.058, col  test loss 160.167


Epoch 32: 543batch [00:36, 14.69batch/s, loss=630]


epoch 32: avg train loss 569.35, bar train loss 3.906, len train loss 0.061, col train loss 159.191


Epoch 33: 2batch [00:00, 14.49batch/s, loss=569]

epoch 32: avg test  loss 577.19, bar  test loss 4.124, len  test loss 0.070, col  test loss 160.214


Epoch 33: 543batch [00:37, 14.53batch/s, loss=569]


epoch 33: avg train loss 568.67, bar train loss 3.880, len train loss 0.062, col train loss 159.088


Epoch 34: 2batch [00:00, 14.18batch/s, loss=560]

epoch 33: avg test  loss 575.14, bar  test loss 4.086, len  test loss 0.057, col  test loss 160.227


Epoch 34: 543batch [00:36, 14.69batch/s, loss=650]


epoch 34: avg train loss 567.97, bar train loss 3.869, len train loss 0.060, col train loss 159.018


Epoch 35: 2batch [00:00, 14.49batch/s, loss=570]

epoch 34: avg test  loss 575.05, bar  test loss 4.028, len  test loss 0.061, col  test loss 160.105


Epoch 35: 543batch [00:36, 14.68batch/s, loss=602]


epoch 35: avg train loss 566.75, bar train loss 3.837, len train loss 0.058, col train loss 158.887
epoch 35: avg test  loss 574.82, bar  test loss 4.018, len  test loss 0.062, col  test loss 160.001


Epoch 36: 543batch [00:37, 14.65batch/s, loss=599]


epoch 36: avg train loss 566.44, bar train loss 3.821, len train loss 0.060, col train loss 158.820


Epoch 37: 2batch [00:00, 14.49batch/s, loss=552]

epoch 36: avg test  loss 573.78, bar  test loss 4.016, len  test loss 0.061, col  test loss 159.966


Epoch 37: 543batch [00:36, 14.69batch/s, loss=547]


epoch 37: avg train loss 565.45, bar train loss 3.795, len train loss 0.058, col train loss 158.741


Epoch 38: 2batch [00:00, 14.39batch/s, loss=543]

epoch 37: avg test  loss 572.82, bar  test loss 3.990, len  test loss 0.059, col  test loss 159.908


Epoch 38: 543batch [00:36, 14.68batch/s, loss=576]


epoch 38: avg train loss 565.09, bar train loss 3.786, len train loss 0.059, col train loss 158.680


Epoch 39: 2batch [00:00, 14.29batch/s, loss=536]

epoch 38: avg test  loss 573.45, bar  test loss 3.997, len  test loss 0.063, col  test loss 159.946


Epoch 39: 543batch [00:37, 14.62batch/s, loss=620]


epoch 39: avg train loss 564.29, bar train loss 3.766, len train loss 0.057, col train loss 158.606


Epoch 40: 2batch [00:00, 14.29batch/s, loss=569]

epoch 39: avg test  loss 574.87, bar  test loss 3.983, len  test loss 0.068, col  test loss 159.913


Epoch 40: 543batch [00:36, 14.71batch/s, loss=565]


epoch 40: avg train loss 563.81, bar train loss 3.751, len train loss 0.057, col train loss 158.545
epoch 40: avg test  loss 572.91, bar  test loss 3.983, len  test loss 0.064, col  test loss 159.799


Epoch 41: 543batch [00:36, 14.68batch/s, loss=544]


epoch 41: avg train loss 562.98, bar train loss 3.728, len train loss 0.057, col train loss 158.448


Epoch 42: 2batch [00:00, 14.49batch/s, loss=557]

epoch 41: avg test  loss 573.68, bar  test loss 3.964, len  test loss 0.057, col  test loss 159.746


Epoch 42: 543batch [00:37, 14.67batch/s, loss=537]


epoch 42: avg train loss 562.88, bar train loss 3.722, len train loss 0.058, col train loss 158.393


Epoch 43: 2batch [00:00, 14.18batch/s, loss=569]

epoch 42: avg test  loss 572.80, bar  test loss 3.968, len  test loss 0.064, col  test loss 159.842


Epoch 43: 543batch [00:37, 14.65batch/s, loss=554]


epoch 43: avg train loss 561.78, bar train loss 3.692, len train loss 0.056, col train loss 158.322


Epoch 44: 2batch [00:00, 14.49batch/s, loss=526]

epoch 43: avg test  loss 571.60, bar  test loss 3.961, len  test loss 0.054, col  test loss 159.773


Epoch 44: 543batch [00:37, 14.66batch/s, loss=563]


epoch 44: avg train loss 561.12, bar train loss 3.678, len train loss 0.055, col train loss 158.251


Epoch 45: 2batch [00:00, 14.49batch/s, loss=571]

epoch 44: avg test  loss 571.09, bar  test loss 3.948, len  test loss 0.057, col  test loss 159.767


Epoch 45: 543batch [00:36, 14.68batch/s, loss=573]


epoch 45: avg train loss 560.67, bar train loss 3.667, len train loss 0.055, col train loss 158.183
epoch 45: avg test  loss 571.53, bar  test loss 3.953, len  test loss 0.060, col  test loss 159.816


Epoch 46: 543batch [00:37, 14.63batch/s, loss=608]


epoch 46: avg train loss 560.43, bar train loss 3.661, len train loss 0.056, col train loss 158.130


Epoch 47: 2batch [00:00, 14.39batch/s, loss=572]

epoch 46: avg test  loss 572.18, bar  test loss 3.933, len  test loss 0.058, col  test loss 159.788


Epoch 47: 543batch [00:37, 14.65batch/s, loss=571]


epoch 47: avg train loss 560.01, bar train loss 3.650, len train loss 0.054, col train loss 158.087


Epoch 48: 2batch [00:00, 14.29batch/s, loss=586]

epoch 47: avg test  loss 570.94, bar  test loss 3.916, len  test loss 0.055, col  test loss 159.660


Epoch 48: 543batch [00:37, 14.65batch/s, loss=594]


epoch 48: avg train loss 559.42, bar train loss 3.628, len train loss 0.055, col train loss 158.039


Epoch 49: 2batch [00:00, 14.60batch/s, loss=587]

epoch 48: avg test  loss 570.30, bar  test loss 3.895, len  test loss 0.055, col  test loss 159.664


Epoch 49: 543batch [00:37, 14.66batch/s, loss=552]


epoch 49: avg train loss 558.88, bar train loss 3.618, len train loss 0.054, col train loss 157.965


Epoch 50: 2batch [00:00, 14.60batch/s, loss=547]

epoch 49: avg test  loss 571.28, bar  test loss 3.927, len  test loss 0.065, col  test loss 159.716


Epoch 50: 543batch [00:37, 14.65batch/s, loss=539]


epoch 50: avg train loss 558.39, bar train loss 3.604, len train loss 0.053, col train loss 157.915
epoch 50: avg test  loss 570.25, bar  test loss 3.949, len  test loss 0.057, col  test loss 159.665


Epoch 51: 543batch [00:37, 14.63batch/s, loss=562]


epoch 51: avg train loss 558.34, bar train loss 3.595, len train loss 0.055, col train loss 157.883


Epoch 52: 2batch [00:00, 14.39batch/s, loss=586]

epoch 51: avg test  loss 572.91, bar  test loss 3.937, len  test loss 0.078, col  test loss 159.659


Epoch 52: 543batch [00:37, 14.64batch/s, loss=556]


epoch 52: avg train loss 557.41, bar train loss 3.580, len train loss 0.053, col train loss 157.799


Epoch 53: 2batch [00:00, 14.18batch/s, loss=562]

epoch 52: avg test  loss 570.46, bar  test loss 3.916, len  test loss 0.058, col  test loss 159.723


Epoch 53: 543batch [00:37, 14.63batch/s, loss=558]


epoch 53: avg train loss 557.24, bar train loss 3.572, len train loss 0.053, col train loss 157.765


Epoch 54: 2batch [00:00, 14.29batch/s, loss=582]

epoch 53: avg test  loss 569.64, bar  test loss 3.895, len  test loss 0.055, col  test loss 159.675


Epoch 54: 543batch [00:37, 14.62batch/s, loss=568]


epoch 54: avg train loss 556.90, bar train loss 3.567, len train loss 0.053, col train loss 157.714


Epoch 55: 2batch [00:00, 14.39batch/s, loss=552]

epoch 54: avg test  loss 571.04, bar  test loss 3.955, len  test loss 0.055, col  test loss 159.779


Epoch 55: 543batch [00:37, 14.63batch/s, loss=570]


epoch 55: avg train loss 556.47, bar train loss 3.553, len train loss 0.054, col train loss 157.643
epoch 55: avg test  loss 570.65, bar  test loss 3.916, len  test loss 0.058, col  test loss 159.640


Epoch 56: 543batch [00:37, 14.62batch/s, loss=603]


epoch 56: avg train loss 555.77, bar train loss 3.535, len train loss 0.053, col train loss 157.559


Epoch 57: 2batch [00:00, 14.29batch/s, loss=546]

epoch 56: avg test  loss 569.86, bar  test loss 3.929, len  test loss 0.057, col  test loss 159.725


Epoch 57: 543batch [00:37, 14.64batch/s, loss=569]


epoch 57: avg train loss 555.28, bar train loss 3.527, len train loss 0.051, col train loss 157.522


Epoch 58: 2batch [00:00, 14.29batch/s, loss=559]

epoch 57: avg test  loss 569.50, bar  test loss 3.910, len  test loss 0.058, col  test loss 159.568


Epoch 58: 543batch [00:37, 14.63batch/s, loss=600]


epoch 58: avg train loss 554.97, bar train loss 3.519, len train loss 0.052, col train loss 157.453


Epoch 59: 2batch [00:00, 14.18batch/s, loss=539]

epoch 58: avg test  loss 570.64, bar  test loss 3.916, len  test loss 0.059, col  test loss 159.693


Epoch 59: 543batch [00:37, 14.62batch/s, loss=569]


epoch 59: avg train loss 554.80, bar train loss 3.512, len train loss 0.053, col train loss 157.419


Epoch 60: 2batch [00:00, 14.39batch/s, loss=555]

epoch 59: avg test  loss 570.12, bar  test loss 3.924, len  test loss 0.056, col  test loss 159.632


Epoch 60: 543batch [00:37, 14.59batch/s, loss=509]


epoch 60: avg train loss 554.15, bar train loss 3.496, len train loss 0.052, col train loss 157.351
epoch 60: avg test  loss 569.01, bar  test loss 3.882, len  test loss 0.052, col  test loss 159.543


Epoch 61: 543batch [00:37, 14.48batch/s, loss=569]


epoch 61: avg train loss 554.08, bar train loss 3.493, len train loss 0.052, col train loss 157.348


Epoch 62: 2batch [00:00, 12.42batch/s, loss=572]

epoch 61: avg test  loss 569.44, bar  test loss 3.917, len  test loss 0.055, col  test loss 159.573


Epoch 62: 543batch [00:37, 14.60batch/s, loss=570]


epoch 62: avg train loss 553.90, bar train loss 3.490, len train loss 0.053, col train loss 157.272


Epoch 63: 2batch [00:00, 14.39batch/s, loss=565]

epoch 62: avg test  loss 572.28, bar  test loss 3.924, len  test loss 0.063, col  test loss 159.645


Epoch 63: 543batch [00:37, 14.51batch/s, loss=553]


epoch 63: avg train loss 553.18, bar train loss 3.473, len train loss 0.052, col train loss 157.207


Epoch 64: 2batch [00:00, 13.79batch/s, loss=539]

epoch 63: avg test  loss 570.21, bar  test loss 3.921, len  test loss 0.061, col  test loss 159.707


Epoch 64: 543batch [00:37, 14.40batch/s, loss=609]


epoch 64: avg train loss 553.07, bar train loss 3.470, len train loss 0.052, col train loss 157.167


Epoch 65: 2batch [00:00, 14.49batch/s, loss=536]

epoch 64: avg test  loss 568.91, bar  test loss 3.877, len  test loss 0.056, col  test loss 159.568


Epoch 65: 543batch [00:36, 14.79batch/s, loss=579]


epoch 65: avg train loss 552.69, bar train loss 3.463, len train loss 0.051, col train loss 157.121
epoch 65: avg test  loss 568.45, bar  test loss 3.883, len  test loss 0.056, col  test loss 159.519


Epoch 66: 543batch [00:36, 14.93batch/s, loss=569]


epoch 66: avg train loss 552.38, bar train loss 3.452, len train loss 0.052, col train loss 157.089


Epoch 67: 2batch [00:00, 14.29batch/s, loss=536]

epoch 66: avg test  loss 569.03, bar  test loss 3.904, len  test loss 0.059, col  test loss 159.595


Epoch 67: 543batch [00:36, 14.94batch/s, loss=518]


epoch 67: avg train loss 551.86, bar train loss 3.440, len train loss 0.051, col train loss 157.038


Epoch 68: 2batch [00:00, 14.60batch/s, loss=549]

epoch 67: avg test  loss 569.75, bar  test loss 3.880, len  test loss 0.067, col  test loss 159.576


Epoch 68: 543batch [00:36, 14.92batch/s, loss=556]


epoch 68: avg train loss 551.68, bar train loss 3.434, len train loss 0.052, col train loss 156.974


Epoch 69: 2batch [00:00, 14.49batch/s, loss=553]

epoch 68: avg test  loss 569.06, bar  test loss 3.889, len  test loss 0.055, col  test loss 159.491


Epoch 69: 543batch [00:36, 14.71batch/s, loss=578]


epoch 69: avg train loss 551.49, bar train loss 3.440, len train loss 0.051, col train loss 156.929


Epoch 70: 2batch [00:00, 13.89batch/s, loss=568]

epoch 69: avg test  loss 569.19, bar  test loss 3.888, len  test loss 0.057, col  test loss 159.503


Epoch 70: 543batch [00:37, 14.67batch/s, loss=582]


epoch 70: avg train loss 551.15, bar train loss 3.430, len train loss 0.051, col train loss 156.892
epoch 70: avg test  loss 568.92, bar  test loss 3.877, len  test loss 0.055, col  test loss 159.477


Epoch 71: 543batch [00:37, 14.41batch/s, loss=548]


epoch 71: avg train loss 550.91, bar train loss 3.430, len train loss 0.051, col train loss 156.787


Epoch 72: 2batch [00:00, 14.39batch/s, loss=567]

epoch 71: avg test  loss 569.81, bar  test loss 3.914, len  test loss 0.064, col  test loss 159.528


Epoch 72: 543batch [00:37, 14.51batch/s, loss=571]


epoch 72: avg train loss 550.71, bar train loss 3.417, len train loss 0.052, col train loss 156.790


Epoch 73: 2batch [00:00, 13.89batch/s, loss=544]

epoch 72: avg test  loss 568.98, bar  test loss 3.883, len  test loss 0.054, col  test loss 159.591


Epoch 73: 4batch [00:00, 13.95batch/s, loss=545]

In [None]:
diva.lqz1[-5:], diva.lqz2[-5:], diva.lpz1[-5:], diva.lpz2[-5:]

In [None]:
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 500, save_folder="VAEFC")

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')