In [1]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
%load_ext tensorboard

In [None]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


writer = SummaryWriter(f"{link}/saved_models/new/NEWHVAE/tensorboard")

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:

    def __init__(self, z1_dim=1000, z2_dim=1000, d_dim=45, x_dim=7500, y_dim=2,
                 h_dim = 600, h2_dim = 600, number_components = 500,
                 beta=1, rec_alpha = 100, rec_beta = 20, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z1_dim = z1_dim
        self.z2_dim = z2_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.h_dim = h_dim
        self.h2_dim = h2_dim
        
        self.number_components = number_components
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_col, x_bar = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len2.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar2.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col2.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        
        
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')
        
        
        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]
    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]                        
        return (x, y, d, x_len, x_col, x_bar, mount)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,5,200), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]) ,2*j] = 1 
                    out_col[i, self.get_color(x[i,:,13,j]), 2*j+1] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len2.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col2.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar2.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():  
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():  
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():  
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():  
            return 4 # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [20]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim, 
                 h_dim, h2_dim, dim0=2000, dim1=1200, dim2=400):
        super(px, self).__init__()

        
        # p(z1|z2)
        
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim+200, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())
        
        
        # p(x|z1,z2,m)
        
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim),
                                   nn.ReLU())
        
        self.fc1 = nn.Sequential(nn.Linear(2*h_dim, dim0, bias=False),  
                                 nn.ReLU())
        
        self.fc2 = nn.Sequential(nn.Linear(dim0, dim1, bias=False),  
                                 nn.ReLU())
        self.fc3 = nn.Sequential(nn.Linear(dim1, dim2, bias=False),
                                 nn.ReLU())
        
        # Predicting length and color of each bar
        #self.color = nn.Sequential(nn.Conv1d(1,5, kernel_size=1, bias=False), 
                                  # nn.Softmax(dim=1))
        # Predicting color of each bar
        self.color_bar_black = nn.Linear(dim2,200)
        self.color_bar_reddd = nn.Linear(dim2,200)
        self.color_bar_bluee = nn.Linear(dim2,200)
        self.color_bar_green = nn.Linear(dim2,200)
        self.color_bar_yelow = nn.Linear(dim2,200)
        
        # Predicting the length of each bar
        self.length_bar_top = nn.Sequential(nn.Linear(dim1,100), nn.Softplus())
        self.length_bar_bot = nn.Sequential(nn.Linear(dim1,100), nn.Softplus())
        #self.length_bar_scale = nn.Sequential(nn.Conv1d(100, 1, kernel_size = 3, padding = 'same', bias=False), nn.Sigmoid())
        
        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(dim1,400), nn.ReLU(),nn.Linear(400,1), nn.Softplus())
        #self.length_RNA_scale = nn.Sequential(nn.Linear(400,1, bias=False), nn.Sigmoid())
        
    def forward(self, z1, mz2):
        
        # p(z1|z2)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)
        
        # p(x|z1,z2,m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = torch.cat([hz1,hz2],1)
        h = self.fc1(h)
        h = nn.Dropout(0.2)(h)
        h = self.fc2(h)
        h = nn.Dropout(0.2)(h)
        len_RNA = self.length_RNA(h)
        len_bar = torch.cat([self.length_bar_top(h)[:,None,:],self.length_bar_bot(h)[:,None,:]], dim=1) 
        h = self.fc3(h)
        h = nn.Dropout(0.3)(h)
        
        len_RNA_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_RNA_sc = torch.exp(self.length_RNA_scale(h))
        
        
        
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_bar_sc = torch.exp(self.length_bar_scale(h))
        
#        col_bar = self.color(h[:,None,:])
        black = self.color_bar_black(h)[:,None,:]
        reddd = self.color_bar_reddd(h)[:,None,:]
        bluee = self.color_bar_bluee(h)[:,None,:]
        green = self.color_bar_green(h)[:,None,:]
        yelow = self.color_bar_yelow(h)[:,None,:]
        
        col = torch.cat([black,reddd,bluee,green,yelow], dim=1)
        
        col_bar = nn.Softmax(dim=1)(col)
        
        return len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar, pz1_m, pz1_s

    def reconstruct_image(self, len_RNA, var_RNA, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }
    
        
        len_RNA = len_RNA.cpu().numpy()
        var_RNA = var_RNA.cpu().numpy()
        #.reshape((100,))
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_RNA.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = int(np.round(np.random.normal(loc=len_RNA[i], scale=var_RNA[i])))
            else:
                limit = int(np.round(len_RNA[i]))
            limit = min(100, limit)
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1 = np.argmax(col_bar[i,:, 2*j])
                    _col_bar_2 = np.argmax(col_bar[i,:, 2*j+1])
                
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output


In [21]:
int(np.round(3.7, 0))
int(3.7)

3

In [22]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [23]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [24]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim):
        super(qz, self).__init__()

        # q(z2 | x)
        self.encoder_z2 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.mu_z2 = nn.Sequential(nn.Linear(5632, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(5632, z2_dim), nn.Softplus())
        
        
        # q(z1 | x, z2)
        self.encoder_z1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(5632, h_dim), nn.ReLU())
        
        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())
        
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

#         torch.nn.init.xavier_uniform_(self.encoder[0].weight)
#         torch.nn.init.xavier_uniform_(self.encoder[3].weight)
#         torch.nn.init.xavier_uniform_(self.fc11[0].weight)
#         self.fc11[0].bias.data.zero_()
#         torch.nn.init.xavier_uniform_(self.fc12[0].weight)
#         self.fc12[0].bias.data.zero_()
    
    def q_z2(self, x):
        z2 = self.encoder_z2(x)
        z2 = z2.view(-1, 5632)
        z2_m = self.mu_z2(z2) 
        z2_s = self.si_z2(z2)
        
        return z2_m, z2_s
    
    def forward(self, x, m):
        
        # q(z2 | x) & m
        z2_m, z2_s = self.q_z2(x)
        # reparameterization trick
        qz2 = dist.Normal(z2_m, z2_s)
        z2 = qz2.rsample()
        # z2 & m
        mz2 = torch.cat([z2, m],1)
        
        
        
        # q(z1 | x, z2, m)
        z1 = self.encoder_z1(x)
        z1 = z1.view(-1, 5632)
        z1 = self.fc_z1(z1)
        
        mz2_ = self.fc_z2(mz2)
        
        z1 = torch.cat([mz2_, z1],1)
        z1 = self.fc_z1_z2(z1)
        z1_m = self.mu_z1(z1)
        z1_s = self.si_z1(z1)
        
        qz1 = dist.Normal(z1_m, z1_s)
        z1 = qz1.rsample()
        
        
        #z_loc = self.fc11(h)
        #z_scale = self.fc12(h) + 1e-7

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [25]:
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [26]:
enc = qz(128,10,10,10,500,400,400)
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 2, 11]           --
│    └─Conv2d: 2-1                       [1, 32, 25, 100]          2,400
│    └─ReLU: 2-2                         [1, 32, 25, 100]          --
│    └─Conv2d: 2-3                       [1, 64, 25, 100]          51,200
│    └─ReLU: 2-4                         [1, 64, 25, 100]          --
│    └─MaxPool2d: 2-5                    [1, 64, 12, 50]           --
│    └─Conv2d: 2-6                       [1, 128, 12, 50]          73,728
│    └─ReLU: 2-7                         [1, 128, 12, 50]          --
│    └─Conv2d: 2-8                       [1, 128, 12, 50]          147,456
│    └─ReLU: 2-9                         [1, 128, 12, 50]          --
│    └─MaxPool2d: 2-10                   [1, 128, 6, 25]           --
│    └─Conv2d: 2-11                      [1, 256, 4, 23]           29

In [27]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    log_var = 2*torch.log(std)
    log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

## Full model class

In [28]:
class HVAE(nn.Module):
    def __init__(self, args):
        super(HVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components
        
        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup
        
        self.add_pseudoinputs()
        
        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []
        
        self.bar = []
        self.len = []
        self.col = []
        
        self.cuda()

    def forward(self, d, x, y, m):
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, pz1_m, pz1_s = self.px(z1, mz2)
        
        return x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s
    
    
    def log_p_z2(self, z2):
        C = self.number_components
        
        X = self.means(self.idle_input).view(-1,3,25,100)
        
        pz2_m, pz2_s = self.qz.q_z2(X)
        
        z_expand = z2.unsqueeze(1)
        means = pz2_m.unsqueeze(0)
        stds = pz2_s.unsqueeze(0)
        
        a = log_Normal_diag(z_expand, means, stds, dim=2) - math.log(C)
        a_max, _ = torch.max(a,1)
        
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a-a_max.unsqueeze(1)),1)))
        
        return log_prior
    
    def loss_function(self, d, x, y, m, out_len, out_bar, out_col):
        
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m)
        
        # Reconstruction Loss
        mask = 1 - F.one_hot(torch.round(out_len).to(torch.int64)*2-1, 200).cumsum(dim=1)[:,None,:]
        mask1 = (1 - F.one_hot(torch.round(out_len).to(torch.int64)-1, 100).cumsum(dim=1)[:,None,:]).repeat(1,2,1)

        x_col = mask.repeat(1,5,1)*x_col
        
        dist_len = dist.Normal(x_len, x_len_scale+1e-7)
        log_len = dist_len.log_prob(out_len[:,None]).mean()
         
        mse_bar = ((((x_bar - out_bar)**2)*mask1).sum(dim=(1,2))/(mask1.sum(dim=(1,2)))).sum()#.detach().item()
        
        max_bar = torch.argmax(x_col, dim=1)
        #print(max_bar.shape, out_col.shape, mask.shape)
        acc_bar = (((max_bar==torch.argmax(out_col, dim=1))*mask).sum().float()/(mask.sum(1))).sum()
        #print(acc_bar.shape)
        RE_len = -log_len
        RE_bar = mse_bar#-log_bar
        RE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        #KL_z = torch.sum(pz.log_prob(z_q) - qz.log_prob(z_q))
          
            
        # KL loss
        KL_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        KL_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        KL_p_z2 = self.log_p_z2(z2).sum()
        KL_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(KL_p_z1 + KL_p_z2 - KL_q_z1 - KL_q_z2)
        #print(KL_p_z1.shape,KL_p_z2.shape,KL_q_z1.shape,KL_q_z2.shape)
        
        
#         self.lpz1.append(KL_p_z1.detach().item())
#         self.lpz2.append(KL_p_z2.detach().item())
#         self.lqz1.append(KL_q_z1.detach().item())
#         self.lqz2.append(KL_q_z2.detach().item())
        
#         self.bar.append(RE_bar.detach())
#         self.col.append(RE_col.detach())
#         self.len.append(RE_len.detach())
        
        
        return self.rec_alpha * RE_len \
                  + self.rec_beta * RE_bar \
                  + self.rec_gamma * RE_col \
                  + self.beta * KL, \
                  RE_bar, RE_len, RE_col, mse_bar, acc_bar
    
    def add_pseudoinputs(self):
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        self.idle_input = Variable(torch.eye(self.number_components, self.number_components), requires_grad=False).to(DEVICE)

In [29]:
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [30]:
default_args = diva_args()
enc = HVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
HVAE                                     --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [31]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [32]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [33]:
len(RNA_dataset)

34721

In [34]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    no_batches = 0
    train_corr = 0
    mse_bar = 0
    acc_bar = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar, m) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col, m= x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, len_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(), m.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        epoch_len_loss += len_loss
        mse_bar += mse
        acc_bar += acc
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar

In [35]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    mse_bar = 0
    acc_bar = 0        
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar, m) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col, m = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)
            loss, bar_loss, len_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(),m.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            epoch_len_loss += len_loss
            mse_bar += mse
            acc_bar += acc
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_len_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar
  

In [36]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col, mtr, atr = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test, mte, ate = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train': atr, 'test':ate}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')
        
        
        if epoch % 50 == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [37]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        m = a[1][-1][:10].to(DEVICE).float()
        x_1, x_1var, x_2, x_2var, x_3 ,z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = diva(d,x,y,m)
        out = diva.px.reconstruct_image(x_1, x_1var, x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [38]:
DEVICE

device(type='cuda')

## Model Training

In [39]:
default_args = diva_args(prewarmup=0, number_components=50)

In [40]:
diva = HVAE(default_args).to(DEVICE)

In [41]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE10/checkpoints/905.pth'))

In [42]:
train_loader = DataLoader(RNA_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=64)

In [43]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=0.001)

In [44]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [45]:
writer.flush()

In [50]:
diva.rec_gamma = 3

In [46]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/HVAEFC2/tensorboard/"

Reusing TensorBoard on port 6006 (pid 18464), started 1 day, 22:13:51 ago. (Use '!kill 18464' to kill it.)

In [51]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 72, save_folder="HVAEFC3",save_interval=5)

Epoch 73: 543batch [01:08,  7.98batch/s, loss=581]


epoch 73: avg train loss 549.50, bar train loss 3.068, len train loss 0.050, col train loss 159.390


Epoch 74: 1batch [00:00,  8.13batch/s, loss=517]

epoch 73: avg test  loss 568.74, bar  test loss 3.778, len  test loss 0.057, col  test loss 160.863


Epoch 74: 543batch [01:07,  7.99batch/s, loss=557]


epoch 74: avg train loss 548.65, bar train loss 3.066, len train loss 0.051, col train loss 159.103


Epoch 75: 1batch [00:00,  8.20batch/s, loss=527]

epoch 74: avg test  loss 569.29, bar  test loss 3.807, len  test loss 0.059, col  test loss 160.802


Epoch 75: 543batch [01:07,  7.99batch/s, loss=553]


epoch 75: avg train loss 547.76, bar train loss 3.057, len train loss 0.050, col train loss 158.913
epoch 75: avg test  loss 567.99, bar  test loss 3.770, len  test loss 0.057, col  test loss 160.738


Epoch 76: 543batch [01:05,  8.27batch/s, loss=587]


epoch 76: avg train loss 546.91, bar train loss 3.042, len train loss 0.049, col train loss 158.784


Epoch 77: 1batch [00:00,  8.55batch/s, loss=539]

epoch 76: avg test  loss 567.77, bar  test loss 3.748, len  test loss 0.059, col  test loss 160.698


Epoch 77: 543batch [01:08,  7.97batch/s, loss=593]


epoch 77: avg train loss 546.66, bar train loss 3.036, len train loss 0.050, col train loss 158.687


Epoch 78: 1batch [00:00,  7.25batch/s, loss=549]

epoch 77: avg test  loss 567.93, bar  test loss 3.775, len  test loss 0.056, col  test loss 160.606


Epoch 78: 543batch [01:07,  7.99batch/s, loss=577]


epoch 78: avg train loss 546.17, bar train loss 3.027, len train loss 0.050, col train loss 158.602


Epoch 79: 1batch [00:00,  8.06batch/s, loss=563]

epoch 78: avg test  loss 568.32, bar  test loss 3.775, len  test loss 0.060, col  test loss 160.591


Epoch 79: 543batch [01:07,  8.00batch/s, loss=530]


epoch 79: avg train loss 545.95, bar train loss 3.023, len train loss 0.050, col train loss 158.529


Epoch 80: 1batch [00:00,  8.20batch/s, loss=574]

epoch 79: avg test  loss 569.46, bar  test loss 3.779, len  test loss 0.067, col  test loss 160.586


Epoch 80: 543batch [01:07,  8.00batch/s, loss=581]


epoch 80: avg train loss 545.53, bar train loss 3.014, len train loss 0.050, col train loss 158.466
epoch 80: avg test  loss 567.35, bar  test loss 3.756, len  test loss 0.056, col  test loss 160.578


Epoch 81: 543batch [01:08,  7.98batch/s, loss=563]


epoch 81: avg train loss 545.09, bar train loss 3.009, len train loss 0.050, col train loss 158.391


Epoch 82: 1batch [00:00,  7.94batch/s, loss=517]

epoch 81: avg test  loss 567.41, bar  test loss 3.749, len  test loss 0.058, col  test loss 160.530


Epoch 82: 543batch [01:07,  8.00batch/s, loss=537]


epoch 82: avg train loss 544.80, bar train loss 2.997, len train loss 0.050, col train loss 158.325


Epoch 83: 1batch [00:00,  8.00batch/s, loss=531]

epoch 82: avg test  loss 567.50, bar  test loss 3.780, len  test loss 0.054, col  test loss 160.495


Epoch 83: 543batch [01:08,  7.96batch/s, loss=551]


epoch 83: avg train loss 544.36, bar train loss 2.992, len train loss 0.050, col train loss 158.241


Epoch 84: 1batch [00:00,  8.20batch/s, loss=486]

epoch 83: avg test  loss 567.02, bar  test loss 3.761, len  test loss 0.056, col  test loss 160.470


Epoch 84: 543batch [01:08,  7.98batch/s, loss=528]


epoch 84: avg train loss 543.77, bar train loss 2.981, len train loss 0.048, col train loss 158.198


Epoch 85: 1batch [00:00,  8.06batch/s, loss=543]

epoch 84: avg test  loss 566.94, bar  test loss 3.767, len  test loss 0.056, col  test loss 160.508


Epoch 85: 543batch [01:07,  7.99batch/s, loss=553]


epoch 85: avg train loss 543.63, bar train loss 2.980, len train loss 0.049, col train loss 158.135
epoch 85: avg test  loss 567.14, bar  test loss 3.760, len  test loss 0.061, col  test loss 160.440


Epoch 86: 543batch [01:08,  7.93batch/s, loss=610]


epoch 86: avg train loss 543.71, bar train loss 2.983, len train loss 0.050, col train loss 158.105


Epoch 87: 1batch [00:00,  8.13batch/s, loss=539]

epoch 86: avg test  loss 567.95, bar  test loss 3.765, len  test loss 0.065, col  test loss 160.394


Epoch 87: 543batch [01:08,  7.97batch/s, loss=546]


epoch 87: avg train loss 543.01, bar train loss 2.965, len train loss 0.049, col train loss 158.045


Epoch 88: 1batch [00:00,  8.06batch/s, loss=558]

epoch 87: avg test  loss 567.70, bar  test loss 3.778, len  test loss 0.059, col  test loss 160.405


Epoch 88: 543batch [01:08,  7.97batch/s, loss=575]


epoch 88: avg train loss 542.95, bar train loss 2.963, len train loss 0.050, col train loss 157.973


Epoch 89: 1batch [00:00,  8.13batch/s, loss=494]

epoch 88: avg test  loss 567.82, bar  test loss 3.780, len  test loss 0.060, col  test loss 160.398


Epoch 89: 543batch [01:08,  7.96batch/s, loss=561]


epoch 89: avg train loss 542.56, bar train loss 2.959, len train loss 0.049, col train loss 157.921


Epoch 90: 1batch [00:00,  8.13batch/s, loss=553]

epoch 89: avg test  loss 568.47, bar  test loss 3.818, len  test loss 0.061, col  test loss 160.424


Epoch 90: 543batch [01:08,  7.97batch/s, loss=597]


epoch 90: avg train loss 542.24, bar train loss 2.951, len train loss 0.050, col train loss 157.874
epoch 90: avg test  loss 570.09, bar  test loss 3.843, len  test loss 0.066, col  test loss 160.526


Epoch 91: 543batch [01:08,  7.97batch/s, loss=559]


epoch 91: avg train loss 542.25, bar train loss 2.955, len train loss 0.049, col train loss 157.846


Epoch 92: 1batch [00:00,  6.85batch/s, loss=515]

epoch 91: avg test  loss 568.11, bar  test loss 3.784, len  test loss 0.058, col  test loss 160.410


Epoch 92: 543batch [01:07,  8.09batch/s, loss=521]


epoch 92: avg train loss 541.63, bar train loss 2.939, len train loss 0.049, col train loss 157.778


Epoch 93: 1batch [00:00,  7.87batch/s, loss=569]

epoch 92: avg test  loss 567.77, bar  test loss 3.796, len  test loss 0.059, col  test loss 160.388


Epoch 93: 543batch [01:07,  8.07batch/s, loss=534]


epoch 93: avg train loss 541.51, bar train loss 2.940, len train loss 0.048, col train loss 157.747


Epoch 94: 1batch [00:00,  8.26batch/s, loss=536]

epoch 93: avg test  loss 568.14, bar  test loss 3.838, len  test loss 0.062, col  test loss 160.332


Epoch 94: 543batch [01:07,  8.06batch/s, loss=562]


epoch 94: avg train loss 541.25, bar train loss 2.933, len train loss 0.049, col train loss 157.683


Epoch 95: 1batch [00:00,  7.87batch/s, loss=543]

epoch 94: avg test  loss 566.58, bar  test loss 3.780, len  test loss 0.057, col  test loss 160.240


Epoch 95: 543batch [01:08,  7.91batch/s, loss=554]


epoch 95: avg train loss 540.95, bar train loss 2.931, len train loss 0.048, col train loss 157.624
epoch 95: avg test  loss 567.52, bar  test loss 3.810, len  test loss 0.058, col  test loss 160.282


Epoch 96: 543batch [01:08,  7.91batch/s, loss=561]


epoch 96: avg train loss 540.79, bar train loss 2.923, len train loss 0.049, col train loss 157.606


Epoch 97: 1batch [00:00,  8.20batch/s, loss=535]

epoch 96: avg test  loss 567.46, bar  test loss 3.789, len  test loss 0.058, col  test loss 160.302


Epoch 97: 543batch [01:08,  7.90batch/s, loss=551]


epoch 97: avg train loss 540.53, bar train loss 2.918, len train loss 0.048, col train loss 157.578


Epoch 98: 1batch [00:00,  7.75batch/s, loss=536]

epoch 97: avg test  loss 567.36, bar  test loss 3.792, len  test loss 0.058, col  test loss 160.329


Epoch 98: 543batch [01:08,  7.91batch/s, loss=590]


epoch 98: avg train loss 540.12, bar train loss 2.911, len train loss 0.048, col train loss 157.508


Epoch 99: 1batch [00:00,  7.94batch/s, loss=539]

epoch 98: avg test  loss 567.17, bar  test loss 3.800, len  test loss 0.056, col  test loss 160.310


Epoch 99: 543batch [01:08,  7.92batch/s, loss=483]


epoch 99: avg train loss 540.21, bar train loss 2.919, len train loss 0.049, col train loss 157.446


Epoch 100: 1batch [00:00,  6.71batch/s, loss=547]

epoch 99: avg test  loss 567.00, bar  test loss 3.805, len  test loss 0.056, col  test loss 160.195


Epoch 100: 543batch [01:08,  7.94batch/s, loss=542]


epoch 100: avg train loss 539.61, bar train loss 2.904, len train loss 0.048, col train loss 157.399
epoch 100: avg test  loss 568.04, bar  test loss 3.838, len  test loss 0.064, col  test loss 160.269


Epoch 101: 543batch [01:08,  7.94batch/s, loss=553]


epoch 101: avg train loss 539.67, bar train loss 2.907, len train loss 0.049, col train loss 157.356


Epoch 102: 1batch [00:00,  8.13batch/s, loss=549]

epoch 101: avg test  loss 567.10, bar  test loss 3.807, len  test loss 0.060, col  test loss 160.284


Epoch 102: 543batch [01:08,  7.96batch/s, loss=516]


epoch 102: avg train loss 539.15, bar train loss 2.891, len train loss 0.048, col train loss 157.313


Epoch 103: 1batch [00:00,  8.06batch/s, loss=539]

epoch 102: avg test  loss 567.57, bar  test loss 3.794, len  test loss 0.058, col  test loss 160.147


Epoch 103: 543batch [01:08,  7.95batch/s, loss=562]


epoch 103: avg train loss 539.15, bar train loss 2.898, len train loss 0.048, col train loss 157.273


Epoch 104: 1batch [00:00,  8.13batch/s, loss=565]

epoch 103: avg test  loss 567.51, bar  test loss 3.816, len  test loss 0.058, col  test loss 160.207


Epoch 104: 543batch [01:08,  7.96batch/s, loss=551]


epoch 104: avg train loss 539.04, bar train loss 2.889, len train loss 0.050, col train loss 157.261


Epoch 105: 1batch [00:00,  8.06batch/s, loss=532]

epoch 104: avg test  loss 568.13, bar  test loss 3.830, len  test loss 0.062, col  test loss 160.191


Epoch 105: 543batch [01:08,  7.96batch/s, loss=557]


epoch 105: avg train loss 538.64, bar train loss 2.884, len train loss 0.049, col train loss 157.182
epoch 105: avg test  loss 568.23, bar  test loss 3.815, len  test loss 0.065, col  test loss 160.220


Epoch 106: 543batch [01:08,  7.94batch/s, loss=553]


epoch 106: avg train loss 538.35, bar train loss 2.883, len train loss 0.048, col train loss 157.136


Epoch 107: 1batch [00:00,  7.94batch/s, loss=557]

epoch 106: avg test  loss 567.79, bar  test loss 3.834, len  test loss 0.061, col  test loss 160.188


Epoch 107: 543batch [01:08,  7.96batch/s, loss=536]


epoch 107: avg train loss 538.38, bar train loss 2.882, len train loss 0.049, col train loss 157.107


Epoch 108: 1batch [00:00,  8.13batch/s, loss=556]

epoch 107: avg test  loss 567.85, bar  test loss 3.823, len  test loss 0.061, col  test loss 160.182


Epoch 108: 543batch [01:08,  7.94batch/s, loss=541]


epoch 108: avg train loss 537.69, bar train loss 2.870, len train loss 0.047, col train loss 157.034


Epoch 109: 1batch [00:00,  8.00batch/s, loss=516]

epoch 108: avg test  loss 568.15, bar  test loss 3.863, len  test loss 0.058, col  test loss 160.158


Epoch 109: 543batch [01:08,  7.96batch/s, loss=535]


epoch 109: avg train loss 537.72, bar train loss 2.866, len train loss 0.048, col train loss 157.059


Epoch 110: 1batch [00:00,  7.94batch/s, loss=524]

epoch 109: avg test  loss 567.60, bar  test loss 3.818, len  test loss 0.061, col  test loss 160.168


Epoch 110: 543batch [01:08,  7.95batch/s, loss=502]


epoch 110: avg train loss 537.55, bar train loss 2.870, len train loss 0.047, col train loss 157.001
epoch 110: avg test  loss 567.92, bar  test loss 3.835, len  test loss 0.065, col  test loss 160.150


Epoch 111: 543batch [01:08,  7.93batch/s, loss=557]


epoch 111: avg train loss 537.15, bar train loss 2.854, len train loss 0.048, col train loss 156.964


Epoch 112: 1batch [00:00,  8.06batch/s, loss=536]

epoch 111: avg test  loss 567.69, bar  test loss 3.848, len  test loss 0.059, col  test loss 160.160


Epoch 112: 543batch [01:08,  7.95batch/s, loss=546]


epoch 112: avg train loss 537.29, bar train loss 2.861, len train loss 0.048, col train loss 156.913


Epoch 113: 1batch [00:00,  8.13batch/s, loss=539]

epoch 112: avg test  loss 567.89, bar  test loss 3.854, len  test loss 0.062, col  test loss 160.151


Epoch 113: 543batch [01:08,  7.95batch/s, loss=515]


epoch 113: avg train loss 536.82, bar train loss 2.851, len train loss 0.048, col train loss 156.871


Epoch 114: 1batch [00:00,  8.06batch/s, loss=498]

epoch 113: avg test  loss 567.07, bar  test loss 3.808, len  test loss 0.060, col  test loss 160.105


Epoch 114: 543batch [01:08,  7.94batch/s, loss=489]


epoch 114: avg train loss 536.75, bar train loss 2.850, len train loss 0.047, col train loss 156.836


Epoch 115: 1batch [00:00,  8.00batch/s, loss=549]

epoch 114: avg test  loss 568.62, bar  test loss 3.837, len  test loss 0.073, col  test loss 160.104


Epoch 115: 543batch [01:08,  7.96batch/s, loss=530]


epoch 115: avg train loss 536.21, bar train loss 2.841, len train loss 0.046, col train loss 156.779
epoch 115: avg test  loss 567.72, bar  test loss 3.834, len  test loss 0.063, col  test loss 160.074


Epoch 116: 543batch [01:08,  7.91batch/s, loss=536]


epoch 116: avg train loss 536.21, bar train loss 2.838, len train loss 0.048, col train loss 156.762


Epoch 117: 1batch [00:00,  8.06batch/s, loss=550]

epoch 116: avg test  loss 567.07, bar  test loss 3.822, len  test loss 0.061, col  test loss 160.125


Epoch 117: 543batch [01:08,  7.94batch/s, loss=517]


epoch 117: avg train loss 536.11, bar train loss 2.843, len train loss 0.047, col train loss 156.699


Epoch 118: 1batch [00:00,  7.58batch/s, loss=508]

epoch 117: avg test  loss 567.34, bar  test loss 3.851, len  test loss 0.060, col  test loss 160.040


Epoch 118: 543batch [01:08,  7.94batch/s, loss=492]


epoch 118: avg train loss 536.01, bar train loss 2.837, len train loss 0.047, col train loss 156.725


Epoch 119: 1batch [00:00,  8.13batch/s, loss=542]

epoch 118: avg test  loss 568.04, bar  test loss 3.859, len  test loss 0.062, col  test loss 160.094


Epoch 119: 543batch [01:08,  7.95batch/s, loss=516]


epoch 119: avg train loss 535.58, bar train loss 2.830, len train loss 0.047, col train loss 156.636


Epoch 120: 1batch [00:00,  8.06batch/s, loss=532]

epoch 119: avg test  loss 569.69, bar  test loss 3.888, len  test loss 0.072, col  test loss 160.133


Epoch 120: 543batch [01:08,  7.93batch/s, loss=550]


epoch 120: avg train loss 535.63, bar train loss 2.830, len train loss 0.048, col train loss 156.620
epoch 120: avg test  loss 567.96, bar  test loss 3.842, len  test loss 0.065, col  test loss 160.103


Epoch 121: 543batch [01:08,  7.92batch/s, loss=530]


epoch 121: avg train loss 535.65, bar train loss 2.832, len train loss 0.048, col train loss 156.593


Epoch 122: 1batch [00:00,  8.13batch/s, loss=551]

epoch 121: avg test  loss 567.90, bar  test loss 3.861, len  test loss 0.064, col  test loss 160.063


Epoch 122: 543batch [01:08,  7.95batch/s, loss=556]


epoch 122: avg train loss 535.07, bar train loss 2.818, len train loss 0.047, col train loss 156.547


Epoch 123: 1batch [00:00,  8.20batch/s, loss=520]

epoch 122: avg test  loss 567.60, bar  test loss 3.842, len  test loss 0.059, col  test loss 160.094


Epoch 123: 543batch [01:08,  7.95batch/s, loss=535]


epoch 123: avg train loss 535.04, bar train loss 2.824, len train loss 0.047, col train loss 156.492


Epoch 124: 1batch [00:00,  8.06batch/s, loss=533]

epoch 123: avg test  loss 567.47, bar  test loss 3.849, len  test loss 0.061, col  test loss 160.091


Epoch 124: 543batch [01:08,  7.94batch/s, loss=554]


epoch 124: avg train loss 534.81, bar train loss 2.816, len train loss 0.047, col train loss 156.492


Epoch 125: 1batch [00:00,  8.13batch/s, loss=493]

epoch 124: avg test  loss 568.41, bar  test loss 3.854, len  test loss 0.066, col  test loss 159.972


Epoch 125: 543batch [01:08,  7.94batch/s, loss=543]


epoch 125: avg train loss 534.76, bar train loss 2.813, len train loss 0.047, col train loss 156.450
epoch 125: avg test  loss 571.12, bar  test loss 3.918, len  test loss 0.079, col  test loss 160.202


Epoch 126: 543batch [01:08,  7.93batch/s, loss=551]


epoch 126: avg train loss 535.63, bar train loss 2.832, len train loss 0.052, col train loss 156.467


Epoch 127: 1batch [00:00,  8.13batch/s, loss=504]

epoch 126: avg test  loss 568.22, bar  test loss 3.863, len  test loss 0.068, col  test loss 160.028


Epoch 127: 543batch [01:08,  7.95batch/s, loss=563]


epoch 127: avg train loss 534.33, bar train loss 2.808, len train loss 0.047, col train loss 156.383


Epoch 128: 1batch [00:00,  8.13batch/s, loss=533]

epoch 127: avg test  loss 567.59, bar  test loss 3.846, len  test loss 0.062, col  test loss 160.046


Epoch 128: 543batch [01:08,  7.95batch/s, loss=488]


epoch 128: avg train loss 534.56, bar train loss 2.811, len train loss 0.048, col train loss 156.389


Epoch 129: 1batch [00:00,  8.06batch/s, loss=543]

epoch 128: avg test  loss 568.97, bar  test loss 3.882, len  test loss 0.066, col  test loss 160.082


Epoch 129: 543batch [01:08,  7.91batch/s, loss=530]


epoch 129: avg train loss 534.09, bar train loss 2.802, len train loss 0.047, col train loss 156.338


Epoch 130: 1batch [00:00,  8.06batch/s, loss=557]

epoch 129: avg test  loss 567.61, bar  test loss 3.865, len  test loss 0.062, col  test loss 159.927


Epoch 130: 543batch [01:08,  7.93batch/s, loss=543]


epoch 130: avg train loss 533.69, bar train loss 2.791, len train loss 0.047, col train loss 156.299
epoch 130: avg test  loss 567.88, bar  test loss 3.859, len  test loss 0.064, col  test loss 159.998


Epoch 131: 543batch [01:05,  8.32batch/s, loss=551]


epoch 131: avg train loss 533.90, bar train loss 2.798, len train loss 0.048, col train loss 156.288


Epoch 132: 1batch [00:00,  8.26batch/s, loss=500]

epoch 131: avg test  loss 568.45, bar  test loss 3.856, len  test loss 0.065, col  test loss 160.082


Epoch 132: 543batch [01:03,  8.52batch/s, loss=564]


epoch 132: avg train loss 533.42, bar train loss 2.790, len train loss 0.046, col train loss 156.242


Epoch 133: 1batch [00:00,  8.33batch/s, loss=522]

epoch 132: avg test  loss 568.90, bar  test loss 3.867, len  test loss 0.071, col  test loss 160.083


Epoch 133: 543batch [01:03,  8.50batch/s, loss=567]


epoch 133: avg train loss 533.27, bar train loss 2.786, len train loss 0.047, col train loss 156.203


Epoch 134: 1batch [00:00,  8.33batch/s, loss=515]

epoch 133: avg test  loss 568.76, bar  test loss 3.862, len  test loss 0.064, col  test loss 160.014


Epoch 134: 543batch [01:03,  8.49batch/s, loss=582]


epoch 134: avg train loss 533.21, bar train loss 2.785, len train loss 0.047, col train loss 156.191


Epoch 135: 1batch [00:00,  8.40batch/s, loss=563]

epoch 134: avg test  loss 568.34, bar  test loss 3.887, len  test loss 0.062, col  test loss 160.039


Epoch 135: 543batch [01:04,  8.48batch/s, loss=494]


epoch 135: avg train loss 533.12, bar train loss 2.787, len train loss 0.046, col train loss 156.185
epoch 135: avg test  loss 568.49, bar  test loss 3.869, len  test loss 0.069, col  test loss 160.070


Epoch 136: 543batch [01:04,  8.48batch/s, loss=552]


epoch 136: avg train loss 532.98, bar train loss 2.783, len train loss 0.047, col train loss 156.144


Epoch 137: 1batch [00:00,  8.47batch/s, loss=544]

epoch 136: avg test  loss 568.56, bar  test loss 3.877, len  test loss 0.065, col  test loss 160.023


Epoch 137: 543batch [01:04,  8.48batch/s, loss=533]


epoch 137: avg train loss 532.71, bar train loss 2.777, len train loss 0.046, col train loss 156.121


Epoch 138: 1batch [00:00,  8.26batch/s, loss=536]

epoch 137: avg test  loss 567.82, bar  test loss 3.873, len  test loss 0.062, col  test loss 160.024


Epoch 138: 543batch [01:03,  8.51batch/s, loss=554]


epoch 138: avg train loss 532.53, bar train loss 2.777, len train loss 0.046, col train loss 156.068


Epoch 139: 1batch [00:00,  8.33batch/s, loss=505]

epoch 138: avg test  loss 568.34, bar  test loss 3.884, len  test loss 0.063, col  test loss 160.001


Epoch 139: 543batch [01:04,  8.46batch/s, loss=573]


epoch 139: avg train loss 532.86, bar train loss 2.779, len train loss 0.048, col train loss 156.077


Epoch 140: 1batch [00:00,  8.40batch/s, loss=519]

epoch 139: avg test  loss 568.18, bar  test loss 3.868, len  test loss 0.061, col  test loss 159.997


Epoch 140: 543batch [01:03,  8.51batch/s, loss=516]


epoch 140: avg train loss 532.20, bar train loss 2.771, len train loss 0.046, col train loss 156.009
epoch 140: avg test  loss 568.50, bar  test loss 3.897, len  test loss 0.065, col  test loss 160.002


Epoch 141: 543batch [01:04,  8.47batch/s, loss=567]


epoch 141: avg train loss 531.89, bar train loss 2.766, len train loss 0.045, col train loss 155.947


Epoch 142: 1batch [00:00,  8.33batch/s, loss=565]

epoch 141: avg test  loss 568.44, bar  test loss 3.871, len  test loss 0.063, col  test loss 160.075


Epoch 142: 543batch [01:03,  8.49batch/s, loss=521]


epoch 142: avg train loss 531.86, bar train loss 2.761, len train loss 0.046, col train loss 155.965


Epoch 143: 1batch [00:00,  8.40batch/s, loss=552]

epoch 142: avg test  loss 568.28, bar  test loss 3.874, len  test loss 0.063, col  test loss 160.026


Epoch 143: 543batch [01:04,  8.46batch/s, loss=514]


epoch 143: avg train loss 531.96, bar train loss 2.765, len train loss 0.046, col train loss 155.961


Epoch 144: 1batch [00:00,  8.40batch/s, loss=502]

epoch 143: avg test  loss 568.74, bar  test loss 3.884, len  test loss 0.062, col  test loss 159.961


Epoch 144: 543batch [01:04,  8.36batch/s, loss=574]


epoch 144: avg train loss 531.56, bar train loss 2.756, len train loss 0.045, col train loss 155.936


Epoch 145: 1batch [00:00,  8.47batch/s, loss=530]

epoch 144: avg test  loss 567.97, bar  test loss 3.869, len  test loss 0.063, col  test loss 159.987


Epoch 145: 543batch [01:03,  8.50batch/s, loss=547]


epoch 145: avg train loss 531.42, bar train loss 2.751, len train loss 0.046, col train loss 155.887
epoch 145: avg test  loss 568.59, bar  test loss 3.894, len  test loss 0.067, col  test loss 159.997


Epoch 146: 543batch [01:05,  8.33batch/s, loss=501]


epoch 146: avg train loss 531.56, bar train loss 2.759, len train loss 0.046, col train loss 155.883


Epoch 147: 1batch [00:00,  8.55batch/s, loss=559]

epoch 146: avg test  loss 568.37, bar  test loss 3.873, len  test loss 0.067, col  test loss 160.027


Epoch 147: 543batch [01:04,  8.38batch/s, loss=509]


epoch 147: avg train loss 531.80, bar train loss 2.762, len train loss 0.047, col train loss 155.896


Epoch 148: 1batch [00:00,  8.26batch/s, loss=508]

epoch 147: avg test  loss 569.05, bar  test loss 3.907, len  test loss 0.065, col  test loss 160.031


Epoch 148: 543batch [01:08,  7.87batch/s, loss=507]


epoch 148: avg train loss 531.28, bar train loss 2.754, len train loss 0.046, col train loss 155.826


Epoch 149: 1batch [00:00,  8.40batch/s, loss=539]

epoch 148: avg test  loss 570.55, bar  test loss 3.903, len  test loss 0.075, col  test loss 159.998


Epoch 149: 543batch [01:08,  7.89batch/s, loss=600]


epoch 149: avg train loss 531.03, bar train loss 2.750, len train loss 0.046, col train loss 155.768


Epoch 150: 1batch [00:00,  8.00batch/s, loss=533]

epoch 149: avg test  loss 568.51, bar  test loss 3.897, len  test loss 0.064, col  test loss 160.012


Epoch 150: 543batch [01:08,  7.89batch/s, loss=515]


epoch 150: avg train loss 531.01, bar train loss 2.745, len train loss 0.047, col train loss 155.766
epoch 150: avg test  loss 568.08, bar  test loss 3.882, len  test loss 0.060, col  test loss 159.952


Epoch 151: 543batch [01:09,  7.87batch/s, loss=532]


epoch 151: avg train loss 530.61, bar train loss 2.737, len train loss 0.045, col train loss 155.731


Epoch 152: 1batch [00:00,  7.94batch/s, loss=499]

epoch 151: avg test  loss 568.58, bar  test loss 3.884, len  test loss 0.069, col  test loss 159.983


Epoch 152: 543batch [01:08,  7.89batch/s, loss=523]


epoch 152: avg train loss 530.57, bar train loss 2.739, len train loss 0.046, col train loss 155.692


Epoch 153: 1batch [00:00,  8.40batch/s, loss=528]

epoch 152: avg test  loss 568.17, bar  test loss 3.888, len  test loss 0.061, col  test loss 159.917


Epoch 153: 543batch [01:08,  7.89batch/s, loss=508]


epoch 153: avg train loss 530.66, bar train loss 2.740, len train loss 0.046, col train loss 155.721


Epoch 154: 1batch [00:00,  8.06batch/s, loss=545]

epoch 153: avg test  loss 568.63, bar  test loss 3.907, len  test loss 0.066, col  test loss 159.994


Epoch 154: 543batch [01:08,  7.91batch/s, loss=553]


epoch 154: avg train loss 530.43, bar train loss 2.739, len train loss 0.046, col train loss 155.671


Epoch 155: 1batch [00:00,  7.94batch/s, loss=524]

epoch 154: avg test  loss 568.94, bar  test loss 3.911, len  test loss 0.067, col  test loss 160.014


Epoch 155: 543batch [01:08,  7.90batch/s, loss=529]


epoch 155: avg train loss 530.41, bar train loss 2.739, len train loss 0.046, col train loss 155.659
epoch 155: avg test  loss 568.57, bar  test loss 3.906, len  test loss 0.062, col  test loss 160.016


Epoch 156: 543batch [01:09,  7.87batch/s, loss=504]


epoch 156: avg train loss 532.46, bar train loss 2.772, len train loss 0.048, col train loss 155.727


Epoch 157: 1batch [00:00,  7.87batch/s, loss=580]

epoch 156: avg test  loss 575.84, bar  test loss 4.105, len  test loss 0.081, col  test loss 160.289


Epoch 157: 543batch [01:08,  7.95batch/s, loss=517]


epoch 157: avg train loss 532.51, bar train loss 2.788, len train loss 0.049, col train loss 155.735


Epoch 158: 1batch [00:00,  8.47batch/s, loss=511]

epoch 157: avg test  loss 570.26, bar  test loss 3.920, len  test loss 0.070, col  test loss 160.019


Epoch 158: 543batch [01:04,  8.46batch/s, loss=553]


epoch 158: avg train loss 530.72, bar train loss 2.738, len train loss 0.046, col train loss 155.649


Epoch 159: 1batch [00:00,  8.40batch/s, loss=526]

epoch 158: avg test  loss 569.07, bar  test loss 3.924, len  test loss 0.064, col  test loss 159.972


Epoch 159: 543batch [01:04,  8.47batch/s, loss=560]


epoch 159: avg train loss 530.25, bar train loss 2.728, len train loss 0.046, col train loss 155.576


Epoch 160: 1batch [00:00,  8.26batch/s, loss=570]

epoch 159: avg test  loss 569.41, bar  test loss 3.922, len  test loss 0.067, col  test loss 159.974


Epoch 160: 543batch [01:04,  8.47batch/s, loss=495]


epoch 160: avg train loss 530.20, bar train loss 2.728, len train loss 0.046, col train loss 155.591
epoch 160: avg test  loss 569.06, bar  test loss 3.916, len  test loss 0.066, col  test loss 159.993


Epoch 161: 543batch [01:04,  8.41batch/s, loss=559]


epoch 161: avg train loss 530.01, bar train loss 2.731, len train loss 0.045, col train loss 155.522


Epoch 162: 1batch [00:00,  8.33batch/s, loss=547]

epoch 161: avg test  loss 569.60, bar  test loss 3.907, len  test loss 0.070, col  test loss 160.001


Epoch 162: 543batch [01:04,  8.45batch/s, loss=474]


epoch 162: avg train loss 530.12, bar train loss 2.728, len train loss 0.046, col train loss 155.550


Epoch 163: 1batch [00:00,  8.33batch/s, loss=531]

epoch 162: avg test  loss 569.04, bar  test loss 3.915, len  test loss 0.063, col  test loss 160.028


Epoch 163: 543batch [01:04,  8.43batch/s, loss=547]


epoch 163: avg train loss 529.76, bar train loss 2.719, len train loss 0.046, col train loss 155.521


Epoch 164: 1batch [00:00,  8.40batch/s, loss=535]

epoch 163: avg test  loss 569.65, bar  test loss 3.912, len  test loss 0.061, col  test loss 160.015


Epoch 164: 543batch [01:06,  8.21batch/s, loss=529]


epoch 164: avg train loss 529.71, bar train loss 2.722, len train loss 0.046, col train loss 155.477


Epoch 165: 1batch [00:00,  8.06batch/s, loss=478]

epoch 164: avg test  loss 570.76, bar  test loss 3.917, len  test loss 0.078, col  test loss 159.995


Epoch 165: 543batch [01:08,  7.88batch/s, loss=539]


epoch 165: avg train loss 529.52, bar train loss 2.717, len train loss 0.047, col train loss 155.406
epoch 165: avg test  loss 569.80, bar  test loss 3.934, len  test loss 0.071, col  test loss 160.012


Epoch 166: 543batch [01:09,  7.83batch/s, loss=509]


epoch 166: avg train loss 529.37, bar train loss 2.716, len train loss 0.045, col train loss 155.425


Epoch 167: 1batch [00:00,  8.06batch/s, loss=525]

epoch 166: avg test  loss 569.50, bar  test loss 3.939, len  test loss 0.066, col  test loss 159.951


Epoch 167: 543batch [01:09,  7.87batch/s, loss=558]


epoch 167: avg train loss 529.07, bar train loss 2.708, len train loss 0.044, col train loss 155.411


Epoch 168: 1batch [00:00,  8.00batch/s, loss=493]

epoch 167: avg test  loss 569.97, bar  test loss 3.932, len  test loss 0.068, col  test loss 159.962


Epoch 168: 543batch [01:07,  8.06batch/s, loss=539]


epoch 168: avg train loss 529.12, bar train loss 2.705, len train loss 0.045, col train loss 155.418


Epoch 169: 1batch [00:00,  8.33batch/s, loss=522]

epoch 168: avg test  loss 569.34, bar  test loss 3.917, len  test loss 0.067, col  test loss 160.000


Epoch 169: 543batch [01:05,  8.34batch/s, loss=531]


epoch 169: avg train loss 529.18, bar train loss 2.711, len train loss 0.046, col train loss 155.380


Epoch 170: 1batch [00:00,  8.40batch/s, loss=564]

epoch 169: avg test  loss 569.80, bar  test loss 3.925, len  test loss 0.069, col  test loss 160.086


Epoch 170: 543batch [01:07,  8.04batch/s, loss=599]


epoch 170: avg train loss 529.35, bar train loss 2.715, len train loss 0.047, col train loss 155.392
epoch 170: avg test  loss 569.08, bar  test loss 3.917, len  test loss 0.067, col  test loss 159.983


Epoch 171: 543batch [01:09,  7.83batch/s, loss=527]


epoch 171: avg train loss 529.35, bar train loss 2.716, len train loss 0.046, col train loss 155.409


Epoch 172: 1batch [00:00,  7.94batch/s, loss=532]

epoch 171: avg test  loss 569.47, bar  test loss 3.922, len  test loss 0.069, col  test loss 159.962


Epoch 172: 543batch [01:08,  7.87batch/s, loss=560]


epoch 172: avg train loss 528.87, bar train loss 2.702, len train loss 0.046, col train loss 155.339


Epoch 173: 1batch [00:00,  8.00batch/s, loss=525]

epoch 172: avg test  loss 569.77, bar  test loss 3.938, len  test loss 0.067, col  test loss 160.012


Epoch 173: 543batch [01:09,  7.85batch/s, loss=553]


epoch 173: avg train loss 528.51, bar train loss 2.702, len train loss 0.044, col train loss 155.291


Epoch 174: 1batch [00:00,  7.94batch/s, loss=490]

epoch 173: avg test  loss 569.52, bar  test loss 3.926, len  test loss 0.067, col  test loss 160.012


Epoch 174: 543batch [01:09,  7.85batch/s, loss=527]


epoch 174: avg train loss 528.55, bar train loss 2.700, len train loss 0.045, col train loss 155.308


Epoch 175: 1batch [00:00,  8.06batch/s, loss=531]

epoch 174: avg test  loss 569.41, bar  test loss 3.942, len  test loss 0.066, col  test loss 160.023


Epoch 175: 543batch [01:09,  7.86batch/s, loss=475]


epoch 175: avg train loss 528.28, bar train loss 2.696, len train loss 0.044, col train loss 155.273
epoch 175: avg test  loss 568.82, bar  test loss 3.924, len  test loss 0.062, col  test loss 160.005


Epoch 176: 543batch [01:08,  7.96batch/s, loss=527]


epoch 176: avg train loss 528.24, bar train loss 2.695, len train loss 0.045, col train loss 155.227


Epoch 177: 1batch [00:00,  8.26batch/s, loss=518]

epoch 176: avg test  loss 569.97, bar  test loss 3.942, len  test loss 0.071, col  test loss 159.947


Epoch 177: 543batch [01:04,  8.41batch/s, loss=535]


epoch 177: avg train loss 528.53, bar train loss 2.705, len train loss 0.046, col train loss 155.234


Epoch 178: 1batch [00:00,  8.33batch/s, loss=508]

epoch 177: avg test  loss 570.58, bar  test loss 3.950, len  test loss 0.073, col  test loss 159.999


Epoch 178: 543batch [01:04,  8.42batch/s, loss=518]


epoch 178: avg train loss 528.18, bar train loss 2.698, len train loss 0.044, col train loss 155.218


Epoch 179: 1batch [00:00,  8.40batch/s, loss=524]

epoch 178: avg test  loss 569.46, bar  test loss 3.946, len  test loss 0.066, col  test loss 159.930


Epoch 179: 543batch [01:04,  8.43batch/s, loss=532]


epoch 179: avg train loss 528.39, bar train loss 2.701, len train loss 0.046, col train loss 155.243


Epoch 180: 1batch [00:00,  8.13batch/s, loss=503]

epoch 179: avg test  loss 570.38, bar  test loss 3.967, len  test loss 0.072, col  test loss 160.019


Epoch 180: 543batch [01:04,  8.40batch/s, loss=523]


epoch 180: avg train loss 528.07, bar train loss 2.690, len train loss 0.045, col train loss 155.205
epoch 180: avg test  loss 569.96, bar  test loss 3.948, len  test loss 0.068, col  test loss 159.989


Epoch 181: 543batch [01:05,  8.34batch/s, loss=530]


epoch 181: avg train loss 528.38, bar train loss 2.700, len train loss 0.046, col train loss 155.232


Epoch 182: 1batch [00:00,  8.47batch/s, loss=531]

epoch 181: avg test  loss 569.79, bar  test loss 3.937, len  test loss 0.067, col  test loss 159.926


Epoch 182: 543batch [01:04,  8.42batch/s, loss=527]


epoch 182: avg train loss 527.72, bar train loss 2.688, len train loss 0.045, col train loss 155.125


Epoch 183: 1batch [00:00,  8.40batch/s, loss=525]

epoch 182: avg test  loss 571.94, bar  test loss 3.953, len  test loss 0.087, col  test loss 160.040


Epoch 183: 543batch [01:04,  8.42batch/s, loss=547]


epoch 183: avg train loss 527.50, bar train loss 2.682, len train loss 0.044, col train loss 155.118


Epoch 184: 1batch [00:00,  8.26batch/s, loss=520]

epoch 183: avg test  loss 569.39, bar  test loss 3.941, len  test loss 0.068, col  test loss 159.998


Epoch 184: 543batch [01:04,  8.43batch/s, loss=498]


epoch 184: avg train loss 527.61, bar train loss 2.681, len train loss 0.045, col train loss 155.130


Epoch 185: 1batch [00:00,  8.20batch/s, loss=547]

epoch 184: avg test  loss 570.89, bar  test loss 3.985, len  test loss 0.069, col  test loss 160.046


Epoch 185: 543batch [01:04,  8.42batch/s, loss=532]


epoch 185: avg train loss 527.65, bar train loss 2.686, len train loss 0.044, col train loss 155.141
epoch 185: avg test  loss 571.16, bar  test loss 3.964, len  test loss 0.079, col  test loss 160.011


Epoch 186: 543batch [01:04,  8.39batch/s, loss=524]


epoch 186: avg train loss 527.08, bar train loss 2.670, len train loss 0.045, col train loss 155.047


Epoch 187: 1batch [00:00,  8.20batch/s, loss=544]

epoch 186: avg test  loss 571.27, bar  test loss 3.964, len  test loss 0.070, col  test loss 160.023


Epoch 187: 543batch [01:07,  8.05batch/s, loss=516]


epoch 187: avg train loss 527.20, bar train loss 2.676, len train loss 0.045, col train loss 155.048


Epoch 188: 1batch [00:00,  7.94batch/s, loss=556]

epoch 187: avg test  loss 570.13, bar  test loss 3.949, len  test loss 0.068, col  test loss 159.996


Epoch 188: 543batch [01:09,  7.85batch/s, loss=557]


epoch 188: avg train loss 527.16, bar train loss 2.675, len train loss 0.045, col train loss 155.036


Epoch 189: 1batch [00:00,  8.13batch/s, loss=564]

epoch 188: avg test  loss 570.04, bar  test loss 3.948, len  test loss 0.067, col  test loss 159.999


Epoch 189: 543batch [01:10,  7.69batch/s, loss=557]


epoch 189: avg train loss 526.61, bar train loss 2.666, len train loss 0.044, col train loss 154.982


Epoch 190: 1batch [00:00,  8.00batch/s, loss=537]

epoch 189: avg test  loss 570.54, bar  test loss 3.960, len  test loss 0.071, col  test loss 160.013


Epoch 190: 543batch [01:10,  7.75batch/s, loss=536]


epoch 190: avg train loss 526.52, bar train loss 2.667, len train loss 0.044, col train loss 154.940
epoch 190: avg test  loss 570.08, bar  test loss 3.937, len  test loss 0.073, col  test loss 159.996


Epoch 191: 543batch [01:10,  7.72batch/s, loss=509]


epoch 191: avg train loss 526.67, bar train loss 2.671, len train loss 0.044, col train loss 154.969


Epoch 192: 1batch [00:00,  7.94batch/s, loss=518]

epoch 191: avg test  loss 569.44, bar  test loss 3.955, len  test loss 0.065, col  test loss 159.988


Epoch 192: 543batch [01:10,  7.75batch/s, loss=528]


epoch 192: avg train loss 526.81, bar train loss 2.675, len train loss 0.045, col train loss 154.953


Epoch 193: 1batch [00:00,  7.94batch/s, loss=520]

epoch 192: avg test  loss 571.05, bar  test loss 3.973, len  test loss 0.071, col  test loss 160.073


Epoch 193: 543batch [01:09,  7.84batch/s, loss=511]


epoch 193: avg train loss 526.57, bar train loss 2.670, len train loss 0.044, col train loss 154.958


Epoch 194: 1batch [00:00,  8.06batch/s, loss=525]

epoch 193: avg test  loss 569.71, bar  test loss 3.956, len  test loss 0.070, col  test loss 159.953


Epoch 194: 543batch [01:09,  7.81batch/s, loss=465]


epoch 194: avg train loss 526.61, bar train loss 2.667, len train loss 0.045, col train loss 154.951


Epoch 195: 1batch [00:00,  7.94batch/s, loss=529]

epoch 194: avg test  loss 569.94, bar  test loss 3.973, len  test loss 0.065, col  test loss 159.995


Epoch 195: 543batch [01:09,  7.81batch/s, loss=478]


epoch 195: avg train loss 526.20, bar train loss 2.664, len train loss 0.043, col train loss 154.892
epoch 195: avg test  loss 569.54, bar  test loss 3.956, len  test loss 0.070, col  test loss 159.908


Epoch 196: 543batch [01:09,  7.77batch/s, loss=485]


epoch 196: avg train loss 525.93, bar train loss 2.658, len train loss 0.043, col train loss 154.847


Epoch 197: 1batch [00:00,  7.94batch/s, loss=536]

epoch 196: avg test  loss 569.55, bar  test loss 3.957, len  test loss 0.069, col  test loss 159.966


Epoch 197: 543batch [01:09,  7.83batch/s, loss=549]


epoch 197: avg train loss 526.00, bar train loss 2.657, len train loss 0.044, col train loss 154.866


Epoch 198: 1batch [00:00,  7.87batch/s, loss=513]

epoch 197: avg test  loss 570.33, bar  test loss 3.975, len  test loss 0.068, col  test loss 160.054


Epoch 198: 543batch [01:09,  7.83batch/s, loss=482]


epoch 198: avg train loss 526.13, bar train loss 2.660, len train loss 0.045, col train loss 154.850


Epoch 199: 1batch [00:00,  7.87batch/s, loss=530]

epoch 198: avg test  loss 569.98, bar  test loss 3.956, len  test loss 0.069, col  test loss 159.982


Epoch 199: 543batch [01:09,  7.84batch/s, loss=527]


epoch 199: avg train loss 525.99, bar train loss 2.658, len train loss 0.044, col train loss 154.832


Epoch 200: 1batch [00:00,  7.94batch/s, loss=509]

epoch 199: avg test  loss 569.79, bar  test loss 3.961, len  test loss 0.068, col  test loss 159.946


Epoch 200: 543batch [01:09,  7.81batch/s, loss=523]


epoch 200: avg train loss 526.09, bar train loss 2.662, len train loss 0.044, col train loss 154.861
epoch 200: avg test  loss 569.88, bar  test loss 3.955, len  test loss 0.070, col  test loss 159.978


Epoch 201: 543batch [01:10,  7.72batch/s, loss=517]


epoch 201: avg train loss 525.77, bar train loss 2.655, len train loss 0.044, col train loss 154.792


Epoch 202: 1batch [00:00,  7.87batch/s, loss=529]

epoch 201: avg test  loss 569.93, bar  test loss 3.958, len  test loss 0.071, col  test loss 159.946


Epoch 202: 543batch [01:09,  7.82batch/s, loss=519]


epoch 202: avg train loss 525.89, bar train loss 2.653, len train loss 0.044, col train loss 154.828


Epoch 203: 1batch [00:00,  7.94batch/s, loss=531]

epoch 202: avg test  loss 569.82, bar  test loss 3.958, len  test loss 0.068, col  test loss 159.949


Epoch 203: 543batch [01:09,  7.82batch/s, loss=545]


epoch 203: avg train loss 525.54, bar train loss 2.648, len train loss 0.044, col train loss 154.762


Epoch 204: 1batch [00:00,  7.87batch/s, loss=495]

epoch 203: avg test  loss 570.90, bar  test loss 3.988, len  test loss 0.069, col  test loss 159.954


Epoch 204: 543batch [01:09,  7.81batch/s, loss=536]


epoch 204: avg train loss 525.60, bar train loss 2.653, len train loss 0.045, col train loss 154.749


Epoch 205: 1batch [00:00,  7.94batch/s, loss=512]

epoch 204: avg test  loss 571.08, bar  test loss 3.988, len  test loss 0.076, col  test loss 160.020


Epoch 205: 543batch [01:09,  7.82batch/s, loss=554]


epoch 205: avg train loss 525.23, bar train loss 2.642, len train loss 0.043, col train loss 154.724
epoch 205: avg test  loss 570.44, bar  test loss 3.985, len  test loss 0.070, col  test loss 159.967


Epoch 206: 543batch [01:10,  7.72batch/s, loss=534]


epoch 206: avg train loss 525.50, bar train loss 2.647, len train loss 0.044, col train loss 154.757


Epoch 207: 1batch [00:00,  7.87batch/s, loss=522]

epoch 206: avg test  loss 570.74, bar  test loss 4.000, len  test loss 0.070, col  test loss 159.973


Epoch 207: 543batch [01:09,  7.79batch/s, loss=566]


epoch 207: avg train loss 525.37, bar train loss 2.648, len train loss 0.044, col train loss 154.723


Epoch 208: 1batch [00:00,  7.94batch/s, loss=514]

epoch 207: avg test  loss 570.28, bar  test loss 3.971, len  test loss 0.069, col  test loss 160.000


Epoch 208: 543batch [01:09,  7.79batch/s, loss=533]


epoch 208: avg train loss 525.23, bar train loss 2.643, len train loss 0.045, col train loss 154.672


Epoch 209: 1batch [00:00,  7.81batch/s, loss=503]

epoch 208: avg test  loss 570.90, bar  test loss 3.980, len  test loss 0.071, col  test loss 160.021


Epoch 209: 543batch [01:09,  7.81batch/s, loss=522]


epoch 209: avg train loss 525.29, bar train loss 2.646, len train loss 0.044, col train loss 154.692


Epoch 210: 1batch [00:00,  7.94batch/s, loss=534]

epoch 209: avg test  loss 570.27, bar  test loss 3.972, len  test loss 0.067, col  test loss 160.046


Epoch 210: 100batch [00:13,  7.64batch/s, loss=539]


KeyboardInterrupt: 

In [None]:
diva.lqz1[-5:], diva.lqz2[-5:], diva.lpz1[-5:], diva.lpz2[-5:]

In [None]:
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 500, save_folder="VAEFC")

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')