In [2]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [3]:
%load_ext tensorboard

In [76]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


writer = SummaryWriter(f"{link}/saved_models/HVAEFCP1/tensorboard")

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
DEVICE

device(type='cuda')

# Model Classes

In [7]:
class diva_args:

    def __init__(self, z1_dim=1000, z2_dim=1000, d_dim=45, x_dim=7500, y_dim=2,
                 h_dim = 600, h2_dim = 600, number_components = 500,
                 beta=1, rec_alpha = 100, rec_beta = 20, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z1_dim = z1_dim
        self.z2_dim = z2_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.h_dim = h_dim
        self.h2_dim = h2_dim
        
        self.number_components = number_components
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [8]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_col, x_bar = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len2.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar2.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col2.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        
        
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')
        
        
        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]
    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]                        
        return (x, y, d, x_len, x_col, x_bar, mount)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,5,200), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]) ,2*j] = 1 
                    out_col[i, self.get_color(x[i,:,13,j]), 2*j+1] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len2.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col2.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar2.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():  
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():  
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():  
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():  
            return 4 # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [65]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim, 
                 h_dim, h2_dim, dim0=2000, dim1=1200, dim2=400):
        super(px, self).__init__()

        
        # p(z1|z2)
        
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim+200, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())
        
        
        # p(x|z1,z2,m)
        
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim),
                                   nn.ReLU())
        
        self.fc1 = nn.Sequential(nn.Linear(2*h_dim, dim0, bias=False),  
                                 nn.ReLU())
        
        self.fc2 = nn.Sequential(nn.Linear(dim0, dim2, bias=False),  
                                 nn.ReLU())
#         self.fc3 = nn.Sequential(nn.Linear(dim1, dim2, bias=False),
#                                  nn.ReLU())
        
        # Predicting length and color of each bar
        #self.color = nn.Sequential(nn.Conv1d(1,5, kernel_size=1, bias=False), 
                                  # nn.Softmax(dim=1))
        # Predicting color of each bar
        self.color_bar_black = nn.Linear(dim2,200)
        self.color_bar_reddd = nn.Linear(dim2,200)
        self.color_bar_bluee = nn.Linear(dim2,200)
        self.color_bar_green = nn.Linear(dim2,200)
        self.color_bar_yelow = nn.Linear(dim2,200)
        
        # Predicting the length of each bar
        self.length_bar_top = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())
        self.length_bar_bot = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())
        #self.length_bar_scale = nn.Sequential(nn.Conv1d(100, 1, kernel_size = 3, padding = 'same', bias=False), nn.Sigmoid())
        
        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(dim2,400), nn.ReLU(),nn.Linear(400,1), nn.Softplus())
        #self.length_RNA_scale = nn.Sequential(nn.Linear(400,1, bias=False), nn.Sigmoid())
        
    def forward(self, z1, mz2):
        
        # p(z1|z2)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)
        
        # p(x|z1,z2,m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = torch.cat([hz1,hz2],1)
        h = self.fc1(h)
        h = nn.Dropout(0.2)(h)
        h = self.fc2(h)
        h = nn.Dropout(0.2)(h)
        len_RNA = self.length_RNA(h)
        len_bar = torch.cat([self.length_bar_top(h)[:,None,:],self.length_bar_bot(h)[:,None,:]], dim=1) 
#         h = self.fc3(h)
#         h = nn.Dropout(0.3)(h)
        
        len_RNA_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_RNA_sc = torch.exp(self.length_RNA_scale(h))
        
        
        
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_bar_sc = torch.exp(self.length_bar_scale(h))
        
#        col_bar = self.color(h[:,None,:])
        black = self.color_bar_black(h)
        black = nn.Sigmoid()(black)[:,None,:]
        reddd = self.color_bar_reddd(h)[:,None,:]
        bluee = self.color_bar_bluee(h)[:,None,:]
        green = self.color_bar_green(h)[:,None,:]
        yelow = self.color_bar_yelow(h)[:,None,:]
        
        col = torch.cat([reddd,bluee,green,yelow], dim=1)
        
        col = nn.Softmax(dim=1)(col)
        
        
        col_bar = torch.cat([black,col*((1-black).repeat(1,4,1))], 1)
        
        return len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar, pz1_m, pz1_s

    def reconstruct_image(self, len_RNA, var_RNA, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }
    
        
        len_RNA = len_RNA.cpu().numpy()
        var_RNA = var_RNA.cpu().numpy()
        #.reshape((100,))
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_RNA.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = int(np.round(np.random.normal(loc=len_RNA[i], scale=var_RNA[i])))
            else:
                limit = int(np.round(len_RNA[i]))
            limit = min(100, limit)
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1 = np.argmax(col_bar[i,:, 2*j])
                    _col_bar_2 = np.argmax(col_bar[i,:, 2*j+1])
                
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output


In [66]:
int(np.round(3.7, 0))
int(3.7)

3

In [67]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [68]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [69]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim):
        super(qz, self).__init__()

        # q(z2 | x)
        self.encoder_z2 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'valid', bias=False),
            nn.ReLU(), 
#             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
#             nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.mu_z2 = nn.Sequential(nn.Linear(2560, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(2560, z2_dim), nn.Softplus())
        
        
        # q(z1 | x, z2)
        self.encoder_z1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding = 'valid',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'valid', bias=False),
            nn.ReLU(), 
#             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
#             nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )
        
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(2560, h_dim), nn.ReLU())
        
        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())
        
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

#         torch.nn.init.xavier_uniform_(self.encoder[0].weight)
#         torch.nn.init.xavier_uniform_(self.encoder[3].weight)
#         torch.nn.init.xavier_uniform_(self.fc11[0].weight)
#         self.fc11[0].bias.data.zero_()
#         torch.nn.init.xavier_uniform_(self.fc12[0].weight)
#         self.fc12[0].bias.data.zero_()
    
    def q_z2(self, x):
        z2 = self.encoder_z2(x)
        z2 = z2.view(-1, 2560)
        z2_m = self.mu_z2(z2) 
        z2_s = self.si_z2(z2)
        
        return z2_m, z2_s
    
    def forward(self, x, m):
        
        # q(z2 | x) & m
        z2_m, z2_s = self.q_z2(x)
        # reparameterization trick
        qz2 = dist.Normal(z2_m, z2_s)
        z2 = qz2.rsample()
        # z2 & m
        mz2 = torch.cat([z2, m],1)
        
        
        
        # q(z1 | x, z2, m)
        z1 = self.encoder_z1(x)
        z1 = z1.view(-1, 2560)
        z1 = self.fc_z1(z1)
        
        mz2_ = self.fc_z2(mz2)
        
        z1 = torch.cat([mz2_, z1],1)
        z1 = self.fc_z1_z2(z1)
        z1_m = self.mu_z1(z1)
        z1_s = self.si_z1(z1)
        
        qz1 = dist.Normal(z1_m, z1_s)
        z1 = qz1.rsample()
        
        
        #z_loc = self.fc11(h)
        #z_scale = self.fc12(h) + 1e-7

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [70]:
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [71]:
enc = qz(128,10,10,10,500,400,400)
enc(torch.zeros((1,3,25,100)), torch.zeros((1,200)))
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 1, 10]           --
│    └─Conv2d: 2-1                       [1, 32, 23, 98]           864
│    └─ReLU: 2-2                         [1, 32, 23, 98]           --
│    └─Conv2d: 2-3                       [1, 64, 21, 96]           18,432
│    └─ReLU: 2-4                         [1, 64, 21, 96]           --
│    └─MaxPool2d: 2-5                    [1, 64, 10, 48]           --
│    └─Conv2d: 2-6                       [1, 128, 8, 46]           73,728
│    └─ReLU: 2-7                         [1, 128, 8, 46]           --
│    └─MaxPool2d: 2-8                    [1, 128, 4, 23]           --
│    └─Conv2d: 2-9                       [1, 256, 2, 21]           294,912
│    └─ReLU: 2-10                        [1, 256, 2, 21]           --
│    └─MaxPool2d: 2-11                   [1, 256, 1, 10]           --
├

In [72]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    log_var = 2*torch.log(std)
    log_normal = -0.5 * ( log_var + torch.pow( x - mean, 2 ) / torch.exp( log_var ) )
    if average:
        return torch.mean( log_normal, dim )
    else:
        return torch.sum( log_normal, dim )

## Full model class

In [79]:
class HVAE(nn.Module):
    def __init__(self, args):
        super(HVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components
        
        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim, 
                     self.h_dim, self.h2_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup
        
        self.add_pseudoinputs()
        
        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []
        
        self.bar = []
        self.len = []
        self.col = []
        
        self.cuda()

    def forward(self, d, x, y, m):
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, pz1_m, pz1_s = self.px(z1, mz2)
        
        return x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s
    
    
    def log_p_z2(self, z2):
        C = self.number_components
        
        X = self.means(self.idle_input).view(-1,3,25,100)
        
        pz2_m, pz2_s = self.qz.q_z2(X)
        
        z_expand = z2.unsqueeze(1)
        means = pz2_m.unsqueeze(0)
        stds = pz2_s.unsqueeze(0)
        
        a = log_Normal_diag(z_expand, means, stds, dim=2) - math.log(C)
        a_max, _ = torch.max(a,1)
        
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a-a_max.unsqueeze(1)),1)))
        
        return log_prior
    
    def loss_function(self, d, x, y, m, out_len, out_bar, out_col):
        
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m)
        
        # Reconstruction Loss
        mask = 1 - F.one_hot(torch.round(out_len).to(torch.int64)*2-1, 200).cumsum(dim=1)[:,None,:]
        mask1 = (1 - F.one_hot(torch.round(out_len).to(torch.int64)-1, 100).cumsum(dim=1)[:,None,:]).repeat(1,2,1)

        x_col = mask.repeat(1,5,1)*x_col
        
#         x_gap = x_col[:,0,:]
#         x_nuc = x_col[:,1:,:]
                
        
        
        dist_len = dist.Normal(x_len, x_len_scale+1e-7)
        log_len = dist_len.log_prob(out_len[:,None]).mean()
         
        mse_bar = ((((x_bar - out_bar)**2)*mask1).sum(dim=(1,2))/(mask1.sum(dim=(1,2)))).sum()#.detach().item()
        
        max_bar = torch.argmax(x_col, dim=1)
        #print(max_bar.shape, out_col.shape, mask.shape)
        acc_bar = (((max_bar==torch.argmax(out_col, dim=1))*mask).sum().float()/(mask.sum(1))).sum()
        acc_bar2 = (((max_bar==torch.argsort(out_col, dim=1)[:,1,:])*mask).sum().float()/(mask.sum(1))).sum() + acc_bar
        #print(acc_bar.shape)
        RE_len = -log_len
        RE_bar = mse_bar#-log_bar
        RE_col = F.cross_entropy(x_col, out_col, reduction='sum')
          
            
        # KL loss
        KL_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        KL_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        KL_p_z2 = self.log_p_z2(z2).sum()
        KL_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(KL_p_z1 + KL_p_z2 - KL_q_z1 - KL_q_z2)
        #print(KL_p_z1.shape,KL_p_z2.shape,KL_q_z1.shape,KL_q_z2.shape)
        
        
#         self.lpz1.append(KL_p_z1.detach().item())
#         self.lpz2.append(KL_p_z2.detach().item())
#         self.lqz1.append(KL_q_z1.detach().item())
#         self.lqz2.append(KL_q_z2.detach().item())
        
#         self.bar.append(RE_bar.detach())
#         self.col.append(RE_col.detach())
#         self.len.append(RE_len.detach())
        
        
        return self.rec_alpha * RE_len \
                  + self.rec_beta * RE_bar \
                  + self.rec_gamma * RE_col \
                  + self.beta * KL, \
                  RE_bar, RE_len, RE_col, mse_bar, acc_bar, acc_bar2
    
    def add_pseudoinputs(self):
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        self.idle_input = Variable(torch.eye(self.number_components, self.number_components), requires_grad=False).to(DEVICE)

In [80]:
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [81]:
default_args = diva_args()
enc = HVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
HVAE                                     --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [77]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [78]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [24]:
len(RNA_dataset)

34721

In [82]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    no_batches = 0
    train_corr = 0
    mse_bar = 0
    acc_bar = 0
    acc_bar2 = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar, m) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col, m= x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, len_loss, col_loss, mse, acc, acc2 = model.loss_function(d.float(), x.float(), y.float(), m.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        epoch_len_loss += len_loss
        mse_bar += mse
        acc_bar += acc
        acc_bar2 += acc2
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    acc_bar2 /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar, acc_bar2

In [83]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    mse_bar = 0
    acc_bar = 0   
    acc_bar2 = 0
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar, m) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col, m = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE), m.to(DEVICE)
            loss, bar_loss, len_loss, col_loss, mse, acc, acc2 = model.loss_function(d.float(), x.float(), y.float(),m.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            epoch_len_loss += len_loss
            mse_bar += mse
            acc_bar += acc
            acc_bar2 += acc2
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_len_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    acc_bar2 /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar, acc_bar2
  

In [97]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col, mtr, atr, atr2 = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test, mte, ate, ate2 = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train-top1': atr, 'test-top1':ate, 'train-top2': atr2, 'test-top2':ate2}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')
        
        
        if epoch % 50 == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [98]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        m = a[1][-1][:10].to(DEVICE).float()
        x_1, x_1var, x_2, x_2var, x_3 ,z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = diva(d,x,y,m)
        out = diva.px.reconstruct_image(x_1, x_1var, x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [99]:
DEVICE

device(type='cuda')

## Model Training

In [100]:
default_args = diva_args(prewarmup=0, number_components=50)

In [101]:
diva = HVAE(default_args).to(DEVICE)

In [102]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE10/checkpoints/905.pth'))

In [103]:
train_loader = DataLoader(RNA_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=64)

In [104]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=0.001)

In [105]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [106]:
writer.flush()

In [107]:
diva.rec_gamma = 3

In [95]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/HVAEFCP1/tensorboard/"

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 0, save_folder="HVAEFCP1",save_interval=5)

Epoch 1: 543batch [00:36, 14.81batch/s, loss=762]   


epoch 1: avg train loss 818.13, bar train loss 11.790, len train loss 0.638, col train loss 171.595
epoch 1: avg test  loss 730.94, bar  test loss 9.307, len  test loss 0.301, col  test loss 169.814


Epoch 2: 543batch [00:36, 14.75batch/s, loss=774]


epoch 2: avg train loss 706.39, bar train loss 8.756, len train loss 0.193, col train loss 168.990


Epoch 3: 2batch [00:00, 13.89batch/s, loss=744]

epoch 2: avg test  loss 692.03, bar  test loss 8.329, len  test loss 0.153, col  test loss 168.410


Epoch 3: 543batch [00:37, 14.50batch/s, loss=743]


epoch 3: avg train loss 688.16, bar train loss 8.104, len train loss 0.167, col train loss 167.908


Epoch 4: 2batch [00:00, 14.39batch/s, loss=630]

epoch 3: avg test  loss 674.53, bar  test loss 7.725, len  test loss 0.132, col  test loss 166.789


Epoch 4: 543batch [00:37, 14.42batch/s, loss=693]


epoch 4: avg train loss 670.80, bar train loss 7.247, len train loss 0.207, col train loss 166.047


Epoch 5: 2batch [00:00, 14.18batch/s, loss=660]

epoch 4: avg test  loss 651.84, bar  test loss 6.765, len  test loss 0.129, col  test loss 165.391


Epoch 5: 543batch [00:37, 14.61batch/s, loss=611]


epoch 5: avg train loss 650.31, bar train loss 6.537, len train loss 0.170, col train loss 165.018
epoch 5: avg test  loss 642.37, bar  test loss 6.364, len  test loss 0.138, col  test loss 164.764


Epoch 6: 543batch [00:36, 14.72batch/s, loss=642]


epoch 6: avg train loss 641.45, bar train loss 6.157, len train loss 0.171, col train loss 164.480


Epoch 7: 2batch [00:00, 14.39batch/s, loss=644]

epoch 6: avg test  loss 633.91, bar  test loss 5.987, len  test loss 0.140, col  test loss 164.221


Epoch 7: 543batch [00:37, 14.53batch/s, loss=682]


epoch 7: avg train loss 633.40, bar train loss 5.870, len train loss 0.161, col train loss 163.988


Epoch 8: 2batch [00:00, 14.39batch/s, loss=645]

epoch 7: avg test  loss 628.93, bar  test loss 5.739, len  test loss 0.155, col  test loss 163.680


Epoch 8: 543batch [00:37, 14.60batch/s, loss=598]


epoch 8: avg train loss 626.43, bar train loss 5.623, len train loss 0.154, col train loss 163.606


Epoch 9: 2batch [00:00, 13.61batch/s, loss=616]

epoch 8: avg test  loss 620.14, bar  test loss 5.497, len  test loss 0.121, col  test loss 163.501


Epoch 9: 543batch [00:37, 14.50batch/s, loss=652]


epoch 9: avg train loss 620.35, bar train loss 5.438, len train loss 0.142, col train loss 163.251


Epoch 10: 2batch [00:00, 13.89batch/s, loss=619]

epoch 9: avg test  loss 627.33, bar  test loss 5.454, len  test loss 0.211, col  test loss 163.145


Epoch 10: 543batch [00:37, 14.49batch/s, loss=601]


epoch 10: avg train loss 616.21, bar train loss 5.291, len train loss 0.140, col train loss 162.945
epoch 10: avg test  loss 612.13, bar  test loss 5.216, len  test loss 0.118, col  test loss 162.933


Epoch 11: 543batch [00:37, 14.58batch/s, loss=623]


epoch 11: avg train loss 611.44, bar train loss 5.137, len train loss 0.132, col train loss 162.655


Epoch 12: 2batch [00:00, 14.39batch/s, loss=609]

epoch 11: avg test  loss 610.95, bar  test loss 5.095, len  test loss 0.130, col  test loss 162.675


Epoch 12: 543batch [00:36, 14.68batch/s, loss=590]


epoch 12: avg train loss 606.94, bar train loss 5.000, len train loss 0.122, col train loss 162.386


Epoch 13: 2batch [00:00, 14.29batch/s, loss=605]

epoch 12: avg test  loss 606.62, bar  test loss 4.986, len  test loss 0.123, col  test loss 162.496


Epoch 13: 543batch [00:36, 14.72batch/s, loss=632]


epoch 13: avg train loss 604.54, bar train loss 4.904, len train loss 0.125, col train loss 162.112


Epoch 14: 2batch [00:00, 14.49batch/s, loss=557]

epoch 13: avg test  loss 606.59, bar  test loss 4.889, len  test loss 0.134, col  test loss 162.200


Epoch 14: 543batch [00:36, 14.73batch/s, loss=565]


epoch 14: avg train loss 601.00, bar train loss 4.804, len train loss 0.116, col train loss 161.905


Epoch 15: 2batch [00:00, 13.99batch/s, loss=604]

epoch 14: avg test  loss 623.56, bar  test loss 4.847, len  test loss 0.330, col  test loss 162.079


Epoch 15: 543batch [00:37, 14.57batch/s, loss=678]


epoch 15: avg train loss 597.11, bar train loss 4.691, len train loss 0.106, col train loss 161.696
epoch 15: avg test  loss 595.60, bar  test loss 4.682, len  test loss 0.085, col  test loss 161.909


Epoch 16: 543batch [00:36, 14.72batch/s, loss=570]


epoch 16: avg train loss 593.15, bar train loss 4.594, len train loss 0.092, col train loss 161.485


Epoch 17: 2batch [00:00, 14.39batch/s, loss=634]

epoch 16: avg test  loss 592.92, bar  test loss 4.617, len  test loss 0.080, col  test loss 161.661


Epoch 17: 543batch [00:37, 14.62batch/s, loss=570]


epoch 17: avg train loss 591.01, bar train loss 4.528, len train loss 0.089, col train loss 161.307


Epoch 18: 2batch [00:00, 14.60batch/s, loss=591]

epoch 17: avg test  loss 595.07, bar  test loss 4.530, len  test loss 0.124, col  test loss 161.550


Epoch 18: 543batch [00:36, 14.72batch/s, loss=578]


epoch 18: avg train loss 588.62, bar train loss 4.469, len train loss 0.083, col train loss 161.118


Epoch 19: 2batch [00:00, 14.18batch/s, loss=568]

epoch 18: avg test  loss 588.97, bar  test loss 4.433, len  test loss 0.076, col  test loss 161.401


Epoch 19: 543batch [00:37, 14.67batch/s, loss=614]


epoch 19: avg train loss 586.60, bar train loss 4.404, len train loss 0.080, col train loss 160.980


Epoch 20: 2batch [00:00, 14.18batch/s, loss=606]

epoch 19: avg test  loss 587.59, bar  test loss 4.403, len  test loss 0.072, col  test loss 161.381


Epoch 20: 543batch [00:36, 14.70batch/s, loss=609]


epoch 20: avg train loss 584.29, bar train loss 4.341, len train loss 0.074, col train loss 160.816
epoch 20: avg test  loss 587.76, bar  test loss 4.487, len  test loss 0.074, col  test loss 161.212


Epoch 21: 543batch [00:37, 14.65batch/s, loss=606]


epoch 21: avg train loss 582.26, bar train loss 4.292, len train loss 0.069, col train loss 160.635


Epoch 22: 2batch [00:00, 14.29batch/s, loss=604]

epoch 21: avg test  loss 584.81, bar  test loss 4.329, len  test loss 0.066, col  test loss 161.129


Epoch 22: 543batch [00:36, 14.71batch/s, loss=535]


epoch 22: avg train loss 581.00, bar train loss 4.248, len train loss 0.070, col train loss 160.492


Epoch 23: 2batch [00:00, 14.60batch/s, loss=559]

epoch 22: avg test  loss 583.31, bar  test loss 4.346, len  test loss 0.064, col  test loss 161.033


Epoch 23: 543batch [00:36, 14.70batch/s, loss=594]


epoch 23: avg train loss 579.39, bar train loss 4.204, len train loss 0.068, col train loss 160.320


Epoch 24: 2batch [00:00, 14.29batch/s, loss=567]

epoch 23: avg test  loss 582.57, bar  test loss 4.290, len  test loss 0.069, col  test loss 160.897


Epoch 24: 543batch [00:36, 14.72batch/s, loss=618]


epoch 24: avg train loss 578.13, bar train loss 4.155, len train loss 0.069, col train loss 160.208


Epoch 25: 2batch [00:00, 14.29batch/s, loss=565]

epoch 24: avg test  loss 581.31, bar  test loss 4.248, len  test loss 0.064, col  test loss 160.783


Epoch 25: 543batch [00:36, 14.70batch/s, loss=579]


epoch 25: avg train loss 577.00, bar train loss 4.129, len train loss 0.067, col train loss 160.050
epoch 25: avg test  loss 580.37, bar  test loss 4.249, len  test loss 0.059, col  test loss 160.738


Epoch 26: 543batch [00:36, 14.71batch/s, loss=576]


epoch 26: avg train loss 575.70, bar train loss 4.088, len train loss 0.066, col train loss 159.891


Epoch 27: 2batch [00:00, 14.29batch/s, loss=585]

epoch 26: avg test  loss 579.77, bar  test loss 4.215, len  test loss 0.064, col  test loss 160.594


Epoch 27: 543batch [00:37, 14.61batch/s, loss=595]


epoch 27: avg train loss 574.32, bar train loss 4.055, len train loss 0.064, col train loss 159.757


Epoch 28: 2batch [00:00, 14.29batch/s, loss=586]

epoch 27: avg test  loss 579.60, bar  test loss 4.203, len  test loss 0.060, col  test loss 160.447


Epoch 28: 543batch [00:37, 14.66batch/s, loss=614]


epoch 28: avg train loss 573.26, bar train loss 4.025, len train loss 0.064, col train loss 159.611


Epoch 29: 2batch [00:00, 14.29batch/s, loss=550]

epoch 28: avg test  loss 578.57, bar  test loss 4.167, len  test loss 0.063, col  test loss 160.485


Epoch 29: 543batch [00:36, 14.70batch/s, loss=549]


epoch 29: avg train loss 572.64, bar train loss 3.996, len train loss 0.065, col train loss 159.516


Epoch 30: 2batch [00:00, 14.18batch/s, loss=543]

epoch 29: avg test  loss 578.37, bar  test loss 4.117, len  test loss 0.070, col  test loss 160.340


Epoch 30: 543batch [00:37, 14.67batch/s, loss=592]


epoch 30: avg train loss 571.07, bar train loss 3.967, len train loss 0.061, col train loss 159.377
epoch 30: avg test  loss 576.97, bar  test loss 4.116, len  test loss 0.059, col  test loss 160.185


Epoch 31: 543batch [00:36, 14.69batch/s, loss=547]


epoch 31: avg train loss 570.34, bar train loss 3.928, len train loss 0.063, col train loss 159.272


Epoch 32: 2batch [00:00, 14.39batch/s, loss=560]

epoch 31: avg test  loss 576.28, bar  test loss 4.074, len  test loss 0.058, col  test loss 160.167


Epoch 32: 543batch [00:36, 14.69batch/s, loss=630]


epoch 32: avg train loss 569.35, bar train loss 3.906, len train loss 0.061, col train loss 159.191


Epoch 33: 2batch [00:00, 14.49batch/s, loss=569]

epoch 32: avg test  loss 577.19, bar  test loss 4.124, len  test loss 0.070, col  test loss 160.214


Epoch 33: 543batch [00:37, 14.53batch/s, loss=569]


epoch 33: avg train loss 568.67, bar train loss 3.880, len train loss 0.062, col train loss 159.088


Epoch 34: 2batch [00:00, 14.18batch/s, loss=560]

epoch 33: avg test  loss 575.14, bar  test loss 4.086, len  test loss 0.057, col  test loss 160.227


Epoch 34: 543batch [00:36, 14.69batch/s, loss=650]


epoch 34: avg train loss 567.97, bar train loss 3.869, len train loss 0.060, col train loss 159.018


Epoch 35: 2batch [00:00, 14.49batch/s, loss=570]

epoch 34: avg test  loss 575.05, bar  test loss 4.028, len  test loss 0.061, col  test loss 160.105


Epoch 35: 543batch [00:36, 14.68batch/s, loss=602]


epoch 35: avg train loss 566.75, bar train loss 3.837, len train loss 0.058, col train loss 158.887
epoch 35: avg test  loss 574.82, bar  test loss 4.018, len  test loss 0.062, col  test loss 160.001


Epoch 36: 543batch [00:37, 14.65batch/s, loss=599]


epoch 36: avg train loss 566.44, bar train loss 3.821, len train loss 0.060, col train loss 158.820


Epoch 37: 2batch [00:00, 14.49batch/s, loss=552]

epoch 36: avg test  loss 573.78, bar  test loss 4.016, len  test loss 0.061, col  test loss 159.966


Epoch 37: 543batch [00:36, 14.69batch/s, loss=547]


epoch 37: avg train loss 565.45, bar train loss 3.795, len train loss 0.058, col train loss 158.741


Epoch 38: 2batch [00:00, 14.39batch/s, loss=543]

epoch 37: avg test  loss 572.82, bar  test loss 3.990, len  test loss 0.059, col  test loss 159.908


Epoch 38: 543batch [00:36, 14.68batch/s, loss=576]


epoch 38: avg train loss 565.09, bar train loss 3.786, len train loss 0.059, col train loss 158.680


Epoch 39: 2batch [00:00, 14.29batch/s, loss=536]

epoch 38: avg test  loss 573.45, bar  test loss 3.997, len  test loss 0.063, col  test loss 159.946


Epoch 39: 543batch [00:37, 14.62batch/s, loss=620]


epoch 39: avg train loss 564.29, bar train loss 3.766, len train loss 0.057, col train loss 158.606


Epoch 40: 2batch [00:00, 14.29batch/s, loss=569]

epoch 39: avg test  loss 574.87, bar  test loss 3.983, len  test loss 0.068, col  test loss 159.913


Epoch 40: 543batch [00:36, 14.71batch/s, loss=565]


epoch 40: avg train loss 563.81, bar train loss 3.751, len train loss 0.057, col train loss 158.545
epoch 40: avg test  loss 572.91, bar  test loss 3.983, len  test loss 0.064, col  test loss 159.799


Epoch 41: 543batch [00:36, 14.68batch/s, loss=544]


epoch 41: avg train loss 562.98, bar train loss 3.728, len train loss 0.057, col train loss 158.448


Epoch 42: 2batch [00:00, 14.49batch/s, loss=557]

epoch 41: avg test  loss 573.68, bar  test loss 3.964, len  test loss 0.057, col  test loss 159.746


Epoch 42: 543batch [00:37, 14.67batch/s, loss=537]


epoch 42: avg train loss 562.88, bar train loss 3.722, len train loss 0.058, col train loss 158.393


Epoch 43: 2batch [00:00, 14.18batch/s, loss=569]

epoch 42: avg test  loss 572.80, bar  test loss 3.968, len  test loss 0.064, col  test loss 159.842


Epoch 43: 543batch [00:37, 14.65batch/s, loss=554]


epoch 43: avg train loss 561.78, bar train loss 3.692, len train loss 0.056, col train loss 158.322


Epoch 44: 2batch [00:00, 14.49batch/s, loss=526]

epoch 43: avg test  loss 571.60, bar  test loss 3.961, len  test loss 0.054, col  test loss 159.773


Epoch 44: 543batch [00:37, 14.66batch/s, loss=563]


epoch 44: avg train loss 561.12, bar train loss 3.678, len train loss 0.055, col train loss 158.251


Epoch 45: 2batch [00:00, 14.49batch/s, loss=571]

epoch 44: avg test  loss 571.09, bar  test loss 3.948, len  test loss 0.057, col  test loss 159.767


Epoch 45: 543batch [00:36, 14.68batch/s, loss=573]


epoch 45: avg train loss 560.67, bar train loss 3.667, len train loss 0.055, col train loss 158.183
epoch 45: avg test  loss 571.53, bar  test loss 3.953, len  test loss 0.060, col  test loss 159.816


Epoch 46: 543batch [00:37, 14.63batch/s, loss=608]


epoch 46: avg train loss 560.43, bar train loss 3.661, len train loss 0.056, col train loss 158.130


Epoch 47: 2batch [00:00, 14.39batch/s, loss=572]

epoch 46: avg test  loss 572.18, bar  test loss 3.933, len  test loss 0.058, col  test loss 159.788


Epoch 47: 543batch [00:37, 14.65batch/s, loss=571]


epoch 47: avg train loss 560.01, bar train loss 3.650, len train loss 0.054, col train loss 158.087


Epoch 48: 2batch [00:00, 14.29batch/s, loss=586]

epoch 47: avg test  loss 570.94, bar  test loss 3.916, len  test loss 0.055, col  test loss 159.660


Epoch 48: 543batch [00:37, 14.65batch/s, loss=594]


epoch 48: avg train loss 559.42, bar train loss 3.628, len train loss 0.055, col train loss 158.039


Epoch 49: 2batch [00:00, 14.60batch/s, loss=587]

epoch 48: avg test  loss 570.30, bar  test loss 3.895, len  test loss 0.055, col  test loss 159.664


Epoch 49: 543batch [00:37, 14.66batch/s, loss=552]


epoch 49: avg train loss 558.88, bar train loss 3.618, len train loss 0.054, col train loss 157.965


Epoch 50: 2batch [00:00, 14.60batch/s, loss=547]

epoch 49: avg test  loss 571.28, bar  test loss 3.927, len  test loss 0.065, col  test loss 159.716


Epoch 50: 543batch [00:37, 14.65batch/s, loss=539]


epoch 50: avg train loss 558.39, bar train loss 3.604, len train loss 0.053, col train loss 157.915
epoch 50: avg test  loss 570.25, bar  test loss 3.949, len  test loss 0.057, col  test loss 159.665


Epoch 51: 543batch [00:37, 14.63batch/s, loss=562]


epoch 51: avg train loss 558.34, bar train loss 3.595, len train loss 0.055, col train loss 157.883


Epoch 52: 2batch [00:00, 14.39batch/s, loss=586]

epoch 51: avg test  loss 572.91, bar  test loss 3.937, len  test loss 0.078, col  test loss 159.659


Epoch 52: 543batch [00:37, 14.64batch/s, loss=556]


epoch 52: avg train loss 557.41, bar train loss 3.580, len train loss 0.053, col train loss 157.799


Epoch 53: 2batch [00:00, 14.18batch/s, loss=562]

epoch 52: avg test  loss 570.46, bar  test loss 3.916, len  test loss 0.058, col  test loss 159.723


Epoch 53: 543batch [00:37, 14.63batch/s, loss=558]


epoch 53: avg train loss 557.24, bar train loss 3.572, len train loss 0.053, col train loss 157.765


Epoch 54: 2batch [00:00, 14.29batch/s, loss=582]

epoch 53: avg test  loss 569.64, bar  test loss 3.895, len  test loss 0.055, col  test loss 159.675


Epoch 54: 543batch [00:37, 14.62batch/s, loss=568]


epoch 54: avg train loss 556.90, bar train loss 3.567, len train loss 0.053, col train loss 157.714


Epoch 55: 2batch [00:00, 14.39batch/s, loss=552]

epoch 54: avg test  loss 571.04, bar  test loss 3.955, len  test loss 0.055, col  test loss 159.779


Epoch 55: 543batch [00:37, 14.63batch/s, loss=570]


epoch 55: avg train loss 556.47, bar train loss 3.553, len train loss 0.054, col train loss 157.643
epoch 55: avg test  loss 570.65, bar  test loss 3.916, len  test loss 0.058, col  test loss 159.640


Epoch 56: 543batch [00:37, 14.62batch/s, loss=603]


epoch 56: avg train loss 555.77, bar train loss 3.535, len train loss 0.053, col train loss 157.559


Epoch 57: 2batch [00:00, 14.29batch/s, loss=546]

epoch 56: avg test  loss 569.86, bar  test loss 3.929, len  test loss 0.057, col  test loss 159.725


Epoch 57: 543batch [00:37, 14.64batch/s, loss=569]


epoch 57: avg train loss 555.28, bar train loss 3.527, len train loss 0.051, col train loss 157.522


Epoch 58: 2batch [00:00, 14.29batch/s, loss=559]

epoch 57: avg test  loss 569.50, bar  test loss 3.910, len  test loss 0.058, col  test loss 159.568


Epoch 58: 543batch [00:37, 14.63batch/s, loss=600]


epoch 58: avg train loss 554.97, bar train loss 3.519, len train loss 0.052, col train loss 157.453


Epoch 59: 2batch [00:00, 14.18batch/s, loss=539]

epoch 58: avg test  loss 570.64, bar  test loss 3.916, len  test loss 0.059, col  test loss 159.693


Epoch 59: 543batch [00:37, 14.62batch/s, loss=569]


epoch 59: avg train loss 554.80, bar train loss 3.512, len train loss 0.053, col train loss 157.419


Epoch 60: 2batch [00:00, 14.39batch/s, loss=555]

epoch 59: avg test  loss 570.12, bar  test loss 3.924, len  test loss 0.056, col  test loss 159.632


Epoch 60: 543batch [00:37, 14.59batch/s, loss=509]


epoch 60: avg train loss 554.15, bar train loss 3.496, len train loss 0.052, col train loss 157.351
epoch 60: avg test  loss 569.01, bar  test loss 3.882, len  test loss 0.052, col  test loss 159.543


Epoch 61: 543batch [00:37, 14.48batch/s, loss=569]


epoch 61: avg train loss 554.08, bar train loss 3.493, len train loss 0.052, col train loss 157.348


Epoch 62: 2batch [00:00, 12.42batch/s, loss=572]

epoch 61: avg test  loss 569.44, bar  test loss 3.917, len  test loss 0.055, col  test loss 159.573


Epoch 62: 543batch [00:37, 14.60batch/s, loss=570]


epoch 62: avg train loss 553.90, bar train loss 3.490, len train loss 0.053, col train loss 157.272


Epoch 63: 2batch [00:00, 14.39batch/s, loss=565]

epoch 62: avg test  loss 572.28, bar  test loss 3.924, len  test loss 0.063, col  test loss 159.645


Epoch 63: 543batch [00:37, 14.51batch/s, loss=553]


epoch 63: avg train loss 553.18, bar train loss 3.473, len train loss 0.052, col train loss 157.207


Epoch 64: 2batch [00:00, 13.79batch/s, loss=539]

epoch 63: avg test  loss 570.21, bar  test loss 3.921, len  test loss 0.061, col  test loss 159.707


Epoch 64: 543batch [00:37, 14.40batch/s, loss=609]


epoch 64: avg train loss 553.07, bar train loss 3.470, len train loss 0.052, col train loss 157.167


Epoch 65: 2batch [00:00, 14.49batch/s, loss=536]

epoch 64: avg test  loss 568.91, bar  test loss 3.877, len  test loss 0.056, col  test loss 159.568


Epoch 65: 543batch [00:36, 14.79batch/s, loss=579]


epoch 65: avg train loss 552.69, bar train loss 3.463, len train loss 0.051, col train loss 157.121
epoch 65: avg test  loss 568.45, bar  test loss 3.883, len  test loss 0.056, col  test loss 159.519


Epoch 66: 543batch [00:36, 14.93batch/s, loss=569]


epoch 66: avg train loss 552.38, bar train loss 3.452, len train loss 0.052, col train loss 157.089


Epoch 67: 2batch [00:00, 14.29batch/s, loss=536]

epoch 66: avg test  loss 569.03, bar  test loss 3.904, len  test loss 0.059, col  test loss 159.595


Epoch 67: 543batch [00:36, 14.94batch/s, loss=518]


epoch 67: avg train loss 551.86, bar train loss 3.440, len train loss 0.051, col train loss 157.038


Epoch 68: 2batch [00:00, 14.60batch/s, loss=549]

epoch 67: avg test  loss 569.75, bar  test loss 3.880, len  test loss 0.067, col  test loss 159.576


Epoch 68: 543batch [00:36, 14.92batch/s, loss=556]


epoch 68: avg train loss 551.68, bar train loss 3.434, len train loss 0.052, col train loss 156.974


Epoch 69: 2batch [00:00, 14.49batch/s, loss=553]

epoch 68: avg test  loss 569.06, bar  test loss 3.889, len  test loss 0.055, col  test loss 159.491


Epoch 69: 543batch [00:36, 14.71batch/s, loss=578]


epoch 69: avg train loss 551.49, bar train loss 3.440, len train loss 0.051, col train loss 156.929


Epoch 70: 2batch [00:00, 13.89batch/s, loss=568]

epoch 69: avg test  loss 569.19, bar  test loss 3.888, len  test loss 0.057, col  test loss 159.503


Epoch 70: 543batch [00:37, 14.67batch/s, loss=582]


epoch 70: avg train loss 551.15, bar train loss 3.430, len train loss 0.051, col train loss 156.892
epoch 70: avg test  loss 568.92, bar  test loss 3.877, len  test loss 0.055, col  test loss 159.477


Epoch 71: 543batch [00:37, 14.41batch/s, loss=548]


epoch 71: avg train loss 550.91, bar train loss 3.430, len train loss 0.051, col train loss 156.787


Epoch 72: 2batch [00:00, 14.39batch/s, loss=567]

epoch 71: avg test  loss 569.81, bar  test loss 3.914, len  test loss 0.064, col  test loss 159.528


Epoch 72: 543batch [00:37, 14.51batch/s, loss=571]


epoch 72: avg train loss 550.71, bar train loss 3.417, len train loss 0.052, col train loss 156.790


Epoch 73: 2batch [00:00, 13.89batch/s, loss=544]

epoch 72: avg test  loss 568.98, bar  test loss 3.883, len  test loss 0.054, col  test loss 159.591


Epoch 73: 543batch [00:36, 14.68batch/s, loss=539]


epoch 73: avg train loss 550.10, bar train loss 3.410, len train loss 0.050, col train loss 156.714


Epoch 74: 2batch [00:00, 14.60batch/s, loss=537]

epoch 73: avg test  loss 569.63, bar  test loss 3.892, len  test loss 0.061, col  test loss 159.531


Epoch 74: 543batch [00:37, 14.56batch/s, loss=504]


epoch 74: avg train loss 549.83, bar train loss 3.396, len train loss 0.051, col train loss 156.693


Epoch 75: 2batch [00:00, 14.49batch/s, loss=567]

epoch 74: avg test  loss 568.51, bar  test loss 3.884, len  test loss 0.055, col  test loss 159.497


Epoch 75: 543batch [00:36, 14.91batch/s, loss=526]


epoch 75: avg train loss 549.56, bar train loss 3.398, len train loss 0.050, col train loss 156.604
epoch 75: avg test  loss 570.50, bar  test loss 3.869, len  test loss 0.059, col  test loss 159.577


Epoch 76: 543batch [00:36, 14.87batch/s, loss=592]


epoch 76: avg train loss 549.15, bar train loss 3.385, len train loss 0.051, col train loss 156.565


Epoch 77: 2batch [00:00, 14.60batch/s, loss=572]

epoch 76: avg test  loss 569.23, bar  test loss 3.903, len  test loss 0.059, col  test loss 159.579


Epoch 77: 543batch [00:36, 14.88batch/s, loss=592]


epoch 77: avg train loss 548.94, bar train loss 3.383, len train loss 0.051, col train loss 156.492


Epoch 78: 2batch [00:00, 14.60batch/s, loss=547]

epoch 77: avg test  loss 569.58, bar  test loss 3.923, len  test loss 0.061, col  test loss 159.732


Epoch 78: 543batch [00:36, 14.90batch/s, loss=486]


epoch 78: avg train loss 548.36, bar train loss 3.366, len train loss 0.050, col train loss 156.453


Epoch 79: 2batch [00:00, 14.71batch/s, loss=553]

epoch 78: avg test  loss 568.51, bar  test loss 3.904, len  test loss 0.059, col  test loss 159.499


Epoch 79: 543batch [00:36, 14.90batch/s, loss=550]


epoch 79: avg train loss 548.32, bar train loss 3.367, len train loss 0.051, col train loss 156.410


Epoch 80: 2batch [00:00, 14.60batch/s, loss=571]

epoch 79: avg test  loss 569.09, bar  test loss 3.911, len  test loss 0.058, col  test loss 159.503


Epoch 80: 543batch [00:36, 14.88batch/s, loss=557]


epoch 80: avg train loss 548.30, bar train loss 3.373, len train loss 0.050, col train loss 156.370
epoch 80: avg test  loss 569.13, bar  test loss 3.896, len  test loss 0.056, col  test loss 159.587


Epoch 81: 543batch [00:36, 14.83batch/s, loss=566]


epoch 81: avg train loss 547.83, bar train loss 3.359, len train loss 0.050, col train loss 156.342


Epoch 82: 2batch [00:00, 14.60batch/s, loss=563]

epoch 81: avg test  loss 569.02, bar  test loss 3.897, len  test loss 0.059, col  test loss 159.525


Epoch 82: 543batch [00:36, 14.74batch/s, loss=533]


epoch 82: avg train loss 547.61, bar train loss 3.354, len train loss 0.050, col train loss 156.289


Epoch 83: 2batch [00:00, 12.27batch/s, loss=554]

epoch 82: avg test  loss 569.00, bar  test loss 3.909, len  test loss 0.058, col  test loss 159.558


Epoch 83: 543batch [00:37, 14.34batch/s, loss=603]


epoch 83: avg train loss 547.54, bar train loss 3.354, len train loss 0.050, col train loss 156.289


Epoch 84: 2batch [00:00, 14.60batch/s, loss=570]

epoch 83: avg test  loss 568.59, bar  test loss 3.895, len  test loss 0.058, col  test loss 159.450


Epoch 84: 543batch [00:38, 14.15batch/s, loss=559]


epoch 84: avg train loss 547.19, bar train loss 3.351, len train loss 0.050, col train loss 156.188


Epoch 85: 2batch [00:00, 13.51batch/s, loss=556]

epoch 84: avg test  loss 568.73, bar  test loss 3.909, len  test loss 0.056, col  test loss 159.495


Epoch 85: 543batch [00:38, 14.13batch/s, loss=580]


epoch 85: avg train loss 547.04, bar train loss 3.345, len train loss 0.050, col train loss 156.176
epoch 85: avg test  loss 570.15, bar  test loss 3.929, len  test loss 0.059, col  test loss 159.597


Epoch 86: 543batch [00:36, 14.69batch/s, loss=578]


epoch 86: avg train loss 546.75, bar train loss 3.340, len train loss 0.050, col train loss 156.129


Epoch 87: 2batch [00:00, 14.49batch/s, loss=539]

epoch 86: avg test  loss 569.75, bar  test loss 3.917, len  test loss 0.061, col  test loss 159.645


Epoch 87: 543batch [00:38, 14.18batch/s, loss=529]


epoch 87: avg train loss 546.58, bar train loss 3.331, len train loss 0.051, col train loss 156.115


Epoch 88: 2batch [00:00, 13.89batch/s, loss=538]

epoch 87: avg test  loss 569.31, bar  test loss 3.897, len  test loss 0.059, col  test loss 159.462


Epoch 88: 543batch [00:39, 13.81batch/s, loss=613]


epoch 88: avg train loss 546.04, bar train loss 3.325, len train loss 0.050, col train loss 156.020


Epoch 89: 2batch [00:00, 11.83batch/s, loss=522]

epoch 88: avg test  loss 570.26, bar  test loss 3.915, len  test loss 0.058, col  test loss 159.481


Epoch 89: 543batch [00:40, 13.53batch/s, loss=555]


epoch 89: avg train loss 546.22, bar train loss 3.326, len train loss 0.051, col train loss 156.036


Epoch 90: 2batch [00:00, 13.42batch/s, loss=545]

epoch 89: avg test  loss 569.74, bar  test loss 3.890, len  test loss 0.059, col  test loss 159.462


Epoch 90: 543batch [00:40, 13.28batch/s, loss=556]


epoch 90: avg train loss 545.99, bar train loss 3.321, len train loss 0.050, col train loss 156.020
epoch 90: avg test  loss 568.89, bar  test loss 3.918, len  test loss 0.058, col  test loss 159.517


Epoch 91: 543batch [00:40, 13.56batch/s, loss=532]


epoch 91: avg train loss 545.27, bar train loss 3.304, len train loss 0.050, col train loss 155.913


Epoch 92: 2batch [00:00, 13.07batch/s, loss=525]

epoch 91: avg test  loss 569.71, bar  test loss 3.909, len  test loss 0.060, col  test loss 159.465


Epoch 92: 543batch [00:47, 11.46batch/s, loss=552]


epoch 92: avg train loss 545.34, bar train loss 3.313, len train loss 0.050, col train loss 155.898


Epoch 93: 0batch [00:00, ?batch/s]

epoch 92: avg test  loss 568.97, bar  test loss 3.919, len  test loss 0.059, col  test loss 159.603


Epoch 93: 543batch [00:51, 10.54batch/s, loss=590]


epoch 93: avg train loss 545.19, bar train loss 3.304, len train loss 0.051, col train loss 155.864


Epoch 94: 0batch [00:00, ?batch/s, loss=525]

epoch 93: avg test  loss 569.46, bar  test loss 3.897, len  test loss 0.056, col  test loss 159.529


Epoch 94: 543batch [00:46, 11.74batch/s, loss=555]


epoch 94: avg train loss 544.72, bar train loss 3.292, len train loss 0.050, col train loss 155.832


Epoch 95: 2batch [00:00, 12.58batch/s, loss=525]

epoch 94: avg test  loss 569.43, bar  test loss 3.916, len  test loss 0.058, col  test loss 159.511


Epoch 95: 543batch [00:42, 12.88batch/s, loss=547]


epoch 95: avg train loss 544.52, bar train loss 3.291, len train loss 0.050, col train loss 155.782
epoch 95: avg test  loss 570.51, bar  test loss 3.926, len  test loss 0.067, col  test loss 159.579


Epoch 96: 543batch [00:41, 12.95batch/s, loss=567]


epoch 96: avg train loss 544.18, bar train loss 3.287, len train loss 0.049, col train loss 155.714


Epoch 97: 0batch [00:00, ?batch/s, loss=517]

epoch 96: avg test  loss 569.66, bar  test loss 3.915, len  test loss 0.058, col  test loss 159.691


Epoch 97: 543batch [00:41, 12.94batch/s, loss=510]


epoch 97: avg train loss 544.11, bar train loss 3.286, len train loss 0.049, col train loss 155.696


Epoch 98: 2batch [00:00, 12.74batch/s, loss=521]

epoch 97: avg test  loss 569.25, bar  test loss 3.921, len  test loss 0.060, col  test loss 159.580


Epoch 98: 543batch [00:41, 12.94batch/s, loss=548]


epoch 98: avg train loss 543.91, bar train loss 3.281, len train loss 0.050, col train loss 155.641


Epoch 99: 2batch [00:00, 11.30batch/s, loss=540]

epoch 98: avg test  loss 570.51, bar  test loss 3.947, len  test loss 0.065, col  test loss 159.653


Epoch 99: 543batch [00:42, 12.88batch/s, loss=496]


epoch 99: avg train loss 543.67, bar train loss 3.275, len train loss 0.049, col train loss 155.630


Epoch 100: 0batch [00:00, ?batch/s]

epoch 99: avg test  loss 569.29, bar  test loss 3.911, len  test loss 0.059, col  test loss 159.576


Epoch 100: 543batch [00:42, 12.67batch/s, loss=505]


epoch 100: avg train loss 543.35, bar train loss 3.271, len train loss 0.049, col train loss 155.587
epoch 100: avg test  loss 569.41, bar  test loss 3.924, len  test loss 0.059, col  test loss 159.555


Epoch 101: 543batch [00:43, 12.53batch/s, loss=524]


epoch 101: avg train loss 543.29, bar train loss 3.273, len train loss 0.049, col train loss 155.561


Epoch 102: 2batch [00:00, 12.74batch/s, loss=527]

epoch 101: avg test  loss 571.20, bar  test loss 3.971, len  test loss 0.069, col  test loss 159.749


Epoch 102: 543batch [00:42, 12.82batch/s, loss=563]


epoch 102: avg train loss 543.23, bar train loss 3.271, len train loss 0.050, col train loss 155.507


Epoch 103: 0batch [00:00, ?batch/s]

epoch 102: avg test  loss 569.68, bar  test loss 3.937, len  test loss 0.059, col  test loss 159.554


Epoch 103: 543batch [00:55,  9.76batch/s, loss=526]


epoch 103: avg train loss 543.17, bar train loss 3.269, len train loss 0.050, col train loss 155.513


Epoch 104: 1batch [00:00,  9.71batch/s, loss=549]

epoch 103: avg test  loss 569.87, bar  test loss 3.946, len  test loss 0.061, col  test loss 159.686


Epoch 104: 543batch [00:51, 10.60batch/s, loss=555]


epoch 104: avg train loss 543.15, bar train loss 3.271, len train loss 0.049, col train loss 155.498


Epoch 105: 0batch [00:00, ?batch/s, loss=543]

epoch 104: avg test  loss 570.24, bar  test loss 3.953, len  test loss 0.060, col  test loss 159.577


Epoch 105: 543batch [00:44, 12.21batch/s, loss=567]


epoch 105: avg train loss 542.61, bar train loss 3.254, len train loss 0.050, col train loss 155.431
epoch 105: avg test  loss 571.07, bar  test loss 3.915, len  test loss 0.070, col  test loss 159.570


Epoch 106: 543batch [00:40, 13.40batch/s, loss=577]


epoch 106: avg train loss 542.43, bar train loss 3.253, len train loss 0.049, col train loss 155.401


Epoch 107: 0batch [00:00, ?batch/s, loss=528]

epoch 106: avg test  loss 569.65, bar  test loss 3.937, len  test loss 0.059, col  test loss 159.552


Epoch 107: 543batch [00:57,  9.44batch/s, loss=520]


epoch 107: avg train loss 542.24, bar train loss 3.255, len train loss 0.049, col train loss 155.366


Epoch 108: 1batch [00:00,  7.87batch/s, loss=573]

epoch 107: avg test  loss 570.18, bar  test loss 3.943, len  test loss 0.057, col  test loss 159.637


Epoch 108: 543batch [00:52, 10.33batch/s, loss=526]


epoch 108: avg train loss 541.94, bar train loss 3.248, len train loss 0.049, col train loss 155.320


Epoch 109: 0batch [00:00, ?batch/s, loss=544]

epoch 108: avg test  loss 570.07, bar  test loss 3.943, len  test loss 0.057, col  test loss 159.648


Epoch 109: 543batch [00:57,  9.47batch/s, loss=538]


epoch 109: avg train loss 542.08, bar train loss 3.255, len train loss 0.049, col train loss 155.301


Epoch 110: 1batch [00:00,  8.33batch/s, loss=559]

epoch 109: avg test  loss 569.87, bar  test loss 3.957, len  test loss 0.059, col  test loss 159.637


Epoch 110: 543batch [00:57,  9.52batch/s, loss=565]


epoch 110: avg train loss 541.63, bar train loss 3.243, len train loss 0.049, col train loss 155.243
epoch 110: avg test  loss 570.50, bar  test loss 3.943, len  test loss 0.067, col  test loss 159.577


Epoch 111: 543batch [00:56,  9.63batch/s, loss=548]


epoch 111: avg train loss 541.74, bar train loss 3.241, len train loss 0.049, col train loss 155.274


Epoch 112: 0batch [00:00, ?batch/s, loss=529]

epoch 111: avg test  loss 569.94, bar  test loss 3.972, len  test loss 0.060, col  test loss 159.638


Epoch 112: 543batch [00:49, 11.03batch/s, loss=505]


epoch 112: avg train loss 541.12, bar train loss 3.230, len train loss 0.048, col train loss 155.212


Epoch 113: 0batch [00:00, ?batch/s]

epoch 112: avg test  loss 570.37, bar  test loss 3.971, len  test loss 0.057, col  test loss 159.657


Epoch 113: 543batch [00:53, 10.10batch/s, loss=618]


epoch 113: avg train loss 541.24, bar train loss 3.232, len train loss 0.049, col train loss 155.199


Epoch 114: 0batch [00:00, ?batch/s, loss=564]

epoch 113: avg test  loss 570.35, bar  test loss 3.966, len  test loss 0.059, col  test loss 159.725


Epoch 114: 543batch [00:53, 10.15batch/s, loss=527]


epoch 114: avg train loss 541.23, bar train loss 3.234, len train loss 0.049, col train loss 155.162


Epoch 115: 1batch [00:00,  8.62batch/s, loss=548]

epoch 114: avg test  loss 571.65, bar  test loss 4.029, len  test loss 0.062, col  test loss 159.697


Epoch 115: 543batch [00:54, 10.03batch/s, loss=531]


epoch 115: avg train loss 540.95, bar train loss 3.231, len train loss 0.049, col train loss 155.112
epoch 115: avg test  loss 571.50, bar  test loss 3.947, len  test loss 0.063, col  test loss 159.640


Epoch 116: 543batch [00:51, 10.47batch/s, loss=546]


epoch 116: avg train loss 540.86, bar train loss 3.227, len train loss 0.050, col train loss 155.081


Epoch 117: 0batch [00:00, ?batch/s, loss=563]

epoch 116: avg test  loss 570.73, bar  test loss 3.940, len  test loss 0.062, col  test loss 159.718


Epoch 117: 543batch [00:55,  9.76batch/s, loss=545]


epoch 117: avg train loss 540.87, bar train loss 3.230, len train loss 0.050, col train loss 155.096


Epoch 118: 0batch [00:00, ?batch/s, loss=490]

epoch 117: avg test  loss 570.44, bar  test loss 3.981, len  test loss 0.060, col  test loss 159.683


Epoch 118: 543batch [00:51, 10.46batch/s, loss=543]


epoch 118: avg train loss 540.49, bar train loss 3.225, len train loss 0.049, col train loss 155.025


Epoch 119: 0batch [00:00, ?batch/s, loss=541]

epoch 118: avg test  loss 570.79, bar  test loss 3.990, len  test loss 0.061, col  test loss 159.702


Epoch 119: 543batch [00:52, 10.41batch/s, loss=565]


epoch 119: avg train loss 540.30, bar train loss 3.219, len train loss 0.049, col train loss 154.975


Epoch 120: 0batch [00:00, ?batch/s, loss=545]

epoch 119: avg test  loss 570.83, bar  test loss 3.959, len  test loss 0.057, col  test loss 159.765


Epoch 120: 543batch [00:52, 10.26batch/s, loss=526]


epoch 120: avg train loss 539.99, bar train loss 3.216, len train loss 0.048, col train loss 154.968
epoch 120: avg test  loss 570.39, bar  test loss 3.961, len  test loss 0.058, col  test loss 159.722


Epoch 121: 543batch [00:58,  9.33batch/s, loss=606]


epoch 121: avg train loss 540.05, bar train loss 3.215, len train loss 0.049, col train loss 154.950


Epoch 122: 0batch [00:00, ?batch/s]

epoch 121: avg test  loss 570.83, bar  test loss 3.973, len  test loss 0.060, col  test loss 159.646


Epoch 122: 543batch [00:57,  9.50batch/s, loss=505]


epoch 122: avg train loss 539.91, bar train loss 3.216, len train loss 0.048, col train loss 154.923


Epoch 123: 1batch [00:00,  7.35batch/s, loss=557]

epoch 122: avg test  loss 572.39, bar  test loss 3.971, len  test loss 0.067, col  test loss 159.725


Epoch 123: 543batch [00:52, 10.39batch/s, loss=574]


epoch 123: avg train loss 539.66, bar train loss 3.207, len train loss 0.048, col train loss 154.905


Epoch 124: 0batch [00:00, ?batch/s, loss=561]

epoch 123: avg test  loss 571.03, bar  test loss 3.984, len  test loss 0.060, col  test loss 159.702


Epoch 124: 543batch [00:59,  9.11batch/s, loss=560]


epoch 124: avg train loss 539.40, bar train loss 3.205, len train loss 0.048, col train loss 154.852


Epoch 125: 1batch [00:00,  8.70batch/s, loss=567]

epoch 124: avg test  loss 571.28, bar  test loss 3.988, len  test loss 0.062, col  test loss 159.801


Epoch 125: 543batch [00:54, 10.02batch/s, loss=615]


epoch 125: avg train loss 539.70, bar train loss 3.210, len train loss 0.049, col train loss 154.880
epoch 125: avg test  loss 571.51, bar  test loss 3.999, len  test loss 0.061, col  test loss 159.802


Epoch 126: 543batch [00:52, 10.38batch/s, loss=517]


epoch 126: avg train loss 539.47, bar train loss 3.204, len train loss 0.049, col train loss 154.849


Epoch 127: 0batch [00:00, ?batch/s, loss=512]

epoch 126: avg test  loss 571.41, bar  test loss 3.992, len  test loss 0.061, col  test loss 159.857


Epoch 127: 543batch [00:55,  9.86batch/s, loss=571]


epoch 127: avg train loss 539.41, bar train loss 3.205, len train loss 0.049, col train loss 154.815


Epoch 128: 1batch [00:00,  8.13batch/s, loss=558]

epoch 127: avg test  loss 572.35, bar  test loss 3.977, len  test loss 0.071, col  test loss 159.684


Epoch 128: 543batch [00:57,  9.39batch/s, loss=548]


epoch 128: avg train loss 539.00, bar train loss 3.195, len train loss 0.049, col train loss 154.779


Epoch 129: 2batch [00:00, 12.42batch/s, loss=520]

epoch 128: avg test  loss 571.32, bar  test loss 3.989, len  test loss 0.063, col  test loss 159.804


Epoch 129: 543batch [00:58,  9.30batch/s, loss=538]


epoch 129: avg train loss 538.86, bar train loss 3.196, len train loss 0.049, col train loss 154.727


Epoch 130: 1batch [00:00,  9.43batch/s, loss=563]

epoch 129: avg test  loss 572.98, bar  test loss 4.015, len  test loss 0.066, col  test loss 159.987


Epoch 130: 543batch [00:59,  9.20batch/s, loss=530]


epoch 130: avg train loss 538.88, bar train loss 3.193, len train loss 0.049, col train loss 154.735
epoch 130: avg test  loss 571.37, bar  test loss 3.999, len  test loss 0.060, col  test loss 159.836


Epoch 131: 543batch [00:53, 10.11batch/s, loss=509]


epoch 131: avg train loss 538.84, bar train loss 3.197, len train loss 0.048, col train loss 154.721


Epoch 132: 2batch [00:00, 11.83batch/s, loss=546]

epoch 131: avg test  loss 572.19, bar  test loss 4.016, len  test loss 0.066, col  test loss 159.914


Epoch 132: 543batch [00:58,  9.34batch/s, loss=549]


epoch 132: avg train loss 538.39, bar train loss 3.185, len train loss 0.049, col train loss 154.664


Epoch 133: 0batch [00:00, ?batch/s, loss=552]

epoch 132: avg test  loss 571.57, bar  test loss 4.021, len  test loss 0.061, col  test loss 159.839


Epoch 133: 543batch [00:51, 10.49batch/s, loss=565]


epoch 133: avg train loss 538.60, bar train loss 3.195, len train loss 0.049, col train loss 154.642


Epoch 134: 0batch [00:00, ?batch/s, loss=556]

epoch 133: avg test  loss 571.10, bar  test loss 4.000, len  test loss 0.060, col  test loss 159.715


Epoch 134: 543batch [00:54, 10.00batch/s, loss=517]


epoch 134: avg train loss 538.30, bar train loss 3.185, len train loss 0.049, col train loss 154.641


Epoch 135: 0batch [00:00, ?batch/s, loss=519]

epoch 134: avg test  loss 571.56, bar  test loss 4.014, len  test loss 0.059, col  test loss 159.797


Epoch 135: 543batch [00:55,  9.84batch/s, loss=530]


epoch 135: avg train loss 538.22, bar train loss 3.185, len train loss 0.048, col train loss 154.633
epoch 135: avg test  loss 571.31, bar  test loss 4.011, len  test loss 0.058, col  test loss 159.772


Epoch 136: 543batch [00:54, 10.01batch/s, loss=561]


epoch 136: avg train loss 538.16, bar train loss 3.185, len train loss 0.049, col train loss 154.593


Epoch 137: 0batch [00:00, ?batch/s, loss=579]

epoch 136: avg test  loss 572.26, bar  test loss 4.023, len  test loss 0.064, col  test loss 159.849


Epoch 137: 543batch [00:53, 10.23batch/s, loss=547]


epoch 137: avg train loss 537.87, bar train loss 3.179, len train loss 0.048, col train loss 154.565


Epoch 138: 0batch [00:00, ?batch/s, loss=563]

epoch 137: avg test  loss 571.75, bar  test loss 4.022, len  test loss 0.061, col  test loss 159.759


Epoch 138: 543batch [00:53, 10.24batch/s, loss=552]


epoch 138: avg train loss 537.61, bar train loss 3.176, len train loss 0.047, col train loss 154.525


Epoch 139: 1batch [00:00,  8.33batch/s, loss=528]

epoch 138: avg test  loss 572.25, bar  test loss 4.021, len  test loss 0.063, col  test loss 159.821


Epoch 139: 543batch [00:57,  9.40batch/s, loss=525]


epoch 139: avg train loss 538.00, bar train loss 3.189, len train loss 0.049, col train loss 154.517


Epoch 140: 0batch [00:00, ?batch/s]

epoch 139: avg test  loss 572.23, bar  test loss 4.042, len  test loss 0.061, col  test loss 159.844


Epoch 140: 543batch [00:55,  9.84batch/s, loss=535]


epoch 140: avg train loss 537.76, bar train loss 3.179, len train loss 0.049, col train loss 154.506
epoch 140: avg test  loss 571.56, bar  test loss 3.994, len  test loss 0.058, col  test loss 159.806


Epoch 141: 543batch [00:44, 12.26batch/s, loss=574]


epoch 141: avg train loss 537.35, bar train loss 3.169, len train loss 0.048, col train loss 154.460


Epoch 142: 2batch [00:00, 13.24batch/s, loss=555]

epoch 141: avg test  loss 571.45, bar  test loss 4.018, len  test loss 0.060, col  test loss 159.864


Epoch 142: 543batch [00:40, 13.46batch/s, loss=515]


epoch 142: avg train loss 537.53, bar train loss 3.178, len train loss 0.049, col train loss 154.436


Epoch 143: 2batch [00:00, 13.42batch/s, loss=549]

epoch 142: avg test  loss 571.81, bar  test loss 4.038, len  test loss 0.060, col  test loss 159.759


Epoch 143: 543batch [00:40, 13.50batch/s, loss=520]


epoch 143: avg train loss 537.03, bar train loss 3.168, len train loss 0.048, col train loss 154.395


Epoch 144: 0batch [00:00, ?batch/s, loss=593]

epoch 143: avg test  loss 571.97, bar  test loss 4.025, len  test loss 0.060, col  test loss 159.824


Epoch 144: 543batch [00:47, 11.50batch/s, loss=498]


epoch 144: avg train loss 536.92, bar train loss 3.166, len train loss 0.047, col train loss 154.397


Epoch 145: 0batch [00:00, ?batch/s, loss=564]

epoch 144: avg test  loss 572.17, bar  test loss 4.032, len  test loss 0.060, col  test loss 159.790


Epoch 145: 543batch [00:55,  9.77batch/s, loss=538]


epoch 145: avg train loss 536.89, bar train loss 3.169, len train loss 0.048, col train loss 154.366
epoch 145: avg test  loss 573.51, bar  test loss 4.057, len  test loss 0.066, col  test loss 159.980


Epoch 146: 543batch [00:40, 13.46batch/s, loss=574]


epoch 146: avg train loss 536.94, bar train loss 3.167, len train loss 0.048, col train loss 154.370


Epoch 147: 2batch [00:00, 13.42batch/s, loss=551]

epoch 146: avg test  loss 573.49, bar  test loss 4.065, len  test loss 0.069, col  test loss 159.916


Epoch 147: 543batch [00:42, 12.65batch/s, loss=554]


epoch 147: avg train loss 536.71, bar train loss 3.161, len train loss 0.048, col train loss 154.356


Epoch 148: 2batch [00:00, 12.05batch/s, loss=536]

epoch 147: avg test  loss 573.37, bar  test loss 4.062, len  test loss 0.069, col  test loss 159.923


Epoch 148: 543batch [00:55,  9.83batch/s, loss=538]


epoch 148: avg train loss 536.68, bar train loss 3.160, len train loss 0.048, col train loss 154.353


Epoch 149: 0batch [00:00, ?batch/s, loss=552]

epoch 148: avg test  loss 572.63, bar  test loss 4.048, len  test loss 0.063, col  test loss 159.969


Epoch 149: 543batch [00:54, 10.00batch/s, loss=551]


epoch 149: avg train loss 536.79, bar train loss 3.156, len train loss 0.049, col train loss 154.363


Epoch 150: 2batch [00:00, 11.83batch/s, loss=517]

epoch 149: avg test  loss 572.46, bar  test loss 4.049, len  test loss 0.062, col  test loss 159.812


Epoch 150: 543batch [00:56,  9.53batch/s, loss=549]


epoch 150: avg train loss 536.55, bar train loss 3.160, len train loss 0.048, col train loss 154.307
epoch 150: avg test  loss 572.74, bar  test loss 4.043, len  test loss 0.063, col  test loss 159.956


Epoch 151: 543batch [00:54,  9.88batch/s, loss=569]


epoch 151: avg train loss 536.26, bar train loss 3.155, len train loss 0.048, col train loss 154.265


Epoch 152: 2batch [00:00, 13.42batch/s, loss=562]

epoch 151: avg test  loss 572.68, bar  test loss 4.059, len  test loss 0.063, col  test loss 159.870


Epoch 152: 543batch [00:40, 13.46batch/s, loss=520]


epoch 152: avg train loss 536.15, bar train loss 3.152, len train loss 0.048, col train loss 154.233


Epoch 153: 2batch [00:00, 13.07batch/s, loss=571]

epoch 152: avg test  loss 572.68, bar  test loss 4.039, len  test loss 0.062, col  test loss 159.877


Epoch 153: 543batch [00:40, 13.40batch/s, loss=544]


epoch 153: avg train loss 536.12, bar train loss 3.154, len train loss 0.048, col train loss 154.236


Epoch 154: 0batch [00:00, ?batch/s]

epoch 153: avg test  loss 572.81, bar  test loss 4.062, len  test loss 0.063, col  test loss 159.894


Epoch 154: 543batch [00:40, 13.43batch/s, loss=487]


epoch 154: avg train loss 535.70, bar train loss 3.143, len train loss 0.048, col train loss 154.165


Epoch 155: 2batch [00:00, 13.33batch/s, loss=531]

epoch 154: avg test  loss 572.78, bar  test loss 4.069, len  test loss 0.061, col  test loss 159.878


Epoch 155: 543batch [00:40, 13.41batch/s, loss=547]


epoch 155: avg train loss 536.14, bar train loss 3.154, len train loss 0.049, col train loss 154.201
epoch 155: avg test  loss 573.53, bar  test loss 4.078, len  test loss 0.065, col  test loss 159.838


Epoch 156: 543batch [00:40, 13.41batch/s, loss=539]


epoch 156: avg train loss 535.42, bar train loss 3.143, len train loss 0.047, col train loss 154.103


Epoch 157: 2batch [00:00, 12.50batch/s, loss=528]

epoch 156: avg test  loss 573.66, bar  test loss 4.091, len  test loss 0.062, col  test loss 160.027


Epoch 157: 543batch [00:40, 13.32batch/s, loss=551]


epoch 157: avg train loss 535.66, bar train loss 3.143, len train loss 0.048, col train loss 154.128


Epoch 158: 2batch [00:00, 13.07batch/s, loss=547]

epoch 157: avg test  loss 572.66, bar  test loss 4.059, len  test loss 0.060, col  test loss 159.976


Epoch 158: 543batch [00:40, 13.34batch/s, loss=564]


epoch 158: avg train loss 535.80, bar train loss 3.152, len train loss 0.049, col train loss 154.116


Epoch 159: 0batch [00:00, ?batch/s, loss=555]

epoch 158: avg test  loss 573.36, bar  test loss 4.067, len  test loss 0.063, col  test loss 160.016


Epoch 159: 543batch [00:40, 13.41batch/s, loss=524]


epoch 159: avg train loss 535.45, bar train loss 3.141, len train loss 0.048, col train loss 154.124


Epoch 160: 2batch [00:00, 13.33batch/s, loss=508]

epoch 159: avg test  loss 572.66, bar  test loss 4.064, len  test loss 0.062, col  test loss 160.019


Epoch 160: 543batch [00:40, 13.40batch/s, loss=555]


epoch 160: avg train loss 535.34, bar train loss 3.144, len train loss 0.047, col train loss 154.103
epoch 160: avg test  loss 573.70, bar  test loss 4.072, len  test loss 0.064, col  test loss 160.081


Epoch 161: 543batch [00:40, 13.36batch/s, loss=497]


epoch 161: avg train loss 535.01, bar train loss 3.134, len train loss 0.048, col train loss 154.027


Epoch 162: 2batch [00:00, 13.07batch/s, loss=557]

epoch 161: avg test  loss 574.55, bar  test loss 4.075, len  test loss 0.072, col  test loss 160.005


Epoch 162: 543batch [00:40, 13.31batch/s, loss=481]


epoch 162: avg train loss 535.20, bar train loss 3.140, len train loss 0.048, col train loss 154.055


Epoch 163: 2batch [00:00, 13.51batch/s, loss=537]

epoch 162: avg test  loss 573.93, bar  test loss 4.076, len  test loss 0.067, col  test loss 159.975


Epoch 163: 543batch [00:40, 13.35batch/s, loss=528]


epoch 163: avg train loss 535.20, bar train loss 3.138, len train loss 0.048, col train loss 154.055


Epoch 164: 2batch [00:00, 13.07batch/s, loss=551]

epoch 163: avg test  loss 573.21, bar  test loss 4.077, len  test loss 0.063, col  test loss 159.994


Epoch 164: 543batch [00:40, 13.33batch/s, loss=553]


epoch 164: avg train loss 535.03, bar train loss 3.131, len train loss 0.048, col train loss 154.041


Epoch 165: 2batch [00:00, 13.25batch/s, loss=510]

epoch 164: avg test  loss 573.32, bar  test loss 4.088, len  test loss 0.061, col  test loss 159.992


Epoch 165: 543batch [00:40, 13.37batch/s, loss=601]


epoch 165: avg train loss 534.96, bar train loss 3.134, len train loss 0.048, col train loss 154.010
epoch 165: avg test  loss 573.58, bar  test loss 4.083, len  test loss 0.064, col  test loss 160.051


Epoch 166: 543batch [00:40, 13.35batch/s, loss=532]


epoch 166: avg train loss 535.11, bar train loss 3.144, len train loss 0.048, col train loss 154.012


Epoch 167: 2batch [00:00, 13.07batch/s, loss=516]

epoch 166: avg test  loss 574.56, bar  test loss 4.110, len  test loss 0.069, col  test loss 160.036


Epoch 167: 543batch [00:40, 13.37batch/s, loss=553]


epoch 167: avg train loss 534.77, bar train loss 3.138, len train loss 0.048, col train loss 153.955


Epoch 168: 0batch [00:00, ?batch/s, loss=559]

epoch 167: avg test  loss 574.47, bar  test loss 4.092, len  test loss 0.075, col  test loss 160.001


Epoch 168: 543batch [00:41, 13.13batch/s, loss=553]


epoch 168: avg train loss 534.63, bar train loss 3.133, len train loss 0.047, col train loss 153.949


Epoch 169: 0batch [00:00, ?batch/s, loss=524]

epoch 168: avg test  loss 573.93, bar  test loss 4.084, len  test loss 0.063, col  test loss 160.189


Epoch 169: 543batch [00:40, 13.26batch/s, loss=517]


epoch 169: avg train loss 534.47, bar train loss 3.124, len train loss 0.048, col train loss 153.936


Epoch 170: 2batch [00:00, 12.99batch/s, loss=515]

epoch 169: avg test  loss 575.17, bar  test loss 4.097, len  test loss 0.079, col  test loss 159.980


Epoch 170: 543batch [00:40, 13.27batch/s, loss=525]


epoch 170: avg train loss 534.69, bar train loss 3.133, len train loss 0.048, col train loss 153.946
epoch 170: avg test  loss 574.16, bar  test loss 4.095, len  test loss 0.068, col  test loss 160.105


Epoch 171: 543batch [00:41, 13.22batch/s, loss=543]


epoch 171: avg train loss 534.42, bar train loss 3.129, len train loss 0.048, col train loss 153.912


Epoch 172: 0batch [00:00, ?batch/s, loss=528]

epoch 171: avg test  loss 573.56, bar  test loss 4.086, len  test loss 0.061, col  test loss 159.989


Epoch 172: 543batch [00:40, 13.26batch/s, loss=557]


epoch 172: avg train loss 533.96, bar train loss 3.116, len train loss 0.047, col train loss 153.867


Epoch 173: 2batch [00:00, 12.66batch/s, loss=494]

epoch 172: avg test  loss 576.07, bar  test loss 4.130, len  test loss 0.080, col  test loss 160.131


Epoch 173: 543batch [00:41, 13.22batch/s, loss=553]


epoch 173: avg train loss 533.94, bar train loss 3.117, len train loss 0.047, col train loss 153.857


Epoch 174: 2batch [00:00, 12.82batch/s, loss=539]

epoch 173: avg test  loss 575.53, bar  test loss 4.122, len  test loss 0.075, col  test loss 160.109


Epoch 174: 543batch [00:41, 13.22batch/s, loss=533]


epoch 174: avg train loss 534.16, bar train loss 3.124, len train loss 0.049, col train loss 153.847


Epoch 175: 2batch [00:00, 13.33batch/s, loss=521]

epoch 174: avg test  loss 574.01, bar  test loss 4.108, len  test loss 0.062, col  test loss 160.021


Epoch 175: 543batch [00:41, 13.20batch/s, loss=553]


epoch 175: avg train loss 533.80, bar train loss 3.116, len train loss 0.047, col train loss 153.829
epoch 175: avg test  loss 574.30, bar  test loss 4.116, len  test loss 0.066, col  test loss 160.097


Epoch 176: 543batch [00:41, 13.18batch/s, loss=575]


epoch 176: avg train loss 533.72, bar train loss 3.116, len train loss 0.048, col train loss 153.801


Epoch 177: 2batch [00:00, 12.90batch/s, loss=542]

epoch 176: avg test  loss 574.40, bar  test loss 4.116, len  test loss 0.066, col  test loss 160.072


Epoch 177: 543batch [00:41, 13.17batch/s, loss=551]


epoch 177: avg train loss 534.16, bar train loss 3.122, len train loss 0.049, col train loss 153.858


Epoch 178: 0batch [00:00, ?batch/s]

epoch 177: avg test  loss 574.17, bar  test loss 4.101, len  test loss 0.069, col  test loss 160.099


Epoch 178: 543batch [00:42, 12.80batch/s, loss=554]


epoch 178: avg train loss 533.64, bar train loss 3.116, len train loss 0.048, col train loss 153.770


Epoch 179: 2batch [00:00, 13.16batch/s, loss=544]

epoch 178: avg test  loss 575.20, bar  test loss 4.114, len  test loss 0.070, col  test loss 160.170


Epoch 179: 543batch [00:41, 13.17batch/s, loss=561]


epoch 179: avg train loss 533.46, bar train loss 3.111, len train loss 0.048, col train loss 153.732


Epoch 180: 2batch [00:00, 13.07batch/s, loss=541]

epoch 179: avg test  loss 574.59, bar  test loss 4.109, len  test loss 0.067, col  test loss 160.183


Epoch 180: 543batch [00:41, 13.17batch/s, loss=508]


epoch 180: avg train loss 533.26, bar train loss 3.110, len train loss 0.047, col train loss 153.718
epoch 180: avg test  loss 574.20, bar  test loss 4.114, len  test loss 0.067, col  test loss 160.048


Epoch 181: 543batch [00:41, 13.14batch/s, loss=519]


epoch 181: avg train loss 533.33, bar train loss 3.112, len train loss 0.047, col train loss 153.727


Epoch 182: 2batch [00:00, 13.24batch/s, loss=499]

epoch 181: avg test  loss 574.03, bar  test loss 4.121, len  test loss 0.062, col  test loss 160.091


Epoch 182: 543batch [00:41, 13.20batch/s, loss=525]


epoch 182: avg train loss 533.23, bar train loss 3.111, len train loss 0.047, col train loss 153.725


Epoch 183: 2batch [00:00, 13.07batch/s, loss=537]

epoch 182: avg test  loss 574.79, bar  test loss 4.126, len  test loss 0.071, col  test loss 160.146


Epoch 183: 543batch [00:41, 13.06batch/s, loss=539]


epoch 183: avg train loss 533.20, bar train loss 3.106, len train loss 0.048, col train loss 153.711


Epoch 184: 2batch [00:00, 12.27batch/s, loss=546]

epoch 183: avg test  loss 574.81, bar  test loss 4.145, len  test loss 0.066, col  test loss 160.203


Epoch 184: 543batch [00:41, 13.12batch/s, loss=548]


epoch 184: avg train loss 533.07, bar train loss 3.110, len train loss 0.047, col train loss 153.643


Epoch 185: 2batch [00:00, 12.99batch/s, loss=537]

epoch 184: avg test  loss 577.06, bar  test loss 4.150, len  test loss 0.085, col  test loss 160.251


Epoch 185: 543batch [00:41, 13.16batch/s, loss=573]


epoch 185: avg train loss 532.79, bar train loss 3.103, len train loss 0.047, col train loss 153.624
epoch 185: avg test  loss 574.80, bar  test loss 4.133, len  test loss 0.068, col  test loss 160.195


Epoch 186: 543batch [00:42, 12.84batch/s, loss=551]


epoch 186: avg train loss 533.38, bar train loss 3.110, len train loss 0.049, col train loss 153.698


Epoch 187: 2batch [00:00, 13.07batch/s, loss=508]

epoch 186: avg test  loss 574.58, bar  test loss 4.117, len  test loss 0.067, col  test loss 160.073


Epoch 187: 543batch [00:41, 13.13batch/s, loss=499]


epoch 187: avg train loss 532.98, bar train loss 3.102, len train loss 0.048, col train loss 153.690


Epoch 188: 1batch [00:00,  7.87batch/s, loss=522]

epoch 187: avg test  loss 574.56, bar  test loss 4.121, len  test loss 0.066, col  test loss 160.061


Epoch 188: 543batch [00:41, 13.10batch/s, loss=555]


epoch 188: avg train loss 532.64, bar train loss 3.097, len train loss 0.047, col train loss 153.639


Epoch 189: 2batch [00:00, 12.74batch/s, loss=561]

epoch 188: avg test  loss 574.77, bar  test loss 4.127, len  test loss 0.066, col  test loss 160.076


Epoch 189: 543batch [00:41, 13.06batch/s, loss=490]


epoch 189: avg train loss 532.67, bar train loss 3.102, len train loss 0.048, col train loss 153.578


Epoch 190: 1batch [00:00,  7.69batch/s, loss=526]

epoch 189: avg test  loss 575.03, bar  test loss 4.146, len  test loss 0.067, col  test loss 160.228


Epoch 190: 543batch [00:41, 13.10batch/s, loss=479]


epoch 190: avg train loss 532.78, bar train loss 3.105, len train loss 0.048, col train loss 153.601
epoch 190: avg test  loss 576.09, bar  test loss 4.148, len  test loss 0.068, col  test loss 160.245


Epoch 191: 543batch [00:41, 12.94batch/s, loss=553]


epoch 191: avg train loss 532.66, bar train loss 3.097, len train loss 0.048, col train loss 153.593


Epoch 192: 2batch [00:00, 12.74batch/s, loss=547]

epoch 191: avg test  loss 575.38, bar  test loss 4.144, len  test loss 0.065, col  test loss 160.230


Epoch 192: 543batch [00:41, 13.07batch/s, loss=541]


epoch 192: avg train loss 532.87, bar train loss 3.107, len train loss 0.048, col train loss 153.586


Epoch 193: 2batch [00:00, 12.58batch/s, loss=570]

epoch 192: avg test  loss 575.31, bar  test loss 4.135, len  test loss 0.070, col  test loss 160.229


Epoch 193: 543batch [00:41, 13.06batch/s, loss=546]


epoch 193: avg train loss 532.36, bar train loss 3.091, len train loss 0.048, col train loss 153.560


Epoch 194: 2batch [00:00, 13.07batch/s, loss=557]

epoch 193: avg test  loss 574.97, bar  test loss 4.137, len  test loss 0.065, col  test loss 160.212


Epoch 194: 543batch [00:41, 13.00batch/s, loss=524]


epoch 194: avg train loss 532.27, bar train loss 3.093, len train loss 0.048, col train loss 153.528


Epoch 195: 2batch [00:00, 12.90batch/s, loss=527]

epoch 194: avg test  loss 575.56, bar  test loss 4.151, len  test loss 0.069, col  test loss 160.192


Epoch 195: 543batch [00:41, 13.06batch/s, loss=539]


epoch 195: avg train loss 532.28, bar train loss 3.089, len train loss 0.048, col train loss 153.542
epoch 195: avg test  loss 574.99, bar  test loss 4.136, len  test loss 0.064, col  test loss 160.174


Epoch 196: 543batch [00:41, 13.04batch/s, loss=538]


epoch 196: avg train loss 532.05, bar train loss 3.092, len train loss 0.048, col train loss 153.458


Epoch 197: 2batch [00:00, 13.16batch/s, loss=527]

epoch 196: avg test  loss 576.35, bar  test loss 4.172, len  test loss 0.070, col  test loss 160.263


Epoch 197: 543batch [00:41, 13.02batch/s, loss=548]


epoch 197: avg train loss 532.55, bar train loss 3.102, len train loss 0.048, col train loss 153.533


Epoch 198: 2batch [00:00, 13.07batch/s, loss=513]

epoch 197: avg test  loss 576.70, bar  test loss 4.170, len  test loss 0.074, col  test loss 160.275


Epoch 198: 543batch [00:41, 13.01batch/s, loss=524]


epoch 198: avg train loss 531.78, bar train loss 3.081, len train loss 0.047, col train loss 153.471


Epoch 199: 2batch [00:00, 12.99batch/s, loss=501]

epoch 198: avg test  loss 575.11, bar  test loss 4.149, len  test loss 0.066, col  test loss 160.140


Epoch 199: 543batch [00:41, 13.03batch/s, loss=572]


epoch 199: avg train loss 532.03, bar train loss 3.086, len train loss 0.048, col train loss 153.481


Epoch 200: 0batch [00:00, ?batch/s, loss=517]

epoch 199: avg test  loss 574.94, bar  test loss 4.128, len  test loss 0.066, col  test loss 160.165


Epoch 200: 543batch [00:41, 13.02batch/s, loss=522]


epoch 200: avg train loss 531.56, bar train loss 3.077, len train loss 0.048, col train loss 153.402
epoch 200: avg test  loss 574.95, bar  test loss 4.145, len  test loss 0.065, col  test loss 160.107


Epoch 201: 543batch [00:41, 12.95batch/s, loss=548]


epoch 201: avg train loss 531.70, bar train loss 3.082, len train loss 0.048, col train loss 153.412


Epoch 202: 2batch [00:00, 12.99batch/s, loss=525]

epoch 201: avg test  loss 575.68, bar  test loss 4.161, len  test loss 0.067, col  test loss 160.208


Epoch 202: 543batch [00:41, 13.01batch/s, loss=519]


epoch 202: avg train loss 531.66, bar train loss 3.083, len train loss 0.047, col train loss 153.439


Epoch 203: 0batch [00:00, ?batch/s, loss=507]

epoch 202: avg test  loss 575.98, bar  test loss 4.173, len  test loss 0.066, col  test loss 160.333


Epoch 203: 543batch [00:41, 13.00batch/s, loss=548]


epoch 203: avg train loss 531.33, bar train loss 3.078, len train loss 0.047, col train loss 153.358


Epoch 204: 2batch [00:00, 12.50batch/s, loss=505]

epoch 203: avg test  loss 575.89, bar  test loss 4.167, len  test loss 0.071, col  test loss 160.240


Epoch 204: 543batch [00:41, 13.00batch/s, loss=530]


epoch 204: avg train loss 531.55, bar train loss 3.082, len train loss 0.048, col train loss 153.392


Epoch 205: 2batch [00:00, 13.16batch/s, loss=522]

epoch 204: avg test  loss 575.93, bar  test loss 4.187, len  test loss 0.067, col  test loss 160.151


Epoch 205: 543batch [00:41, 12.95batch/s, loss=519]


epoch 205: avg train loss 531.15, bar train loss 3.076, len train loss 0.046, col train loss 153.323
epoch 205: avg test  loss 575.72, bar  test loss 4.155, len  test loss 0.069, col  test loss 160.192


Epoch 206: 543batch [00:41, 12.96batch/s, loss=572]


epoch 206: avg train loss 531.42, bar train loss 3.076, len train loss 0.049, col train loss 153.337


Epoch 207: 0batch [00:00, ?batch/s, loss=540]

epoch 206: avg test  loss 576.37, bar  test loss 4.189, len  test loss 0.066, col  test loss 160.326


Epoch 207: 543batch [00:42, 12.70batch/s, loss=509]


epoch 207: avg train loss 531.34, bar train loss 3.077, len train loss 0.049, col train loss 153.303


Epoch 208: 2batch [00:00, 12.74batch/s, loss=569]

epoch 207: avg test  loss 575.74, bar  test loss 4.176, len  test loss 0.068, col  test loss 160.223


Epoch 208: 543batch [00:41, 12.97batch/s, loss=521]


epoch 208: avg train loss 531.18, bar train loss 3.079, len train loss 0.047, col train loss 153.299


Epoch 209: 2batch [00:00, 13.07batch/s, loss=546]

epoch 208: avg test  loss 575.54, bar  test loss 4.153, len  test loss 0.067, col  test loss 160.173


Epoch 209: 543batch [00:42, 12.92batch/s, loss=599]


epoch 209: avg train loss 530.97, bar train loss 3.071, len train loss 0.048, col train loss 153.267


Epoch 210: 0batch [00:00, ?batch/s, loss=555]

epoch 209: avg test  loss 575.34, bar  test loss 4.149, len  test loss 0.068, col  test loss 160.264


Epoch 210: 543batch [00:41, 12.93batch/s, loss=550]


epoch 210: avg train loss 530.83, bar train loss 3.066, len train loss 0.047, col train loss 153.282
epoch 210: avg test  loss 575.97, bar  test loss 4.173, len  test loss 0.071, col  test loss 160.196


Epoch 211: 543batch [00:42, 12.89batch/s, loss=533]


epoch 211: avg train loss 530.72, bar train loss 3.059, len train loss 0.048, col train loss 153.246


Epoch 212: 2batch [00:00, 13.07batch/s, loss=529]

epoch 211: avg test  loss 576.24, bar  test loss 4.182, len  test loss 0.067, col  test loss 160.211


Epoch 212: 543batch [00:41, 12.94batch/s, loss=579]


epoch 212: avg train loss 530.96, bar train loss 3.074, len train loss 0.048, col train loss 153.253


Epoch 213: 2batch [00:00, 12.58batch/s, loss=556]

epoch 212: avg test  loss 575.52, bar  test loss 4.163, len  test loss 0.066, col  test loss 160.191


Epoch 213: 543batch [00:42, 12.84batch/s, loss=526]


epoch 213: avg train loss 530.93, bar train loss 3.072, len train loss 0.048, col train loss 153.251


Epoch 214: 2batch [00:00, 12.90batch/s, loss=563]

epoch 213: avg test  loss 576.20, bar  test loss 4.186, len  test loss 0.066, col  test loss 160.301


Epoch 214: 543batch [00:41, 12.93batch/s, loss=525]


epoch 214: avg train loss 530.44, bar train loss 3.063, len train loss 0.047, col train loss 153.189


Epoch 215: 2batch [00:00, 12.82batch/s, loss=531]

epoch 214: avg test  loss 576.40, bar  test loss 4.182, len  test loss 0.072, col  test loss 160.273


Epoch 215: 543batch [00:42, 12.92batch/s, loss=537]


epoch 215: avg train loss 530.77, bar train loss 3.067, len train loss 0.048, col train loss 153.228
epoch 215: avg test  loss 576.58, bar  test loss 4.188, len  test loss 0.069, col  test loss 160.354


Epoch 216: 543batch [00:41, 12.94batch/s, loss=525]


epoch 216: avg train loss 530.60, bar train loss 3.069, len train loss 0.047, col train loss 153.195


Epoch 217: 2batch [00:00, 12.99batch/s, loss=533]

epoch 216: avg test  loss 576.11, bar  test loss 4.183, len  test loss 0.068, col  test loss 160.214


Epoch 217: 543batch [00:42, 12.90batch/s, loss=592]


epoch 217: avg train loss 530.55, bar train loss 3.062, len train loss 0.047, col train loss 153.216


Epoch 218: 2batch [00:00, 12.74batch/s, loss=505]

epoch 217: avg test  loss 576.51, bar  test loss 4.193, len  test loss 0.072, col  test loss 160.211


Epoch 218: 543batch [00:42, 12.90batch/s, loss=515]


epoch 218: avg train loss 530.40, bar train loss 3.065, len train loss 0.048, col train loss 153.143


Epoch 219: 2batch [00:00, 12.90batch/s, loss=529]

epoch 218: avg test  loss 576.69, bar  test loss 4.176, len  test loss 0.076, col  test loss 160.272


Epoch 219: 543batch [00:42, 12.89batch/s, loss=523]


epoch 219: avg train loss 530.23, bar train loss 3.058, len train loss 0.048, col train loss 153.119


Epoch 220: 2batch [00:00, 12.66batch/s, loss=534]

epoch 219: avg test  loss 577.28, bar  test loss 4.193, len  test loss 0.074, col  test loss 160.425


Epoch 220: 543batch [00:42, 12.86batch/s, loss=558]


epoch 220: avg train loss 530.21, bar train loss 3.054, len train loss 0.048, col train loss 153.137
epoch 220: avg test  loss 576.41, bar  test loss 4.186, len  test loss 0.065, col  test loss 160.278


Epoch 221: 543batch [00:42, 12.85batch/s, loss=552]


epoch 221: avg train loss 530.08, bar train loss 3.059, len train loss 0.047, col train loss 153.093


Epoch 222: 2batch [00:00, 12.90batch/s, loss=543]

epoch 221: avg test  loss 576.80, bar  test loss 4.199, len  test loss 0.067, col  test loss 160.295


Epoch 222: 543batch [00:42, 12.87batch/s, loss=524]


epoch 222: avg train loss 529.93, bar train loss 3.054, len train loss 0.048, col train loss 153.048


Epoch 223: 2batch [00:00, 12.58batch/s, loss=561]

epoch 222: avg test  loss 576.76, bar  test loss 4.197, len  test loss 0.070, col  test loss 160.328


Epoch 223: 543batch [00:42, 12.87batch/s, loss=530]


epoch 223: avg train loss 529.85, bar train loss 3.053, len train loss 0.048, col train loss 153.049


Epoch 224: 2batch [00:00, 12.74batch/s, loss=548]

epoch 223: avg test  loss 576.60, bar  test loss 4.189, len  test loss 0.069, col  test loss 160.398


Epoch 224: 543batch [00:42, 12.86batch/s, loss=519]


epoch 224: avg train loss 529.81, bar train loss 3.050, len train loss 0.047, col train loss 153.059


Epoch 225: 2batch [00:00, 12.50batch/s, loss=515]

epoch 224: avg test  loss 576.91, bar  test loss 4.189, len  test loss 0.071, col  test loss 160.385


Epoch 225: 543batch [00:42, 12.83batch/s, loss=487]


epoch 225: avg train loss 529.90, bar train loss 3.060, len train loss 0.047, col train loss 153.038
epoch 225: avg test  loss 576.88, bar  test loss 4.211, len  test loss 0.070, col  test loss 160.343


Epoch 226: 543batch [00:42, 12.80batch/s, loss=535]


epoch 226: avg train loss 530.02, bar train loss 3.054, len train loss 0.049, col train loss 153.066


Epoch 227: 2batch [00:00, 12.66batch/s, loss=509]

epoch 226: avg test  loss 576.87, bar  test loss 4.209, len  test loss 0.071, col  test loss 160.285


Epoch 227: 543batch [00:42, 12.85batch/s, loss=524]


epoch 227: avg train loss 529.87, bar train loss 3.058, len train loss 0.048, col train loss 153.035


Epoch 228: 2batch [00:00, 12.82batch/s, loss=503]

epoch 227: avg test  loss 577.57, bar  test loss 4.203, len  test loss 0.071, col  test loss 160.313


Epoch 228: 543batch [00:42, 12.78batch/s, loss=523]


epoch 228: avg train loss 529.74, bar train loss 3.053, len train loss 0.047, col train loss 153.034


Epoch 229: 2batch [00:00, 12.82batch/s, loss=543]

epoch 228: avg test  loss 576.35, bar  test loss 4.188, len  test loss 0.070, col  test loss 160.241


Epoch 229: 543batch [00:42, 12.81batch/s, loss=549]


epoch 229: avg train loss 529.47, bar train loss 3.049, len train loss 0.047, col train loss 152.983


Epoch 230: 2batch [00:00, 12.90batch/s, loss=520]

epoch 229: avg test  loss 576.88, bar  test loss 4.206, len  test loss 0.068, col  test loss 160.288


Epoch 230: 543batch [00:42, 12.80batch/s, loss=510]


epoch 230: avg train loss 529.64, bar train loss 3.047, len train loss 0.047, col train loss 153.027
epoch 230: avg test  loss 576.61, bar  test loss 4.186, len  test loss 0.070, col  test loss 160.343


Epoch 231: 543batch [00:42, 12.79batch/s, loss=539]


epoch 231: avg train loss 529.32, bar train loss 3.041, len train loss 0.047, col train loss 153.005


Epoch 232: 2batch [00:00, 12.82batch/s, loss=529]

epoch 231: avg test  loss 577.41, bar  test loss 4.198, len  test loss 0.075, col  test loss 160.334


Epoch 232: 543batch [00:42, 12.78batch/s, loss=559]


epoch 232: avg train loss 529.56, bar train loss 3.051, len train loss 0.048, col train loss 152.967


Epoch 233: 2batch [00:00, 12.66batch/s, loss=523]

epoch 232: avg test  loss 577.13, bar  test loss 4.214, len  test loss 0.072, col  test loss 160.333


Epoch 233: 543batch [00:42, 12.76batch/s, loss=522]


epoch 233: avg train loss 529.65, bar train loss 3.052, len train loss 0.048, col train loss 152.968


Epoch 234: 2batch [00:00, 12.05batch/s, loss=542]

epoch 233: avg test  loss 577.61, bar  test loss 4.221, len  test loss 0.073, col  test loss 160.357


Epoch 234: 543batch [00:43, 12.56batch/s, loss=592]


epoch 234: avg train loss 528.97, bar train loss 3.040, len train loss 0.046, col train loss 152.901


Epoch 235: 2batch [00:00, 12.58batch/s, loss=526]

epoch 234: avg test  loss 577.40, bar  test loss 4.220, len  test loss 0.070, col  test loss 160.366


Epoch 235: 543batch [00:42, 12.75batch/s, loss=535]


epoch 235: avg train loss 529.04, bar train loss 3.039, len train loss 0.047, col train loss 152.906
epoch 235: avg test  loss 577.60, bar  test loss 4.215, len  test loss 0.072, col  test loss 160.359


Epoch 236: 543batch [00:42, 12.73batch/s, loss=532]


epoch 236: avg train loss 529.32, bar train loss 3.045, len train loss 0.047, col train loss 152.953


Epoch 237: 2batch [00:00, 12.35batch/s, loss=547]

epoch 236: avg test  loss 577.91, bar  test loss 4.218, len  test loss 0.076, col  test loss 160.467


Epoch 237: 543batch [00:42, 12.73batch/s, loss=502]


epoch 237: avg train loss 529.33, bar train loss 3.049, len train loss 0.048, col train loss 152.914


Epoch 238: 2batch [00:00, 12.66batch/s, loss=519]

epoch 237: avg test  loss 577.11, bar  test loss 4.208, len  test loss 0.073, col  test loss 160.342


Epoch 238: 543batch [00:42, 12.67batch/s, loss=573]


epoch 238: avg train loss 528.78, bar train loss 3.038, len train loss 0.046, col train loss 152.855


Epoch 239: 2batch [00:00, 12.82batch/s, loss=530]

epoch 238: avg test  loss 577.28, bar  test loss 4.222, len  test loss 0.070, col  test loss 160.333


Epoch 239: 543batch [00:43, 12.62batch/s, loss=546]


epoch 239: avg train loss 529.06, bar train loss 3.039, len train loss 0.048, col train loss 152.876


Epoch 240: 0batch [00:00, ?batch/s]

epoch 239: avg test  loss 577.06, bar  test loss 4.206, len  test loss 0.068, col  test loss 160.379


Epoch 240: 543batch [00:42, 12.69batch/s, loss=512]


epoch 240: avg train loss 528.90, bar train loss 3.041, len train loss 0.047, col train loss 152.858
epoch 240: avg test  loss 577.13, bar  test loss 4.208, len  test loss 0.068, col  test loss 160.412


Epoch 241: 543batch [00:42, 12.72batch/s, loss=521]


epoch 241: avg train loss 528.76, bar train loss 3.035, len train loss 0.046, col train loss 152.868


Epoch 242: 2batch [00:00, 12.82batch/s, loss=519]

epoch 241: avg test  loss 577.32, bar  test loss 4.218, len  test loss 0.070, col  test loss 160.410


Epoch 242: 543batch [00:42, 12.68batch/s, loss=513]


epoch 242: avg train loss 528.75, bar train loss 3.037, len train loss 0.047, col train loss 152.832


Epoch 243: 2batch [00:00, 12.66batch/s, loss=536]

epoch 242: avg test  loss 577.64, bar  test loss 4.226, len  test loss 0.071, col  test loss 160.456


Epoch 243: 543batch [00:42, 12.64batch/s, loss=591]


epoch 243: avg train loss 528.91, bar train loss 3.038, len train loss 0.047, col train loss 152.866


Epoch 244: 2batch [00:00, 12.42batch/s, loss=510]

epoch 243: avg test  loss 577.27, bar  test loss 4.210, len  test loss 0.071, col  test loss 160.395


Epoch 244: 543batch [00:43, 12.54batch/s, loss=536]


epoch 244: avg train loss 528.66, bar train loss 3.039, len train loss 0.046, col train loss 152.809


Epoch 245: 2batch [00:00, 12.74batch/s, loss=522]

epoch 244: avg test  loss 577.54, bar  test loss 4.223, len  test loss 0.071, col  test loss 160.297


Epoch 245: 543batch [00:43, 12.45batch/s, loss=546]


epoch 245: avg train loss 528.77, bar train loss 3.036, len train loss 0.047, col train loss 152.846
epoch 245: avg test  loss 577.98, bar  test loss 4.225, len  test loss 0.076, col  test loss 160.294


Epoch 246: 543batch [00:43, 12.50batch/s, loss=501]


epoch 246: avg train loss 528.57, bar train loss 3.034, len train loss 0.047, col train loss 152.795


Epoch 247: 0batch [00:00, ?batch/s, loss=532]

epoch 246: avg test  loss 578.58, bar  test loss 4.232, len  test loss 0.078, col  test loss 160.366


Epoch 247: 543batch [00:43, 12.58batch/s, loss=548]


epoch 247: avg train loss 528.47, bar train loss 3.035, len train loss 0.047, col train loss 152.749


Epoch 248: 0batch [00:00, ?batch/s]

epoch 247: avg test  loss 577.88, bar  test loss 4.236, len  test loss 0.072, col  test loss 160.454


Epoch 248: 543batch [00:43, 12.43batch/s, loss=551]


epoch 248: avg train loss 528.34, bar train loss 3.035, len train loss 0.046, col train loss 152.738


Epoch 249: 2batch [00:00, 12.12batch/s, loss=530]

epoch 248: avg test  loss 577.63, bar  test loss 4.223, len  test loss 0.074, col  test loss 160.345


Epoch 249: 543batch [00:43, 12.54batch/s, loss=588]


epoch 249: avg train loss 528.21, bar train loss 3.024, len train loss 0.047, col train loss 152.758


Epoch 250: 2batch [00:00, 12.42batch/s, loss=502]

epoch 249: avg test  loss 578.19, bar  test loss 4.240, len  test loss 0.076, col  test loss 160.433


Epoch 250: 543batch [00:43, 12.52batch/s, loss=532]


epoch 250: avg train loss 528.24, bar train loss 3.028, len train loss 0.047, col train loss 152.747
epoch 250: avg test  loss 577.79, bar  test loss 4.223, len  test loss 0.074, col  test loss 160.521


Epoch 251: 543batch [00:43, 12.41batch/s, loss=504]


epoch 251: avg train loss 528.16, bar train loss 3.024, len train loss 0.047, col train loss 152.730


Epoch 252: 2batch [00:00, 12.58batch/s, loss=504]

epoch 251: avg test  loss 578.80, bar  test loss 4.246, len  test loss 0.081, col  test loss 160.439


Epoch 252: 543batch [00:43, 12.51batch/s, loss=555]


epoch 252: avg train loss 528.49, bar train loss 3.031, len train loss 0.048, col train loss 152.762


Epoch 253: 2batch [00:00, 12.58batch/s, loss=523]

epoch 252: avg test  loss 578.61, bar  test loss 4.263, len  test loss 0.074, col  test loss 160.428


Epoch 253: 543batch [00:43, 12.50batch/s, loss=488]


epoch 253: avg train loss 528.25, bar train loss 3.033, len train loss 0.047, col train loss 152.709


Epoch 254: 2batch [00:00, 12.42batch/s, loss=491]

epoch 253: avg test  loss 577.78, bar  test loss 4.229, len  test loss 0.069, col  test loss 160.424


Epoch 254: 543batch [00:43, 12.45batch/s, loss=554]


epoch 254: avg train loss 527.88, bar train loss 3.020, len train loss 0.047, col train loss 152.681


Epoch 255: 2batch [00:00, 12.66batch/s, loss=508]

epoch 254: avg test  loss 578.06, bar  test loss 4.234, len  test loss 0.072, col  test loss 160.441


Epoch 255: 543batch [00:43, 12.38batch/s, loss=510]


epoch 255: avg train loss 528.18, bar train loss 3.029, len train loss 0.046, col train loss 152.728
epoch 255: avg test  loss 577.88, bar  test loss 4.224, len  test loss 0.070, col  test loss 160.461


Epoch 256: 543batch [00:43, 12.43batch/s, loss=524]


epoch 256: avg train loss 527.89, bar train loss 3.023, len train loss 0.046, col train loss 152.663


Epoch 257: 0batch [00:00, ?batch/s, loss=536]

epoch 256: avg test  loss 578.50, bar  test loss 4.254, len  test loss 0.077, col  test loss 160.390


Epoch 257: 543batch [00:43, 12.45batch/s, loss=532]


epoch 257: avg train loss 527.87, bar train loss 3.020, len train loss 0.047, col train loss 152.664


Epoch 258: 2batch [00:00, 12.58batch/s, loss=526]

epoch 257: avg test  loss 578.60, bar  test loss 4.267, len  test loss 0.074, col  test loss 160.433


Epoch 258: 543batch [00:43, 12.42batch/s, loss=569]


epoch 258: avg train loss 528.08, bar train loss 3.029, len train loss 0.047, col train loss 152.701


Epoch 259: 0batch [00:00, ?batch/s, loss=524]

epoch 258: avg test  loss 577.79, bar  test loss 4.238, len  test loss 0.070, col  test loss 160.404


Epoch 259: 543batch [00:43, 12.41batch/s, loss=564]


epoch 259: avg train loss 527.93, bar train loss 3.023, len train loss 0.047, col train loss 152.661


Epoch 260: 2batch [00:00, 12.19batch/s, loss=558]

epoch 259: avg test  loss 577.46, bar  test loss 4.217, len  test loss 0.071, col  test loss 160.396


Epoch 260: 543batch [00:43, 12.41batch/s, loss=527]


epoch 260: avg train loss 527.68, bar train loss 3.018, len train loss 0.047, col train loss 152.642
epoch 260: avg test  loss 578.15, bar  test loss 4.244, len  test loss 0.071, col  test loss 160.471


Epoch 261: 543batch [00:44, 12.34batch/s, loss=528]


epoch 261: avg train loss 527.47, bar train loss 3.017, len train loss 0.047, col train loss 152.569


Epoch 262: 0batch [00:00, ?batch/s, loss=534]

epoch 261: avg test  loss 578.49, bar  test loss 4.255, len  test loss 0.071, col  test loss 160.454


Epoch 262: 543batch [00:43, 12.42batch/s, loss=558]


epoch 262: avg train loss 527.70, bar train loss 3.017, len train loss 0.047, col train loss 152.627


Epoch 263: 2batch [00:00, 12.12batch/s, loss=519]

epoch 262: avg test  loss 579.19, bar  test loss 4.284, len  test loss 0.075, col  test loss 160.472


Epoch 263: 543batch [00:43, 12.40batch/s, loss=572]


epoch 263: avg train loss 527.57, bar train loss 3.025, len train loss 0.046, col train loss 152.589


Epoch 264: 2batch [00:00, 12.50batch/s, loss=485]

epoch 263: avg test  loss 577.63, bar  test loss 4.226, len  test loss 0.071, col  test loss 160.349


Epoch 264: 543batch [00:43, 12.34batch/s, loss=555]


epoch 264: avg train loss 527.68, bar train loss 3.021, len train loss 0.048, col train loss 152.591


Epoch 265: 2batch [00:00, 12.50batch/s, loss=508]

epoch 264: avg test  loss 578.74, bar  test loss 4.261, len  test loss 0.070, col  test loss 160.564


Epoch 265: 543batch [00:43, 12.34batch/s, loss=543]


epoch 265: avg train loss 527.74, bar train loss 3.018, len train loss 0.047, col train loss 152.645
epoch 265: avg test  loss 577.97, bar  test loss 4.254, len  test loss 0.070, col  test loss 160.460


Epoch 266: 543batch [00:44, 12.24batch/s, loss=556]


epoch 266: avg train loss 527.41, bar train loss 3.020, len train loss 0.047, col train loss 152.540


Epoch 267: 0batch [00:00, ?batch/s, loss=514]

epoch 266: avg test  loss 578.27, bar  test loss 4.247, len  test loss 0.071, col  test loss 160.552


Epoch 267: 543batch [00:44, 12.29batch/s, loss=551]


epoch 267: avg train loss 527.26, bar train loss 3.019, len train loss 0.046, col train loss 152.530


Epoch 268: 0batch [00:00, ?batch/s, loss=519]

epoch 267: avg test  loss 578.34, bar  test loss 4.251, len  test loss 0.071, col  test loss 160.494


Epoch 268: 543batch [00:44, 12.30batch/s, loss=632]


epoch 268: avg train loss 527.26, bar train loss 3.010, len train loss 0.046, col train loss 152.566


Epoch 269: 0batch [00:00, ?batch/s, loss=546]

epoch 268: avg test  loss 578.57, bar  test loss 4.248, len  test loss 0.074, col  test loss 160.526


Epoch 269: 543batch [00:44, 12.30batch/s, loss=571]


epoch 269: avg train loss 527.42, bar train loss 3.020, len train loss 0.047, col train loss 152.527


Epoch 270: 2batch [00:00, 12.50batch/s, loss=529]

epoch 269: avg test  loss 578.70, bar  test loss 4.253, len  test loss 0.071, col  test loss 160.550


Epoch 270: 543batch [00:43, 12.40batch/s, loss=498]


epoch 270: avg train loss 527.26, bar train loss 3.014, len train loss 0.046, col train loss 152.554
epoch 270: avg test  loss 577.92, bar  test loss 4.236, len  test loss 0.069, col  test loss 160.435


Epoch 271: 543batch [00:44, 12.26batch/s, loss=565]


epoch 271: avg train loss 527.41, bar train loss 3.015, len train loss 0.047, col train loss 152.579


Epoch 272: 2batch [00:00, 12.35batch/s, loss=522]

epoch 271: avg test  loss 578.61, bar  test loss 4.245, len  test loss 0.073, col  test loss 160.525


Epoch 272: 543batch [00:43, 12.35batch/s, loss=500]


epoch 272: avg train loss 527.11, bar train loss 3.009, len train loss 0.047, col train loss 152.520


Epoch 273: 2batch [00:00, 12.82batch/s, loss=522]

epoch 272: avg test  loss 578.26, bar  test loss 4.234, len  test loss 0.071, col  test loss 160.537


Epoch 273: 543batch [00:43, 12.35batch/s, loss=525]


epoch 273: avg train loss 527.36, bar train loss 3.017, len train loss 0.047, col train loss 152.509


Epoch 274: 2batch [00:00, 12.12batch/s, loss=544]

epoch 273: avg test  loss 578.51, bar  test loss 4.244, len  test loss 0.071, col  test loss 160.552


Epoch 274: 496batch [00:40, 12.24batch/s, loss=514]

In [None]:
diva.lqz1[-5:], diva.lqz2[-5:], diva.lpz1[-5:], diva.lpz2[-5:]

In [None]:
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 500, save_folder="VAEFC")

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')