In [1]:
import os

# Project root holding data/ and saved_models/. Set the MIRNA_DIR environment
# variable to point the notebook elsewhere instead of editing this
# machine-specific default.
link = os.environ.get('MIRNA_DIR', 'D:/users/Marko/downloads/mirna/')

# Imports

In [2]:
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
# Put `link` first on the import path so any Python modules stored there can be
# imported (the commented-out line is presumably an earlier Colab / Google Drive
# location).
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


# Module-level TensorBoard writer. train() reads this global (and skips logging
# when it is None) to record per-epoch scalars.
writer = SummaryWriter(f"{link}/saved_models/new/RHVAE/tensorboard")

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Display the device selected above.
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:
    """Plain container for the RHVAE hyper-parameters.

    z1_dim / z2_dim are the latent sizes and h_dim / h2_dim the hidden widths;
    d_dim, x_dim and y_dim are stored and passed through to the sub-networks.
    number_components is the number of learned pseudo-inputs behind the z2
    prior. beta weights the KL term, rec_beta the bar-length MSE term and
    rec_gamma the colour cross-entropy term; warmup / prewarmup drive the KL
    annealing schedule used during training.
    """

    def __init__(self, z1_dim=256, z2_dim=256, d_dim=45, x_dim=7500, y_dim=2,
                 h_dim = 256, h2_dim = 256, number_components = 500,
                 beta=1, rec_beta = 20,
                 rec_gamma = 1, warmup = 1, prewarmup = 1):
        # latent and data dimensions
        self.z1_dim, self.z2_dim = z1_dim, z2_dim
        self.d_dim, self.x_dim, self.y_dim = d_dim, x_dim, y_dim
        # hidden layer widths
        self.h_dim, self.h2_dim = h_dim, h2_dim
        # pseudo-input count for the mixture prior over z2
        self.number_components = number_components
        # loss weights and KL annealing schedule
        self.beta, self.rec_beta, self.rec_gamma = beta, rec_beta, rec_gamma
        self.warmup, self.prewarmup = warmup, prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):
    """Pre-miRNA image dataset with labels, species "domains" and bar encodings.

    Each item is ``(x, y, d, x2, mount)``:
      * x     -- the image scaled to [0, 1], transposed to channels-first
      * y     -- one-hot class label
      * d     -- one-hot species label derived from the sequence name
      * x2    -- per-column bar encoding (see ``get_encoded_values``)
      * mount -- the matching row of the ``*_mountain.npy`` array
    """

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        """Load split ``ds`` (e.g. 'train' / 'test') from ``{link}/data``.

        create_encodings -- recompute the bar encodings from the images (slow)
                            and cache them instead of loading the cached file.
        use_subset       -- keep only a random ~50% of the 'hsa' observations.
        """
        # loading images (stored as 0-255, scaled to [0, 1])
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255

        # loading labels
        print('Loading Labels! (~10s)')
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = self._one_hot_encoder().fit_transform(labels)

        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x = self.get_encoded_values(self.images, ds)
        else:
            # the cache is written with np.save despite its .npz name; np.load
            # detects the real format from the file header
            x = np.load(f'{link}/data/modmirbase_{ds}_images_RNN.npz')
        self.x = x

        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')

        # loading names
        print('Loading Names! (~5s)')
        names = np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        # species with more than 200 observations in past research
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin',
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        self.names = [self._species_of(name) for name in names]

        # performing one hot encoding of the species labels
        self.names_ohe = self._one_hot_encoder().fit_transform(
            np.array(self.names).reshape(-1, 1))

        if use_subset:
            # short-circuit keeps the RNG draw only for 'hsa' rows, as before
            idxes = np.array([i == 'hsa' and np.random.choice([True, False])
                              for i in self.names], dtype=bool)
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x = self.x[idxes]
            self.mountain = self.mountain[idxes]
            # keep the name list aligned with the filtered arrays
            self.names = [n for n, keep in zip(self.names, idxes) if keep]

    @staticmethod
    def _one_hot_encoder():
        """Dense-output OneHotEncoder across scikit-learn versions.

        ``sparse`` was renamed ``sparse_output`` in scikit-learn 1.2 and the
        old keyword was later removed, so try the new name first.
        """
        try:
            return OneHotEncoder(categories='auto', sparse_output=False)
        except TypeError:
            return OneHotEncoder(categories='auto', sparse=False)

    def _species_of(self, name):
        """Map a sequence name to a species code from ``self.species``.

        Returns the first code contained in the lower-cased name; names that
        contain 'random' or are purely numeric become 'hsa'; anything else
        becomes 'notfound'.
        """
        lowered = name.lower()
        for code in self.species:
            if code in lowered:
                return code
        if 'random' in lowered or name.isdigit():
            return 'hsa'
        return 'notfound'

    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))   # HWC -> CHW
        x2 = self.x[idx]
        mount = self.mountain[idx]
        return (x, y, d, x2, mount)

    def get_encoded_values(self, x, ds):
        """Encode a batch of channels-last images column by column.

        x  -- (n, 25, 100, 3) images with pixel values in {0, 1}
        ds -- split name, used for the cache file name

        Returns an (n, 100, 14) uint8 array, one row per image column:
          idx 0 - 5: one-hot colour of the top bar (bar starts at row 12)
          idx 6 - 11: one-hot colour of the bottom bar (bar starts at row 13)
          idx 12: length of the top bar (counted upwards from row 12)
          idx 13: length of the bottom bar (counted downwards from row 13)
        The result is also written to ``{link}/data/modmirbase_{ds}_images_RNN.npz``
        (in np.save format).
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out = np.zeros((n,100,14), dtype=np.uint8)
        white = np.array([1.,1.,1.])

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            for j in range(100):

                # check color of top bar
                col = self.get_color(x[i,:,12,j])
                out[i, j, col] = 1
                if col < 5:
                    # length of top bar: walk upwards until a white pixel
                    len1 = 0
                    while not (x[i,:,12-len1,j] == white).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out[i, j, 12] = len1

                # check color of bottom bar
                col = self.get_color(x[i,:,13,j])
                out[i, j, col+6] = 1
                if col < 5:
                    # length of bottom bar: walk downwards until a white pixel
                    len2 = 0
                    while not (x[i,:,13+len2,j] == white).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out[i, j, 13] = len2

        with open(f'{link}/data/modmirbase_{ds}_images_RNN.npz', 'wb') as f:
            np.save(f, out)

        return out

    def get_color(self, pixel):
        """Return the colour code of an RGB pixel whose channels are 0 or 1.

        0 black, 1 red, 2 blue, 3 green, 4 yellow, 5 white.

        Raises ValueError for any other pixel. The old version printed a
        message and returned None, which made ``out[i, j, None] = 1`` overwrite
        a whole encoding row before failing with an unrelated TypeError.
        """
        palette = ((0, 0, 0),   # 0 black
                   (1, 0, 0),   # 1 red
                   (0, 0, 1),   # 2 blue
                   (0, 1, 0),   # 3 green
                   (1, 1, 0),   # 4 yellow
                   (1, 1, 1))   # 5 white
        for code, rgb in enumerate(palette):
            if (pixel == np.array(rgb)).all():
                return code
        raise ValueError(f"Something wrong! unexpected pixel value {pixel!r}")


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    """Generative network: p(z1 | z2, m) and an autoregressive decoder for x2.

    ``mz2`` is z2 concatenated with 200 extra conditioning features (hence the
    ``z2_dim + 200`` input layers). x2 is decoded as a sequence of 14-dim
    column vectors: 0-5 top-bar colour logits, 6-11 bottom-bar colour logits,
    12-13 top / bottom bar lengths.

    d_dim, x_dim, y_dim, dim0, dim1 and dim2 are accepted for interface
    compatibility but are not used.
    """
    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim,
                 h_dim, h2_dim, dim0=2000, dim1=800, dim2=400):
        super(px, self).__init__()

        # p(z1 | z2, m)
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim+200, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

        # p(x | z1, z2, m): z1 and mz2 are embedded separately, then projected
        # to one 14-dim vector that is fed to the LSTM as its first step
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim),
                                   nn.ReLU())
        self.fc = nn.Sequential(nn.Linear(2*h_dim, 14),
                                nn.ReLU())

        # two stacked 2-layer LSTMs over the 14-feature column encodings
        self.rnn1 = nn.LSTM(14, 64, num_layers=2, batch_first=True)
        self.rnn2 = nn.LSTM(64,128, num_layers=2, batch_first=True)
        # per-step head producing 6 + 6 colour logits and 2 bar lengths
        self.fc2 = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Dropout(0.4), nn.Linear(128,14))

    def forward(self, z1, mz2, x2):
        """Decode with teacher forcing.

        z1  -- (batch, z1_dim) latent sample
        mz2 -- (batch, z2_dim + 200) z2 concatenated with the conditioning input
        x2  -- (batch, seq, 14) target encoding used as teacher-forcing input

        Returns (output, pz1_m, pz1_s): output is (batch, seq, 14) raw
        predictions; pz1_m / pz1_s are the mean / std of p(z1 | z2, m).
        """
        # p(z1|z2)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)

        # p(x|z1,z2,m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = torch.cat([hz1,hz2],1)
        h = self.fc(h).unsqueeze(1)                      # (batch, 1, 14)

        bs = x2.shape[0]
        # zero start token, created alongside h so the module works on any
        # device instead of depending on the global DEVICE
        start = torch.zeros((bs, 1, h.shape[2]), device=h.device, dtype=h.dtype)

        # input sequence [h, start, x2[0], ..., x2[seq-2]]: after dropping the
        # step that consumed h, output t is predicted from x2[:t]
        inputs = torch.cat([h, start, x2[:, :-1, :]], dim=1)
        rnn_o, _ = self.rnn1(inputs)
        rnn_o, _ = self.rnn2(rnn_o)
        output = self.fc2(rnn_o)[:, 1:]

        return output, pz1_m, pz1_s

    def reconstruct_image(self, out, sample=False):
        """Render decoder output back into RGB bar images.

        out    -- (n, 100, 14) tensor: per column, entries 0-5 pick the top-bar
                  colour (argmax), 6-11 the bottom-bar colour, 12 / 13 are the
                  top / bottom bar lengths in pixels (rounded to nearest).
        sample -- unused; sampling is not supported yet.

        Returns an (n, 25, 100, 3) float array on a white background. The top
        bar is painted upwards ending at row 12, the bottom bar downwards from
        row 13. Lengths are clamped to the rows available on each side, so
        negative or oversized predictions never wrap around via negative slice
        indices.

        Colour codes: 0 black, 1 red, 2 blue, 3 green, 4 yellow, 5 white.
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0]), # yellow
                  5: np.array([1,1,1])  # white
                  }

        out = out.detach().cpu().numpy()
        n = out.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            for j in range(100):
                col1 = color_dict[int(np.argmax(out[i,j,:6]))]
                col2 = color_dict[int(np.argmax(out[i,j,6:12]))]
                len1 = int(round(float(out[i,j,12])))
                len2 = int(round(float(out[i,j,13])))

                # paint upper bar: rows h1..12, with h1 kept inside [0, 13]
                h1 = min(13, max(0, 13 - len1))
                output[i, h1:13, j] = col1
                # paint lower bar: rows 13..h2-1, with h2 kept inside [13, 25]
                h2 = max(13, min(25, 13 + len2))
                output[i, 13:h2, j] = col2

        return output


In [9]:
# Scratch check: np.round goes to the nearest integer (3.7 -> 4) while int()
# truncates (3.7 -> 3), as the output below shows.
int(np.round(3.7, 0)),int(3.7)

(4, 3)

In [10]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# Decoder layer summary: inputs are z1 (1, 256), mz2 = [z2, m] (1, 256 + 200)
# and the (1, 100, 14) bar encoding x2.
pzy_ = px(45, 7500, 2, 256,256,256,256)
summary(pzy_, [(1,256),(1,456),(1,100,14)])

Layer (type:depth-idx)                   Output Shape              Param #
px                                       --                        --
├─Sequential: 1-1                        [1, 256]                  --
│    └─Linear: 2-1                       [1, 256]                  116,992
│    └─ReLU: 2-2                         [1, 256]                  --
│    └─Linear: 2-3                       [1, 256]                  65,792
│    └─ReLU: 2-4                         [1, 256]                  --
├─Sequential: 1-2                        [1, 256]                  --
│    └─Linear: 2-5                       [1, 256]                  65,792
├─Sequential: 1-3                        [1, 256]                  --
│    └─Linear: 2-6                       [1, 256]                  65,792
│    └─Softplus: 2-7                     [1, 256]                  --
├─Sequential: 1-4                        [1, 256]                  --
│    └─Linear: 2-8                       [1, 256]                  6

## Encoder Classes

In [11]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [12]:
class qz(nn.Module):
    """Inference network: q(z2 | x) and q(z1 | x, z2, m).

    Both posteriors use the same CNN architecture (with separate weights) on a
    (batch, 3, 25, 100) image; three 2x2 max-pools reduce it to
    72 x 3 x 12 = 2592 features. d_dim, x_dim and y_dim are accepted for
    interface compatibility but are not used.
    """

    # channels x height x width after the conv stack for a 25 x 100 input
    _FLAT_DIM = 72 * 3 * 12

    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim):
        super(qz, self).__init__()

        # q(z2 | x)
        self.encoder_z2 = self._conv_encoder()
        self.mu_z2 = nn.Sequential(nn.Linear(self._FLAT_DIM, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(self._FLAT_DIM, z2_dim), nn.Softplus())

        # q(z1 | x, z2, m)
        self.encoder_z1 = self._conv_encoder()
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(self._FLAT_DIM, h_dim), nn.ReLU())
        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

    @staticmethod
    def _conv_encoder():
        """Three (conv-ReLU-conv-ReLU-maxpool) stages, 3 -> 48 -> 60 -> 72 channels.

        Layer order (and therefore Sequential indices / state_dict keys) is the
        same as the two hand-written stacks this replaces.
        """
        layers = []
        in_ch = 3
        for out_ch in (48, 60, 72):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding = 'same'),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding = 'same'),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
            in_ch = out_ch
        return nn.Sequential(*layers)

    def q_z2(self, x):
        """Mean and std of q(z2 | x) for a (batch, 3, 25, 100) image batch."""
        # flatten(1) keeps the batch axis; view(-1, 2592) would silently
        # re-split the batch if the spatial size ever differed
        h = self.encoder_z2(x).flatten(1)
        z2_m = self.mu_z2(h)
        z2_s = self.si_z2(h)

        return z2_m, z2_s

    def forward(self, x, m):
        """Sample z2 ~ q(z2 | x) and z1 ~ q(z1 | x, z2, m) (reparameterised).

        x -- (batch, 3, 25, 100) image; m -- (batch, 200) conditioning input.
        Returns (z1, z2, mz2, z1_m, z1_s, z2_m, z2_s) with mz2 = [z2, m].
        """
        # q(z2 | x), reparameterization trick
        z2_m, z2_s = self.q_z2(x)
        z2 = dist.Normal(z2_m, z2_s).rsample()
        # z2 & m
        mz2 = torch.cat([z2, m],1)

        # q(z1 | x, z2, m): combine image features with the embedded [z2, m]
        hx = self.fc_z1(self.encoder_z1(x).flatten(1))
        hzm = self.fc_z2(mz2)
        h = self.fc_z1_z2(torch.cat([hzm, hx],1))
        z1_m = self.mu_z1(h)
        z1_s = self.si_z1(h)
        z1 = dist.Normal(z1_m, z1_s).rsample()

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [13]:
# Scratch check: torch.cat along dim 1 appends columns, the same operation the
# models use to build [z2, m] and to join hidden features.
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [14]:
# Smoke-test the encoder on a blank (1, 3, 25, 100) image plus a (1, 200) m
# input, then show its layer summary.
enc = qz(128,10,10,256,256,256,256)
enc(torch.zeros((1,3,25,100)), torch.zeros((1,200)))
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 72, 3, 12]            --
│    └─Conv2d: 2-1                       [1, 48, 25, 100]          1,344
│    └─ReLU: 2-2                         [1, 48, 25, 100]          --
│    └─Conv2d: 2-3                       [1, 48, 25, 100]          20,784
│    └─ReLU: 2-4                         [1, 48, 25, 100]          --
│    └─MaxPool2d: 2-5                    [1, 48, 12, 50]           --
│    └─Conv2d: 2-6                       [1, 60, 12, 50]           25,980
│    └─ReLU: 2-7                         [1, 60, 12, 50]           --
│    └─Conv2d: 2-8                       [1, 60, 12, 50]           32,460
│    └─ReLU: 2-9                         [1, 60, 12, 50]           --
│    └─MaxPool2d: 2-10                   [1, 60, 6, 25]            --
│    └─Conv2d: 2-11                      [1, 72, 6, 25]            38,

In [15]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    """Log-density of a diagonal Gaussian *without* the -0.5*log(2*pi) constant.

    Element-wise value: -0.5 * (log(std**2) + (x - mean)**2 / std**2), then
    reduced over ``dim`` with a mean (``average=True``) or a sum (default).
    """
    log_var = torch.log(std).mul(2)
    sq_err = (x - mean).pow(2)
    elementwise = -0.5 * (log_var + sq_err / log_var.exp())
    reduce = torch.mean if average else torch.sum
    return reduce(elementwise, dim)

## Full model class

In [16]:
class RHVAE(nn.Module):
    """Hierarchical VAE over bar encodings with a learned pseudo-input prior on z2.

    ``qz`` provides q(z2 | x) and q(z1 | x, z2, m); ``px`` provides
    p(z1 | z2, m) and the autoregressive reconstruction of x2. The prior p(z2)
    is a uniform mixture of q(z2 | u_k) over ``number_components`` learned
    pseudo-images u_k (see ``add_pseudoinputs`` / ``log_p_z2``).
    """
    def __init__(self, args):
        super(RHVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components

        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        # loss weights
        self.beta = args.beta
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        # KL annealing settings
        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.add_pseudoinputs()

        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []

        self.bar = []
        self.len = []
        self.col = []

        # follow the notebook-wide DEVICE; a hard-coded .cuda() raises on
        # CPU-only machines
        self.to(DEVICE)

    def forward(self, d, x, y, m, x2):
        """Encode x (with m) and decode x2; d and y are accepted but unused.

        Returns (x2_hat, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s).
        """
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x2_hat, pz1_m, pz1_s = self.px(z1, mz2, x2)

        return x2_hat, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s

    def log_p_z2(self, z2):
        """log p(z2) under the pseudo-input mixture prior, per sample.

        z2 -- (batch, z2_dim). Returns a (batch,) tensor equal to
        log((1/C) * sum_k N(z2; mu_k, sigma_k)) with (mu_k, sigma_k) = q_z2(u_k),
        up to the Gaussian normalising constant (see log_Normal_diag).
        """
        C = self.number_components

        # C pseudo-images produced from the rows of the identity input
        X = self.means(self.idle_input).view(-1,3,25,100)
        pz2_m, pz2_s = self.qz.q_z2(X)

        # (batch, 1, z2_dim) against (1, C, z2_dim) -> (batch, C)
        a = log_Normal_diag(z2.unsqueeze(1), pz2_m.unsqueeze(0), pz2_s.unsqueeze(0), dim=2) - math.log(C)

        # numerically stable log-sum-exp over the mixture components
        return torch.logsumexp(a, 1)

    def loss_function(self, d, x, y, m, x2):
        """Return (total, RE_bar, RE_col, mse_bar, acc_bar, acc_bar2, ..., acc_bar5).

        total  = rec_beta * RE_bar + rec_gamma * RE_col + beta * KL
        RE_col -- cross-entropy of the top (0-5) and bottom (6-11) colour
                  logits against the one-hot targets, mean over batch and steps
        RE_bar -- MSE of the two bar lengths (features 12-13), mean
        KL     -- single-sample estimate summed over the batch of
                  log q(z1) + log q(z2) - log p(z1 | z2) - log p(z2)
        The acc_* values are fixed placeholders (0, 1, 1, 1, 1).

        NOTE(review): the reconstruction terms are batch means while KL is a
        batch sum, so the effective KL weight grows with batch size — confirm
        this is intended.
        """
        x2_hat, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m, x2)

        # Reconstruction loss: cross_entropy expects the class axis at dim 1
        x_rec_hat = x2_hat.permute((0,2,1))
        x_rec = x2.permute((0,2,1))
        rec = F.cross_entropy(x_rec_hat[:,:6,:], x_rec[:,:6,:]) \
            + F.cross_entropy(x_rec_hat[:,6:12,:], x_rec[:,6:12,:])

        mse_bar = F.mse_loss(x2_hat[:,:,12:], x2[:,:,12:], reduction='mean')

        RE_bar = mse_bar
        RE_col = rec

        # KL from log-density estimates; the omitted Gaussian constants cancel
        # between each p / q pair
        log_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        log_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        log_p_z2 = self.log_p_z2(z2).sum()
        log_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(log_p_z1 + log_p_z2 - log_q_z1 - log_q_z2)

        # accuracy slots are placeholders kept for the training loop's interface
        acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5 = 0, 1, 1, 1, 1

        total = self.rec_beta * RE_bar + self.rec_gamma * RE_col + self.beta * KL
        return (total, RE_bar, RE_col, mse_bar,
                acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5)

    def add_pseudoinputs(self):
        """Create the learned pseudo-inputs used by the z2 prior.

        ``means`` maps each one-hot row of ``idle_input`` (a C x C identity) to a
        flattened 3 x 25 x 100 image clipped to [0, 1]. ``idle_input`` is a
        non-persistent buffer, so ``.to()`` moves it with the model while the
        state_dict keys stay unchanged.
        """
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        eye = torch.eye(self.number_components, self.number_components)
        if 'idle_input' in self._buffers:
            self.idle_input = eye
        else:
            self.register_buffer('idle_input', eye, persistent=False)

In [17]:
# Scratch check: the N(0, 1) log-density at 10 is -50 - 0.5*log(2*pi) ~= -50.919.
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [18]:
# Layer summary of the full model. forward() does not use d or y, so their
# (1, 1) shapes are placeholders; m is (1, 200) and x2 is (1, 100, 14).
default_args = diva_args()
enc = RHVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200),(1,100,14)])

Layer (type:depth-idx)                   Output Shape              Param #
RHVAE                                    --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [19]:
# Training split; create_encodings=False loads the cached bar encodings.
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [20]:
# Test split, also using the cached bar encodings.
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [21]:
#RNA_dataset.x_bar.shape, RNA_dataset.x_col.shape 

In [22]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    """Run one optimisation epoch over ``train_loader``.

    Returns (loss, bar_loss, col_loss, mse_bar, acc_bar, acc_bar2, acc_bar3,
    acc_bar4, acc_bar5): each is the sum over batches of the matching
    ``model.loss_function`` output divided by the number of samples in the
    dataset. Tensor values are detached, so no autograd graph outlives its
    batch.
    """
    model.train()
    # running sums, in loss_function's output order
    totals = [0] * 9
    pbar = tqdm(enumerate(train_loader), unit="batch",
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x2, m) in pbar:
        # To device
        x, y, d, x2, m = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x2.to(DEVICE), m.to(DEVICE)

        optimizer.zero_grad()
        stats = model.loss_function(d.float(), x.float(), y.float(), m.float(), x2.float())
        loss = stats[0]

        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])

        # detach so the running sums do not chain every batch's graph together
        totals = [t + (s.detach() if torch.is_tensor(s) else s)
                  for t, s in zip(totals, stats)]

    n_samples = len(train_loader.dataset)
    return tuple(t / n_samples for t in totals)

In [23]:
def test_single_epoch(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Returns the same 9-tuple layout as ``model.loss_function`` (total loss,
    bar loss, colour loss, bar MSE and five accuracy slots), each summed over
    batches and divided by the dataset size. ``epoch`` is unused and kept for
    symmetry with ``train_single_epoch``.
    """
    model.eval()
    sums = [0] * 9
    with torch.no_grad():
        for x, y, d, x2, m in test_loader:
            # argument order expected by loss_function: d, x, y, m, x2
            batch = [t.to(DEVICE).float() for t in (d, x, y, m, x2)]
            stats = model.loss_function(*batch)
            sums = [acc + s for acc, s in zip(sums, stats)]
    size = len(test_loader.dataset)
    return tuple(s / size for s in sums)
  

In [24]:
def _kl_beta(args, epoch):
    """KL weight for ``epoch``.

    Before ``args.prewarmup`` it is the constant args.beta / args.prewarmup;
    afterwards it ramps linearly over ``args.warmup`` epochs and is capped at
    args.beta. A non-positive warmup means no ramp (this previously raised
    ZeroDivisionError).
    """
    if epoch < args.prewarmup:
        return args.beta / args.prewarmup
    if args.warmup <= 0:
        return args.beta
    return min(args.beta, args.beta * (epoch - args.prewarmup * 1.) / args.warmup)


def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    """Train ``diva`` for epochs start_epoch+1 .. end_epoch with KL annealing.

    Each epoch sets diva.beta from the warm-up schedule, runs one train and
    one test epoch, and prints the losses. If the global ``writer`` is not
    None, it also logs the losses to TensorBoard. Every ``save_interval``
    epochs it saves reconstruction figures. Every 50 epochs it saves a
    state_dict checkpoint under ``{link}/saved_models/{save_folder}/``,
    creating the directories as needed.

    Returns (train_losses, test_losses), one numpy scalar per epoch.
    """
    import os

    model_dir = f'{link}/saved_models/{save_folder}'
    epoch_loss_sup = []
    test_loss = []

    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = _kl_beta(args, epoch)

        train_loss, avg_loss_bar, avg_loss_col, mtr, atr, atr2, atr3, atr4, atr5  = train_single_epoch(train_loader, diva, optimizer, epoch)
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, train_loss)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        # split the total into its weighted reconstruction part and the
        # remaining (beta-weighted) KL part
        rec_loss_train = diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_col_test, mte, ate, ate2, ate3, ate4, ate5  = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)

        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        if writer is not None:
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("coll loss", {'train':avg_loss_col, 'test':avg_loss_col_test}, epoch)

        if epoch % save_interval == 0:
            os.makedirs(f'{model_dir}/reconstructions', exist_ok=True)
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')

        if epoch % 50 == 0:
            os.makedirs(f'{model_dir}/checkpoints', exist_ok=True)
            torch.save(diva.state_dict(), f'{model_dir}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [25]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr='', n_rows=10, figsize=None):
    """Save an original-vs-reconstruction figure for the first batch of a loader.

    Uses up to ``n_rows`` samples of the first batch of ``test_loader`` and
    writes ``{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png``,
    creating the directory if needed. The model's train/eval mode is restored
    afterwards.

    figsize -- forwarded to plt.subplots. The old code called
               plt.figure(figsize=(80, 20)) and then drew on a different figure
               created by plt.subplots, so that size was never applied. The
               default (None) keeps the size the saved images actually had.
    """
    import os

    batch = next(iter(test_loader))
    # batches are (x, y, d, x2, m)
    x, y, d, x2, m = (t[:n_rows].to(DEVICE).float() for t in batch)

    was_training = diva.training
    diva.eval()
    try:
        with torch.no_grad():
            x2_hat = diva(d, x, y, m, x2)[0]
            out = diva.px.reconstruct_image(x2_hat)
    finally:
        diva.train(was_training)

    n = len(out)
    # squeeze=False keeps ax 2-D even when fewer than two rows are plotted
    fig, ax = plt.subplots(nrows=n, ncols=2, figsize=figsize, squeeze=False)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(n):
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 1].imshow(out[i])
        for axis in ax[i]:
            axis.xaxis.set_visible(False)
            axis.yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)

    out_dir = f'{link}/saved_models/{name}/reconstructions'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/e{epoch}{estr}.png')
    plt.close(fig)

In [26]:
# Confirm the device again before training.
DEVICE

device(type='cuda')

## Model Training

In [27]:
# Run configuration: no KL pre-warm-up, 50 pseudo-inputs for the z2 prior,
# 256-dimensional z1 and z2; other settings use diva_args defaults.
default_args = diva_args(prewarmup=0, number_components=50, z1_dim=256, z2_dim=256)

In [28]:
# Fresh model for this run, placed on the selected device.
diva = RHVAE(default_args).to(DEVICE)

In [29]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/new/HVAE3/checkpoints/100.pth'))

In [30]:
# Only the training loader is shuffled; both splits use batches of 128.
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [31]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
# Adam with lr = 1e-3 (the commented-out SGD line is an earlier alternative).
optimizer = optim.Adam(diva.parameters(), lr=0.001)

In [32]:
#RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

In [33]:
# Write any pending TensorBoard events to disk.
writer.flush()

In [34]:
#diva.rec_gamma = 3

In [35]:
%tensorboard  --logdir="D:/users/Marko/downloads/mirna/saved_models/new/RHVAE/tensorboard/"

Reusing TensorBoard on port 6006 (pid 18760), started 1:03:36 ago. (Use '!kill 18760' to kill it.)

In [36]:
# First training run.
# NOTE(review): the positional 500 and 0 look like the epoch limit and the starting
# epoch offset (the resumed call further down passes 1000, 50 and its log starts at
# "Epoch 51") -- confirm against train()'s signature.
# This run was stopped manually (KeyboardInterrupt during epoch 51), so `lss` and
# `lss_t` are NOT assigned by this cell.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 0, save_folder="new/RHVAE",save_interval=5)

Epoch 1: 272batch [00:37,  7.21batch/s, loss=1.79] 


epoch 1: avg train loss 1.01, bar train loss 0.042, col train loss 0.017
epoch 1: avg test  loss 0.46, bar  test loss 0.020, col  test loss 0.013


Epoch 2: 272batch [00:37,  7.24batch/s, loss=1.75] 


epoch 2: avg train loss 0.48, bar train loss 0.022, col train loss 0.013


Epoch 3: 1batch [00:00,  7.30batch/s, loss=0.506]

epoch 2: avg test  loss 0.44, bar  test loss 0.019, col  test loss 0.013


Epoch 3: 272batch [00:37,  7.21batch/s, loss=1.74] 


epoch 3: avg train loss 0.47, bar train loss 0.021, col train loss 0.013


Epoch 4: 1batch [00:00,  7.30batch/s, loss=0.476]

epoch 3: avg test  loss 0.41, bar  test loss 0.019, col  test loss 0.013


Epoch 4: 272batch [00:37,  7.26batch/s, loss=1.82] 


epoch 4: avg train loss 0.46, bar train loss 0.021, col train loss 0.013


Epoch 5: 1batch [00:00,  7.25batch/s, loss=0.451]

epoch 4: avg test  loss 0.41, bar  test loss 0.019, col  test loss 0.013


Epoch 5: 272batch [00:37,  7.17batch/s, loss=1.78] 


epoch 5: avg train loss 0.45, bar train loss 0.020, col train loss 0.013
epoch 5: avg test  loss 0.41, bar  test loss 0.018, col  test loss 0.012


Epoch 6: 272batch [00:38,  7.15batch/s, loss=1.66] 


epoch 6: avg train loss 0.44, bar train loss 0.020, col train loss 0.013


Epoch 7: 1batch [00:00,  7.30batch/s, loss=0.436]

epoch 6: avg test  loss 0.40, bar  test loss 0.018, col  test loss 0.012


Epoch 7: 272batch [00:37,  7.24batch/s, loss=1.93] 


epoch 7: avg train loss 0.45, bar train loss 0.020, col train loss 0.013


Epoch 8: 1batch [00:00,  7.30batch/s, loss=0.429]

epoch 7: avg test  loss 0.39, bar  test loss 0.018, col  test loss 0.012


Epoch 8: 272batch [00:37,  7.19batch/s, loss=1.73] 


epoch 8: avg train loss 0.43, bar train loss 0.019, col train loss 0.013


Epoch 9: 1batch [00:00,  6.85batch/s, loss=0.459]

epoch 8: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 9: 272batch [00:38,  7.14batch/s, loss=1.63] 


epoch 9: avg train loss 0.42, bar train loss 0.019, col train loss 0.012


Epoch 10: 1batch [00:00,  7.19batch/s, loss=0.452]

epoch 9: avg test  loss 0.39, bar  test loss 0.017, col  test loss 0.012


Epoch 10: 272batch [00:37,  7.16batch/s, loss=1.62] 


epoch 10: avg train loss 0.42, bar train loss 0.019, col train loss 0.012
epoch 10: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 11: 272batch [00:38,  7.15batch/s, loss=1.47] 


epoch 11: avg train loss 0.41, bar train loss 0.019, col train loss 0.012


Epoch 12: 1batch [00:00,  7.25batch/s, loss=0.398]

epoch 11: avg test  loss 0.38, bar  test loss 0.017, col  test loss 0.012


Epoch 12: 272batch [00:37,  7.18batch/s, loss=1.65] 


epoch 12: avg train loss 0.41, bar train loss 0.018, col train loss 0.012


Epoch 13: 1batch [00:00,  7.30batch/s, loss=0.419]

epoch 12: avg test  loss 0.37, bar  test loss 0.017, col  test loss 0.012


Epoch 13: 272batch [00:38,  7.14batch/s, loss=1.45] 


epoch 13: avg train loss 0.40, bar train loss 0.018, col train loss 0.012


Epoch 14: 1batch [00:00,  7.30batch/s, loss=0.412]

epoch 13: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 14: 272batch [00:37,  7.19batch/s, loss=1.48] 


epoch 14: avg train loss 0.40, bar train loss 0.018, col train loss 0.012


Epoch 15: 1batch [00:00,  7.25batch/s, loss=0.444]

epoch 14: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 15: 272batch [00:37,  7.22batch/s, loss=1.38] 


epoch 15: avg train loss 0.39, bar train loss 0.018, col train loss 0.012
epoch 15: avg test  loss 0.37, bar  test loss 0.016, col  test loss 0.012


Epoch 16: 272batch [00:37,  7.22batch/s, loss=1.45] 


epoch 16: avg train loss 0.39, bar train loss 0.017, col train loss 0.012


Epoch 17: 1batch [00:00,  7.25batch/s, loss=0.396]

epoch 16: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 17: 272batch [00:38,  7.15batch/s, loss=1.23] 


epoch 17: avg train loss 0.39, bar train loss 0.017, col train loss 0.012


Epoch 18: 1batch [00:00,  7.14batch/s, loss=0.369]

epoch 17: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 18: 272batch [00:38,  7.13batch/s, loss=1.51] 


epoch 18: avg train loss 0.38, bar train loss 0.017, col train loss 0.012


Epoch 19: 1batch [00:00,  7.25batch/s, loss=0.367]

epoch 18: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 19: 272batch [00:37,  7.16batch/s, loss=1.43] 


epoch 19: avg train loss 0.38, bar train loss 0.017, col train loss 0.012


Epoch 20: 1batch [00:00,  7.25batch/s, loss=0.378]

epoch 19: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 20: 272batch [00:38,  7.14batch/s, loss=1.4]  


epoch 20: avg train loss 0.38, bar train loss 0.017, col train loss 0.012
epoch 20: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 21: 272batch [00:38,  7.11batch/s, loss=1.38] 


epoch 21: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 22: 1batch [00:00,  7.04batch/s, loss=0.388]

epoch 21: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 22: 272batch [00:38,  7.11batch/s, loss=1.34] 


epoch 22: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 23: 1batch [00:00,  7.19batch/s, loss=0.38]

epoch 22: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 23: 272batch [00:37,  7.21batch/s, loss=1.3]  


epoch 23: avg train loss 0.37, bar train loss 0.017, col train loss 0.012


Epoch 24: 1batch [00:00,  7.19batch/s, loss=0.367]

epoch 23: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 24: 272batch [00:38,  7.12batch/s, loss=1.33] 


epoch 24: avg train loss 0.37, bar train loss 0.016, col train loss 0.012


Epoch 25: 1batch [00:00,  7.19batch/s, loss=0.41]

epoch 24: avg test  loss 0.36, bar  test loss 0.016, col  test loss 0.012


Epoch 25: 272batch [00:38,  7.11batch/s, loss=1.42] 


epoch 25: avg train loss 0.37, bar train loss 0.016, col train loss 0.012
epoch 25: avg test  loss 0.35, bar  test loss 0.016, col  test loss 0.012


Epoch 26: 272batch [00:38,  7.07batch/s, loss=1.41] 


epoch 26: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 27: 1batch [00:00,  7.19batch/s, loss=0.345]

epoch 26: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 27: 272batch [00:38,  7.09batch/s, loss=1.49] 


epoch 27: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 28: 1batch [00:00,  7.09batch/s, loss=0.391]

epoch 27: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 28: 272batch [00:37,  7.17batch/s, loss=1.35] 


epoch 28: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 29: 1batch [00:00,  6.62batch/s, loss=0.402]

epoch 28: avg test  loss 0.35, bar  test loss 0.015, col  test loss 0.012


Epoch 29: 272batch [00:38,  7.07batch/s, loss=1.23] 


epoch 29: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 30: 1batch [00:00,  6.99batch/s, loss=0.354]

epoch 29: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 30: 272batch [00:38,  7.04batch/s, loss=1.26] 


epoch 30: avg train loss 0.36, bar train loss 0.016, col train loss 0.012
epoch 30: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 31: 272batch [00:38,  7.09batch/s, loss=1.34] 


epoch 31: avg train loss 0.36, bar train loss 0.016, col train loss 0.012


Epoch 32: 1batch [00:00,  7.14batch/s, loss=0.378]

epoch 31: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 32: 272batch [00:38,  7.05batch/s, loss=1.29] 


epoch 32: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 33: 1batch [00:00,  7.14batch/s, loss=0.376]

epoch 32: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 33: 272batch [00:38,  7.10batch/s, loss=1.34] 


epoch 33: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 34: 1batch [00:00,  7.14batch/s, loss=0.332]

epoch 33: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 34: 272batch [00:38,  7.06batch/s, loss=1.3]  


epoch 34: avg train loss 0.35, bar train loss 0.016, col train loss 0.012


Epoch 35: 1batch [00:00,  7.19batch/s, loss=0.317]

epoch 34: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 35: 272batch [00:38,  7.06batch/s, loss=1.38] 


epoch 35: avg train loss 0.35, bar train loss 0.015, col train loss 0.012
epoch 35: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 36: 272batch [00:38,  7.06batch/s, loss=1.24] 


epoch 36: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 37: 1batch [00:00,  7.19batch/s, loss=0.373]

epoch 36: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 37: 272batch [00:38,  7.08batch/s, loss=1.34] 


epoch 37: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 38: 1batch [00:00,  6.99batch/s, loss=0.328]

epoch 37: avg test  loss 0.34, bar  test loss 0.015, col  test loss 0.012


Epoch 38: 272batch [00:38,  7.00batch/s, loss=1.28] 


epoch 38: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 39: 1batch [00:00,  7.04batch/s, loss=0.332]

epoch 38: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 39: 272batch [00:38,  7.01batch/s, loss=1.3]  


epoch 39: avg train loss 0.35, bar train loss 0.015, col train loss 0.012


Epoch 40: 1batch [00:00,  7.04batch/s, loss=0.345]

epoch 39: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 40: 272batch [00:38,  7.01batch/s, loss=1.13] 


epoch 40: avg train loss 0.34, bar train loss 0.015, col train loss 0.012
epoch 40: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 41: 272batch [00:38,  7.02batch/s, loss=1.29] 


epoch 41: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 42: 1batch [00:00,  6.99batch/s, loss=0.298]

epoch 41: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 42: 272batch [00:38,  7.01batch/s, loss=1.32] 


epoch 42: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 43: 1batch [00:00,  7.09batch/s, loss=0.369]

epoch 42: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 43: 272batch [00:38,  7.02batch/s, loss=1.1]  


epoch 43: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 44: 1batch [00:00,  7.04batch/s, loss=0.349]

epoch 43: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 44: 272batch [00:39,  6.88batch/s, loss=1.21] 


epoch 44: avg train loss 0.34, bar train loss 0.015, col train loss 0.012


Epoch 45: 1batch [00:00,  6.94batch/s, loss=0.292]

epoch 44: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 45: 272batch [00:39,  6.97batch/s, loss=1.1]  


epoch 45: avg train loss 0.34, bar train loss 0.015, col train loss 0.012
epoch 45: avg test  loss 0.33, bar  test loss 0.014, col  test loss 0.012


Epoch 46: 272batch [00:39,  6.88batch/s, loss=1.11] 


epoch 46: avg train loss 0.33, bar train loss 0.015, col train loss 0.012


Epoch 47: 1batch [00:00,  6.54batch/s, loss=0.291]

epoch 46: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 47: 272batch [00:40,  6.74batch/s, loss=1.28] 


epoch 47: avg train loss 0.33, bar train loss 0.015, col train loss 0.012


Epoch 48: 1batch [00:00,  6.99batch/s, loss=0.366]

epoch 47: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 48: 272batch [00:39,  6.81batch/s, loss=1.21] 


epoch 48: avg train loss 0.33, bar train loss 0.015, col train loss 0.012


Epoch 49: 1batch [00:00,  6.90batch/s, loss=0.35]

epoch 48: avg test  loss 0.33, bar  test loss 0.015, col  test loss 0.012


Epoch 49: 272batch [00:41,  6.59batch/s, loss=1.06] 


epoch 49: avg train loss 0.33, bar train loss 0.015, col train loss 0.012


Epoch 50: 1batch [00:00,  6.85batch/s, loss=0.314]

epoch 49: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 50: 272batch [00:41,  6.61batch/s, loss=1.29] 


epoch 50: avg train loss 0.33, bar train loss 0.015, col train loss 0.012
epoch 50: avg test  loss 0.32, bar  test loss 0.014, col  test loss 0.012


Epoch 51: 234batch [00:34,  6.73batch/s, loss=0.36] 


KeyboardInterrupt: 

In [37]:
# Set the model's `rec_gamma` attribute to 20 before resuming training.
# NOTE(review): this mutates the live model only; `default_args.rec_gamma` (still the
# diva_args default of 1, and passed to train() below) is not updated -- confirm which
# of the two the loss actually reads. The jump in logged loss (~0.33 -> ~0.55) at
# epoch 51 suggests it is the model attribute.
diva.rec_gamma = 20

In [38]:
# Resume training the same model and optimizer, after diva.rec_gamma was set to 20.
# NOTE(review): 1000 and 50 appear to be the epoch limit and the epoch offset to
# resume from (the log continues at "Epoch 51") -- confirm against train()'s signature.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1000,50,save_folder="new/RHVAE",save_interval=5)

Epoch 51: 272batch [00:40,  6.80batch/s, loss=2.05] 


epoch 51: avg train loss 0.55, bar train loss 0.014, col train loss 0.012


Epoch 52: 1batch [00:00,  6.67batch/s, loss=0.556]

epoch 51: avg test  loss 0.54, bar  test loss 0.014, col  test loss 0.012


Epoch 52: 272batch [00:40,  6.77batch/s, loss=2.08] 


epoch 52: avg train loss 0.55, bar train loss 0.014, col train loss 0.012


Epoch 53: 1batch [00:00,  6.90batch/s, loss=0.61]

epoch 52: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 53: 272batch [00:40,  6.79batch/s, loss=2.11] 


epoch 53: avg train loss 0.55, bar train loss 0.014, col train loss 0.012


Epoch 54: 1batch [00:00,  6.94batch/s, loss=0.521]

epoch 53: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 54: 272batch [00:40,  6.72batch/s, loss=2.11] 


epoch 54: avg train loss 0.55, bar train loss 0.014, col train loss 0.012


Epoch 55: 1batch [00:00,  6.90batch/s, loss=0.574]

epoch 54: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 55: 272batch [00:40,  6.72batch/s, loss=1.89] 


epoch 55: avg train loss 0.54, bar train loss 0.014, col train loss 0.012
epoch 55: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 56: 272batch [00:41,  6.63batch/s, loss=2.03] 


epoch 56: avg train loss 0.54, bar train loss 0.014, col train loss 0.012


Epoch 57: 1batch [00:00,  6.80batch/s, loss=0.522]

epoch 56: avg test  loss 0.54, bar  test loss 0.014, col  test loss 0.012


Epoch 57: 272batch [00:40,  6.65batch/s, loss=2.02] 


epoch 57: avg train loss 0.54, bar train loss 0.014, col train loss 0.012


Epoch 58: 1batch [00:00,  6.76batch/s, loss=0.491]

epoch 57: avg test  loss 0.54, bar  test loss 0.014, col  test loss 0.012


Epoch 58: 272batch [00:41,  6.59batch/s, loss=1.87] 


epoch 58: avg train loss 0.54, bar train loss 0.014, col train loss 0.012


Epoch 59: 1batch [00:00,  6.80batch/s, loss=0.502]

epoch 58: avg test  loss 0.54, bar  test loss 0.014, col  test loss 0.012


Epoch 59: 272batch [00:41,  6.56batch/s, loss=2.07] 


epoch 59: avg train loss 0.54, bar train loss 0.014, col train loss 0.012


Epoch 60: 1batch [00:00,  6.94batch/s, loss=0.508]

epoch 59: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 60: 272batch [00:40,  6.78batch/s, loss=1.99] 


epoch 60: avg train loss 0.54, bar train loss 0.014, col train loss 0.012
epoch 60: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 61: 272batch [00:40,  6.75batch/s, loss=2.11] 


epoch 61: avg train loss 0.53, bar train loss 0.014, col train loss 0.012


Epoch 62: 1batch [00:00,  6.71batch/s, loss=0.581]

epoch 61: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 62: 272batch [00:40,  6.70batch/s, loss=1.92] 


epoch 62: avg train loss 0.53, bar train loss 0.014, col train loss 0.012


Epoch 63: 1batch [00:00,  6.85batch/s, loss=0.53]

epoch 62: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 63: 272batch [00:40,  6.69batch/s, loss=2.04] 


epoch 63: avg train loss 0.53, bar train loss 0.014, col train loss 0.012


Epoch 64: 1batch [00:00,  6.71batch/s, loss=0.488]

epoch 63: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 64: 272batch [00:40,  6.63batch/s, loss=2.19] 


epoch 64: avg train loss 0.53, bar train loss 0.013, col train loss 0.012


Epoch 65: 1batch [00:00,  6.71batch/s, loss=0.563]

epoch 64: avg test  loss 0.54, bar  test loss 0.014, col  test loss 0.012


Epoch 65: 272batch [00:40,  6.64batch/s, loss=2.02] 


epoch 65: avg train loss 0.53, bar train loss 0.013, col train loss 0.012
epoch 65: avg test  loss 0.55, bar  test loss 0.014, col  test loss 0.012


Epoch 66: 272batch [00:41,  6.60batch/s, loss=1.95] 


epoch 66: avg train loss 0.53, bar train loss 0.013, col train loss 0.012


Epoch 67: 1batch [00:00,  6.58batch/s, loss=0.551]

epoch 66: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 67: 272batch [00:41,  6.61batch/s, loss=1.92] 


epoch 67: avg train loss 0.52, bar train loss 0.013, col train loss 0.012


Epoch 68: 1batch [00:00,  6.71batch/s, loss=0.529]

epoch 67: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 68: 272batch [00:41,  6.61batch/s, loss=1.79] 


epoch 68: avg train loss 0.52, bar train loss 0.013, col train loss 0.012


Epoch 69: 1batch [00:00,  6.76batch/s, loss=0.499]

epoch 68: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 69: 272batch [00:41,  6.57batch/s, loss=1.95] 


epoch 69: avg train loss 0.52, bar train loss 0.013, col train loss 0.012


Epoch 70: 1batch [00:00,  6.67batch/s, loss=0.546]

epoch 69: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 70: 272batch [00:41,  6.56batch/s, loss=1.87] 


epoch 70: avg train loss 0.52, bar train loss 0.013, col train loss 0.012
epoch 70: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 71: 272batch [00:42,  6.43batch/s, loss=2.06] 


epoch 71: avg train loss 0.52, bar train loss 0.013, col train loss 0.012


Epoch 72: 1batch [00:00,  6.37batch/s, loss=0.472]

epoch 71: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 72: 272batch [00:42,  6.42batch/s, loss=2.01] 


epoch 72: avg train loss 0.52, bar train loss 0.013, col train loss 0.012


Epoch 73: 1batch [00:00,  6.54batch/s, loss=0.483]

epoch 72: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 73: 272batch [00:42,  6.40batch/s, loss=2.01] 


epoch 73: avg train loss 0.51, bar train loss 0.013, col train loss 0.012


Epoch 74: 1batch [00:00,  6.54batch/s, loss=0.53]

epoch 73: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 74: 272batch [00:42,  6.39batch/s, loss=1.97] 


epoch 74: avg train loss 0.51, bar train loss 0.013, col train loss 0.012


Epoch 75: 1batch [00:00,  6.49batch/s, loss=0.491]

epoch 74: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 75: 272batch [00:42,  6.37batch/s, loss=1.79] 


epoch 75: avg train loss 0.51, bar train loss 0.013, col train loss 0.012
epoch 75: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 76: 272batch [00:42,  6.34batch/s, loss=2.01] 


epoch 76: avg train loss 0.51, bar train loss 0.013, col train loss 0.012


Epoch 77: 1batch [00:00,  6.49batch/s, loss=0.536]

epoch 76: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 77: 272batch [00:42,  6.35batch/s, loss=1.9]  


epoch 77: avg train loss 0.51, bar train loss 0.012, col train loss 0.012


Epoch 78: 1batch [00:00,  6.49batch/s, loss=0.523]

epoch 77: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 78: 272batch [00:42,  6.33batch/s, loss=1.87] 


epoch 78: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 79: 1batch [00:00,  6.54batch/s, loss=0.519]

epoch 78: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 79: 272batch [00:43,  6.18batch/s, loss=1.88] 


epoch 79: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 80: 1batch [00:00,  6.29batch/s, loss=0.518]

epoch 79: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 80: 272batch [00:43,  6.20batch/s, loss=1.68] 


epoch 80: avg train loss 0.50, bar train loss 0.012, col train loss 0.012
epoch 80: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 81: 272batch [00:43,  6.19batch/s, loss=1.92] 


epoch 81: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 82: 1batch [00:00,  6.37batch/s, loss=0.499]

epoch 81: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 82: 272batch [00:44,  6.18batch/s, loss=1.88] 


epoch 82: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 83: 1batch [00:00,  6.37batch/s, loss=0.464]

epoch 82: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 83: 272batch [00:44,  6.16batch/s, loss=1.89] 


epoch 83: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 84: 1batch [00:00,  6.41batch/s, loss=0.474]

epoch 83: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 84: 272batch [00:44,  6.15batch/s, loss=1.95] 


epoch 84: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 85: 1batch [00:00,  6.29batch/s, loss=0.465]

epoch 84: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 85: 272batch [00:44,  6.12batch/s, loss=1.79] 


epoch 85: avg train loss 0.50, bar train loss 0.012, col train loss 0.012
epoch 85: avg test  loss 0.55, bar  test loss 0.015, col  test loss 0.012


Epoch 86: 272batch [00:44,  6.11batch/s, loss=2.06] 


epoch 86: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 87: 1batch [00:00,  6.49batch/s, loss=0.466]

epoch 86: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 87: 272batch [00:44,  6.11batch/s, loss=1.81] 


epoch 87: avg train loss 0.50, bar train loss 0.012, col train loss 0.012


Epoch 88: 1batch [00:00,  6.21batch/s, loss=0.514]

epoch 87: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 88: 272batch [00:44,  6.12batch/s, loss=1.85] 


epoch 88: avg train loss 0.49, bar train loss 0.012, col train loss 0.012


Epoch 89: 0batch [00:00, ?batch/s, loss=0.454]

epoch 88: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 89: 272batch [00:44,  6.11batch/s, loss=1.9]  


epoch 89: avg train loss 0.49, bar train loss 0.012, col train loss 0.012


Epoch 90: 1batch [00:00,  6.25batch/s, loss=0.463]

epoch 89: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 90: 272batch [00:44,  6.09batch/s, loss=1.82] 


epoch 90: avg train loss 0.49, bar train loss 0.012, col train loss 0.012
epoch 90: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 91: 272batch [00:44,  6.08batch/s, loss=1.88] 


epoch 91: avg train loss 0.49, bar train loss 0.012, col train loss 0.012


Epoch 92: 0batch [00:00, ?batch/s, loss=0.475]

epoch 91: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 92: 272batch [00:44,  6.07batch/s, loss=1.83] 


epoch 92: avg train loss 0.49, bar train loss 0.012, col train loss 0.011


Epoch 93: 1batch [00:00,  6.21batch/s, loss=0.496]

epoch 92: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 93: 272batch [00:44,  6.07batch/s, loss=1.94] 


epoch 93: avg train loss 0.49, bar train loss 0.012, col train loss 0.011


Epoch 94: 0batch [00:00, ?batch/s, loss=0.499]

epoch 93: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 94: 272batch [00:44,  6.06batch/s, loss=1.72] 


epoch 94: avg train loss 0.49, bar train loss 0.012, col train loss 0.011


Epoch 95: 0batch [00:00, ?batch/s, loss=0.493]

epoch 94: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 95: 272batch [00:44,  6.06batch/s, loss=1.7]  


epoch 95: avg train loss 0.48, bar train loss 0.011, col train loss 0.011
epoch 95: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 96: 272batch [00:45,  6.03batch/s, loss=1.95] 


epoch 96: avg train loss 0.49, bar train loss 0.011, col train loss 0.011


Epoch 97: 0batch [00:00, ?batch/s, loss=0.533]

epoch 96: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 97: 272batch [00:45,  5.94batch/s, loss=1.61] 


epoch 97: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 98: 0batch [00:00, ?batch/s, loss=0.468]

epoch 97: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 98: 272batch [00:45,  6.03batch/s, loss=1.83] 


epoch 98: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 99: 0batch [00:00, ?batch/s, loss=0.521]

epoch 98: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 99: 272batch [00:45,  5.99batch/s, loss=1.92] 


epoch 99: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 100: 0batch [00:00, ?batch/s, loss=0.525]

epoch 99: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 100: 272batch [00:45,  6.01batch/s, loss=1.87] 


epoch 100: avg train loss 0.48, bar train loss 0.011, col train loss 0.011
epoch 100: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 101: 272batch [00:45,  5.98batch/s, loss=1.83] 


epoch 101: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 102: 0batch [00:00, ?batch/s, loss=0.511]

epoch 101: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 102: 272batch [00:45,  6.00batch/s, loss=1.75] 


epoch 102: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 103: 0batch [00:00, ?batch/s]

epoch 102: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 103: 272batch [00:45,  5.93batch/s, loss=1.81] 


epoch 103: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 104: 0batch [00:00, ?batch/s, loss=0.458]

epoch 103: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 104: 272batch [00:46,  5.87batch/s, loss=1.67] 


epoch 104: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 105: 0batch [00:00, ?batch/s, loss=0.527]

epoch 104: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 105: 272batch [00:46,  5.91batch/s, loss=1.7]  


epoch 105: avg train loss 0.48, bar train loss 0.011, col train loss 0.011
epoch 105: avg test  loss 0.56, bar  test loss 0.015, col  test loss 0.012


Epoch 106: 272batch [00:46,  5.85batch/s, loss=1.7]  


epoch 106: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 107: 0batch [00:00, ?batch/s, loss=0.448]

epoch 106: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 107: 272batch [00:47,  5.75batch/s, loss=1.75] 


epoch 107: avg train loss 0.48, bar train loss 0.011, col train loss 0.011


Epoch 108: 0batch [00:00, ?batch/s, loss=0.468]

epoch 107: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 108: 272batch [00:46,  5.85batch/s, loss=1.68] 


epoch 108: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 109: 0batch [00:00, ?batch/s, loss=0.511]

epoch 108: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 109: 272batch [00:46,  5.87batch/s, loss=1.76] 


epoch 109: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 110: 1batch [00:00,  6.67batch/s, loss=0.466]

epoch 109: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 110: 272batch [00:47,  5.78batch/s, loss=1.9]  


epoch 110: avg train loss 0.47, bar train loss 0.011, col train loss 0.011
epoch 110: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 111: 272batch [00:46,  5.84batch/s, loss=1.87] 


epoch 111: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 112: 0batch [00:00, ?batch/s, loss=0.443]

epoch 111: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 112: 272batch [00:46,  5.82batch/s, loss=1.8]  


epoch 112: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 113: 0batch [00:00, ?batch/s, loss=0.484]

epoch 112: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 113: 272batch [00:46,  5.88batch/s, loss=1.47] 


epoch 113: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 114: 0batch [00:00, ?batch/s]

epoch 113: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 114: 272batch [00:47,  5.76batch/s, loss=1.55] 


epoch 114: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 115: 0batch [00:00, ?batch/s]

epoch 114: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 115: 272batch [00:47,  5.72batch/s, loss=1.78] 


epoch 115: avg train loss 0.47, bar train loss 0.011, col train loss 0.011
epoch 115: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 116: 272batch [00:47,  5.74batch/s, loss=1.76] 


epoch 116: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 117: 0batch [00:00, ?batch/s, loss=0.509]

epoch 116: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 117: 272batch [00:47,  5.74batch/s, loss=1.74] 


epoch 117: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 118: 0batch [00:00, ?batch/s, loss=0.474]

epoch 117: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 118: 272batch [00:46,  5.90batch/s, loss=1.9]  


epoch 118: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 119: 0batch [00:00, ?batch/s, loss=0.46]

epoch 118: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 119: 272batch [00:46,  5.89batch/s, loss=1.74] 


epoch 119: avg train loss 0.47, bar train loss 0.011, col train loss 0.011


Epoch 120: 0batch [00:00, ?batch/s, loss=0.474]

epoch 119: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 120: 272batch [00:46,  5.84batch/s, loss=1.66] 


epoch 120: avg train loss 0.46, bar train loss 0.010, col train loss 0.011
epoch 120: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 121: 272batch [00:46,  5.83batch/s, loss=1.74] 


epoch 121: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 122: 0batch [00:00, ?batch/s]

epoch 121: avg test  loss 0.57, bar  test loss 0.015, col  test loss 0.012


Epoch 122: 272batch [00:46,  5.85batch/s, loss=1.73] 


epoch 122: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 123: 0batch [00:00, ?batch/s, loss=0.483]

epoch 122: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 123: 272batch [00:46,  5.85batch/s, loss=1.57] 


epoch 123: avg train loss 0.47, bar train loss 0.010, col train loss 0.011


Epoch 124: 0batch [00:00, ?batch/s, loss=0.448]

epoch 123: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 124: 272batch [00:46,  5.81batch/s, loss=1.7]  


epoch 124: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 125: 0batch [00:00, ?batch/s, loss=0.472]

epoch 124: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 125: 272batch [00:46,  5.82batch/s, loss=1.77] 


epoch 125: avg train loss 0.46, bar train loss 0.010, col train loss 0.011
epoch 125: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 126: 272batch [00:46,  5.82batch/s, loss=1.65] 


epoch 126: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 127: 0batch [00:00, ?batch/s, loss=0.452]

epoch 126: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 127: 272batch [00:46,  5.82batch/s, loss=1.48] 


epoch 127: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 128: 0batch [00:00, ?batch/s, loss=0.436]

epoch 127: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 128: 272batch [00:46,  5.80batch/s, loss=1.7]  


epoch 128: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 129: 0batch [00:00, ?batch/s]

epoch 128: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 129: 272batch [00:46,  5.79batch/s, loss=1.67] 


epoch 129: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 130: 0batch [00:00, ?batch/s, loss=0.485]

epoch 129: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 130: 272batch [00:47,  5.78batch/s, loss=1.82] 


epoch 130: avg train loss 0.46, bar train loss 0.010, col train loss 0.011
epoch 130: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 131: 272batch [00:47,  5.78batch/s, loss=1.69] 


epoch 131: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 132: 0batch [00:00, ?batch/s, loss=0.464]

epoch 131: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 132: 272batch [00:46,  5.80batch/s, loss=1.59] 


epoch 132: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 133: 0batch [00:00, ?batch/s, loss=0.452]

epoch 132: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 133: 272batch [00:47,  5.79batch/s, loss=1.67] 


epoch 133: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 134: 1batch [00:00,  6.33batch/s, loss=0.482]

epoch 133: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 134: 272batch [00:47,  5.75batch/s, loss=1.5]  


epoch 134: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 135: 1batch [00:00,  6.17batch/s, loss=0.466]

epoch 134: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 135: 272batch [00:47,  5.76batch/s, loss=1.75] 


epoch 135: avg train loss 0.46, bar train loss 0.010, col train loss 0.011
epoch 135: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 136: 272batch [00:47,  5.75batch/s, loss=1.77] 


epoch 136: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 137: 0batch [00:00, ?batch/s, loss=0.436]

epoch 136: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 137: 272batch [00:47,  5.76batch/s, loss=1.6]  


epoch 137: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 138: 0batch [00:00, ?batch/s, loss=0.453]

epoch 137: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 138: 272batch [00:47,  5.76batch/s, loss=1.75] 


epoch 138: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 139: 0batch [00:00, ?batch/s, loss=0.45]

epoch 138: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 139: 272batch [00:47,  5.76batch/s, loss=1.6]  


epoch 139: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 140: 0batch [00:00, ?batch/s]

epoch 139: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 140: 272batch [00:47,  5.74batch/s, loss=1.64] 


epoch 140: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 140: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 141: 272batch [00:47,  5.73batch/s, loss=1.92] 


epoch 141: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 142: 0batch [00:00, ?batch/s, loss=0.453]

epoch 141: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 142: 272batch [00:47,  5.74batch/s, loss=1.53] 


epoch 142: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 143: 0batch [00:00, ?batch/s, loss=0.439]

epoch 142: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 143: 272batch [00:47,  5.72batch/s, loss=1.74] 


epoch 143: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 144: 0batch [00:00, ?batch/s, loss=0.411]

epoch 143: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 144: 272batch [00:47,  5.71batch/s, loss=1.64] 


epoch 144: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 145: 0batch [00:00, ?batch/s]

epoch 144: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 145: 272batch [00:47,  5.71batch/s, loss=1.65] 


epoch 145: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 145: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 146: 272batch [00:48,  5.65batch/s, loss=1.57] 


epoch 146: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 147: 0batch [00:00, ?batch/s]

epoch 146: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 147: 272batch [00:47,  5.69batch/s, loss=1.58] 


epoch 147: avg train loss 0.46, bar train loss 0.010, col train loss 0.011


Epoch 148: 0batch [00:00, ?batch/s]

epoch 147: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 148: 272batch [00:47,  5.70batch/s, loss=1.73] 


epoch 148: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 149: 0batch [00:00, ?batch/s, loss=0.406]

epoch 148: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 149: 272batch [00:47,  5.67batch/s, loss=1.7]  


epoch 149: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 150: 0batch [00:00, ?batch/s, loss=0.415]

epoch 149: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 150: 272batch [00:48,  5.66batch/s, loss=1.51] 


epoch 150: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 150: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 151: 272batch [00:48,  5.63batch/s, loss=1.6]  


epoch 151: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 152: 0batch [00:00, ?batch/s]

epoch 151: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 152: 272batch [00:48,  5.63batch/s, loss=1.75] 


epoch 152: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 153: 0batch [00:00, ?batch/s, loss=0.466]

epoch 152: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 153: 272batch [00:48,  5.62batch/s, loss=1.46] 


epoch 153: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 154: 0batch [00:00, ?batch/s]

epoch 153: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 154: 272batch [00:48,  5.60batch/s, loss=1.64] 


epoch 154: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 155: 0batch [00:00, ?batch/s]

epoch 154: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 155: 272batch [00:48,  5.61batch/s, loss=1.89] 


epoch 155: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 155: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 156: 272batch [00:48,  5.56batch/s, loss=1.6]  


epoch 156: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 157: 0batch [00:00, ?batch/s, loss=0.413]

epoch 156: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 157: 272batch [00:48,  5.59batch/s, loss=1.54] 


epoch 157: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 158: 0batch [00:00, ?batch/s, loss=0.448]

epoch 157: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 158: 272batch [00:48,  5.57batch/s, loss=1.85] 


epoch 158: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 159: 0batch [00:00, ?batch/s, loss=0.45]

epoch 158: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 159: 272batch [00:48,  5.57batch/s, loss=1.7]  


epoch 159: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 160: 0batch [00:00, ?batch/s, loss=0.412]

epoch 159: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 160: 272batch [00:49,  5.54batch/s, loss=1.76] 


epoch 160: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 160: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 161: 272batch [00:49,  5.53batch/s, loss=1.76] 


epoch 161: avg train loss 0.45, bar train loss 0.009, col train loss 0.011


Epoch 162: 0batch [00:00, ?batch/s, loss=0.48]

epoch 161: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 162: 272batch [00:49,  5.54batch/s, loss=1.6]  


epoch 162: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 163: 0batch [00:00, ?batch/s]

epoch 162: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 163: 272batch [00:49,  5.52batch/s, loss=1.48] 


epoch 163: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 164: 0batch [00:00, ?batch/s]

epoch 163: avg test  loss 0.57, bar  test loss 0.016, col  test loss 0.012


Epoch 164: 272batch [00:49,  5.51batch/s, loss=1.7]  


epoch 164: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 165: 0batch [00:00, ?batch/s, loss=0.482]

epoch 164: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 165: 272batch [00:49,  5.50batch/s, loss=1.47] 


epoch 165: avg train loss 0.45, bar train loss 0.010, col train loss 0.011
epoch 165: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 166: 272batch [00:49,  5.50batch/s, loss=1.74] 


epoch 166: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 167: 0batch [00:00, ?batch/s, loss=0.433]

epoch 166: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 167: 272batch [00:49,  5.50batch/s, loss=1.63] 


epoch 167: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 168: 0batch [00:00, ?batch/s]

epoch 167: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 168: 272batch [00:49,  5.48batch/s, loss=1.64] 


epoch 168: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 169: 0batch [00:00, ?batch/s]

epoch 168: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 169: 272batch [00:49,  5.46batch/s, loss=1.66] 


epoch 169: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 170: 0batch [00:00, ?batch/s]

epoch 169: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 170: 272batch [00:49,  5.47batch/s, loss=1.67] 


epoch 170: avg train loss 0.44, bar train loss 0.010, col train loss 0.011
epoch 170: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 171: 272batch [00:49,  5.45batch/s, loss=1.67] 


epoch 171: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 172: 0batch [00:00, ?batch/s, loss=0.428]

epoch 171: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 172: 272batch [00:49,  5.47batch/s, loss=1.7]  


epoch 172: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 173: 0batch [00:00, ?batch/s]

epoch 172: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 173: 272batch [00:49,  5.44batch/s, loss=1.8]  


epoch 173: avg train loss 0.45, bar train loss 0.010, col train loss 0.011


Epoch 174: 0batch [00:00, ?batch/s, loss=0.452]

epoch 173: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 174: 272batch [00:50,  5.43batch/s, loss=1.54] 


epoch 174: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 175: 0batch [00:00, ?batch/s]

epoch 174: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 175: 272batch [00:50,  5.42batch/s, loss=1.68] 


epoch 175: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 175: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 176: 272batch [00:50,  5.41batch/s, loss=1.66] 


epoch 176: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 177: 0batch [00:00, ?batch/s]

epoch 176: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 177: 272batch [00:50,  5.42batch/s, loss=1.82] 


epoch 177: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 178: 0batch [00:00, ?batch/s]

epoch 177: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 178: 272batch [00:50,  5.40batch/s, loss=1.67] 


epoch 178: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 179: 0batch [00:00, ?batch/s]

epoch 178: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 179: 272batch [00:50,  5.38batch/s, loss=1.58] 


epoch 179: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 180: 0batch [00:00, ?batch/s]

epoch 179: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 180: 272batch [00:50,  5.38batch/s, loss=1.62] 


epoch 180: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 180: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 181: 272batch [00:50,  5.35batch/s, loss=1.66] 


epoch 181: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 182: 0batch [00:00, ?batch/s, loss=0.449]

epoch 181: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 182: 272batch [00:50,  5.36batch/s, loss=1.7]  


epoch 182: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 183: 0batch [00:00, ?batch/s]

epoch 182: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 183: 272batch [00:50,  5.34batch/s, loss=1.58] 


epoch 183: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 184: 0batch [00:00, ?batch/s]

epoch 183: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 184: 272batch [00:50,  5.34batch/s, loss=1.65] 


epoch 184: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 185: 0batch [00:00, ?batch/s]

epoch 184: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 185: 272batch [00:51,  5.23batch/s, loss=1.67] 


epoch 185: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 185: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 186: 272batch [00:58,  4.63batch/s, loss=1.59] 


epoch 186: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 187: 0batch [00:00, ?batch/s]

epoch 186: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 187: 272batch [00:53,  5.07batch/s, loss=1.73] 


epoch 187: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 188: 0batch [00:00, ?batch/s]

epoch 187: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 188: 272batch [00:52,  5.19batch/s, loss=1.79] 


epoch 188: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 189: 0batch [00:00, ?batch/s]

epoch 188: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 189: 272batch [00:58,  4.61batch/s, loss=1.72] 


epoch 189: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 190: 0batch [00:00, ?batch/s]

epoch 189: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 190: 272batch [01:02,  4.33batch/s, loss=1.62] 


epoch 190: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 190: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 191: 272batch [00:51,  5.26batch/s, loss=1.53] 


epoch 191: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 192: 0batch [00:00, ?batch/s, loss=0.446]

epoch 191: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 192: 272batch [00:50,  5.41batch/s, loss=1.55] 


epoch 192: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 193: 0batch [00:00, ?batch/s]

epoch 192: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 193: 272batch [00:54,  4.96batch/s, loss=1.39] 


epoch 193: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 194: 0batch [00:00, ?batch/s]

epoch 193: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 194: 272batch [00:52,  5.16batch/s, loss=1.67] 


epoch 194: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 195: 0batch [00:00, ?batch/s]

epoch 194: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 195: 272batch [00:53,  5.05batch/s, loss=1.63] 


epoch 195: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 195: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 196: 272batch [00:55,  4.89batch/s, loss=1.62] 


epoch 196: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 197: 0batch [00:00, ?batch/s]

epoch 196: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 197: 272batch [00:52,  5.18batch/s, loss=1.78] 


epoch 197: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 198: 0batch [00:00, ?batch/s]

epoch 197: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 198: 272batch [00:49,  5.54batch/s, loss=1.75] 


epoch 198: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 199: 0batch [00:00, ?batch/s, loss=0.381]

epoch 198: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 199: 272batch [00:50,  5.40batch/s, loss=1.61] 


epoch 199: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 200: 0batch [00:00, ?batch/s, loss=0.437]

epoch 199: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 200: 272batch [00:49,  5.52batch/s, loss=1.65] 


epoch 200: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 200: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 201: 272batch [00:49,  5.52batch/s, loss=1.64] 


epoch 201: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 202: 0batch [00:00, ?batch/s, loss=0.417]

epoch 201: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 202: 272batch [00:49,  5.52batch/s, loss=1.64] 


epoch 202: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 203: 0batch [00:00, ?batch/s, loss=0.418]

epoch 202: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 203: 272batch [00:49,  5.54batch/s, loss=1.62] 


epoch 203: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 204: 0batch [00:00, ?batch/s]

epoch 203: avg test  loss 0.59, bar  test loss 0.017, col  test loss 0.012


Epoch 204: 272batch [00:49,  5.55batch/s, loss=1.71] 


epoch 204: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 205: 0batch [00:00, ?batch/s]

epoch 204: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 205: 272batch [00:50,  5.34batch/s, loss=1.58] 


epoch 205: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 205: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 206: 272batch [00:52,  5.16batch/s, loss=1.48] 


epoch 206: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 207: 0batch [00:00, ?batch/s, loss=0.411]

epoch 206: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 207: 272batch [00:51,  5.31batch/s, loss=1.55] 


epoch 207: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 208: 0batch [00:00, ?batch/s]

epoch 207: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 208: 272batch [00:51,  5.27batch/s, loss=1.56] 


epoch 208: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 209: 0batch [00:00, ?batch/s]

epoch 208: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 209: 272batch [00:50,  5.34batch/s, loss=1.49] 


epoch 209: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 210: 0batch [00:00, ?batch/s, loss=0.43]

epoch 209: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 210: 272batch [00:48,  5.65batch/s, loss=1.61] 


epoch 210: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 210: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 211: 272batch [00:47,  5.68batch/s, loss=1.47] 


epoch 211: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 212: 0batch [00:00, ?batch/s, loss=0.42]

epoch 211: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 212: 272batch [00:47,  5.68batch/s, loss=1.58] 


epoch 212: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 213: 0batch [00:00, ?batch/s, loss=0.425]

epoch 212: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 213: 272batch [00:47,  5.67batch/s, loss=1.54] 


epoch 213: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 214: 0batch [00:00, ?batch/s]

epoch 213: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 214: 272batch [00:48,  5.66batch/s, loss=1.54] 


epoch 214: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 215: 0batch [00:00, ?batch/s, loss=0.464]

epoch 214: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 215: 272batch [00:48,  5.64batch/s, loss=1.37] 


epoch 215: avg train loss 0.43, bar train loss 0.009, col train loss 0.011
epoch 215: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 216: 272batch [00:48,  5.62batch/s, loss=1.61] 


epoch 216: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 217: 0batch [00:00, ?batch/s, loss=0.424]

epoch 216: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 217: 272batch [00:48,  5.63batch/s, loss=1.46] 


epoch 217: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 218: 0batch [00:00, ?batch/s, loss=0.416]

epoch 217: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 218: 272batch [00:48,  5.60batch/s, loss=1.59] 


epoch 218: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 219: 0batch [00:00, ?batch/s]

epoch 218: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 219: 272batch [00:48,  5.61batch/s, loss=1.58] 


epoch 219: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 220: 0batch [00:00, ?batch/s, loss=0.386]

epoch 219: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 220: 272batch [00:48,  5.59batch/s, loss=1.67] 


epoch 220: avg train loss 0.43, bar train loss 0.009, col train loss 0.011
epoch 220: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 221: 272batch [00:48,  5.57batch/s, loss=1.5]  


epoch 221: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 222: 0batch [00:00, ?batch/s, loss=0.423]

epoch 221: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 222: 272batch [00:48,  5.58batch/s, loss=1.67] 


epoch 222: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 223: 0batch [00:00, ?batch/s, loss=0.387]

epoch 222: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 223: 272batch [00:48,  5.57batch/s, loss=1.45] 


epoch 223: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 224: 0batch [00:00, ?batch/s, loss=0.458]

epoch 223: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 224: 272batch [00:49,  5.55batch/s, loss=1.61] 


epoch 224: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 225: 0batch [00:00, ?batch/s, loss=0.436]

epoch 224: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 225: 272batch [00:48,  5.56batch/s, loss=1.69] 


epoch 225: avg train loss 0.43, bar train loss 0.009, col train loss 0.011
epoch 225: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 226: 272batch [00:49,  5.51batch/s, loss=1.48] 


epoch 226: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 227: 0batch [00:00, ?batch/s, loss=0.456]

epoch 226: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 227: 272batch [00:49,  5.54batch/s, loss=1.51] 


epoch 227: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 228: 0batch [00:00, ?batch/s]

epoch 227: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 228: 272batch [00:49,  5.53batch/s, loss=1.42] 


epoch 228: avg train loss 0.44, bar train loss 0.009, col train loss 0.011


Epoch 229: 0batch [00:00, ?batch/s, loss=0.416]

epoch 228: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 229: 272batch [00:49,  5.54batch/s, loss=1.53] 


epoch 229: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 230: 0batch [00:00, ?batch/s]

epoch 229: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 230: 272batch [00:49,  5.54batch/s, loss=1.53] 


epoch 230: avg train loss 0.43, bar train loss 0.009, col train loss 0.011
epoch 230: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 231: 272batch [00:49,  5.50batch/s, loss=1.76] 


epoch 231: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 232: 0batch [00:00, ?batch/s]

epoch 231: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 232: 272batch [00:49,  5.52batch/s, loss=1.76] 


epoch 232: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 233: 0batch [00:00, ?batch/s]

epoch 232: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 233: 272batch [00:49,  5.51batch/s, loss=1.47] 


epoch 233: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 234: 0batch [00:00, ?batch/s, loss=0.428]

epoch 233: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 234: 272batch [00:49,  5.49batch/s, loss=1.64] 


epoch 234: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 235: 0batch [00:00, ?batch/s]

epoch 234: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 235: 272batch [00:50,  5.39batch/s, loss=1.46] 


epoch 235: avg train loss 0.44, bar train loss 0.009, col train loss 0.011
epoch 235: avg test  loss 0.59, bar  test loss 0.016, col  test loss 0.012


Epoch 236: 272batch [00:52,  5.18batch/s, loss=1.67] 


epoch 236: avg train loss 0.43, bar train loss 0.009, col train loss 0.011


Epoch 237: 0batch [00:00, ?batch/s]

epoch 236: avg test  loss 0.58, bar  test loss 0.016, col  test loss 0.012


Epoch 237: 133batch [00:26,  5.10batch/s, loss=0.454]


KeyboardInterrupt: 

In [None]:
# NOTE(review): the two positional integers after `optimizer` are interpreted
# by train(), which is not visible here — confirm their meaning there.
lss2, lss_t2 = train(
    default_args, train_loader, test_loader, diva, optimizer,
    1000, 500,
    save_folder="VAEFC",
)

In [None]:
# NOTE(review): the two positional integers after `optimizer` are interpreted
# by train(), which is not visible here — confirm their meaning there.
lss, lss_t = train(
    default_args, train_loader, test_loader, diva, optimizer,
    1600, 1000,
    save_folder="VAEFC",
)

In [None]:
def plot_loss_acc(lss, lss_t, yacc=None, yacc_t=None):
    """Plot train/test loss curves and, optionally, train/test accuracy curves.

    Losses are drawn as solid lines on the left y-axis. If either accuracy
    sequence is given, accuracies are drawn dashed on a secondary (twin)
    y-axis, and one legend lists every curve from both axes.

    Parameters
    ----------
    lss, lss_t : sequence
        Train and test loss values, one entry per plotted point (assumed to be
        one per epoch — TODO confirm against ``train``). Entries may be
        Python/NumPy scalars or torch tensors; tensors are detached and moved
        to the CPU before plotting.
    yacc, yacc_t : sequence, optional
        Train and test accuracy values at the same x positions. Either may be
        omitted; with both omitted only the losses are plotted.
    """
    def _as_plottable(values):
        # Tensors (possibly on CUDA or tracking gradients) cannot be handed to
        # matplotlib directly, so convert them to NumPy first.
        return [v.detach().cpu().numpy() if hasattr(v, "detach") else v for v in values]

    fig, ax = plt.subplots()
    ax.plot(_as_plottable(lss), label="train loss")
    ax.plot(_as_plottable(lss_t), label="test loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")

    lines, labels = ax.get_legend_handles_labels()
    legend_ax = ax

    if yacc is not None or yacc_t is not None:
        # twinx() axes restart the colour cycle, so dashed lines keep the
        # accuracy curves distinguishable from the loss curves.
        ax1 = ax.twinx()
        if yacc is not None:
            ax1.plot(_as_plottable(yacc), label="train accuracy", ls="--")
        if yacc_t is not None:
            ax1.plot(_as_plottable(yacc_t), label="test accuracy", ls="--")
        ax1.set_ylabel("accuracy")
        lines2, labels2 = ax1.get_legend_handles_labels()
        lines, labels = lines + lines2, labels + labels2
        # The twin axes is drawn on top; placing the legend there keeps it
        # above every curve.
        legend_ax = ax1

    legend_ax.legend(lines, labels)

In [None]:
# Train/test loss curves of the run stored in lss / lss_t.
plot_loss_acc(lss=lss, lss_t=lss_t)

In [None]:
# NOTE(review): passes accuracy curves as the 3rd/4th arguments, so this needs a
# plot_loss_acc that accepts them. lss3 / lss_t3 / yacc3 / yacc_t3 are not
# defined in this part of the notebook — presumably from an earlier run; confirm.
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=(0, 1, 2, 3, 4, 5, 6, 7), step=5,
                           loader=None, cell_height=1.0):
    """Show how decoded images change while latent dimensions are swept.

    Takes the first ``len(var_idx)`` inputs of one batch from ``loader``,
    encodes them with ``diva.qzx`` / ``diva.qzy`` / ``diva.qzd`` and passes the
    latents to ``change``. ``change`` overwrites dimensions ``var_idx`` of the
    ``lat_space`` latent with a sweep of values and decodes each result. The
    decoded images are shown in a grid: row j is input j, column i is the i-th
    sweep value.

    Parameters
    ----------
    diva : model
        Object exposing ``qzx``, ``qzy``, ``qzd`` encoders and a ``px`` decoder,
        plus ``eval()``/``train(mode)``/``training`` (presumably a
        ``torch.nn.Module``). It is run in eval mode, and its previous
        train/eval mode is restored afterwards.
    lat_space : {"y", "x", "d"}
        Which latent to modify (forwarded to ``change``).
    var_idx : sequence of int
        Latent dimensions to overwrite. Its length also sets how many inputs
        (grid rows) are used.
    step : int
        Step of the value sweep, forwarded to ``change``.
    loader : iterable of batches, optional
        Source of inputs; defaults to the global ``test_loader``. Element 0 of
        a batch is used as the model input.
    cell_height : float
        Height in inches of one grid cell; the width is 4x the height.
    """
    var_idx = list(var_idx)
    n_rows = len(var_idx)
    if loader is None:
        loader = test_loader

    batch = next(iter(loader))
    x = batch[0][:n_rows].to(DEVICE).float()

    # Restore the caller's train/eval mode instead of leaving the model in eval mode.
    was_training = diva.training
    diva.eval()
    try:
        with torch.no_grad():
            zx, _ = diva.qzx(x)
            zy, zy_sc = diva.qzy(x)
            zd, _ = diva.qzd(x)

            print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

            out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    finally:
        diva.train(was_training)

    n_cols = out.shape[0]
    # squeeze=False keeps `ax` 2-D even for a single row or column. The sizing
    # keeps a 4:1 cell aspect at a few inches overall rather than hundreds.
    fig, ax = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False,
                           figsize=(4 * cell_height * n_cols, cell_height * n_rows))
    for i in range(n_cols):
        for j in range(n_rows):
            ax[j, i].imshow(out[i, j])
            ax[j, i].axis("off")

In [None]:
def change(zx, zy, zd, idx, lat="y", model=None, step=2, lo=-30, hi=15,
           img_shape=(25, 100, 3)):
    """Decode latents after overwriting selected dimensions with a sweep of values.

    For each value v in ``np.arange(lo, hi, step)`` and each sample
    j in ``range(len(idx))``:

    * dimensions ``idx`` of sample j in the latent chosen by ``lat`` are set
      to v;
    * ``model.px(zd[j], zx[j], zy[j])`` is decoded;
    * ``model.px.reconstruct_image(...)`` is stored at ``out[i, j]``.

    The input tensors are not modified; the sweep is applied to copies.

    Parameters
    ----------
    zx, zy, zd : torch.Tensor
        Latent batches indexed as ``z[sample, dim]``.
    idx : list of int
        Latent dimensions to overwrite. ``len(idx)`` also sets how many
        samples are decoded.
    lat : {"y", "x", "d"}
        Which latent to modify.
    model : optional
        Model with a ``px`` decoder exposing ``reconstruct_image``. Defaults
        to the global ``diva``, looked up when the function is called.
    step, lo, hi : numbers
        Sweep values are ``np.arange(lo, hi, step)``.
    img_shape : tuple of int
        Shape of one image returned by ``model.px.reconstruct_image``.

    Returns
    -------
    np.ndarray
        Shape ``(len(np.arange(lo, hi, step)), len(idx), *img_shape)``.

    Raises
    ------
    ValueError
        If ``lat`` is not one of "y", "x", "d".
    """
    if lat not in ("y", "x", "d"):
        raise ValueError(f"lat must be 'y', 'x' or 'd', got {lat!r}")
    if model is None:
        # Looked up at call time so a retrained / reassigned `diva` is used.
        model = diva

    # Copies, so the caller's latent tensors survive the sweep unchanged.
    latents = {"x": zx.clone(), "y": zy.clone(), "d": zd.clone()}
    target = latents[lat]

    dif = np.arange(lo, hi, step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), *img_shape))
    for i, value in enumerate(dif):
        for j in range(len(idx)):
            # NOTE(review): this overwrites *all* dimensions in idx for sample j,
            # not only idx[j]; confirm that is the intended sweep.
            target[j, idx] = value.item()
            len_, bar, col = model.px(latents["d"][j], latents["x"][j], latents["y"][j])
            out[i, j] = model.px.reconstruct_image(len_[None, :], bar, col)

    return out



In [None]:
# Sweep the "y" latent of the trained model and show the decoded images.
plot_change_latent_var(diva=diva, lat_space="y")

In [None]:
# Loss (solid, left axis) and accuracy (dashed, right axis) for the lss2 / yacc2
# run; the first plotted epoch is 50 and the x range follows the curve length.
fig, ax = plt.subplots()
epoch_axis2 = np.arange(50, 50 + len(lss2))
ax.plot(epoch_axis2, [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(epoch_axis2, [i.cpu().detach().numpy() for i in lss_t2], label="test loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax1 = ax.twinx()
ax1.plot(epoch_axis2, yacc2, label="train accuracy", ls="--")
ax1.plot(epoch_axis2, yacc_t2, label="test accuracy", ls="--")
ax1.set_ylabel("accuracy")

# plt.legend() would only cover the current (twin) axes and drop the loss
# curves; merge both axes' handles into one legend on the top axes instead.
loss_handles, loss_labels = ax.get_legend_handles_labels()
acc_handles, acc_labels = ax1.get_legend_handles_labels()
ax1.legend(loss_handles + acc_handles, loss_labels + acc_labels);

In [None]:
# Loss (solid, left axis) and accuracy (dashed, right axis) for the lss3 / yacc3
# run; the first plotted epoch is 120 and the x range follows the curve length.
fig, ax = plt.subplots()
epoch_axis3 = np.arange(120, 120 + len(lss3))
ax.plot(epoch_axis3, [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(epoch_axis3, [i.cpu().detach().numpy() for i in lss_t3], label="test loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax1 = ax.twinx()
ax1.plot(epoch_axis3, yacc3, label="train accuracy", ls="--")
ax1.plot(epoch_axis3, yacc_t3, label="test accuracy", ls="--")
ax1.set_ylabel("accuracy")

# plt.legend() would only cover the current (twin) axes and drop the loss
# curves; merge both axes' handles into one legend on the top axes instead.
loss_handles, loss_labels = ax.get_legend_handles_labels()
acc_handles, acc_labels = ax1.get_legend_handles_labels()
ax1.legend(loss_handles + acc_handles, loss_labels + acc_labels);

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''Plot one latent space of the model — not implemented yet (stub).

    lat_space: one of "y", "d", "x" — which latent space to plot.

    Raises
    ------
    ValueError
        If ``lat_space`` is not one of "y", "d", "x".
    NotImplementedError
        For any valid ``lat_space``, so a caller gets an explicit error rather
        than an empty result.
    '''
    if lat_space not in ("y", "d", "x"):
        raise ValueError(f"lat_space must be 'y', 'd' or 'x', got {lat_space!r}")
    raise NotImplementedError("plot_latent_space is not implemented yet")

    

In [None]:
# NOTE(review): `plot`, `x` and `out` are not defined in this part of the
# notebook (`out` only exists as a local inside plot_change_latent_var / change);
# this cell presumably relies on state from an earlier session — confirm.
plot(x, out, 0)

In [None]:
# Show the first nine inputs of `x` (channels-first tensors) in a 3x3 grid and
# save the figure.
fig, ax = plt.subplots(nrows=3, ncols=3)
for i, panel in enumerate(ax.flat):
    panel.imshow(x[i].cpu().permute(1, 2, 0))

fig.savefig('divastamporg.png')