In [1]:
# Base directory that the rest of the notebook uses for data files
# (f'{link}/data/...') and for saved models / TensorBoard logs.
# NOTE(review): this is an absolute, machine-specific path; it has to be edited
# before the notebook can run anywhere else.
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
# Load the TensorBoard notebook extension (IPython line magic).
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import math


# Module-level TensorBoard writer; the training code checks `writer is not None`
# before logging scalars to it.
# `link` already ends with '/', so the resulting path contains a double slash.
writer = SummaryWriter(f"{link}/saved_models/new/HVAE3/tensorboard")

In [4]:
# Device used for the model and for every batch: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Display the selected device.
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:
    """Plain container for the hyper-parameters read by the model and training code.

    Attributes (all copied unchanged from the constructor arguments):
        z1_dim, z2_dim:      sizes of the two latent variables.
        d_dim, x_dim, y_dim: forwarded to the encoder/decoder constructors.
        h_dim, h2_dim:       hidden-layer widths of encoder and decoder.
        number_components:   number of pseudo-inputs of the z2 prior.
        beta:                weight of the KL term in the loss.
        rec_alpha, rec_beta, rec_gamma:
                             weights of the length, bar and colour
                             reconstruction terms.
        warmup, prewarmup:   epochs used by the beta schedule in `train`.
    """

    def __init__(
        self,
        z1_dim=1000,
        z2_dim=1000,
        d_dim=45,
        x_dim=7500,
        y_dim=2,
        h_dim=600,
        h2_dim=600,
        number_components=500,
        beta=1,
        rec_alpha=100,
        rec_beta=20,
        rec_gamma=1,
        warmup=1,
        prewarmup=1,
    ):
        # latent and input/label sizes
        self.z1_dim, self.z2_dim = z1_dim, z2_dim
        self.d_dim, self.x_dim, self.y_dim = d_dim, x_dim, y_dim

        # hidden-layer widths
        self.h_dim, self.h2_dim = h_dim, h2_dim

        # size of the pseudo-input prior
        self.number_components = number_components

        # loss weights
        self.beta = beta
        self.rec_alpha, self.rec_beta, self.rec_gamma = rec_alpha, rec_beta, rec_gamma

        # beta schedule
        self.warmup, self.prewarmup = warmup, prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):
    """In-memory dataset of modmirbase RNA images and their encodings.

    All arrays are read from ``{link}/data/modmirbase_{ds}_*`` files, where
    ``link`` is the module-level base path of the notebook.

    Indexing returns the tuple ``(x, y, d, x_len, x_col, x_bar, mount)``:
    the channel-first image, the one-hot class label, the one-hot species
    label, the encoded strand length, the encoded bar colours, the encoded bar
    lengths and the matching row of the ``mountain`` array.
    """

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        """
        ds:               dataset split used in the file names ('train', 'test', ...)
        create_encodings: if True the length/bar/colour encodings are recomputed
                          from the images (slow) and written to disk, otherwise
                          they are loaded from the previously written files
        use_subset:       if True keep only a random ~50% of the 'hsa' samples
        """
        # loading images (pixel values are scaled by 1/255)
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255

        # loading labels
        print('Loading Labels! (~10s)')
        ohe = self._make_one_hot_encoder()
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)

        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_bar, x_col = self.get_encoded_values(self.images, ds)
        else:
            # These files carry an .npz suffix but are written with np.save in
            # get_encoded_values, so np.load returns the array directly.
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len3.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar3.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col3.npz')

        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col

        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')

        # loading names
        print('Loading Names! (~5s)')
        names = np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin',
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                # first species code contained in the (lower-cased) name wins
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                # names containing 'random' and purely numeric names are
                # labelled 'hsa'; everything else is 'notfound'
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')

        # performing one hot encoding
        ohe = self._make_one_hot_encoder()
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))

        if use_subset:
            # keep each 'hsa' sample with probability 1/2, drop all others
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]

    @staticmethod
    def _make_one_hot_encoder():
        """Returns a OneHotEncoder that produces dense arrays.

        Newer scikit-learn releases name the option ``sparse_output`` and no
        longer accept ``sparse``; older releases only know ``sparse``.
        """
        try:
            return OneHotEncoder(categories='auto', sparse_output=False)
        except TypeError:
            return OneHotEncoder(categories='auto', sparse=False)

    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        # channel-last -> channel-first
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]
        return (x, y, d, x_len, x_col, x_bar, mount)

    def get_encoded_values(self, x, ds):
        """
        given a batch of channel-last images
        returns length of strand, length of bars and colors of bars
        and writes the three arrays to the `{link}/data` folder

        x:  image batch; the indexing below assumes 25 rows and 100 columns
            with the two bars meeting at rows 12/13 -- TODO confirm
        ds: dataset split used in the output file names

        raises ValueError (from get_color) for a bar pixel of unknown colour
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,5,2,100), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                # a white pixel in row 12 marks the end of the strand
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]), 0, j] = 1
                    out_col[i, self.get_color(x[i,:,13,j]), 1, j] = 1
                    # check length of bars
                    len1 = 0
                    # walk upwards from row 12 until a white pixel or the top edge
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    # walk downwards from row 13 until a white pixel or the bottom edge
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                # strand fills all 100 columns
                out_len[i] = rna_len

        # np.save writes .npy data even though the file names end in .npz;
        # __init__ relies on that when it loads these files again.
        with open(f'{link}/data/modmirbase_{ds}_images_len3.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col3.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar3.npz', 'wb') as f:
            np.save(f, out_bar)

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        (0 black, 1 red, 2 blue, 3 green, 4 yellow)

        raises ValueError for any other pixel value; returning None here would
        be used as an array index by get_encoded_values and silently write to
        the wrong slice of the colour array
        """
        if (pixel == np.array([0,0,0])).all():
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():
            return 4 # yellow
        else:
            raise ValueError(f"pixel {pixel!r} is not one of the five known bar colours")


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    """Decoder: p(z1 | z2, m) and p(x | z1, z2, m).

    The image is not decoded pixel by pixel. Three heads predict the length of
    the RNA strand, the lengths of the upper/lower bars and the colour
    distribution of every bar; `reconstruct_image` turns those into an image.

    `d_dim`, `x_dim`, `y_dim` and `dim0` are accepted for signature
    compatibility but are not used by any layer. The second input of `forward`
    has `z2_dim + 200` features (z2 concatenated with a 200-wide side input).
    """

    def __init__(self, d_dim, x_dim, y_dim, z1_dim, z2_dim,
                 h_dim, h2_dim, dim0=2000, dim1=1200, dim2=400):
        super(px, self).__init__()

        # p(z1|z2)
        self.p_z1 = nn.Sequential(nn.Linear(z2_dim+200, h2_dim),
                                  nn.ReLU(),
                                  nn.Linear(h2_dim, h2_dim),
                                  nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

        # p(x|z1,z2,m)
        self.px_z1 = nn.Sequential(nn.Linear(z1_dim, h_dim),
                                   nn.ReLU())
        self.px_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim),
                                   nn.ReLU())

        # separate decoders for length of RNA, color and size of bars
        # colour head: 600 features are reshaped to (12, 1, 50) in forward and
        # upsampled to (5, 2, 100) by the transposed convolutions
        self.fc_col = nn.Sequential(nn.Linear(2*h_dim, 600),
                                    nn.ReLU())
        self.dc_col1 = nn.Sequential(nn.ConvTranspose2d(in_channels=12, out_channels=36,
                                                        kernel_size=4, stride=(2,2), padding=(1,1)),
                                     nn.ReLU())
        self.dc_col2 = nn.Sequential(nn.ConvTranspose2d(in_channels=36, out_channels=36,
                                                        kernel_size=3, stride=(1,1), padding=(1,1)),
                                     nn.ReLU())
        self.dc_col3 = nn.Sequential(nn.ConvTranspose2d(in_channels=36, out_channels=72,
                                                        kernel_size=3, stride=(1,1), padding=(1,1)),
                                     nn.ReLU())
        self.dc_col4 = nn.Sequential(nn.ConvTranspose2d(in_channels=72, out_channels=72,
                                                        kernel_size=3, stride=(1,1), padding=(1,1)),
                                     nn.ReLU())
        # final layer ends in a softmax over the 5 colour channels
        self.dc_col5 = nn.Sequential(nn.ConvTranspose2d(in_channels=72, out_channels=5,
                                                        kernel_size=3, stride=(1,1), padding=(1,1)),
                                     nn.ReLU(),
                                     nn.Conv2d(in_channels=5, out_channels=5, kernel_size=1, stride=1, padding=0),
                                     nn.Softmax(dim=1))

        self.fc_bar = nn.Sequential(nn.Linear(2*h_dim, dim1),
                                    nn.ReLU(),
                                    nn.Linear(dim1, dim2),
                                    nn.ReLU(),
                                    nn.Dropout(0.2))

        self.fc_len = nn.Sequential(nn.Linear(2*h_dim, dim1),
                                    nn.ReLU(),
                                    nn.Linear(dim1, dim2),
                                    nn.ReLU(),
                                    nn.Dropout(0.2))

        # Predicting the length of each bar (top and bottom row, 100 columns)
        self.length_bar_top = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())
        self.length_bar_bot = nn.Sequential(nn.Linear(dim2,100), nn.Softplus())

        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(dim2,400), nn.ReLU(),nn.Linear(400,1), nn.Softplus())

    def forward(self, z1, mz2):
        """
        z1:  sample of the first latent, (batch, z1_dim)
        mz2: z2 concatenated with the side input, (batch, z2_dim + 200)

        returns (len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar, pz1_m, pz1_s)
            len_RNA:    (batch, 1) predicted strand length
            len_RNA_sc: constant scale 1.0, shape (1,)
            len_bar:    (batch, 2, 100) predicted bar lengths (top, bottom)
            len_bar_sc: constant scale 1.0, shape (1,)
            col_bar:    (batch, 5, 2, 100) colour probabilities
            pz1_m, pz1_s: mean and scale of p(z1 | z2, m)
        """
        # p(z1|z2)
        pz1 = self.p_z1(mz2)
        pz1_m = self.mu_z1(pz1)
        pz1_s = self.si_z1(pz1)

        # p(x|z1,z2,m)
        hz1 = self.px_z1(z1)
        hz2 = self.px_z2(mz2)
        h = torch.cat([hz1,hz2],1)

        len_RNA = self.fc_len(h)
        len_RNA = self.length_RNA(len_RNA)
        # Fixed unit scale. This used to be a fresh nn.Parameter per call: it
        # was never registered (so never trained), was placed on the global
        # DEVICE instead of the input's device, and required grad, which made
        # reconstruct_image's .numpy() call fail.
        len_RNA_sc = torch.tensor([1.], device=h.device)

        len_bar = self.fc_bar(h)
        len_bar = torch.cat([self.length_bar_top(len_bar)[:,None,:],self.length_bar_bot(len_bar)[:,None,:]], dim=1)
        len_bar_sc = torch.tensor([1.], device=h.device)

        col = self.fc_col(h)
        col = col.reshape(-1,12,1,50)
        col1 = self.dc_col1(col)
        # residual connections around the channel-preserving layers
        col2 = self.dc_col2(col1) + col1
        col3 = self.dc_col3(col2)
        col4 = self.dc_col4(col3) + col3
        col_bar = self.dc_col5(col4)

        return len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar, pz1_m, pz1_s

    @staticmethod
    def _sample_color(probs):
        """Draws one colour index from a probability vector.

        The vector is renormalised in float64 because np.random.choice rejects
        float32 softmax outputs whose sum is off by rounding error.
        """
        p = probs.astype(np.float64)
        p = p / p.sum()
        return np.random.choice(len(p), p=p)

    def reconstruct_image(self, len_RNA, var_RNA, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder

        len_RNA: (n, 1) or (n,) strand lengths
        var_RNA: scale of the strand length; one value or one per sample
        len_bar: (n, 2, 100) bar lengths, index 0 -> top, index 1 -> bottom
        var_bar: scale of the bar lengths; anything that broadcasts to len_bar
        col_bar: (n, 5, 2, 100) colour probabilities, same top/bottom layout
        sample:  if False use the rounded means and the most likely colour,
                 if True draw lengths from a normal and colours from col_bar

        returns an (n, 25, 100, 3) array, white where nothing is painted
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }

        len_RNA = len_RNA.detach().cpu().numpy()
        var_RNA = var_RNA.detach().cpu().numpy()
        len_bar = len_bar.detach().cpu().numpy()
        var_bar = var_bar.detach().cpu().numpy()
        col_bar = col_bar.detach().cpu().numpy()
        n = len_RNA.shape[0]
        # one scalar length per sample
        len_RNA = len_RNA.reshape(n)
        if sample:
            # forward() returns a single shared scale; expand it so it can be
            # indexed per sample / per bar below
            var_RNA = np.ones(n) * var_RNA.reshape(-1)
            var_bar = np.ones(len_bar.shape) * var_bar
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = int(np.round(np.random.normal(loc=len_RNA[i], scale=var_RNA[i])))
            else:
                limit = int(np.round(len_RNA[i]))
            limit = min(100, limit)
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = self._sample_color(col_bar[i,:,0,j])
                    _col_bar_2 = self._sample_color(col_bar[i,:,1,j])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j]))
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1 = np.argmax(col_bar[i,:,0,j])
                    _col_bar_2 = np.argmax(col_bar[i,:,1,j])

                h1 = max(0,13-_len_bar_1)
                # paint upper bar (rows h1..12)
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar (rows 13..h2-1)
                output[i, 13:h2, j] = color_dict[_col_bar_2]

        return output


In [9]:
# Scratch check of rounding vs. truncation; only the last expression is shown
# as the cell output.
int(np.round(3.7, 0))
int(3.7)

3

In [10]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# Smoke test of the decoder: z1_dim = z2_dim = h_dim = h2_dim = 500, so the two
# inputs are (1, 500) for z1 and (1, 700) for z2 plus the 200-wide side input.
pzy_ = px(45, 7500, 2, 500,500,500,500)
summary(pzy_, [(1,500),(1,700)])

Layer (type:depth-idx)                   Output Shape              Param #
px                                       --                        --
├─Sequential: 1-1                        [1, 500]                  --
│    └─Linear: 2-1                       [1, 500]                  350,500
│    └─ReLU: 2-2                         [1, 500]                  --
│    └─Linear: 2-3                       [1, 500]                  250,500
│    └─ReLU: 2-4                         [1, 500]                  --
├─Sequential: 1-2                        [1, 500]                  --
│    └─Linear: 2-5                       [1, 500]                  250,500
├─Sequential: 1-3                        [1, 500]                  --
│    └─Linear: 2-6                       [1, 500]                  250,500
│    └─Softplus: 2-7                     [1, 500]                  --
├─Sequential: 1-4                        [1, 500]                  --
│    └─Linear: 2-8                       [1, 500]                

## Encoder Classes

In [11]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [12]:
class qz(nn.Module):
    """Encoder producing q(z2 | x) and q(z1 | x, z2, m).

    Both distributions use their own copy of the same convolutional feature
    extractor, whose flattened output has 2560 features per sample.
    `d_dim`, `x_dim` and `y_dim` are accepted but not used by any layer.
    """

    def __init__(self, d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim):
        super().__init__()

        def conv_stack():
            """Builds one bias-free conv feature extractor."""
            layers = []
            # (in_channels, out_channels, followed by max-pool?)
            for c_in, c_out, pooled in ((3, 32, False), (32, 64, True), (64, 128, True)):
                layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1,
                                        padding='valid', bias=False))
                layers.append(nn.ReLU())
                if pooled:
                    layers.append(nn.MaxPool2d(2, 2))
            # last conv relies on the default padding
            layers.append(nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*layers)

        # q(z2 | x)
        self.encoder_z2 = conv_stack()
        self.mu_z2 = nn.Sequential(nn.Linear(2560, z2_dim))
        self.si_z2 = nn.Sequential(nn.Linear(2560, z2_dim), nn.Softplus())

        # q(z1 | x, z2)
        self.encoder_z1 = conv_stack()
        self.fc_z2 = nn.Sequential(nn.Linear(z2_dim+200, h_dim), nn.ReLU())
        self.fc_z1 = nn.Sequential(nn.Linear(2560, h_dim), nn.ReLU())
        self.fc_z1_z2 = nn.Sequential(nn.Linear(2*h_dim, h2_dim), nn.ReLU())
        self.mu_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim))
        self.si_z1 = nn.Sequential(nn.Linear(h2_dim, z1_dim), nn.Softplus())

    def q_z2(self, x):
        """Returns mean and scale of q(z2 | x)."""
        feats = self.encoder_z2(x).view(-1, 2560)
        return self.mu_z2(feats), self.si_z2(feats)

    def forward(self, x, m):
        """
        x: image batch
        m: side input that is concatenated to z2 (200 features per sample)

        returns (z1, z2, mz2, z1_m, z1_s, z2_m, z2_s) where z1 and z2 are
        reparameterised samples and mz2 is z2 concatenated with m
        """
        # q(z2 | x), sampled with the reparameterization trick
        z2_m, z2_s = self.q_z2(x)
        z2 = dist.Normal(z2_m, z2_s).rsample()
        mz2 = torch.cat((z2, m), dim=1)

        # q(z1 | x, z2, m): image features and (z2, m) features are embedded
        # separately and then merged
        img_feats = self.fc_z1(self.encoder_z1(x).view(-1, 2560))
        lat_feats = self.fc_z2(mz2)
        joint = self.fc_z1_z2(torch.cat((lat_feats, img_feats), dim=1))
        z1_m = self.mu_z1(joint)
        z1_s = self.si_z1(joint)
        z1 = dist.Normal(z1_m, z1_s).rsample()

        return z1, z2, mz2, z1_m, z1_s, z2_m, z2_s




In [13]:
# Scratch check: torch.cat along dim 1 joins tensors with different column counts.
a = torch.tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,3],[4,6]])

torch.cat([a,b],1)

tensor([[1, 2, 3, 1, 3],
        [4, 5, 6, 4, 6]])

In [14]:
# Smoke test of the encoder (z1_dim=10, z2_dim=500, h_dim=h2_dim=400): one
# forward pass on an all-zero (1,3,25,100) image with a (1,200) side input,
# followed by the layer summary.
enc = qz(128,10,10,10,500,400,400)
enc(torch.zeros((1,3,25,100)), torch.zeros((1,200)))
summary(enc, [(1,3,25,100),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 1, 10]           --
│    └─Conv2d: 2-1                       [1, 32, 23, 98]           864
│    └─ReLU: 2-2                         [1, 32, 23, 98]           --
│    └─Conv2d: 2-3                       [1, 64, 21, 96]           18,432
│    └─ReLU: 2-4                         [1, 64, 21, 96]           --
│    └─MaxPool2d: 2-5                    [1, 64, 10, 48]           --
│    └─Conv2d: 2-6                       [1, 128, 8, 46]           73,728
│    └─ReLU: 2-7                         [1, 128, 8, 46]           --
│    └─MaxPool2d: 2-8                    [1, 128, 4, 23]           --
│    └─Conv2d: 2-9                       [1, 256, 2, 21]           294,912
│    └─ReLU: 2-10                        [1, 256, 2, 21]           --
│    └─MaxPool2d: 2-11                   [1, 256, 1, 10]           --
├

In [15]:
def log_Normal_diag(x, mean, std, average=False, dim=None):
    """Log-density of a diagonal Gaussian, reduced over `dim`.

    The element-wise value is ``-0.5 * (log(std^2) + (x - mean)^2 / std^2)``;
    the constant ``-0.5 * log(2*pi)`` is not included.

    x, mean, std: tensors that broadcast against each other
    average:      reduce with the mean instead of the sum
    dim:          dimension(s) handed to the reduction
    """
    log_var = 2*torch.log(std)
    squared_error = torch.pow(x - mean, 2)
    per_element = -0.5 * (log_var + squared_error / torch.exp(log_var))
    reduce = torch.mean if average else torch.sum
    return reduce(per_element, dim)

## Full model class

In [16]:
class HVAE(nn.Module):
    """Two-level VAE built from the `qz` encoder and the `px` decoder.

    The prior over z2 is a uniform mixture of the encoder's q(z2 | u_k) at
    `number_components` learned pseudo-inputs u_k (see `add_pseudoinputs` and
    `log_p_z2`). The module moves itself to the global `DEVICE` on creation.
    """

    def __init__(self, args):
        """args: object exposing the attributes of `diva_args`."""
        super(HVAE, self).__init__()
        self.z1_dim = args.z1_dim
        self.z2_dim = args.z2_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.h_dim = args.h_dim
        self.h2_dim = args.h2_dim
        self.number_components = args.number_components

        #d_dim, x_dim, y_dim, z1_dim ,z2_dim, h_dim, h2_dim
        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z1_dim, self.z2_dim,
                     self.h_dim, self.h2_dim)

        # weight of the KL term (overwritten per epoch by the training loop)
        self.beta = args.beta

        # weights of the length / bar / colour reconstruction terms
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.add_pseudoinputs()

        # optional per-batch logs; nothing in this class appends to them
        self.lqz1 = []
        self.lqz2 = []
        self.lpz1 = []
        self.lpz2 = []

        self.bar = []
        self.len = []
        self.col = []

        # Follow DEVICE instead of calling .cuda() unconditionally, which
        # raised on machines without CUDA although DEVICE falls back to CPU.
        self.to(DEVICE)

    def forward(self, d, x, y, m):
        """Encodes `x` (with side input `m`) and decodes the samples.

        `d` and `y` are accepted but not used.

        returns (x_len, x_len_scale, x_bar, x_bar_scale, x_col,
                 z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s)
        """
        # Encode
        z1, z2, mz2, z1_m, z1_s, z2_m, z2_s = self.qz(x, m)
        # Decode
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, pz1_m, pz1_s = self.px(z1, mz2)

        return x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s

    def log_p_z2(self, z2):
        """Log-density of `z2` under the pseudo-input mixture prior.

        returns one value per row of `z2`
        """
        C = self.number_components

        # pseudo-inputs, shaped like images
        X = self.means(self.idle_input).view(-1,3,25,100)

        # one Gaussian component per pseudo-input
        pz2_m, pz2_s = self.qz.q_z2(X)

        z_expand = z2.unsqueeze(1)      # (batch, 1, z2_dim)
        means = pz2_m.unsqueeze(0)      # (1, C, z2_dim)
        stds = pz2_s.unsqueeze(0)

        # log of (1/C) * N(z2; component), shape (batch, C)
        a = log_Normal_diag(z_expand, means, stds, dim=2) - math.log(C)
        a_max, _ = torch.max(a,1)

        # numerically stable log-sum-exp over the components
        log_prior = (a_max + torch.log(torch.sum(torch.exp(a-a_max.unsqueeze(1)),1)))

        return log_prior

    def loss_function(self, d, x, y, m, out_len, out_bar, out_col):
        """Weighted reconstruction + KL loss for one batch.

        d, x, y, m: inputs of `forward`
        out_len:    (batch,) target strand lengths
        out_bar:    (batch, 2, 100) target bar lengths
        out_col:    (batch, 5, 2, 100) target colour encoding

        returns (loss, RE_bar, RE_len, RE_col, mse_bar,
                 acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5)
        where loss = rec_alpha*RE_len + rec_beta*RE_bar + rec_gamma*RE_col + beta*KL
        """
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, z1, z2, z1_m, z1_s, z2_m, z2_s, pz1_m, pz1_s = self.forward(d, x, y, m)

        # Reconstruction Loss
        # keep[b, j] is 1 for columns j < out_len[b] - 1 and 0 afterwards.
        # NOTE(review): the column at index out_len - 1 is masked out as well,
        # and out_len == 0 makes one_hot fail -- confirm this is intended.
        keep = 1 - F.one_hot(torch.round(out_len).to(torch.int64)-1, 100).cumsum(dim=1)
        mask = keep[:,None,None,:]
        mask1 = keep[:,None,:].repeat(1,2,1)

        x_col = mask.repeat(1,5,2,1)*x_col

        dist_len = dist.Normal(x_len, x_len_scale+1e-7)
        # mean over the batch, unlike the other terms which are sums
        log_len = dist_len.log_prob(out_len[:,None]).mean()

        # per-sample masked mean squared error of the bar lengths, summed over the batch
        mse_bar = ((((x_bar - out_bar)**2)*mask1).sum(dim=(1,2))/(mask1.sum(dim=(1,2)))).sum()#.detach().item()

        # acc_bar..acc_bar5 count where the predicted colour equals the k-th
        # entry of the ascending argsort of the target, accumulated over k.
        # NOTE(review): argsort is ascending, so index 0 is the smallest target
        # entry, not the target colour -- confirm the intended ranking.
        max_bar = torch.argmax(x_col, dim=1)
        sort_bar = torch.argsort(out_col, dim=1)
        acc_bar = (((max_bar==sort_bar[:,0,:,:])*mask1).sum((1,2))/out_len).sum()
        acc_bar2 = (((max_bar==sort_bar[:,1,:,:])*mask1).sum((1,2))/out_len).sum() + acc_bar
        acc_bar3 = (((max_bar==sort_bar[:,2,:,:])*mask1).sum((1,2))/out_len).sum() + acc_bar2
        acc_bar4 = (((max_bar==sort_bar[:,3,:,:])*mask1).sum((1,2))/out_len).sum() + acc_bar3
        acc_bar5 = (((max_bar==sort_bar[:,4,:,:])*mask1).sum((1,2))/out_len).sum() + acc_bar4

        RE_len = -log_len
        RE_bar = mse_bar#-log_bar
        # NOTE(review): x_col already went through the decoder's softmax and
        # F.cross_entropy applies log_softmax again -- confirm this is intended.
        RE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        # KL loss, estimated from the single sample (z1, z2):
        # -(log p(z1|z2) + log p(z2) - log q(z1|x,z2) - log q(z2|x))
        KL_p_z1 = log_Normal_diag(z1, pz1_m, pz1_s, dim=1).sum()
        KL_q_z1 = log_Normal_diag(z1, z1_m, z1_s, dim=1).sum()
        KL_p_z2 = self.log_p_z2(z2).sum()
        KL_q_z2 = log_Normal_diag(z2, z2_m, z2_s, dim=1).sum()
        KL = -(KL_p_z1 + KL_p_z2 - KL_q_z1 - KL_q_z2)

        return self.rec_alpha * RE_len \
                  + self.rec_beta * RE_bar \
                  + self.rec_gamma * RE_col \
                  + self.beta * KL, \
                  RE_bar, RE_len, RE_col, mse_bar, acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5

    def add_pseudoinputs(self):
        """Creates the learnable pseudo-inputs used by `log_p_z2`.

        `means` maps the k-th row of an identity matrix to the k-th
        pseudo-image (3*25*100 values clamped to [0, 1]).
        """
        # TODO: rework pseudo generation based on reconstruction
        nonlinearity = nn.Hardtanh(min_val=0.0, max_val=1.0)
        self.means = nn.Sequential(nn.Linear(self.number_components, 3*25*100, bias=False), nonlinearity)
        # Non-persistent buffer: moves together with the module but is not
        # added to state_dict, so existing checkpoints still load.
        self.register_buffer('idle_input',
                             torch.eye(self.number_components, self.number_components),
                             persistent=False)

In [17]:
# Scratch check: log-density of a standard normal at 10.
a = dist.Normal(0,1)
a.log_prob(torch.tensor(10))

tensor(-50.9189)

In [18]:
# Build the full model with the default hyper-parameters and print its summary.
# Note that `enc` is rebound here; it referred to a bare `qz` encoder in an
# earlier cell.
default_args = diva_args()
enc = HVAE(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1),(1,200)])

Layer (type:depth-idx)                   Output Shape              Param #
HVAE                                     --                        --
├─px: 1-1                                --                        (recursive)
│    └─Sequential: 2-1                   --                        (recursive)
│    │    └─Linear: 3-1                  --                        (recursive)
│    │    └─ReLU: 3-2                    --                        --
│    │    └─Linear: 3-3                  --                        (recursive)
│    │    └─ReLU: 3-4                    --                        --
│    └─Sequential: 2-2                   --                        (recursive)
│    │    └─Linear: 3-5                  --                        (recursive)
│    └─Sequential: 2-3                   --                        (recursive)
│    │    └─Linear: 3-6                  --                        (recursive)
│    │    └─Softplus: 3-7                --                        --
│    └─Sequen

# Training the model

## Loading dataset

In [19]:
# Training split (default ds='train'), using the encodings already stored on disk.
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [20]:
# Test split, using the encodings already stored on disk.
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [21]:
# Sanity check of the encoded targets: bar lengths and bar colours.
RNA_dataset.x_bar.shape, RNA_dataset.x_col.shape 

((34721, 2, 100), (34721, 5, 2, 100))

In [22]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    """Runs one optimisation pass over `train_loader`.

    train_loader: yields (x, y, d, x_len, x_col, x_bar, m) batches
    model:        exposes `loss_function(d, x, y, m, x_len, x_bar, x_col)`
                  returning (loss, bar, len, col, mse, acc, acc2, acc3, acc4, acc5)
    optimizer:    optimizer stepping the model parameters
    epoch:        epoch number, only used for the progress-bar label

    returns the ten values of `loss_function`, each summed over all batches
    and divided by the number of samples in the dataset:
    (train_loss, bar_loss, len_loss, col_loss, mse_bar,
     acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5)
    """
    model.train()
    # running sums, in the order returned by model.loss_function
    totals = [0] * 10
    pbar = tqdm(enumerate(train_loader), unit="batch",
                desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar, m) in pbar:
        # To device
        x, y, d, x_len, x_bar, x_col, m = (
            t.to(DEVICE) for t in (x, y, d, x_len, x_bar, x_col, m))

        optimizer.zero_grad()
        outputs = model.loss_function(d.float(), x.float(), y.float(), m.float(),
                                      x_len.float(), x_bar.float(), x_col.float())
        loss = outputs[0]

        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])

        # Detach before accumulating: summing tensors that still carry autograd
        # history keeps every batch's graph alive for the whole epoch.
        totals = [running + (value.detach() if isinstance(value, torch.Tensor) else value)
                  for running, value in zip(totals, outputs)]

    # acc3..acc5 used to be returned as the raw values of the last batch only;
    # they are now accumulated and normalised like every other metric.
    n_samples = len(train_loader.dataset)
    return tuple(running / n_samples for running in totals)

In [23]:
def test_single_epoch(test_loader, model, epoch):
    """Evaluates `model` on `test_loader` without gradient tracking.

    test_loader: yields (x, y, d, x_len, x_col, x_bar, m) batches
    model:       exposes `loss_function(d, x, y, m, x_len, x_bar, x_col)`
                 returning (loss, bar, len, col, mse, acc, acc2, acc3, acc4, acc5)
    epoch:       unused, kept for signature compatibility

    returns the ten values of `loss_function`, each summed over all batches
    and divided by the number of samples in the dataset:
    (test_loss, bar_loss, len_loss, col_loss, mse_bar,
     acc_bar, acc_bar2, acc_bar3, acc_bar4, acc_bar5)
    """
    model.eval()
    # running sums, in the order returned by model.loss_function
    totals = [0] * 10
    with torch.no_grad():
        for batch_idx, (x, y, d, x_len, x_col, x_bar, m) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col, m = (
                t.to(DEVICE) for t in (x, y, d, x_len, x_bar, x_col, m))
            outputs = model.loss_function(d.float(), x.float(), y.float(), m.float(),
                                          x_len.float(), x_bar.float(), x_col.float())
            totals = [running + value for running, value in zip(totals, outputs)]

    # acc3..acc5 used to be returned as the raw values of the last batch only;
    # they are now accumulated and normalised like every other metric.
    n_samples = len(test_loader.dataset)
    return tuple(running / n_samples for running in totals)
  

In [24]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5, checkpoint_interval=50):
    """Run the train/evaluate loop for epochs ``start_epoch + 1`` .. ``end_epoch``.

    Every epoch this sets ``diva.beta`` from the warm-up schedule, calls
    ``train_single_epoch`` and ``test_single_epoch``, prints the averaged
    losses and, when the module-level ``writer`` is not ``None``, logs them
    to TensorBoard.  Reconstruction images and checkpoints are written
    under ``{link}/saved_models/{save_folder}``.

    Args:
        args: configuration object; ``beta``, ``warmup`` and ``prewarmup``
            are read here.
        train_loader: loader passed to ``train_single_epoch``.
        test_loader: loader passed to ``test_single_epoch``.
        diva: the model; ``beta`` is written and ``rec_alpha``, ``rec_beta``,
            ``rec_gamma`` are read on it.
        optimizer: optimizer passed to ``train_single_epoch``.
        end_epoch: last epoch to run (inclusive).
        start_epoch: epoch already completed; training starts at the next one.
        save_folder: sub-folder of ``{link}/saved_models`` used for outputs.
        save_interval: save reconstruction images every this many epochs.
        checkpoint_interval: save ``state_dict`` every this many epochs.

    Returns:
        ``(train_losses, test_losses)``: one NumPy value per epoch, obtained
        with ``.detach().cpu().numpy()`` (so the per-epoch losses must be
        tensors).
    """
    epoch_loss_sup = []
    test_loss = []

    for epoch in range(start_epoch+1, end_epoch+1):
        # beta grows linearly with the epoch and is capped at args.beta;
        # before the pre-warm-up epoch it is held at args.beta / prewarmup.
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup

        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col, mtr, atr, atr2, atr3, atr4, atr5  = train_single_epoch(train_loader, diva, optimizer, epoch)
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, train_loss)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        # Weighted reconstruction part of the loss; the remainder of the
        # total is what gets logged as the 'dis' series.
        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test, mte, ate, ate2, ate3, ate4, ate5  = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)

        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            # 'rec'/'dis' keep their original (training) meaning so existing
            # curves continue; the test split gets its own keys.
            writer.add_scalars("Reconstruction_vs_Disentanglement",
                               {'rec': rec_loss_train, 'dis': dis_loss_train,
                                'rec-test': rec_loss_test, 'dis-test': dis_loss_test},
                               epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train-top1': atr, 'test-top1':ate,
                                          'train-top2': atr2, 'test-top2':ate2,
                                          'train-top3': atr3, 'test-top3':ate3,
                                          'train-top4': atr4, 'test-top4':ate4,
                                          'train-top5': atr5, 'test-top5':ate5
                                         }, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')

        if epoch % checkpoint_interval == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [25]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr='', n_samples=10, figsize=None):
    """Save a side-by-side image of inputs and their reconstructions.

    Takes the first batch from ``test_loader``, runs up to ``n_samples`` of
    its items through ``diva`` and ``diva.px.reconstruct_image``, and writes
    a two-column grid (original | reconstructed) to
    ``{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png``.

    The model is switched to eval mode and left there.

    Args:
        epoch: epoch number, used in the file name.
        test_loader: loader whose first batch is visualised; each batch is
            indexed as ``[0]`` = x, ``[1]`` = y, ``[2]`` = d, ``[-1]`` = m.
        diva: the model; called as ``diva(d, x, y, m)``.
        name: sub-folder of ``{link}/saved_models`` to write into.
        estr: suffix appended to the file name (e.g. ``'tr'``).
        n_samples: maximum number of rows to plot; fewer are drawn when the
            batch is smaller.
        figsize: optional ``(width, height)`` in inches passed to
            ``plt.subplots``; ``None`` keeps matplotlib's default size.
    """
    batch = next(iter(test_loader))
    with torch.no_grad():
        diva.eval()
        x = batch[0][:n_samples].to(DEVICE).float()
        y = batch[1][:n_samples].to(DEVICE).float()
        d = batch[2][:n_samples].to(DEVICE).float()
        m = batch[-1][:n_samples].to(DEVICE).float()
        # Only the first five model outputs feed the image reconstruction.
        outputs = diva(d, x, y, m)
        out = diva.px.reconstruct_image(*outputs[:5])

    n_rows = x.shape[0]
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(nrows=n_rows, ncols=2, squeeze=False, figsize=figsize)

    ax[0, 0].set_title("Original")
    ax[0, 1].set_title("Reconstructed")

    for i in range(n_rows):
        # assumes x[i] is channel-first (C, H, W); imshow wants (H, W, C) — TODO confirm
        ax[i, 0].imshow(x[i].cpu().permute(1, 2, 0))
        ax[i, 1].imshow(out[i])
        for axis in ax[i]:
            axis.xaxis.set_visible(False)
            axis.yaxis.set_visible(False)

    fig.tight_layout(pad=0.1)
    fig.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    # Close only the figure created here so repeated calls do not pile up.
    plt.close(fig)

In [26]:
# Display the device selected at the top of the notebook.
DEVICE

device(type='cuda')

## Model Training

In [27]:
# Model configuration: diva_args defaults, overriding only the pre-warm-up
# length and the number of components.
default_args = diva_args(
    number_components=50,
    prewarmup=0,
)

In [28]:
# Build the HVAE from default_args and move it to DEVICE.
diva = HVAE(default_args).to(DEVICE)

In [29]:
# Restore the weights stored in checkpoint file 100.pth.  map_location makes
# the checkpoint load onto whatever DEVICE resolved to, so a file saved on a
# GPU can still be opened on a CPU-only machine.
# NOTE(review): the training call further down resumes with start_epoch=135
# while this cell loads checkpoint 100 — confirm the intended starting point.
diva.load_state_dict(torch.load(f'{link}/saved_models/new/HVAE3/checkpoints/100.pth', map_location=DEVICE))

<All keys matched successfully>

In [30]:
# Mini-batches of 128 samples; only the training loader is shuffled.
train_loader = DataLoader(dataset=RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=RNA_dataset_test, batch_size=128, shuffle=False)

In [31]:
# Adam over all model parameters; the SGD line is a disabled alternative.
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(params=diva.parameters(), lr=1e-3)

In [32]:
# Sanity check: smallest and largest value stored in the training set's x_len.
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [33]:
# Flush pending TensorBoard events to disk before opening the dashboard.
writer.flush()

In [34]:
#diva.rec_gamma = 3

In [35]:
%tensorboard  --logdir="D:/users/Marko/downloads/mirna/saved_models/new/HVAE3/tensorboard/"

Reusing TensorBoard on port 6006 (pid 14624), started 1:49:53 ago. (Use '!kill 14624' to kill it.)

In [36]:
# Resume training: runs epochs 136 through 500 and returns the per-epoch
# train and test losses.
lss, lss_t = train(
    default_args,
    train_loader,
    test_loader,
    diva,
    optimizer,
    end_epoch=500,
    start_epoch=135,
    save_folder="new/HVAE3",
    save_interval=5,
)

Epoch 136: 272batch [00:33,  8.15batch/s, loss=201]


epoch 136: avg train loss 211.55, bar train loss 2.507, len train loss 0.019, col train loss 153.775
epoch 136: avg test  loss 272.60, bar  test loss 3.243, len  test loss 0.431, col  test loss 159.984


Epoch 137: 272batch [00:31,  8.69batch/s, loss=228]


epoch 137: avg train loss 209.77, bar train loss 2.486, len train loss 0.019, col train loss 153.594


Epoch 138: 1batch [00:00,  8.93batch/s, loss=207]

epoch 137: avg test  loss 275.81, bar  test loss 3.256, len  test loss 0.451, col  test loss 159.969


Epoch 138: 272batch [00:30,  8.86batch/s, loss=230]


epoch 138: avg train loss 209.52, bar train loss 2.479, len train loss 0.019, col train loss 153.518


Epoch 139: 1batch [00:00,  8.93batch/s, loss=208]

epoch 138: avg test  loss 282.35, bar  test loss 3.274, len  test loss 0.523, col  test loss 160.085


Epoch 139: 272batch [00:30,  8.99batch/s, loss=210]


epoch 139: avg train loss 209.13, bar train loss 2.468, len train loss 0.018, col train loss 153.461


Epoch 140: 1batch [00:00,  8.93batch/s, loss=206]

epoch 139: avg test  loss 274.60, bar  test loss 3.263, len  test loss 0.443, col  test loss 160.063


Epoch 140: 272batch [00:30,  8.93batch/s, loss=240]


epoch 140: avg train loss 209.00, bar train loss 2.461, len train loss 0.019, col train loss 153.432
epoch 140: avg test  loss 270.94, bar  test loss 3.278, len  test loss 0.408, col  test loss 160.121


Epoch 141: 272batch [00:30,  8.78batch/s, loss=211]


epoch 141: avg train loss 208.93, bar train loss 2.461, len train loss 0.019, col train loss 153.377


Epoch 142: 1batch [00:00,  8.93batch/s, loss=212]

epoch 141: avg test  loss 284.45, bar  test loss 3.305, len  test loss 0.537, col  test loss 160.156


Epoch 142: 272batch [00:31,  8.72batch/s, loss=211]


epoch 142: avg train loss 208.64, bar train loss 2.453, len train loss 0.019, col train loss 153.342


Epoch 143: 1batch [00:00,  8.00batch/s, loss=206]

epoch 142: avg test  loss 273.16, bar  test loss 3.280, len  test loss 0.429, col  test loss 160.144


Epoch 143: 272batch [00:30,  8.78batch/s, loss=219]


epoch 143: avg train loss 208.25, bar train loss 2.445, len train loss 0.017, col train loss 153.258


Epoch 144: 1batch [00:00,  9.01batch/s, loss=205]

epoch 143: avg test  loss 277.71, bar  test loss 3.314, len  test loss 0.467, col  test loss 160.151


Epoch 144: 272batch [00:30,  8.93batch/s, loss=210]


epoch 144: avg train loss 208.11, bar train loss 2.439, len train loss 0.018, col train loss 153.238


Epoch 145: 1batch [00:00,  8.77batch/s, loss=212]

epoch 144: avg test  loss 282.10, bar  test loss 3.316, len  test loss 0.512, col  test loss 160.234


Epoch 145: 272batch [00:31,  8.72batch/s, loss=208]


epoch 145: avg train loss 208.00, bar train loss 2.438, len train loss 0.017, col train loss 153.182
epoch 145: avg test  loss 265.04, bar  test loss 3.307, len  test loss 0.342, col  test loss 160.356


Epoch 146: 272batch [00:31,  8.68batch/s, loss=207]


epoch 146: avg train loss 207.78, bar train loss 2.429, len train loss 0.018, col train loss 153.141


Epoch 147: 1batch [00:00,  8.06batch/s, loss=205]

epoch 146: avg test  loss 281.58, bar  test loss 3.300, len  test loss 0.499, col  test loss 160.329


Epoch 147: 272batch [00:30,  8.78batch/s, loss=221]


epoch 147: avg train loss 207.74, bar train loss 2.428, len train loss 0.018, col train loss 153.125


Epoch 148: 1batch [00:00,  8.40batch/s, loss=210]

epoch 147: avg test  loss 273.86, bar  test loss 3.305, len  test loss 0.433, col  test loss 160.318


Epoch 148: 272batch [00:30,  8.85batch/s, loss=200]


epoch 148: avg train loss 207.42, bar train loss 2.419, len train loss 0.018, col train loss 153.038


Epoch 149: 1batch [00:00,  9.01batch/s, loss=211]

epoch 148: avg test  loss 274.59, bar  test loss 3.307, len  test loss 0.439, col  test loss 160.425


Epoch 149: 272batch [00:30,  9.04batch/s, loss=209]


epoch 149: avg train loss 207.20, bar train loss 2.409, len train loss 0.018, col train loss 153.025


Epoch 150: 1batch [00:00,  9.09batch/s, loss=205]

epoch 149: avg test  loss 274.94, bar  test loss 3.352, len  test loss 0.435, col  test loss 160.319


Epoch 150: 272batch [00:30,  8.95batch/s, loss=210]


epoch 150: avg train loss 207.06, bar train loss 2.410, len train loss 0.017, col train loss 152.968
epoch 150: avg test  loss 274.99, bar  test loss 3.332, len  test loss 0.434, col  test loss 160.407


Epoch 151: 272batch [00:30,  9.00batch/s, loss=209]


epoch 151: avg train loss 206.90, bar train loss 2.401, len train loss 0.018, col train loss 152.922


Epoch 152: 1batch [00:00,  8.85batch/s, loss=209]

epoch 151: avg test  loss 275.69, bar  test loss 3.333, len  test loss 0.447, col  test loss 160.351


Epoch 152: 272batch [00:30,  8.82batch/s, loss=209]


epoch 152: avg train loss 206.74, bar train loss 2.399, len train loss 0.017, col train loss 152.886


Epoch 153: 1batch [00:00,  8.93batch/s, loss=204]

epoch 152: avg test  loss 272.26, bar  test loss 3.328, len  test loss 0.411, col  test loss 160.456


Epoch 153: 272batch [00:30,  8.95batch/s, loss=208]


epoch 153: avg train loss 206.38, bar train loss 2.385, len train loss 0.017, col train loss 152.827


Epoch 154: 1batch [00:00,  8.77batch/s, loss=210]

epoch 153: avg test  loss 282.39, bar  test loss 3.385, len  test loss 0.500, col  test loss 160.505


Epoch 154: 272batch [00:30,  8.84batch/s, loss=205]


epoch 154: avg train loss 206.26, bar train loss 2.384, len train loss 0.017, col train loss 152.778


Epoch 155: 1batch [00:00,  8.70batch/s, loss=206]

epoch 154: avg test  loss 277.08, bar  test loss 3.317, len  test loss 0.459, col  test loss 160.378


Epoch 155: 272batch [00:31,  8.63batch/s, loss=226]


epoch 155: avg train loss 206.26, bar train loss 2.383, len train loss 0.017, col train loss 152.758
epoch 155: avg test  loss 282.16, bar  test loss 3.352, len  test loss 0.508, col  test loss 160.509


Epoch 156: 272batch [00:30,  8.93batch/s, loss=211]


epoch 156: avg train loss 206.03, bar train loss 2.375, len train loss 0.017, col train loss 152.711


Epoch 157: 1batch [00:00,  8.85batch/s, loss=198]

epoch 156: avg test  loss 279.10, bar  test loss 3.346, len  test loss 0.476, col  test loss 160.521


Epoch 157: 272batch [00:30,  8.91batch/s, loss=199]


epoch 157: avg train loss 205.91, bar train loss 2.375, len train loss 0.017, col train loss 152.679


Epoch 158: 1batch [00:00,  8.77batch/s, loss=199]

epoch 157: avg test  loss 279.96, bar  test loss 3.349, len  test loss 0.482, col  test loss 160.536


Epoch 158: 272batch [00:31,  8.74batch/s, loss=209]


epoch 158: avg train loss 205.70, bar train loss 2.364, len train loss 0.017, col train loss 152.629


Epoch 159: 1batch [00:00,  8.62batch/s, loss=205]

epoch 158: avg test  loss 274.72, bar  test loss 3.363, len  test loss 0.417, col  test loss 160.507


Epoch 159: 272batch [00:30,  8.84batch/s, loss=214]


epoch 159: avg train loss 205.88, bar train loss 2.365, len train loss 0.018, col train loss 152.641


Epoch 160: 1batch [00:00,  8.62batch/s, loss=203]

epoch 159: avg test  loss 274.25, bar  test loss 3.342, len  test loss 0.427, col  test loss 160.455


Epoch 160: 272batch [00:31,  8.77batch/s, loss=231]


epoch 160: avg train loss 205.36, bar train loss 2.355, len train loss 0.017, col train loss 152.531
epoch 160: avg test  loss 280.33, bar  test loss 3.345, len  test loss 0.482, col  test loss 160.499


Epoch 161: 272batch [00:30,  8.82batch/s, loss=215]


epoch 161: avg train loss 205.22, bar train loss 2.353, len train loss 0.016, col train loss 152.505


Epoch 162: 1batch [00:00,  8.93batch/s, loss=214]

epoch 161: avg test  loss 275.44, bar  test loss 3.365, len  test loss 0.435, col  test loss 160.549


Epoch 162: 272batch [00:30,  8.80batch/s, loss=203]


epoch 162: avg train loss 204.89, bar train loss 2.344, len train loss 0.016, col train loss 152.432


Epoch 163: 1batch [00:00,  8.77batch/s, loss=210]

epoch 162: avg test  loss 276.95, bar  test loss 3.375, len  test loss 0.440, col  test loss 160.606


Epoch 163: 272batch [00:30,  8.82batch/s, loss=216]


epoch 163: avg train loss 205.00, bar train loss 2.345, len train loss 0.017, col train loss 152.429


Epoch 164: 1batch [00:00,  8.70batch/s, loss=208]

epoch 163: avg test  loss 287.18, bar  test loss 3.387, len  test loss 0.546, col  test loss 160.660


Epoch 164: 272batch [00:30,  8.88batch/s, loss=221]


epoch 164: avg train loss 204.81, bar train loss 2.338, len train loss 0.017, col train loss 152.410


Epoch 165: 1batch [00:00,  8.85batch/s, loss=195]

epoch 164: avg test  loss 279.88, bar  test loss 3.389, len  test loss 0.477, col  test loss 160.580


Epoch 165: 272batch [00:30,  8.85batch/s, loss=204]


epoch 165: avg train loss 204.80, bar train loss 2.339, len train loss 0.017, col train loss 152.349
epoch 165: avg test  loss 279.37, bar  test loss 3.392, len  test loss 0.470, col  test loss 160.665


Epoch 166: 272batch [00:30,  8.87batch/s, loss=205]


epoch 166: avg train loss 204.50, bar train loss 2.332, len train loss 0.016, col train loss 152.300


Epoch 167: 1batch [00:00,  8.70batch/s, loss=194]

epoch 166: avg test  loss 284.26, bar  test loss 3.389, len  test loss 0.519, col  test loss 160.613


Epoch 167: 272batch [00:30,  8.85batch/s, loss=219]


epoch 167: avg train loss 204.45, bar train loss 2.330, len train loss 0.016, col train loss 152.279


Epoch 168: 1batch [00:00,  8.33batch/s, loss=206]

epoch 167: avg test  loss 282.33, bar  test loss 3.404, len  test loss 0.478, col  test loss 160.663


Epoch 168: 272batch [00:30,  8.86batch/s, loss=193]


epoch 168: avg train loss 204.39, bar train loss 2.329, len train loss 0.016, col train loss 152.248


Epoch 169: 1batch [00:00,  8.77batch/s, loss=202]

epoch 168: avg test  loss 282.26, bar  test loss 3.406, len  test loss 0.495, col  test loss 160.688


Epoch 169: 272batch [00:30,  8.84batch/s, loss=205]


epoch 169: avg train loss 204.42, bar train loss 2.326, len train loss 0.017, col train loss 152.263


Epoch 170: 1batch [00:00,  8.62batch/s, loss=208]

epoch 169: avg test  loss 284.50, bar  test loss 3.410, len  test loss 0.516, col  test loss 160.691


Epoch 170: 272batch [00:30,  8.84batch/s, loss=214]


epoch 170: avg train loss 203.99, bar train loss 2.314, len train loss 0.016, col train loss 152.163
epoch 170: avg test  loss 279.82, bar  test loss 3.394, len  test loss 0.463, col  test loss 160.768


Epoch 171: 272batch [00:31,  8.64batch/s, loss=218]


epoch 171: avg train loss 203.93, bar train loss 2.312, len train loss 0.017, col train loss 152.144


Epoch 172: 1batch [00:00,  8.70batch/s, loss=210]

epoch 171: avg test  loss 277.59, bar  test loss 3.415, len  test loss 0.449, col  test loss 160.764


Epoch 172: 272batch [00:30,  8.85batch/s, loss=203]


epoch 172: avg train loss 203.91, bar train loss 2.311, len train loss 0.017, col train loss 152.125


Epoch 173: 1batch [00:00,  8.62batch/s, loss=208]

epoch 172: avg test  loss 278.49, bar  test loss 3.398, len  test loss 0.460, col  test loss 160.692


Epoch 173: 272batch [00:30,  8.81batch/s, loss=213]


epoch 173: avg train loss 203.61, bar train loss 2.305, len train loss 0.016, col train loss 152.067


Epoch 174: 1batch [00:00,  8.13batch/s, loss=209]

epoch 173: avg test  loss 283.25, bar  test loss 3.412, len  test loss 0.496, col  test loss 160.831


Epoch 174: 272batch [00:30,  8.81batch/s, loss=198]


epoch 174: avg train loss 203.62, bar train loss 2.305, len train loss 0.016, col train loss 152.027


Epoch 175: 1batch [00:00,  6.85batch/s, loss=204]

epoch 174: avg test  loss 285.15, bar  test loss 3.429, len  test loss 0.512, col  test loss 160.745


Epoch 175: 272batch [00:32,  8.34batch/s, loss=214]


epoch 175: avg train loss 203.44, bar train loss 2.300, len train loss 0.016, col train loss 152.016
epoch 175: avg test  loss 279.09, bar  test loss 3.400, len  test loss 0.461, col  test loss 160.824


Epoch 176: 272batch [00:30,  8.81batch/s, loss=215]


epoch 176: avg train loss 203.52, bar train loss 2.303, len train loss 0.016, col train loss 151.983


Epoch 177: 1batch [00:00,  8.77batch/s, loss=204]

epoch 176: avg test  loss 276.50, bar  test loss 3.424, len  test loss 0.435, col  test loss 160.784


Epoch 177: 272batch [00:30,  8.82batch/s, loss=201]


epoch 177: avg train loss 203.12, bar train loss 2.287, len train loss 0.016, col train loss 151.937


Epoch 178: 1batch [00:00,  8.33batch/s, loss=197]

epoch 177: avg test  loss 275.49, bar  test loss 3.408, len  test loss 0.428, col  test loss 160.713


Epoch 178: 272batch [00:30,  8.79batch/s, loss=227]


epoch 178: avg train loss 203.05, bar train loss 2.287, len train loss 0.016, col train loss 151.904


Epoch 179: 1batch [00:00,  8.62batch/s, loss=199]

epoch 178: avg test  loss 279.80, bar  test loss 3.443, len  test loss 0.460, col  test loss 160.811


Epoch 179: 272batch [00:30,  8.83batch/s, loss=215]


epoch 179: avg train loss 203.03, bar train loss 2.285, len train loss 0.016, col train loss 151.863


Epoch 180: 1batch [00:00,  8.62batch/s, loss=211]

epoch 179: avg test  loss 279.15, bar  test loss 3.447, len  test loss 0.429, col  test loss 160.857


Epoch 180: 272batch [00:30,  8.83batch/s, loss=218]


epoch 180: avg train loss 202.99, bar train loss 2.284, len train loss 0.015, col train loss 151.872
epoch 180: avg test  loss 279.07, bar  test loss 3.433, len  test loss 0.453, col  test loss 160.832


Epoch 181: 272batch [00:30,  8.82batch/s, loss=222]


epoch 181: avg train loss 202.80, bar train loss 2.281, len train loss 0.015, col train loss 151.804


Epoch 182: 1batch [00:00,  8.85batch/s, loss=198]

epoch 181: avg test  loss 281.48, bar  test loss 3.432, len  test loss 0.479, col  test loss 160.867


Epoch 182: 272batch [00:30,  8.83batch/s, loss=209]


epoch 182: avg train loss 202.60, bar train loss 2.274, len train loss 0.015, col train loss 151.768


Epoch 183: 1batch [00:00,  8.70batch/s, loss=199]

epoch 182: avg test  loss 282.46, bar  test loss 3.434, len  test loss 0.493, col  test loss 160.854


Epoch 183: 272batch [00:30,  8.84batch/s, loss=206]


epoch 183: avg train loss 202.45, bar train loss 2.269, len train loss 0.016, col train loss 151.753


Epoch 184: 1batch [00:00,  8.85batch/s, loss=205]

epoch 183: avg test  loss 279.72, bar  test loss 3.439, len  test loss 0.463, col  test loss 160.982


Epoch 184: 272batch [00:30,  8.87batch/s, loss=194]


epoch 184: avg train loss 202.44, bar train loss 2.270, len train loss 0.015, col train loss 151.723


Epoch 185: 1batch [00:00,  8.62batch/s, loss=210]

epoch 184: avg test  loss 281.17, bar  test loss 3.450, len  test loss 0.475, col  test loss 160.913


Epoch 185: 272batch [00:40,  6.72batch/s, loss=232]


epoch 185: avg train loss 202.40, bar train loss 2.270, len train loss 0.016, col train loss 151.699
epoch 185: avg test  loss 285.14, bar  test loss 3.451, len  test loss 0.514, col  test loss 160.904


Epoch 186: 272batch [00:34,  7.83batch/s, loss=189]


epoch 186: avg train loss 202.14, bar train loss 2.260, len train loss 0.016, col train loss 151.639


Epoch 187: 1batch [00:00,  6.45batch/s, loss=202]

epoch 186: avg test  loss 278.72, bar  test loss 3.457, len  test loss 0.452, col  test loss 160.994


Epoch 187: 272batch [00:43,  6.22batch/s, loss=202]


epoch 187: avg train loss 201.90, bar train loss 2.254, len train loss 0.015, col train loss 151.591


Epoch 188: 1batch [00:00,  7.14batch/s, loss=198]

epoch 187: avg test  loss 287.08, bar  test loss 3.465, len  test loss 0.530, col  test loss 161.006


Epoch 188: 272batch [00:44,  6.15batch/s, loss=227]


epoch 188: avg train loss 202.01, bar train loss 2.259, len train loss 0.016, col train loss 151.580


Epoch 189: 1batch [00:00,  5.92batch/s, loss=202]

epoch 188: avg test  loss 285.12, bar  test loss 3.472, len  test loss 0.486, col  test loss 161.064


Epoch 189: 272batch [00:43,  6.21batch/s, loss=209]


epoch 189: avg train loss 202.10, bar train loss 2.262, len train loss 0.016, col train loss 151.548


Epoch 190: 0batch [00:00, ?batch/s]

epoch 189: avg test  loss 282.73, bar  test loss 3.449, len  test loss 0.491, col  test loss 160.927


Epoch 190: 272batch [00:43,  6.20batch/s, loss=197]


epoch 190: avg train loss 201.94, bar train loss 2.259, len train loss 0.015, col train loss 151.510
epoch 190: avg test  loss 282.63, bar  test loss 3.478, len  test loss 0.485, col  test loss 161.047


Epoch 191: 272batch [00:44,  6.13batch/s, loss=199]


epoch 191: avg train loss 201.52, bar train loss 2.242, len train loss 0.015, col train loss 151.500


Epoch 192: 1batch [00:00,  6.33batch/s, loss=201]

epoch 191: avg test  loss 284.99, bar  test loss 3.456, len  test loss 0.512, col  test loss 160.931


Epoch 192: 272batch [00:50,  5.40batch/s, loss=197]


epoch 192: avg train loss 201.42, bar train loss 2.239, len train loss 0.015, col train loss 151.459


Epoch 193: 1batch [00:00,  6.54batch/s, loss=197]

epoch 192: avg test  loss 280.14, bar  test loss 3.493, len  test loss 0.457, col  test loss 160.984


Epoch 193: 272batch [00:41,  6.56batch/s, loss=205]


epoch 193: avg train loss 201.72, bar train loss 2.247, len train loss 0.016, col train loss 151.476


Epoch 194: 1batch [00:00,  5.95batch/s, loss=208]

epoch 193: avg test  loss 283.69, bar  test loss 3.479, len  test loss 0.484, col  test loss 161.058


Epoch 194: 272batch [00:39,  6.81batch/s, loss=214]


epoch 194: avg train loss 201.41, bar train loss 2.240, len train loss 0.015, col train loss 151.424


Epoch 195: 1batch [00:00,  6.54batch/s, loss=208]

epoch 194: avg test  loss 280.98, bar  test loss 3.465, len  test loss 0.460, col  test loss 160.970


Epoch 195: 272batch [00:42,  6.43batch/s, loss=219]


epoch 195: avg train loss 201.26, bar train loss 2.237, len train loss 0.015, col train loss 151.390
epoch 195: avg test  loss 285.00, bar  test loss 3.502, len  test loss 0.503, col  test loss 161.036


Epoch 196: 272batch [00:44,  6.14batch/s, loss=213]


epoch 196: avg train loss 201.21, bar train loss 2.234, len train loss 0.015, col train loss 151.364


Epoch 197: 1batch [00:00,  5.99batch/s, loss=203]

epoch 196: avg test  loss 283.78, bar  test loss 3.491, len  test loss 0.493, col  test loss 161.002


Epoch 197: 272batch [00:43,  6.24batch/s, loss=212]


epoch 197: avg train loss 201.11, bar train loss 2.233, len train loss 0.015, col train loss 151.311


Epoch 198: 1batch [00:00,  6.21batch/s, loss=199]

epoch 197: avg test  loss 281.04, bar  test loss 3.499, len  test loss 0.465, col  test loss 161.081


Epoch 198: 272batch [00:40,  6.73batch/s, loss=213]


epoch 198: avg train loss 201.00, bar train loss 2.227, len train loss 0.015, col train loss 151.306


Epoch 199: 1batch [00:00,  5.88batch/s, loss=192]

epoch 198: avg test  loss 284.11, bar  test loss 3.521, len  test loss 0.480, col  test loss 161.145


Epoch 199: 272batch [00:41,  6.62batch/s, loss=196]


epoch 199: avg train loss 201.10, bar train loss 2.228, len train loss 0.016, col train loss 151.314


Epoch 200: 1batch [00:00,  6.67batch/s, loss=192]

epoch 199: avg test  loss 283.88, bar  test loss 3.497, len  test loss 0.490, col  test loss 161.174


Epoch 200: 272batch [00:39,  6.85batch/s, loss=201]


epoch 200: avg train loss 200.80, bar train loss 2.223, len train loss 0.015, col train loss 151.269
epoch 200: avg test  loss 284.70, bar  test loss 3.512, len  test loss 0.497, col  test loss 161.073


Epoch 201: 272batch [00:41,  6.60batch/s, loss=212]


epoch 201: avg train loss 200.64, bar train loss 2.219, len train loss 0.014, col train loss 151.210


Epoch 202: 1batch [00:00,  7.81batch/s, loss=201]

epoch 201: avg test  loss 280.92, bar  test loss 3.494, len  test loss 0.464, col  test loss 161.220


Epoch 202: 272batch [00:41,  6.62batch/s, loss=211]


epoch 202: avg train loss 200.64, bar train loss 2.221, len train loss 0.014, col train loss 151.199


Epoch 203: 0batch [00:00, ?batch/s]

epoch 202: avg test  loss 285.46, bar  test loss 3.493, len  test loss 0.491, col  test loss 161.163


Epoch 203: 272batch [00:45,  6.04batch/s, loss=209]


epoch 203: avg train loss 200.68, bar train loss 2.221, len train loss 0.015, col train loss 151.164


Epoch 204: 0batch [00:00, ?batch/s]

epoch 203: avg test  loss 283.33, bar  test loss 3.495, len  test loss 0.483, col  test loss 161.161


Epoch 204: 272batch [00:44,  6.08batch/s, loss=199]


epoch 204: avg train loss 200.55, bar train loss 2.212, len train loss 0.016, col train loss 151.155


Epoch 205: 1batch [00:00,  7.14batch/s, loss=196]

epoch 204: avg test  loss 278.61, bar  test loss 3.511, len  test loss 0.420, col  test loss 161.165


Epoch 205: 272batch [00:44,  6.11batch/s, loss=189]


epoch 205: avg train loss 200.61, bar train loss 2.216, len train loss 0.015, col train loss 151.164
epoch 205: avg test  loss 285.01, bar  test loss 3.509, len  test loss 0.483, col  test loss 161.185


Epoch 206: 272batch [00:41,  6.49batch/s, loss=197]


epoch 206: avg train loss 200.37, bar train loss 2.211, len train loss 0.014, col train loss 151.118


Epoch 207: 1batch [00:00,  6.99batch/s, loss=195]

epoch 206: avg test  loss 281.88, bar  test loss 3.508, len  test loss 0.471, col  test loss 161.165


Epoch 207: 272batch [00:44,  6.13batch/s, loss=210]


epoch 207: avg train loss 200.11, bar train loss 2.205, len train loss 0.014, col train loss 151.061


Epoch 208: 1batch [00:00,  6.33batch/s, loss=202]

epoch 207: avg test  loss 278.55, bar  test loss 3.516, len  test loss 0.417, col  test loss 161.085


Epoch 208: 272batch [00:41,  6.54batch/s, loss=196]


epoch 208: avg train loss 200.29, bar train loss 2.209, len train loss 0.015, col train loss 151.068


Epoch 209: 0batch [00:00, ?batch/s]

epoch 208: avg test  loss 284.20, bar  test loss 3.521, len  test loss 0.491, col  test loss 161.223


Epoch 209: 272batch [00:41,  6.55batch/s, loss=199]


epoch 209: avg train loss 200.10, bar train loss 2.204, len train loss 0.015, col train loss 151.044


Epoch 210: 1batch [00:00,  6.45batch/s, loss=193]

epoch 209: avg test  loss 283.16, bar  test loss 3.528, len  test loss 0.479, col  test loss 161.257


Epoch 210: 272batch [00:40,  6.66batch/s, loss=196]


epoch 210: avg train loss 200.00, bar train loss 2.200, len train loss 0.015, col train loss 150.998
epoch 210: avg test  loss 281.21, bar  test loss 3.519, len  test loss 0.459, col  test loss 161.211


Epoch 211: 272batch [00:41,  6.49batch/s, loss=186]


epoch 211: avg train loss 199.89, bar train loss 2.199, len train loss 0.014, col train loss 150.975


Epoch 212: 1batch [00:00,  7.81batch/s, loss=203]

epoch 211: avg test  loss 293.95, bar  test loss 3.530, len  test loss 0.574, col  test loss 161.303


Epoch 212: 272batch [00:38,  7.01batch/s, loss=203]


epoch 212: avg train loss 199.91, bar train loss 2.198, len train loss 0.014, col train loss 150.992


Epoch 213: 1batch [00:00,  7.19batch/s, loss=207]

epoch 212: avg test  loss 286.50, bar  test loss 3.530, len  test loss 0.488, col  test loss 161.252


Epoch 213: 272batch [00:40,  6.70batch/s, loss=235]


epoch 213: avg train loss 199.98, bar train loss 2.199, len train loss 0.015, col train loss 150.970


Epoch 214: 1batch [00:00,  5.95batch/s, loss=200]

epoch 213: avg test  loss 282.41, bar  test loss 3.526, len  test loss 0.472, col  test loss 161.239


Epoch 214: 272batch [00:38,  7.02batch/s, loss=208]


epoch 214: avg train loss 199.78, bar train loss 2.196, len train loss 0.015, col train loss 150.936


Epoch 215: 1batch [00:00,  7.30batch/s, loss=199]

epoch 214: avg test  loss 281.26, bar  test loss 3.540, len  test loss 0.459, col  test loss 161.238


Epoch 215: 272batch [00:42,  6.42batch/s, loss=215]


epoch 215: avg train loss 199.64, bar train loss 2.192, len train loss 0.015, col train loss 150.868
epoch 215: avg test  loss 286.16, bar  test loss 3.532, len  test loss 0.498, col  test loss 161.314


Epoch 216: 272batch [00:42,  6.44batch/s, loss=224]


epoch 216: avg train loss 199.56, bar train loss 2.189, len train loss 0.014, col train loss 150.845


Epoch 217: 0batch [00:00, ?batch/s]

epoch 216: avg test  loss 275.49, bar  test loss 3.543, len  test loss 0.399, col  test loss 161.320


Epoch 217: 272batch [00:47,  5.74batch/s, loss=190]


epoch 217: avg train loss 199.39, bar train loss 2.183, len train loss 0.014, col train loss 150.856


Epoch 218: 1batch [00:00,  7.14batch/s, loss=198]

epoch 217: avg test  loss 283.90, bar  test loss 3.532, len  test loss 0.481, col  test loss 161.252


Epoch 218: 272batch [00:39,  6.83batch/s, loss=215]


epoch 218: avg train loss 199.41, bar train loss 2.184, len train loss 0.014, col train loss 150.828


Epoch 219: 1batch [00:00,  8.13batch/s, loss=196]

epoch 218: avg test  loss 289.47, bar  test loss 3.550, len  test loss 0.536, col  test loss 161.328


Epoch 219: 272batch [00:44,  6.09batch/s, loss=199]


epoch 219: avg train loss 199.36, bar train loss 2.184, len train loss 0.014, col train loss 150.806


Epoch 220: 1batch [00:00,  7.41batch/s, loss=201]

epoch 219: avg test  loss 286.98, bar  test loss 3.559, len  test loss 0.508, col  test loss 161.281


Epoch 220: 272batch [00:41,  6.56batch/s, loss=204]


epoch 220: avg train loss 199.21, bar train loss 2.181, len train loss 0.014, col train loss 150.773
epoch 220: avg test  loss 284.71, bar  test loss 3.536, len  test loss 0.486, col  test loss 161.316


Epoch 221: 272batch [00:37,  7.17batch/s, loss=199]


epoch 221: avg train loss 199.32, bar train loss 2.181, len train loss 0.014, col train loss 150.745


Epoch 222: 1batch [00:00,  6.41batch/s, loss=197]

epoch 221: avg test  loss 287.27, bar  test loss 3.553, len  test loss 0.513, col  test loss 161.295


Epoch 222: 272batch [00:38,  7.02batch/s, loss=216]


epoch 222: avg train loss 199.09, bar train loss 2.178, len train loss 0.014, col train loss 150.714


Epoch 223: 1batch [00:00,  6.54batch/s, loss=194]

epoch 222: avg test  loss 280.24, bar  test loss 3.562, len  test loss 0.440, col  test loss 161.295


Epoch 223: 272batch [00:39,  6.84batch/s, loss=193]


epoch 223: avg train loss 199.20, bar train loss 2.180, len train loss 0.014, col train loss 150.740


Epoch 224: 1batch [00:00,  6.62batch/s, loss=198]

epoch 223: avg test  loss 286.84, bar  test loss 3.543, len  test loss 0.511, col  test loss 161.365


Epoch 224: 272batch [00:41,  6.55batch/s, loss=201]


epoch 224: avg train loss 199.06, bar train loss 2.175, len train loss 0.014, col train loss 150.688


Epoch 225: 0batch [00:00, ?batch/s]

epoch 224: avg test  loss 282.85, bar  test loss 3.565, len  test loss 0.464, col  test loss 161.332


Epoch 225: 272batch [00:32,  8.45batch/s, loss=194]


epoch 225: avg train loss 198.92, bar train loss 2.173, len train loss 0.014, col train loss 150.647
epoch 225: avg test  loss 284.82, bar  test loss 3.570, len  test loss 0.483, col  test loss 161.367


Epoch 226: 272batch [00:32,  8.38batch/s, loss=205]


epoch 226: avg train loss 198.91, bar train loss 2.171, len train loss 0.014, col train loss 150.656


Epoch 227: 1batch [00:00,  7.94batch/s, loss=194]

epoch 226: avg test  loss 283.95, bar  test loss 3.564, len  test loss 0.475, col  test loss 161.515


Epoch 227: 272batch [00:31,  8.64batch/s, loss=209]


epoch 227: avg train loss 198.88, bar train loss 2.170, len train loss 0.014, col train loss 150.651


Epoch 228: 1batch [00:00,  8.70batch/s, loss=188]

epoch 227: avg test  loss 285.94, bar  test loss 3.572, len  test loss 0.498, col  test loss 161.359


Epoch 228: 272batch [00:31,  8.71batch/s, loss=209]


epoch 228: avg train loss 198.82, bar train loss 2.171, len train loss 0.014, col train loss 150.602


Epoch 229: 1batch [00:00,  8.55batch/s, loss=191]

epoch 228: avg test  loss 285.43, bar  test loss 3.578, len  test loss 0.487, col  test loss 161.412


Epoch 229: 272batch [00:31,  8.56batch/s, loss=218]


epoch 229: avg train loss 198.69, bar train loss 2.164, len train loss 0.014, col train loss 150.587


Epoch 230: 1batch [00:00,  8.55batch/s, loss=208]

epoch 229: avg test  loss 284.68, bar  test loss 3.577, len  test loss 0.484, col  test loss 161.347


Epoch 230: 272batch [00:32,  8.25batch/s, loss=182]


epoch 230: avg train loss 198.56, bar train loss 2.162, len train loss 0.014, col train loss 150.554
epoch 230: avg test  loss 281.91, bar  test loss 3.589, len  test loss 0.452, col  test loss 161.428


Epoch 231: 272batch [00:34,  7.84batch/s, loss=184]


epoch 231: avg train loss 198.59, bar train loss 2.165, len train loss 0.014, col train loss 150.542


Epoch 232: 1batch [00:00,  7.87batch/s, loss=190]

epoch 231: avg test  loss 287.45, bar  test loss 3.578, len  test loss 0.508, col  test loss 161.496


Epoch 232: 272batch [00:35,  7.63batch/s, loss=206]


epoch 232: avg train loss 198.41, bar train loss 2.158, len train loss 0.014, col train loss 150.529


Epoch 233: 1batch [00:00,  6.29batch/s, loss=191]

epoch 232: avg test  loss 284.91, bar  test loss 3.574, len  test loss 0.485, col  test loss 161.365


Epoch 233: 272batch [00:42,  6.35batch/s, loss=204]


epoch 233: avg train loss 198.49, bar train loss 2.163, len train loss 0.014, col train loss 150.484


Epoch 234: 1batch [00:00,  6.21batch/s, loss=193]

epoch 233: avg test  loss 286.85, bar  test loss 3.570, len  test loss 0.508, col  test loss 161.353


Epoch 234: 272batch [00:32,  8.50batch/s, loss=208]


epoch 234: avg train loss 198.19, bar train loss 2.153, len train loss 0.014, col train loss 150.459


Epoch 235: 1batch [00:00,  8.70batch/s, loss=206]

epoch 234: avg test  loss 289.00, bar  test loss 3.592, len  test loss 0.524, col  test loss 161.398


Epoch 235: 272batch [00:31,  8.72batch/s, loss=224]


epoch 235: avg train loss 198.23, bar train loss 2.155, len train loss 0.013, col train loss 150.447
epoch 235: avg test  loss 290.41, bar  test loss 3.591, len  test loss 0.534, col  test loss 161.443


Epoch 236: 272batch [00:31,  8.68batch/s, loss=193]


epoch 236: avg train loss 198.30, bar train loss 2.156, len train loss 0.014, col train loss 150.424


Epoch 237: 1batch [00:00,  8.77batch/s, loss=195]

epoch 236: avg test  loss 291.37, bar  test loss 3.585, len  test loss 0.546, col  test loss 161.497


Epoch 237: 272batch [00:31,  8.70batch/s, loss=216]


epoch 237: avg train loss 198.09, bar train loss 2.148, len train loss 0.014, col train loss 150.417


Epoch 238: 1batch [00:00,  6.76batch/s, loss=198]

epoch 237: avg test  loss 284.27, bar  test loss 3.599, len  test loss 0.476, col  test loss 161.553


Epoch 238: 272batch [00:38,  7.03batch/s, loss=188]


epoch 238: avg train loss 198.11, bar train loss 2.153, len train loss 0.013, col train loss 150.388


Epoch 239: 1batch [00:00,  8.20batch/s, loss=188]

epoch 238: avg test  loss 289.13, bar  test loss 3.584, len  test loss 0.523, col  test loss 161.425


Epoch 239: 272batch [00:38,  6.99batch/s, loss=202]


epoch 239: avg train loss 197.89, bar train loss 2.144, len train loss 0.013, col train loss 150.356


Epoch 240: 1batch [00:00,  6.54batch/s, loss=206]

epoch 239: avg test  loss 286.38, bar  test loss 3.602, len  test loss 0.484, col  test loss 161.416


Epoch 240: 272batch [00:34,  7.79batch/s, loss=207]


epoch 240: avg train loss 198.12, bar train loss 2.153, len train loss 0.013, col train loss 150.369
epoch 240: avg test  loss 282.83, bar  test loss 3.602, len  test loss 0.458, col  test loss 161.391


Epoch 241: 272batch [00:42,  6.47batch/s, loss=214]


epoch 241: avg train loss 198.08, bar train loss 2.151, len train loss 0.014, col train loss 150.355


Epoch 242: 1batch [00:00,  7.87batch/s, loss=188]

epoch 241: avg test  loss 282.19, bar  test loss 3.598, len  test loss 0.454, col  test loss 161.504


Epoch 242: 272batch [00:38,  7.02batch/s, loss=218]


epoch 242: avg train loss 197.91, bar train loss 2.146, len train loss 0.013, col train loss 150.322


Epoch 243: 1batch [00:00,  6.71batch/s, loss=195]

epoch 242: avg test  loss 286.06, bar  test loss 3.602, len  test loss 0.481, col  test loss 161.417


Epoch 243: 272batch [00:43,  6.22batch/s, loss=185]


epoch 243: avg train loss 197.82, bar train loss 2.140, len train loss 0.014, col train loss 150.306


Epoch 244: 0batch [00:00, ?batch/s]

epoch 243: avg test  loss 283.05, bar  test loss 3.586, len  test loss 0.460, col  test loss 161.575


Epoch 244: 272batch [00:43,  6.21batch/s, loss=187]


epoch 244: avg train loss 197.68, bar train loss 2.140, len train loss 0.013, col train loss 150.265


Epoch 245: 0batch [00:00, ?batch/s]

epoch 244: avg test  loss 283.18, bar  test loss 3.644, len  test loss 0.456, col  test loss 161.494


Epoch 245: 272batch [00:45,  6.03batch/s, loss=213]


epoch 245: avg train loss 197.76, bar train loss 2.144, len train loss 0.013, col train loss 150.264
epoch 245: avg test  loss 287.50, bar  test loss 3.612, len  test loss 0.502, col  test loss 161.614


Epoch 246: 272batch [00:45,  5.98batch/s, loss=198]


epoch 246: avg train loss 197.54, bar train loss 2.133, len train loss 0.014, col train loss 150.233


Epoch 247: 0batch [00:00, ?batch/s]

epoch 246: avg test  loss 282.81, bar  test loss 3.600, len  test loss 0.458, col  test loss 161.560


Epoch 247: 272batch [00:43,  6.31batch/s, loss=191]


epoch 247: avg train loss 197.41, bar train loss 2.132, len train loss 0.013, col train loss 150.177


Epoch 248: 0batch [00:00, ?batch/s]

epoch 247: avg test  loss 282.91, bar  test loss 3.613, len  test loss 0.460, col  test loss 161.566


Epoch 248: 272batch [00:45,  5.97batch/s, loss=195]


epoch 248: avg train loss 197.53, bar train loss 2.135, len train loss 0.014, col train loss 150.192


Epoch 249: 0batch [00:00, ?batch/s]

epoch 248: avg test  loss 282.29, bar  test loss 3.608, len  test loss 0.432, col  test loss 161.533


Epoch 249: 272batch [00:49,  5.50batch/s, loss=200]


epoch 249: avg train loss 197.43, bar train loss 2.131, len train loss 0.013, col train loss 150.176


Epoch 250: 0batch [00:00, ?batch/s]

epoch 249: avg test  loss 290.02, bar  test loss 3.627, len  test loss 0.518, col  test loss 161.567


Epoch 250: 272batch [00:43,  6.21batch/s, loss=219]


epoch 250: avg train loss 197.37, bar train loss 2.128, len train loss 0.013, col train loss 150.213
epoch 250: avg test  loss 279.39, bar  test loss 3.603, len  test loss 0.421, col  test loss 161.536


Epoch 251: 272batch [00:43,  6.25batch/s, loss=196]


epoch 251: avg train loss 197.26, bar train loss 2.128, len train loss 0.013, col train loss 150.141


Epoch 252: 1batch [00:00,  6.25batch/s, loss=205]

epoch 251: avg test  loss 285.46, bar  test loss 3.626, len  test loss 0.479, col  test loss 161.579


Epoch 252: 272batch [00:45,  5.99batch/s, loss=203]


epoch 252: avg train loss 197.31, bar train loss 2.129, len train loss 0.013, col train loss 150.152


Epoch 253: 1batch [00:00,  7.75batch/s, loss=200]

epoch 252: avg test  loss 286.26, bar  test loss 3.640, len  test loss 0.484, col  test loss 161.627


Epoch 253: 272batch [00:44,  6.07batch/s, loss=225]


epoch 253: avg train loss 197.39, bar train loss 2.133, len train loss 0.013, col train loss 150.121


Epoch 254: 0batch [00:00, ?batch/s, loss=200]

epoch 253: avg test  loss 281.31, bar  test loss 3.633, len  test loss 0.437, col  test loss 161.527


Epoch 254: 272batch [00:49,  5.47batch/s, loss=215]


epoch 254: avg train loss 197.20, bar train loss 2.126, len train loss 0.013, col train loss 150.079


Epoch 255: 1batch [00:00,  6.45batch/s, loss=198]

epoch 254: avg test  loss 286.39, bar  test loss 3.647, len  test loss 0.487, col  test loss 161.627


Epoch 255: 272batch [00:44,  6.15batch/s, loss=202]


epoch 255: avg train loss 197.11, bar train loss 2.126, len train loss 0.013, col train loss 150.075
epoch 255: avg test  loss 285.77, bar  test loss 3.627, len  test loss 0.483, col  test loss 161.640


Epoch 256: 272batch [00:48,  5.66batch/s, loss=206]


epoch 256: avg train loss 197.29, bar train loss 2.128, len train loss 0.014, col train loss 150.104


Epoch 257: 1batch [00:00,  6.21batch/s, loss=195]

epoch 256: avg test  loss 285.55, bar  test loss 3.630, len  test loss 0.482, col  test loss 161.631


Epoch 257: 272batch [00:43,  6.27batch/s, loss=207]


epoch 257: avg train loss 197.08, bar train loss 2.125, len train loss 0.013, col train loss 150.023


Epoch 258: 1batch [00:00,  7.14batch/s, loss=196]

epoch 257: avg test  loss 287.78, bar  test loss 3.639, len  test loss 0.502, col  test loss 161.626


Epoch 258: 272batch [00:42,  6.43batch/s, loss=199]


epoch 258: avg train loss 197.06, bar train loss 2.125, len train loss 0.013, col train loss 150.012


Epoch 259: 1batch [00:00,  5.88batch/s, loss=197]

epoch 258: avg test  loss 286.59, bar  test loss 3.636, len  test loss 0.479, col  test loss 161.630


Epoch 259: 272batch [00:41,  6.51batch/s, loss=215]


epoch 259: avg train loss 197.01, bar train loss 2.122, len train loss 0.013, col train loss 150.011


Epoch 260: 1batch [00:00,  7.69batch/s, loss=195]

epoch 259: avg test  loss 283.03, bar  test loss 3.631, len  test loss 0.456, col  test loss 161.622


Epoch 260: 272batch [00:40,  6.71batch/s, loss=216]


epoch 260: avg train loss 196.90, bar train loss 2.121, len train loss 0.013, col train loss 149.966
epoch 260: avg test  loss 285.03, bar  test loss 3.628, len  test loss 0.474, col  test loss 161.683


Epoch 261: 272batch [00:43,  6.25batch/s, loss=210]


epoch 261: avg train loss 196.77, bar train loss 2.114, len train loss 0.013, col train loss 149.976


Epoch 262: 0batch [00:00, ?batch/s]

epoch 261: avg test  loss 286.11, bar  test loss 3.642, len  test loss 0.476, col  test loss 161.762


Epoch 262: 272batch [00:42,  6.34batch/s, loss=202]


epoch 262: avg train loss 196.82, bar train loss 2.116, len train loss 0.013, col train loss 149.989


Epoch 263: 1batch [00:00,  6.58batch/s, loss=199]

epoch 262: avg test  loss 288.13, bar  test loss 3.639, len  test loss 0.499, col  test loss 161.609


Epoch 263: 272batch [00:43,  6.27batch/s, loss=210]


epoch 263: avg train loss 196.78, bar train loss 2.119, len train loss 0.013, col train loss 149.912


Epoch 264: 1batch [00:00,  6.29batch/s, loss=195]

epoch 263: avg test  loss 288.98, bar  test loss 3.642, len  test loss 0.507, col  test loss 161.664


Epoch 264: 272batch [00:41,  6.48batch/s, loss=203]


epoch 264: avg train loss 196.74, bar train loss 2.116, len train loss 0.012, col train loss 149.929


Epoch 265: 1batch [00:00,  6.54batch/s, loss=197]

epoch 264: avg test  loss 290.08, bar  test loss 3.649, len  test loss 0.521, col  test loss 161.691


Epoch 265: 272batch [00:40,  6.66batch/s, loss=201]


epoch 265: avg train loss 196.63, bar train loss 2.111, len train loss 0.013, col train loss 149.911
epoch 265: avg test  loss 285.70, bar  test loss 3.655, len  test loss 0.472, col  test loss 161.737


Epoch 266: 272batch [00:44,  6.11batch/s, loss=201]


epoch 266: avg train loss 196.69, bar train loss 2.114, len train loss 0.013, col train loss 149.860


Epoch 267: 1batch [00:00,  6.13batch/s, loss=183]

epoch 266: avg test  loss 282.80, bar  test loss 3.647, len  test loss 0.449, col  test loss 161.722


Epoch 267: 272batch [00:41,  6.63batch/s, loss=196]


epoch 267: avg train loss 196.53, bar train loss 2.108, len train loss 0.013, col train loss 149.867


Epoch 268: 1batch [00:00,  7.63batch/s, loss=196]

epoch 267: avg test  loss 283.43, bar  test loss 3.635, len  test loss 0.459, col  test loss 161.615


Epoch 268: 272batch [00:42,  6.39batch/s, loss=195]


epoch 268: avg train loss 196.58, bar train loss 2.115, len train loss 0.012, col train loss 149.857


Epoch 269: 1batch [00:00,  6.13batch/s, loss=203]

epoch 268: avg test  loss 280.70, bar  test loss 3.646, len  test loss 0.425, col  test loss 161.766


Epoch 269: 272batch [00:48,  5.66batch/s, loss=194]


epoch 269: avg train loss 196.40, bar train loss 2.103, len train loss 0.013, col train loss 149.879


Epoch 270: 0batch [00:00, ?batch/s]

epoch 269: avg test  loss 285.22, bar  test loss 3.644, len  test loss 0.475, col  test loss 161.608


Epoch 270: 272batch [00:45,  5.95batch/s, loss=214]


epoch 270: avg train loss 196.31, bar train loss 2.101, len train loss 0.012, col train loss 149.847
epoch 270: avg test  loss 281.43, bar  test loss 3.661, len  test loss 0.434, col  test loss 161.725


Epoch 271: 272batch [00:45,  6.04batch/s, loss=200]


epoch 271: avg train loss 196.36, bar train loss 2.105, len train loss 0.013, col train loss 149.811


Epoch 272: 0batch [00:00, ?batch/s]

epoch 271: avg test  loss 288.57, bar  test loss 3.649, len  test loss 0.503, col  test loss 161.751


Epoch 272: 272batch [00:46,  5.82batch/s, loss=214]


epoch 272: avg train loss 196.44, bar train loss 2.108, len train loss 0.013, col train loss 149.835


Epoch 273: 0batch [00:00, ?batch/s]

epoch 272: avg test  loss 288.22, bar  test loss 3.646, len  test loss 0.502, col  test loss 161.670


Epoch 273: 272batch [00:46,  5.81batch/s, loss=185]


epoch 273: avg train loss 196.15, bar train loss 2.100, len train loss 0.012, col train loss 149.741


Epoch 274: 1batch [00:00,  6.29batch/s, loss=199]

epoch 273: avg test  loss 284.82, bar  test loss 3.650, len  test loss 0.463, col  test loss 161.692


Epoch 274: 272batch [00:43,  6.30batch/s, loss=196]


epoch 274: avg train loss 196.10, bar train loss 2.097, len train loss 0.012, col train loss 149.753


Epoch 275: 1batch [00:00,  6.25batch/s, loss=195]

epoch 274: avg test  loss 283.41, bar  test loss 3.689, len  test loss 0.448, col  test loss 161.690


Epoch 275: 272batch [00:43,  6.27batch/s, loss=192]


epoch 275: avg train loss 196.05, bar train loss 2.096, len train loss 0.012, col train loss 149.742
epoch 275: avg test  loss 289.53, bar  test loss 3.690, len  test loss 0.506, col  test loss 161.877


Epoch 276: 272batch [00:45,  5.99batch/s, loss=206]


epoch 276: avg train loss 196.07, bar train loss 2.097, len train loss 0.012, col train loss 149.717


Epoch 277: 1batch [00:00,  5.95batch/s, loss=203]

epoch 276: avg test  loss 284.29, bar  test loss 3.672, len  test loss 0.458, col  test loss 161.676


Epoch 277: 272batch [00:44,  6.14batch/s, loss=193]


epoch 277: avg train loss 196.15, bar train loss 2.100, len train loss 0.013, col train loss 149.742


Epoch 278: 1batch [00:00,  6.10batch/s, loss=200]

epoch 277: avg test  loss 281.43, bar  test loss 3.669, len  test loss 0.427, col  test loss 161.738


Epoch 278: 272batch [00:42,  6.36batch/s, loss=200]


epoch 278: avg train loss 195.98, bar train loss 2.096, len train loss 0.012, col train loss 149.706


Epoch 279: 1batch [00:00,  5.88batch/s, loss=196]

epoch 278: avg test  loss 288.57, bar  test loss 3.675, len  test loss 0.493, col  test loss 161.800


Epoch 279: 272batch [00:42,  6.46batch/s, loss=202]


epoch 279: avg train loss 195.98, bar train loss 2.094, len train loss 0.012, col train loss 149.688


Epoch 280: 0batch [00:00, ?batch/s]

epoch 279: avg test  loss 296.63, bar  test loss 3.657, len  test loss 0.576, col  test loss 161.706


Epoch 280: 272batch [00:45,  5.93batch/s, loss=203]


epoch 280: avg train loss 196.04, bar train loss 2.097, len train loss 0.013, col train loss 149.675
epoch 280: avg test  loss 290.43, bar  test loss 3.674, len  test loss 0.518, col  test loss 161.768


Epoch 281: 272batch [00:43,  6.28batch/s, loss=198]


epoch 281: avg train loss 195.83, bar train loss 2.090, len train loss 0.013, col train loss 149.647


Epoch 282: 1batch [00:00,  7.75batch/s, loss=189]

epoch 281: avg test  loss 285.13, bar  test loss 3.676, len  test loss 0.466, col  test loss 161.689


Epoch 282: 272batch [00:44,  6.08batch/s, loss=188]


epoch 282: avg train loss 196.12, bar train loss 2.102, len train loss 0.012, col train loss 149.671


Epoch 283: 1batch [00:00,  6.25batch/s, loss=199]

epoch 282: avg test  loss 288.10, bar  test loss 3.687, len  test loss 0.493, col  test loss 161.812


Epoch 283: 272batch [00:42,  6.47batch/s, loss=202]


epoch 283: avg train loss 195.93, bar train loss 2.093, len train loss 0.012, col train loss 149.661


Epoch 284: 1batch [00:00,  6.37batch/s, loss=189]

epoch 283: avg test  loss 285.53, bar  test loss 3.684, len  test loss 0.471, col  test loss 161.775


Epoch 284: 272batch [00:45,  5.98batch/s, loss=201]


epoch 284: avg train loss 195.82, bar train loss 2.091, len train loss 0.012, col train loss 149.619


Epoch 285: 1batch [00:00,  6.37batch/s, loss=190]

epoch 284: avg test  loss 288.19, bar  test loss 3.682, len  test loss 0.497, col  test loss 161.790


Epoch 285: 272batch [00:40,  6.66batch/s, loss=210]


epoch 285: avg train loss 195.75, bar train loss 2.087, len train loss 0.012, col train loss 149.653
epoch 285: avg test  loss 284.87, bar  test loss 3.679, len  test loss 0.463, col  test loss 161.745


Epoch 286: 272batch [00:41,  6.54batch/s, loss=202]


epoch 286: avg train loss 195.60, bar train loss 2.084, len train loss 0.012, col train loss 149.605


Epoch 287: 1batch [00:00,  6.67batch/s, loss=197]

epoch 286: avg test  loss 288.69, bar  test loss 3.678, len  test loss 0.501, col  test loss 161.749


Epoch 287: 272batch [00:42,  6.44batch/s, loss=201]


epoch 287: avg train loss 195.54, bar train loss 2.084, len train loss 0.012, col train loss 149.547


Epoch 288: 1batch [00:00,  7.04batch/s, loss=194]

epoch 287: avg test  loss 286.84, bar  test loss 3.699, len  test loss 0.479, col  test loss 161.833


Epoch 288: 272batch [00:39,  6.91batch/s, loss=224]


epoch 288: avg train loss 195.51, bar train loss 2.080, len train loss 0.012, col train loss 149.563


Epoch 289: 1batch [00:00,  6.25batch/s, loss=204]

epoch 288: avg test  loss 287.98, bar  test loss 3.701, len  test loss 0.485, col  test loss 161.802


Epoch 289: 272batch [00:40,  6.68batch/s, loss=187]


epoch 289: avg train loss 195.55, bar train loss 2.084, len train loss 0.012, col train loss 149.524


Epoch 290: 1batch [00:00,  7.14batch/s, loss=195]

epoch 289: avg test  loss 287.34, bar  test loss 3.686, len  test loss 0.487, col  test loss 161.731


Epoch 290: 272batch [00:39,  6.81batch/s, loss=203]


epoch 290: avg train loss 195.44, bar train loss 2.083, len train loss 0.012, col train loss 149.506
epoch 290: avg test  loss 288.23, bar  test loss 3.701, len  test loss 0.491, col  test loss 161.824


Epoch 291: 272batch [00:50,  5.41batch/s, loss=209]


epoch 291: avg train loss 195.49, bar train loss 2.086, len train loss 0.012, col train loss 149.530


Epoch 292: 1batch [00:00,  6.17batch/s, loss=194]

epoch 291: avg test  loss 286.85, bar  test loss 3.701, len  test loss 0.475, col  test loss 161.876


Epoch 292: 272batch [00:48,  5.66batch/s, loss=190]


epoch 292: avg train loss 195.42, bar train loss 2.081, len train loss 0.012, col train loss 149.501


Epoch 293: 1batch [00:00,  6.90batch/s, loss=192]

epoch 292: avg test  loss 282.93, bar  test loss 3.676, len  test loss 0.444, col  test loss 161.759


Epoch 293: 272batch [00:47,  5.68batch/s, loss=206]


epoch 293: avg train loss 195.51, bar train loss 2.084, len train loss 0.012, col train loss 149.513


Epoch 294: 1batch [00:00,  6.80batch/s, loss=200]

epoch 293: avg test  loss 287.65, bar  test loss 3.691, len  test loss 0.488, col  test loss 161.840


Epoch 294: 272batch [00:41,  6.49batch/s, loss=207]


epoch 294: avg train loss 195.51, bar train loss 2.082, len train loss 0.012, col train loss 149.544


Epoch 295: 1batch [00:00,  6.17batch/s, loss=195]

epoch 294: avg test  loss 287.05, bar  test loss 3.704, len  test loss 0.479, col  test loss 161.811


Epoch 295: 272batch [00:48,  5.58batch/s, loss=190]


epoch 295: avg train loss 195.25, bar train loss 2.078, len train loss 0.012, col train loss 149.445
epoch 295: avg test  loss 287.56, bar  test loss 3.688, len  test loss 0.488, col  test loss 161.781


Epoch 296: 272batch [00:46,  5.81batch/s, loss=216]


epoch 296: avg train loss 195.35, bar train loss 2.082, len train loss 0.012, col train loss 149.457


Epoch 297: 1batch [00:00,  6.54batch/s, loss=193]

epoch 296: avg test  loss 286.73, bar  test loss 3.694, len  test loss 0.479, col  test loss 161.869


Epoch 297: 272batch [00:49,  5.52batch/s, loss=198]


epoch 297: avg train loss 195.31, bar train loss 2.078, len train loss 0.012, col train loss 149.441


Epoch 298: 1batch [00:00,  6.06batch/s, loss=194]

epoch 297: avg test  loss 283.61, bar  test loss 3.681, len  test loss 0.447, col  test loss 161.904


Epoch 298: 272batch [00:51,  5.29batch/s, loss=187]


epoch 298: avg train loss 195.36, bar train loss 2.079, len train loss 0.012, col train loss 149.486


Epoch 299: 1batch [00:00,  6.45batch/s, loss=201]

epoch 298: avg test  loss 288.30, bar  test loss 3.701, len  test loss 0.489, col  test loss 161.847


Epoch 299: 272batch [00:54,  5.02batch/s, loss=191]


epoch 299: avg train loss 195.16, bar train loss 2.073, len train loss 0.012, col train loss 149.427


Epoch 300: 0batch [00:00, ?batch/s]

epoch 299: avg test  loss 286.72, bar  test loss 3.695, len  test loss 0.477, col  test loss 161.841


Epoch 300: 272batch [00:54,  5.01batch/s, loss=204]


epoch 300: avg train loss 195.15, bar train loss 2.072, len train loss 0.012, col train loss 149.420
epoch 300: avg test  loss 284.51, bar  test loss 3.706, len  test loss 0.451, col  test loss 161.794


Epoch 301: 272batch [00:48,  5.59batch/s, loss=206]


epoch 301: avg train loss 195.01, bar train loss 2.069, len train loss 0.012, col train loss 149.388


Epoch 302: 0batch [00:00, ?batch/s]

epoch 301: avg test  loss 289.46, bar  test loss 3.712, len  test loss 0.503, col  test loss 161.883


Epoch 302: 272batch [00:44,  6.09batch/s, loss=211]


epoch 302: avg train loss 195.12, bar train loss 2.073, len train loss 0.012, col train loss 149.409


Epoch 303: 1batch [00:00,  6.21batch/s, loss=199]

epoch 302: avg test  loss 286.84, bar  test loss 3.717, len  test loss 0.476, col  test loss 161.855


Epoch 303: 272batch [00:43,  6.26batch/s, loss=197]


epoch 303: avg train loss 195.02, bar train loss 2.071, len train loss 0.012, col train loss 149.346


Epoch 304: 1batch [00:00,  6.41batch/s, loss=199]

epoch 303: avg test  loss 287.28, bar  test loss 3.707, len  test loss 0.471, col  test loss 161.794


Epoch 304: 272batch [00:48,  5.59batch/s, loss=189]


epoch 304: avg train loss 194.93, bar train loss 2.069, len train loss 0.012, col train loss 149.338


Epoch 305: 1batch [00:00,  6.17batch/s, loss=199]

epoch 304: avg test  loss 287.96, bar  test loss 3.725, len  test loss 0.486, col  test loss 161.839


Epoch 305: 272batch [00:41,  6.58batch/s, loss=197]


epoch 305: avg train loss 195.09, bar train loss 2.072, len train loss 0.012, col train loss 149.379
epoch 305: avg test  loss 289.93, bar  test loss 3.715, len  test loss 0.508, col  test loss 161.871


Epoch 306: 272batch [00:41,  6.63batch/s, loss=211]


epoch 306: avg train loss 194.88, bar train loss 2.067, len train loss 0.012, col train loss 149.315


Epoch 307: 1batch [00:00,  6.37batch/s, loss=187]

epoch 306: avg test  loss 290.27, bar  test loss 3.713, len  test loss 0.510, col  test loss 161.824


Epoch 307: 272batch [00:47,  5.78batch/s, loss=199]


epoch 307: avg train loss 194.90, bar train loss 2.068, len train loss 0.012, col train loss 149.335


Epoch 308: 0batch [00:00, ?batch/s]

epoch 307: avg test  loss 289.60, bar  test loss 3.697, len  test loss 0.496, col  test loss 161.826


Epoch 308: 272batch [00:48,  5.65batch/s, loss=199]


epoch 308: avg train loss 194.90, bar train loss 2.067, len train loss 0.012, col train loss 149.352


Epoch 309: 1batch [00:00,  6.41batch/s, loss=197]

epoch 308: avg test  loss 287.51, bar  test loss 3.710, len  test loss 0.484, col  test loss 161.846


Epoch 309: 272batch [00:44,  6.17batch/s, loss=208]


epoch 309: avg train loss 194.76, bar train loss 2.063, len train loss 0.012, col train loss 149.270


Epoch 310: 1batch [00:00,  6.10batch/s, loss=196]

epoch 309: avg test  loss 284.89, bar  test loss 3.711, len  test loss 0.457, col  test loss 161.892


Epoch 310: 272batch [00:45,  5.99batch/s, loss=190]


epoch 310: avg train loss 194.61, bar train loss 2.060, len train loss 0.011, col train loss 149.248
epoch 310: avg test  loss 286.84, bar  test loss 3.717, len  test loss 0.469, col  test loss 161.848


Epoch 311: 272batch [00:51,  5.33batch/s, loss=214]


epoch 311: avg train loss 194.69, bar train loss 2.061, len train loss 0.011, col train loss 149.295


Epoch 312: 0batch [00:00, ?batch/s]

epoch 311: avg test  loss 288.87, bar  test loss 3.741, len  test loss 0.493, col  test loss 161.931


Epoch 312: 272batch [00:44,  6.08batch/s, loss=197]


epoch 312: avg train loss 194.61, bar train loss 2.061, len train loss 0.012, col train loss 149.224


Epoch 313: 1batch [00:00,  6.80batch/s, loss=194]

epoch 312: avg test  loss 287.64, bar  test loss 3.734, len  test loss 0.480, col  test loss 161.926


Epoch 313: 272batch [00:46,  5.82batch/s, loss=202]


epoch 313: avg train loss 194.65, bar train loss 2.062, len train loss 0.012, col train loss 149.254


Epoch 314: 1batch [00:00,  6.67batch/s, loss=196]

epoch 313: avg test  loss 288.47, bar  test loss 3.739, len  test loss 0.482, col  test loss 161.856


Epoch 314: 272batch [00:49,  5.45batch/s, loss=201]


epoch 314: avg train loss 194.63, bar train loss 2.060, len train loss 0.012, col train loss 149.211


Epoch 315: 1batch [00:00,  7.69batch/s, loss=191]

epoch 314: avg test  loss 293.28, bar  test loss 3.715, len  test loss 0.536, col  test loss 161.873


Epoch 315: 272batch [00:49,  5.49batch/s, loss=197]


epoch 315: avg train loss 194.50, bar train loss 2.059, len train loss 0.011, col train loss 149.187
epoch 315: avg test  loss 287.67, bar  test loss 3.726, len  test loss 0.483, col  test loss 161.862


Epoch 316: 272batch [00:45,  5.94batch/s, loss=206]


epoch 316: avg train loss 194.41, bar train loss 2.053, len train loss 0.011, col train loss 149.221


Epoch 317: 0batch [00:00, ?batch/s]

epoch 316: avg test  loss 288.53, bar  test loss 3.739, len  test loss 0.489, col  test loss 161.937


Epoch 317: 272batch [00:49,  5.52batch/s, loss=206]


epoch 317: avg train loss 194.51, bar train loss 2.057, len train loss 0.012, col train loss 149.194


Epoch 318: 0batch [00:00, ?batch/s]

epoch 317: avg test  loss 290.40, bar  test loss 3.720, len  test loss 0.507, col  test loss 161.914


Epoch 318: 272batch [00:50,  5.35batch/s, loss=203]


epoch 318: avg train loss 194.60, bar train loss 2.059, len train loss 0.012, col train loss 149.215


Epoch 319: 0batch [00:00, ?batch/s]

epoch 318: avg test  loss 287.91, bar  test loss 3.734, len  test loss 0.483, col  test loss 161.997


Epoch 319: 272batch [00:46,  5.88batch/s, loss=199]


epoch 319: avg train loss 194.32, bar train loss 2.049, len train loss 0.012, col train loss 149.173


Epoch 320: 0batch [00:00, ?batch/s]

epoch 319: avg test  loss 285.77, bar  test loss 3.742, len  test loss 0.457, col  test loss 161.923


Epoch 320: 272batch [00:50,  5.40batch/s, loss=219]


epoch 320: avg train loss 194.41, bar train loss 2.055, len train loss 0.011, col train loss 149.172
epoch 320: avg test  loss 287.88, bar  test loss 3.733, len  test loss 0.483, col  test loss 161.925


Epoch 321: 272batch [00:48,  5.58batch/s, loss=200]


epoch 321: avg train loss 194.40, bar train loss 2.051, len train loss 0.012, col train loss 149.174


Epoch 322: 1batch [00:00,  6.21batch/s, loss=201]

epoch 321: avg test  loss 293.40, bar  test loss 3.736, len  test loss 0.538, col  test loss 161.870


Epoch 322: 272batch [00:40,  6.74batch/s, loss=180]


epoch 322: avg train loss 194.37, bar train loss 2.057, len train loss 0.011, col train loss 149.112


Epoch 323: 1batch [00:00,  6.41batch/s, loss=189]

epoch 322: avg test  loss 286.01, bar  test loss 3.742, len  test loss 0.461, col  test loss 161.955


Epoch 323: 272batch [00:47,  5.76batch/s, loss=195]


epoch 323: avg train loss 194.30, bar train loss 2.053, len train loss 0.011, col train loss 149.117


Epoch 324: 0batch [00:00, ?batch/s]

epoch 323: avg test  loss 289.39, bar  test loss 3.767, len  test loss 0.488, col  test loss 161.929


Epoch 324: 272batch [00:54,  4.96batch/s, loss=192]


epoch 324: avg train loss 194.31, bar train loss 2.053, len train loss 0.011, col train loss 149.123


Epoch 325: 1batch [00:00,  6.14batch/s, loss=194]

epoch 324: avg test  loss 289.49, bar  test loss 3.744, len  test loss 0.494, col  test loss 162.032


Epoch 325: 272batch [00:51,  5.31batch/s, loss=193]


epoch 325: avg train loss 194.29, bar train loss 2.051, len train loss 0.011, col train loss 149.137
epoch 325: avg test  loss 288.00, bar  test loss 3.746, len  test loss 0.480, col  test loss 162.075


Epoch 326: 272batch [00:52,  5.23batch/s, loss=186]


epoch 326: avg train loss 194.22, bar train loss 2.047, len train loss 0.012, col train loss 149.126


Epoch 327: 0batch [00:00, ?batch/s]

epoch 326: avg test  loss 286.87, bar  test loss 3.733, len  test loss 0.474, col  test loss 161.924


Epoch 327: 272batch [00:54,  4.96batch/s, loss=198]


epoch 327: avg train loss 194.04, bar train loss 2.043, len train loss 0.012, col train loss 149.035


Epoch 328: 0batch [00:00, ?batch/s]

epoch 327: avg test  loss 286.09, bar  test loss 3.754, len  test loss 0.462, col  test loss 161.849


Epoch 328: 272batch [00:50,  5.44batch/s, loss=207]


epoch 328: avg train loss 194.16, bar train loss 2.050, len train loss 0.011, col train loss 149.047


Epoch 329: 0batch [00:00, ?batch/s]

epoch 328: avg test  loss 291.65, bar  test loss 3.741, len  test loss 0.518, col  test loss 162.063


Epoch 329: 272batch [00:52,  5.20batch/s, loss=210]


epoch 329: avg train loss 194.21, bar train loss 2.049, len train loss 0.011, col train loss 149.102


Epoch 330: 0batch [00:00, ?batch/s]

epoch 329: avg test  loss 285.99, bar  test loss 3.746, len  test loss 0.462, col  test loss 161.992


Epoch 330: 272batch [00:53,  5.13batch/s, loss=201]


epoch 330: avg train loss 194.14, bar train loss 2.047, len train loss 0.011, col train loss 149.063
epoch 330: avg test  loss 290.38, bar  test loss 3.780, len  test loss 0.493, col  test loss 162.011


Epoch 331: 272batch [00:55,  4.91batch/s, loss=193]


epoch 331: avg train loss 194.26, bar train loss 2.051, len train loss 0.012, col train loss 149.118


Epoch 332: 1batch [00:00,  6.67batch/s, loss=193]

epoch 331: avg test  loss 287.88, bar  test loss 3.756, len  test loss 0.477, col  test loss 161.950


Epoch 332: 272batch [00:53,  5.09batch/s, loss=197]


epoch 332: avg train loss 194.08, bar train loss 2.048, len train loss 0.011, col train loss 149.036


Epoch 333: 0batch [00:00, ?batch/s]

epoch 332: avg test  loss 287.94, bar  test loss 3.756, len  test loss 0.477, col  test loss 162.086


Epoch 333: 272batch [00:53,  5.04batch/s, loss=198]


epoch 333: avg train loss 194.09, bar train loss 2.047, len train loss 0.011, col train loss 149.043


Epoch 334: 0batch [00:00, ?batch/s]

epoch 333: avg test  loss 290.82, bar  test loss 3.759, len  test loss 0.507, col  test loss 162.033


Epoch 334: 272batch [00:52,  5.16batch/s, loss=214]


epoch 334: avg train loss 194.02, bar train loss 2.049, len train loss 0.011, col train loss 148.963


Epoch 335: 1batch [00:00,  6.49batch/s, loss=186]

epoch 334: avg test  loss 286.77, bar  test loss 3.772, len  test loss 0.465, col  test loss 161.987


Epoch 335: 272batch [01:02,  4.34batch/s, loss=209]


epoch 335: avg train loss 193.93, bar train loss 2.039, len train loss 0.012, col train loss 149.005
epoch 335: avg test  loss 284.69, bar  test loss 3.762, len  test loss 0.444, col  test loss 161.990


Epoch 336: 272batch [00:43,  6.30batch/s, loss=195]


epoch 336: avg train loss 193.97, bar train loss 2.044, len train loss 0.011, col train loss 148.990


Epoch 337: 1batch [00:00,  6.45batch/s, loss=198]

epoch 336: avg test  loss 291.13, bar  test loss 3.758, len  test loss 0.509, col  test loss 161.957


Epoch 337: 272batch [00:41,  6.51batch/s, loss=199]


epoch 337: avg train loss 193.85, bar train loss 2.041, len train loss 0.011, col train loss 148.985


Epoch 338: 0batch [00:00, ?batch/s]

epoch 337: avg test  loss 286.73, bar  test loss 3.753, len  test loss 0.469, col  test loss 161.953


Epoch 338: 272batch [00:49,  5.53batch/s, loss=209]


epoch 338: avg train loss 193.80, bar train loss 2.038, len train loss 0.011, col train loss 148.986


Epoch 339: 1batch [00:00,  6.13batch/s, loss=189]

epoch 338: avg test  loss 288.55, bar  test loss 3.750, len  test loss 0.487, col  test loss 161.974


Epoch 339: 272batch [00:55,  4.92batch/s, loss=215]


epoch 339: avg train loss 193.79, bar train loss 2.039, len train loss 0.011, col train loss 148.954


Epoch 340: 1batch [00:00,  7.30batch/s, loss=192]

epoch 339: avg test  loss 288.06, bar  test loss 3.750, len  test loss 0.480, col  test loss 162.004


Epoch 340: 272batch [00:47,  5.68batch/s, loss=188]


epoch 340: avg train loss 193.66, bar train loss 2.031, len train loss 0.011, col train loss 148.969
epoch 340: avg test  loss 287.06, bar  test loss 3.754, len  test loss 0.466, col  test loss 162.090


Epoch 341: 272batch [00:45,  6.01batch/s, loss=202]


epoch 341: avg train loss 193.83, bar train loss 2.040, len train loss 0.011, col train loss 148.964


Epoch 342: 0batch [00:00, ?batch/s]

epoch 341: avg test  loss 287.26, bar  test loss 3.745, len  test loss 0.475, col  test loss 161.979


Epoch 342: 272batch [00:51,  5.25batch/s, loss=206]


epoch 342: avg train loss 193.79, bar train loss 2.042, len train loss 0.011, col train loss 148.911


Epoch 343: 0batch [00:00, ?batch/s]

epoch 342: avg test  loss 289.43, bar  test loss 3.749, len  test loss 0.487, col  test loss 162.033


Epoch 343: 272batch [00:46,  5.79batch/s, loss=194]


epoch 343: avg train loss 193.62, bar train loss 2.034, len train loss 0.011, col train loss 148.914


Epoch 344: 1batch [00:00,  6.71batch/s, loss=193]

epoch 343: avg test  loss 287.93, bar  test loss 3.755, len  test loss 0.476, col  test loss 161.957


Epoch 344: 272batch [00:40,  6.70batch/s, loss=197]


epoch 344: avg train loss 193.65, bar train loss 2.036, len train loss 0.011, col train loss 148.887


Epoch 345: 1batch [00:00,  6.37batch/s, loss=195]

epoch 344: avg test  loss 292.63, bar  test loss 3.769, len  test loss 0.522, col  test loss 162.005


Epoch 345: 272batch [00:52,  5.17batch/s, loss=201]


epoch 345: avg train loss 193.72, bar train loss 2.036, len train loss 0.012, col train loss 148.927
epoch 345: avg test  loss 288.61, bar  test loss 3.767, len  test loss 0.477, col  test loss 162.000


Epoch 346: 272batch [00:53,  5.10batch/s, loss=194]


epoch 346: avg train loss 193.67, bar train loss 2.038, len train loss 0.011, col train loss 148.875


Epoch 347: 0batch [00:00, ?batch/s, loss=187]

epoch 346: avg test  loss 291.05, bar  test loss 3.754, len  test loss 0.508, col  test loss 161.977


Epoch 347: 272batch [00:50,  5.42batch/s, loss=198]


epoch 347: avg train loss 193.59, bar train loss 2.034, len train loss 0.011, col train loss 148.883


Epoch 348: 0batch [00:00, ?batch/s]

epoch 347: avg test  loss 291.97, bar  test loss 3.759, len  test loss 0.518, col  test loss 162.012


Epoch 348: 272batch [00:53,  5.08batch/s, loss=203]


epoch 348: avg train loss 193.56, bar train loss 2.036, len train loss 0.011, col train loss 148.855


Epoch 349: 0batch [00:00, ?batch/s, loss=186]

epoch 348: avg test  loss 287.51, bar  test loss 3.763, len  test loss 0.471, col  test loss 161.996


Epoch 349: 272batch [00:48,  5.62batch/s, loss=203]


epoch 349: avg train loss 193.50, bar train loss 2.033, len train loss 0.011, col train loss 148.863


Epoch 350: 1batch [00:00,  7.58batch/s, loss=190]

epoch 349: avg test  loss 288.05, bar  test loss 3.780, len  test loss 0.473, col  test loss 162.018


Epoch 350: 272batch [00:39,  6.83batch/s, loss=194]


epoch 350: avg train loss 193.45, bar train loss 2.031, len train loss 0.011, col train loss 148.841
epoch 350: avg test  loss 291.23, bar  test loss 3.785, len  test loss 0.506, col  test loss 162.068


Epoch 351: 272batch [00:42,  6.39batch/s, loss=201]


epoch 351: avg train loss 193.39, bar train loss 2.029, len train loss 0.011, col train loss 148.805


Epoch 352: 0batch [00:00, ?batch/s]

epoch 351: avg test  loss 290.51, bar  test loss 3.764, len  test loss 0.500, col  test loss 161.989


Epoch 352: 272batch [00:48,  5.57batch/s, loss=205]


epoch 352: avg train loss 193.41, bar train loss 2.029, len train loss 0.011, col train loss 148.846


Epoch 353: 0batch [00:00, ?batch/s]

epoch 352: avg test  loss 287.82, bar  test loss 3.800, len  test loss 0.466, col  test loss 162.087


Epoch 353: 272batch [00:49,  5.53batch/s, loss=181]


epoch 353: avg train loss 193.48, bar train loss 2.032, len train loss 0.011, col train loss 148.841


Epoch 354: 0batch [00:00, ?batch/s]

epoch 353: avg test  loss 286.76, bar  test loss 3.772, len  test loss 0.462, col  test loss 161.983


Epoch 354: 272batch [00:53,  5.04batch/s, loss=190]


epoch 354: avg train loss 193.38, bar train loss 2.026, len train loss 0.011, col train loss 148.834


Epoch 355: 0batch [00:00, ?batch/s]

epoch 354: avg test  loss 288.66, bar  test loss 3.775, len  test loss 0.478, col  test loss 161.986


Epoch 355: 272batch [00:52,  5.21batch/s, loss=194]


epoch 355: avg train loss 193.43, bar train loss 2.029, len train loss 0.011, col train loss 148.827
epoch 355: avg test  loss 288.07, bar  test loss 3.779, len  test loss 0.474, col  test loss 162.115


Epoch 356: 272batch [00:52,  5.17batch/s, loss=192]


epoch 356: avg train loss 193.45, bar train loss 2.030, len train loss 0.011, col train loss 148.798


Epoch 357: 1batch [00:00,  6.54batch/s, loss=194]

epoch 356: avg test  loss 288.37, bar  test loss 3.766, len  test loss 0.481, col  test loss 162.001


Epoch 357: 272batch [00:54,  4.97batch/s, loss=201]


epoch 357: avg train loss 193.30, bar train loss 2.030, len train loss 0.011, col train loss 148.762


Epoch 358: 0batch [00:00, ?batch/s]

epoch 357: avg test  loss 288.88, bar  test loss 3.770, len  test loss 0.485, col  test loss 162.110


Epoch 358: 272batch [00:53,  5.04batch/s, loss=204]


epoch 358: avg train loss 193.29, bar train loss 2.027, len train loss 0.011, col train loss 148.780


Epoch 359: 1batch [00:00,  6.54batch/s, loss=192]

epoch 358: avg test  loss 288.29, bar  test loss 3.799, len  test loss 0.475, col  test loss 162.030


Epoch 359: 272batch [00:56,  4.81batch/s, loss=213]


epoch 359: avg train loss 193.42, bar train loss 2.032, len train loss 0.011, col train loss 148.777


Epoch 360: 0batch [00:00, ?batch/s]

epoch 359: avg test  loss 291.14, bar  test loss 3.786, len  test loss 0.505, col  test loss 162.060


Epoch 360: 272batch [00:54,  4.95batch/s, loss=224]


epoch 360: avg train loss 193.30, bar train loss 2.027, len train loss 0.011, col train loss 148.764
epoch 360: avg test  loss 290.27, bar  test loss 3.768, len  test loss 0.497, col  test loss 162.094


Epoch 361: 272batch [00:54,  5.02batch/s, loss=196]


epoch 361: avg train loss 193.37, bar train loss 2.028, len train loss 0.011, col train loss 148.804


Epoch 362: 0batch [00:00, ?batch/s]

epoch 361: avg test  loss 284.47, bar  test loss 3.795, len  test loss 0.436, col  test loss 162.082


Epoch 362: 272batch [00:53,  5.12batch/s, loss=189]


epoch 362: avg train loss 193.18, bar train loss 2.022, len train loss 0.011, col train loss 148.735


Epoch 363: 0batch [00:00, ?batch/s]

epoch 362: avg test  loss 287.84, bar  test loss 3.810, len  test loss 0.465, col  test loss 162.086


Epoch 363: 272batch [00:52,  5.21batch/s, loss=195]


epoch 363: avg train loss 193.04, bar train loss 2.020, len train loss 0.011, col train loss 148.674


Epoch 364: 0batch [00:00, ?batch/s]

epoch 363: avg test  loss 289.34, bar  test loss 3.805, len  test loss 0.482, col  test loss 162.092


Epoch 364: 272batch [00:53,  5.07batch/s, loss=191]


epoch 364: avg train loss 193.09, bar train loss 2.021, len train loss 0.011, col train loss 148.729


Epoch 365: 0batch [00:00, ?batch/s]

epoch 364: avg test  loss 289.00, bar  test loss 3.794, len  test loss 0.480, col  test loss 162.071


Epoch 365: 272batch [00:51,  5.29batch/s, loss=196]


epoch 365: avg train loss 193.20, bar train loss 2.024, len train loss 0.011, col train loss 148.754
epoch 365: avg test  loss 286.85, bar  test loss 3.796, len  test loss 0.460, col  test loss 162.059


Epoch 366: 272batch [00:50,  5.37batch/s, loss=185]


epoch 366: avg train loss 193.21, bar train loss 2.024, len train loss 0.011, col train loss 148.749


Epoch 367: 1batch [00:00,  7.04batch/s, loss=182]

epoch 366: avg test  loss 288.46, bar  test loss 3.800, len  test loss 0.475, col  test loss 162.096


Epoch 367: 272batch [00:49,  5.45batch/s, loss=209]


epoch 367: avg train loss 193.19, bar train loss 2.022, len train loss 0.012, col train loss 148.724


Epoch 368: 0batch [00:00, ?batch/s]

epoch 367: avg test  loss 292.94, bar  test loss 3.798, len  test loss 0.520, col  test loss 162.206


Epoch 368: 272batch [00:52,  5.17batch/s, loss=187]


epoch 368: avg train loss 192.87, bar train loss 2.013, len train loss 0.011, col train loss 148.657


Epoch 369: 1batch [00:00,  6.67batch/s, loss=191]

epoch 368: avg test  loss 288.44, bar  test loss 3.797, len  test loss 0.475, col  test loss 162.118


Epoch 369: 272batch [00:39,  6.87batch/s, loss=208]


epoch 369: avg train loss 192.95, bar train loss 2.017, len train loss 0.011, col train loss 148.665


Epoch 370: 0batch [00:00, ?batch/s]

epoch 369: avg test  loss 290.91, bar  test loss 3.812, len  test loss 0.496, col  test loss 162.087


Epoch 370: 272batch [00:37,  7.18batch/s, loss=185]


epoch 370: avg train loss 192.96, bar train loss 2.018, len train loss 0.011, col train loss 148.667
epoch 370: avg test  loss 287.88, bar  test loss 3.802, len  test loss 0.469, col  test loss 162.102


Epoch 371: 272batch [00:32,  8.28batch/s, loss=189]


epoch 371: avg train loss 192.96, bar train loss 2.019, len train loss 0.011, col train loss 148.639


Epoch 372: 1batch [00:00,  8.20batch/s, loss=196]

epoch 371: avg test  loss 290.02, bar  test loss 3.806, len  test loss 0.490, col  test loss 162.077


Epoch 372: 272batch [00:32,  8.48batch/s, loss=200]


epoch 372: avg train loss 193.03, bar train loss 2.019, len train loss 0.011, col train loss 148.659


Epoch 373: 1batch [00:00,  8.13batch/s, loss=203]

epoch 372: avg test  loss 289.12, bar  test loss 3.814, len  test loss 0.478, col  test loss 162.148


Epoch 373: 272batch [00:31,  8.52batch/s, loss=198]


epoch 373: avg train loss 192.95, bar train loss 2.017, len train loss 0.011, col train loss 148.663


Epoch 374: 0batch [00:00, ?batch/s]

epoch 373: avg test  loss 291.36, bar  test loss 3.823, len  test loss 0.498, col  test loss 162.157


Epoch 374: 272batch [00:32,  8.43batch/s, loss=189]


epoch 374: avg train loss 192.97, bar train loss 2.018, len train loss 0.011, col train loss 148.655


Epoch 375: 1batch [00:00,  8.00batch/s, loss=201]

epoch 374: avg test  loss 288.98, bar  test loss 3.804, len  test loss 0.479, col  test loss 162.203


Epoch 375: 272batch [00:31,  8.50batch/s, loss=206]


epoch 375: avg train loss 192.93, bar train loss 2.016, len train loss 0.011, col train loss 148.626
epoch 375: avg test  loss 289.69, bar  test loss 3.825, len  test loss 0.483, col  test loss 162.101


Epoch 376: 272batch [00:32,  8.48batch/s, loss=190]


epoch 376: avg train loss 193.01, bar train loss 2.020, len train loss 0.011, col train loss 148.652


Epoch 377: 1batch [00:00,  8.06batch/s, loss=197]

epoch 376: avg test  loss 289.85, bar  test loss 3.805, len  test loss 0.488, col  test loss 162.146


Epoch 377: 272batch [00:32,  8.27batch/s, loss=190]


epoch 377: avg train loss 192.90, bar train loss 2.018, len train loss 0.011, col train loss 148.596


Epoch 378: 1batch [00:00,  8.13batch/s, loss=189]

epoch 377: avg test  loss 289.51, bar  test loss 3.812, len  test loss 0.483, col  test loss 162.154


Epoch 378: 272batch [00:33,  8.04batch/s, loss=196]


epoch 378: avg train loss 192.68, bar train loss 2.010, len train loss 0.011, col train loss 148.564


Epoch 379: 1batch [00:00,  8.40batch/s, loss=186]

epoch 378: avg test  loss 294.80, bar  test loss 3.794, len  test loss 0.540, col  test loss 162.112


Epoch 379: 272batch [00:35,  7.63batch/s, loss=214]


epoch 379: avg train loss 192.76, bar train loss 2.014, len train loss 0.011, col train loss 148.562


Epoch 380: 0batch [00:00, ?batch/s]

epoch 379: avg test  loss 293.50, bar  test loss 3.804, len  test loss 0.522, col  test loss 162.226


Epoch 380: 272batch [00:45,  5.99batch/s, loss=197]


epoch 380: avg train loss 192.87, bar train loss 2.018, len train loss 0.011, col train loss 148.601
epoch 380: avg test  loss 293.77, bar  test loss 3.825, len  test loss 0.524, col  test loss 162.117


Epoch 381: 272batch [00:46,  5.89batch/s, loss=184]


epoch 381: avg train loss 192.72, bar train loss 2.013, len train loss 0.010, col train loss 148.581


Epoch 382: 0batch [00:00, ?batch/s]

epoch 381: avg test  loss 292.09, bar  test loss 3.827, len  test loss 0.505, col  test loss 162.104


Epoch 382: 272batch [00:46,  5.80batch/s, loss=191]


epoch 382: avg train loss 192.73, bar train loss 2.011, len train loss 0.011, col train loss 148.577


Epoch 383: 0batch [00:00, ?batch/s, loss=192]

epoch 382: avg test  loss 286.59, bar  test loss 3.824, len  test loss 0.449, col  test loss 162.246


Epoch 383: 272batch [00:42,  6.39batch/s, loss=185]


epoch 383: avg train loss 192.72, bar train loss 2.012, len train loss 0.011, col train loss 148.558


Epoch 384: 1batch [00:00,  6.45batch/s, loss=193]

epoch 383: avg test  loss 288.89, bar  test loss 3.817, len  test loss 0.475, col  test loss 162.142


Epoch 384: 272batch [00:45,  5.95batch/s, loss=203]


epoch 384: avg train loss 192.70, bar train loss 2.013, len train loss 0.011, col train loss 148.549


Epoch 385: 0batch [00:00, ?batch/s]

epoch 384: avg test  loss 293.39, bar  test loss 3.829, len  test loss 0.518, col  test loss 162.147


Epoch 385: 272batch [00:47,  5.71batch/s, loss=207]


epoch 385: avg train loss 192.70, bar train loss 2.011, len train loss 0.011, col train loss 148.540
epoch 385: avg test  loss 289.43, bar  test loss 3.836, len  test loss 0.476, col  test loss 162.078


Epoch 386: 272batch [00:43,  6.25batch/s, loss=186]


epoch 386: avg train loss 192.71, bar train loss 2.013, len train loss 0.010, col train loss 148.558


Epoch 387: 1batch [00:00,  5.81batch/s, loss=199]

epoch 386: avg test  loss 288.22, bar  test loss 3.801, len  test loss 0.472, col  test loss 162.083


Epoch 387: 272batch [00:45,  5.93batch/s, loss=200]


epoch 387: avg train loss 192.66, bar train loss 2.013, len train loss 0.011, col train loss 148.501


Epoch 388: 1batch [00:00,  6.58batch/s, loss=192]

epoch 387: avg test  loss 286.65, bar  test loss 3.837, len  test loss 0.448, col  test loss 162.140


Epoch 388: 272batch [00:43,  6.31batch/s, loss=207]


epoch 388: avg train loss 192.51, bar train loss 2.005, len train loss 0.011, col train loss 148.499


Epoch 389: 0batch [00:00, ?batch/s]

epoch 388: avg test  loss 288.91, bar  test loss 3.809, len  test loss 0.471, col  test loss 162.171


Epoch 389: 272batch [00:43,  6.32batch/s, loss=190]


epoch 389: avg train loss 192.68, bar train loss 2.014, len train loss 0.011, col train loss 148.510


Epoch 390: 1batch [00:00,  6.54batch/s, loss=193]

epoch 389: avg test  loss 294.90, bar  test loss 3.824, len  test loss 0.534, col  test loss 162.152


Epoch 390: 272batch [00:43,  6.19batch/s, loss=178]


epoch 390: avg train loss 192.57, bar train loss 2.008, len train loss 0.011, col train loss 148.497
epoch 390: avg test  loss 289.79, bar  test loss 3.832, len  test loss 0.482, col  test loss 162.222


Epoch 391: 272batch [00:44,  6.13batch/s, loss=205]


epoch 391: avg train loss 192.65, bar train loss 2.008, len train loss 0.011, col train loss 148.549


Epoch 392: 0batch [00:00, ?batch/s]

epoch 391: avg test  loss 293.51, bar  test loss 3.827, len  test loss 0.521, col  test loss 162.129


Epoch 392: 272batch [00:51,  5.29batch/s, loss=203]


epoch 392: avg train loss 192.49, bar train loss 2.006, len train loss 0.011, col train loss 148.472


Epoch 393: 0batch [00:00, ?batch/s]

epoch 392: avg test  loss 290.39, bar  test loss 3.826, len  test loss 0.488, col  test loss 162.219


Epoch 393: 272batch [00:44,  6.13batch/s, loss=200]


epoch 393: avg train loss 192.55, bar train loss 2.010, len train loss 0.011, col train loss 148.454


Epoch 394: 1batch [00:00,  6.76batch/s, loss=203]

epoch 393: avg test  loss 289.69, bar  test loss 3.836, len  test loss 0.480, col  test loss 162.132


Epoch 394: 272batch [00:42,  6.37batch/s, loss=189]


epoch 394: avg train loss 192.47, bar train loss 2.004, len train loss 0.011, col train loss 148.461


Epoch 395: 1batch [00:00,  6.29batch/s, loss=196]

epoch 394: avg test  loss 289.73, bar  test loss 3.841, len  test loss 0.480, col  test loss 162.227


Epoch 395: 272batch [00:40,  6.71batch/s, loss=198]


epoch 395: avg train loss 192.61, bar train loss 2.011, len train loss 0.011, col train loss 148.477
epoch 395: avg test  loss 288.27, bar  test loss 3.834, len  test loss 0.466, col  test loss 162.153


Epoch 396: 272batch [00:40,  6.66batch/s, loss=188]


epoch 396: avg train loss 192.43, bar train loss 2.004, len train loss 0.011, col train loss 148.458


Epoch 397: 1batch [00:00,  6.10batch/s, loss=189]

epoch 396: avg test  loss 291.72, bar  test loss 3.826, len  test loss 0.499, col  test loss 162.189


Epoch 397: 272batch [00:42,  6.40batch/s, loss=197]


epoch 397: avg train loss 192.31, bar train loss 2.000, len train loss 0.011, col train loss 148.440


Epoch 398: 1batch [00:00,  6.25batch/s, loss=191]

epoch 397: avg test  loss 294.10, bar  test loss 3.835, len  test loss 0.523, col  test loss 162.261


Epoch 398: 272batch [00:40,  6.67batch/s, loss=174]


epoch 398: avg train loss 192.37, bar train loss 2.006, len train loss 0.010, col train loss 148.416


Epoch 399: 0batch [00:00, ?batch/s]

epoch 398: avg test  loss 293.32, bar  test loss 3.858, len  test loss 0.512, col  test loss 162.224


Epoch 399: 272batch [00:51,  5.32batch/s, loss=193]


epoch 399: avg train loss 192.41, bar train loss 2.006, len train loss 0.011, col train loss 148.423


Epoch 400: 0batch [00:00, ?batch/s]

epoch 399: avg test  loss 292.65, bar  test loss 3.833, len  test loss 0.508, col  test loss 162.269


Epoch 400: 272batch [00:45,  5.92batch/s, loss=197]


epoch 400: avg train loss 192.50, bar train loss 2.008, len train loss 0.011, col train loss 148.451
epoch 400: avg test  loss 288.56, bar  test loss 3.823, len  test loss 0.469, col  test loss 162.231


Epoch 401: 272batch [00:49,  5.50batch/s, loss=200]


epoch 401: avg train loss 192.35, bar train loss 2.004, len train loss 0.011, col train loss 148.399


Epoch 402: 1batch [00:00,  6.06batch/s, loss=194]

epoch 401: avg test  loss 289.48, bar  test loss 3.836, len  test loss 0.477, col  test loss 162.242


Epoch 402: 272batch [00:48,  5.64batch/s, loss=187]


epoch 402: avg train loss 192.37, bar train loss 2.005, len train loss 0.010, col train loss 148.408


Epoch 403: 0batch [00:00, ?batch/s]

epoch 402: avg test  loss 290.14, bar  test loss 3.840, len  test loss 0.482, col  test loss 162.177


Epoch 403: 272batch [00:51,  5.25batch/s, loss=195]


epoch 403: avg train loss 192.33, bar train loss 2.002, len train loss 0.011, col train loss 148.413


Epoch 404: 0batch [00:00, ?batch/s]

epoch 403: avg test  loss 289.92, bar  test loss 3.839, len  test loss 0.480, col  test loss 162.190


Epoch 404: 272batch [00:50,  5.34batch/s, loss=200]


epoch 404: avg train loss 192.32, bar train loss 2.003, len train loss 0.010, col train loss 148.394


Epoch 405: 0batch [00:00, ?batch/s]

epoch 404: avg test  loss 290.98, bar  test loss 3.828, len  test loss 0.488, col  test loss 162.173


Epoch 405: 272batch [00:50,  5.40batch/s, loss=205]


epoch 405: avg train loss 192.18, bar train loss 2.000, len train loss 0.010, col train loss 148.347
epoch 405: avg test  loss 288.85, bar  test loss 3.840, len  test loss 0.469, col  test loss 162.220


Epoch 406: 272batch [00:53,  5.05batch/s, loss=192]


epoch 406: avg train loss 192.30, bar train loss 2.004, len train loss 0.010, col train loss 148.377


Epoch 407: 1batch [00:00,  7.25batch/s, loss=189]

epoch 406: avg test  loss 292.99, bar  test loss 3.836, len  test loss 0.511, col  test loss 162.253


Epoch 407: 272batch [00:48,  5.63batch/s, loss=185]


epoch 407: avg train loss 192.19, bar train loss 2.000, len train loss 0.011, col train loss 148.342


Epoch 408: 0batch [00:00, ?batch/s]

epoch 407: avg test  loss 291.44, bar  test loss 3.847, len  test loss 0.494, col  test loss 162.166


Epoch 408: 272batch [00:51,  5.27batch/s, loss=212]


epoch 408: avg train loss 192.20, bar train loss 1.995, len train loss 0.011, col train loss 148.351


Epoch 409: 0batch [00:00, ?batch/s]

epoch 408: avg test  loss 286.70, bar  test loss 3.846, len  test loss 0.447, col  test loss 162.279


Epoch 409: 272batch [00:50,  5.38batch/s, loss=193]


epoch 409: avg train loss 192.39, bar train loss 2.002, len train loss 0.011, col train loss 148.433


Epoch 410: 0batch [00:00, ?batch/s]

epoch 409: avg test  loss 292.56, bar  test loss 3.832, len  test loss 0.504, col  test loss 162.155


Epoch 410: 272batch [00:52,  5.13batch/s, loss=214]


epoch 410: avg train loss 192.31, bar train loss 2.002, len train loss 0.011, col train loss 148.365
epoch 410: avg test  loss 289.60, bar  test loss 3.859, len  test loss 0.475, col  test loss 162.162


Epoch 411: 272batch [00:52,  5.14batch/s, loss=203]


epoch 411: avg train loss 192.10, bar train loss 1.995, len train loss 0.011, col train loss 148.347


Epoch 412: 0batch [00:00, ?batch/s]

epoch 411: avg test  loss 292.98, bar  test loss 3.848, len  test loss 0.510, col  test loss 162.240


Epoch 412: 272batch [00:51,  5.26batch/s, loss=182]


epoch 412: avg train loss 192.00, bar train loss 1.996, len train loss 0.010, col train loss 148.276


Epoch 413: 0batch [00:00, ?batch/s]

epoch 412: avg test  loss 289.96, bar  test loss 3.856, len  test loss 0.478, col  test loss 162.278


Epoch 413: 272batch [00:52,  5.15batch/s, loss=191]


epoch 413: avg train loss 192.00, bar train loss 1.996, len train loss 0.010, col train loss 148.275


Epoch 414: 1batch [00:00,  6.25batch/s, loss=196]

epoch 413: avg test  loss 290.42, bar  test loss 3.853, len  test loss 0.483, col  test loss 162.217


Epoch 414: 272batch [00:49,  5.46batch/s, loss=197]


epoch 414: avg train loss 192.04, bar train loss 1.994, len train loss 0.011, col train loss 148.317


Epoch 415: 1batch [00:00,  6.71batch/s, loss=183]

epoch 414: avg test  loss 290.14, bar  test loss 3.852, len  test loss 0.481, col  test loss 162.236


Epoch 415: 272batch [00:42,  6.38batch/s, loss=186]


epoch 415: avg train loss 192.01, bar train loss 1.996, len train loss 0.010, col train loss 148.295
epoch 415: avg test  loss 289.59, bar  test loss 3.860, len  test loss 0.474, col  test loss 162.260


Epoch 416: 272batch [00:43,  6.21batch/s, loss=200]


epoch 416: avg train loss 192.12, bar train loss 2.001, len train loss 0.010, col train loss 148.295


Epoch 417: 0batch [00:00, ?batch/s]

epoch 416: avg test  loss 292.48, bar  test loss 3.860, len  test loss 0.503, col  test loss 162.266


Epoch 417: 272batch [00:51,  5.25batch/s, loss=199]


epoch 417: avg train loss 192.04, bar train loss 1.997, len train loss 0.010, col train loss 148.284


Epoch 418: 0batch [00:00, ?batch/s]

epoch 417: avg test  loss 291.29, bar  test loss 3.859, len  test loss 0.491, col  test loss 162.241


Epoch 418: 272batch [00:50,  5.38batch/s, loss=172]


epoch 418: avg train loss 191.95, bar train loss 1.993, len train loss 0.010, col train loss 148.271


Epoch 419: 0batch [00:00, ?batch/s]

epoch 418: avg test  loss 291.07, bar  test loss 3.877, len  test loss 0.485, col  test loss 162.290


Epoch 419: 272batch [00:49,  5.53batch/s, loss=198]


epoch 419: avg train loss 191.96, bar train loss 1.991, len train loss 0.011, col train loss 148.281


Epoch 420: 0batch [00:00, ?batch/s]

epoch 419: avg test  loss 290.77, bar  test loss 3.847, len  test loss 0.487, col  test loss 162.281


Epoch 420: 272batch [00:48,  5.58batch/s, loss=195]


epoch 420: avg train loss 191.89, bar train loss 1.990, len train loss 0.010, col train loss 148.290
epoch 420: avg test  loss 289.85, bar  test loss 3.851, len  test loss 0.477, col  test loss 162.180


Epoch 421: 272batch [00:54,  4.96batch/s, loss=191]


epoch 421: avg train loss 191.90, bar train loss 1.992, len train loss 0.010, col train loss 148.268


Epoch 422: 1batch [00:00,  6.10batch/s, loss=190]

epoch 421: avg test  loss 290.41, bar  test loss 3.879, len  test loss 0.479, col  test loss 162.267


Epoch 422: 272batch [00:43,  6.30batch/s, loss=189]


epoch 422: avg train loss 191.91, bar train loss 1.993, len train loss 0.010, col train loss 148.231


Epoch 423: 1batch [00:00,  6.62batch/s, loss=192]

epoch 422: avg test  loss 289.08, bar  test loss 3.868, len  test loss 0.468, col  test loss 162.233


Epoch 423: 272batch [00:43,  6.33batch/s, loss=194]


epoch 423: avg train loss 191.93, bar train loss 1.993, len train loss 0.011, col train loss 148.218


Epoch 424: 1batch [00:00,  6.45batch/s, loss=179]

epoch 423: avg test  loss 290.35, bar  test loss 3.852, len  test loss 0.482, col  test loss 162.266


Epoch 424: 272batch [00:41,  6.60batch/s, loss=187]


epoch 424: avg train loss 192.01, bar train loss 1.996, len train loss 0.011, col train loss 148.249


Epoch 425: 1batch [00:00,  6.33batch/s, loss=195]

epoch 424: avg test  loss 294.24, bar  test loss 3.863, len  test loss 0.520, col  test loss 162.259


Epoch 425: 272batch [00:44,  6.07batch/s, loss=204]


epoch 425: avg train loss 191.76, bar train loss 1.989, len train loss 0.010, col train loss 148.207
epoch 425: avg test  loss 291.48, bar  test loss 3.878, len  test loss 0.488, col  test loss 162.325


Epoch 426: 272batch [00:42,  6.47batch/s, loss=197]


epoch 426: avg train loss 191.85, bar train loss 1.990, len train loss 0.010, col train loss 148.238


Epoch 427: 0batch [00:00, ?batch/s, loss=187]

epoch 426: avg test  loss 291.93, bar  test loss 3.886, len  test loss 0.492, col  test loss 162.280


Epoch 427: 272batch [00:50,  5.42batch/s, loss=180]


epoch 427: avg train loss 191.84, bar train loss 1.992, len train loss 0.010, col train loss 148.211


Epoch 428: 1batch [00:00,  6.54batch/s, loss=190]

epoch 427: avg test  loss 293.88, bar  test loss 3.874, len  test loss 0.513, col  test loss 162.231


Epoch 428: 219batch [00:29,  7.35batch/s, loss=199]


KeyboardInterrupt: 

In [None]:
# Run training and keep the two values train() returns; they are later plotted
# as the "train loss" / "test loss" curves by plot_loss_acc.
# NOTE(review): the meaning of the positional 1000 and 500 is fixed by train()'s
# signature, which is not visible in this part of the notebook — confirm before editing.
# NOTE(review): save_folder is "new/HVAE2" while the SummaryWriter created at the
# top of the notebook points at ".../new/HVAE3/tensorboard" — confirm this is intended.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1000,500,save_folder="new/HVAE2",save_interval=5)

In [None]:
# Second training call; results are kept under separate names (lss2 / lss_t2)
# so they do not overwrite lss / lss_t.
# NOTE(review): the meaning of the positional 1000 and 500 depends on train()'s
# signature, which is not visible here — confirm.
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 500, save_folder="VAEFC")

In [None]:
# Another training call writing to save_folder "VAEFC".
# Caution: this rebinds lss / lss_t, so results assigned to those names by an
# earlier cell are lost once this cell runs.
# NOTE(review): the meaning of the positional 1600 and 1000 depends on train()'s
# signature, which is not visible here — confirm.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t, yacc=None, yacc_t=None):
    """Plot train/test loss curves, optionally with accuracy on a second y-axis.

    Parameters
    ----------
    lss, lss_t : sequence
        Training and test loss values; each is drawn against its index.
    yacc, yacc_t : sequence, optional
        Training and test accuracy values. Each one that is supplied is drawn
        dashed on a twin y-axis sharing the same x-axis. When both are omitted
        the figure contains only the two loss curves, as before.

    Returns
    -------
    None
        The figure is created through ``plt.subplots()`` and left for the
        notebook to render.
    """
    fig, ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label="test loss")

    lines, labels = ax.get_legend_handles_labels()
    legend_ax = ax

    if yacc is not None or yacc_t is not None:
        # Accuracy lives on its own scale, so it gets a twin axis.
        ax1 = ax.twinx()
        if yacc is not None:
            ax1.plot(yacc, label="train accuracy", ls='--')
        if yacc_t is not None:
            ax1.plot(yacc_t, label="test accuracy", ls='--')

        # Merge the handles of both axes so a single legend lists every curve.
        lines2, labels2 = ax1.get_legend_handles_labels()
        lines = lines + lines2
        labels = labels + labels2
        # Attach the legend to the twin axes, which is drawn on top.
        legend_ax = ax1

    legend_ax.legend(lines, labels)

In [None]:
# Plot the loss histories currently bound to lss / lss_t.
plot_loss_acc(lss, lss_t)

In [None]:
# NOTE(review): this call passes four series (two losses, two accuracies), so it
# only succeeds if plot_loss_acc accepts accuracy arguments in addition to the
# two loss arguments — confirm against the definition in use.
# NOTE(review): lss3, lss_t3, yacc3 and yacc_t3 are not assigned in the nearby
# cells; they presumably come from an earlier or removed cell — confirm they
# exist after "Restart & Run All".
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    """Show reconstructions of test samples while latent entries are swept.

    Takes the first batch of the global ``test_loader``, encodes its first
    ``len(var_idx)`` inputs with ``diva.qzx`` / ``diva.qzy`` / ``diva.qzd``,
    lets ``change`` overwrite entries of the chosen latent tensor with a range
    of values, and draws every reconstruction in a grid: one row per sample,
    one column per swept value.

    Parameters
    ----------
    diva : model
        Must provide ``eval()``, ``qzx``, ``qzy`` and ``qzd`` (each returning a
        pair) plus whatever ``change`` needs for decoding.
    lat_space : str
        Forwarded to ``change`` as the latent selector ("y", "x" or "d").
    var_idx : list of int
        Latent indices forwarded to ``change``; its length also sets how many
        samples are used. It must stay a list: it is used for fancy indexing
        in ``change``, where a tuple would index differently.
    step : int
        Spacing of the swept values, forwarded to ``change``.
    """
    # Only the first batch is needed; the inputs are its first element.
    batch = next(iter(test_loader))
    with torch.no_grad():
        diva.eval()
        x = batch[0][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc = diva.qzd(x)

        # Range of the encoded zy and the largest value of its second output,
        # printed to help judge which sweep values are sensible.
        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)

    # squeeze=False keeps `ax` two-dimensional even for a single row or column,
    # so the ax[j, i] indexing below always works.
    fig, ax = plt.subplots(ncols=out.shape[0], nrows=len(var_idx),
                           figsize=(10*4*out.shape[0],10*len(var_idx)),
                           squeeze=False)
    for i in range(out.shape[0]):
        for j in range(len(var_idx)):
            ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    """Decode samples while entries of one latent tensor are swept over a range.

    For every value ``v`` in ``np.arange(-30, 15, step)`` and every row ``j`` in
    ``range(len(idx))``, the entries ``[j, idx]`` of the selected latent tensor
    are set to ``v`` and row ``j`` is decoded with ``model.px`` followed by
    ``model.px.reconstruct_image``.

    Parameters
    ----------
    zx, zy, zd : torch.Tensor
        Latent tensors with at least ``len(idx)`` rows. They are not modified:
        the selected one is cloned before it is edited.
    idx : list of int
        Latent indices to overwrite in each row.
    lat : str
        Which tensor to edit: "y" -> zy, "x" -> zx, "d" -> zd. Any other value
        leaves all three untouched, so every sweep step decodes the same latents.
    model : model
        Must provide ``px(zd_row, zx_row, zy_row)`` returning ``(len_, bar, col)``
        and ``px.reconstruct_image``.
    step : int
        Spacing of the swept values.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(number_of_values, len(idx), 25, 100, 3)`` holding one
        reconstruction per (value, row) pair.
    """
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))

    latents = {"x": zx, "y": zy, "d": zd}
    if lat in latents:
        # Edit a private copy so the caller's tensor is not overwritten.
        latents[lat] = latents[lat].clone()

    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))
    for i in range(dif.shape[0]):
        for j in range(len(idx)):
            if lat in latents:
                # NOTE(review): this overwrites ALL entries listed in idx for
                # row j. If one latent dimension per row was intended, this
                # should index idx[j] instead — confirm.
                latents[lat][j, idx] = dif[i]
            len_, bar, col = model.px(latents["d"][j], latents["x"][j], latents["y"][j])
            out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)

    return out



In [None]:
# Draw the latent sweep with the function's defaults: lat_space="y",
# var_idx=[0..7], step=5.
plot_change_latent_var(diva)

In [None]:
# Loss curves (left axis) and the yacc2 / yacc_t2 series (right axis), both
# drawn against x positions 50..119.
# NOTE(review): yacc2 / yacc_t2 are not assigned in the nearby cells — confirm
# they exist after "Restart & Run All".
x_vals = np.arange(50,120)

fig,ax = plt.subplots()
ax.plot(x_vals, [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(x_vals, [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax.set_ylabel("loss")

ax1 = ax.twinx()
# Dashed so these lines stay distinguishable from the loss curves; each axes
# restarts the default colour cycle.
ax1.plot(x_vals, yacc2, label = "train", ls='--')
ax1.plot(x_vals, yacc_t2, label = "test", ls='--')
ax1.set_ylabel("accuracy")

# plt.legend() would only list the curves of the current (twin) axes, so the
# handles of both axes are merged into a single legend.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax1.get_legend_handles_labels()
ax1.legend(lines + lines2, labels + labels2)

In [None]:
# Loss curves (left axis) and the yacc3 / yacc_t3 series (right axis), both
# drawn against x positions 120..179.
# NOTE(review): lss3, lss_t3, yacc3 and yacc_t3 are not assigned in the nearby
# cells — confirm they exist after "Restart & Run All".
x_vals = np.arange(120,180)

fig,ax = plt.subplots()
ax.plot(x_vals, [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(x_vals, [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax.set_ylabel("loss")

ax1 = ax.twinx()
# Dashed so these lines stay distinguishable from the loss curves.
ax1.plot(x_vals, yacc3, label = "train",c='green', ls='--')
ax1.plot(x_vals, yacc_t3, label = "test", ls='--')
ax1.set_ylabel("accuracy")

# plt.legend() would only list the curves of the current (twin) axes, so the
# handles of both axes are merged into a single legend.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax1.get_legend_handles_labels()
ax1.legend(lines + lines2, labels + labels2)

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    Placeholder for plotting a latent space; not implemented yet.

    lat_space: selects the latent space, one of "y", "d" or "x".

    The body currently holds only this docstring, so calling the function
    does nothing and returns None.
    '''

    

In [None]:
# NOTE(review): `plot`, `x` and `out` are not defined in the nearby cells; this
# presumably relies on an earlier cell — confirm it still runs after
# "Restart & Run All".
plot(x, out, 0)

In [None]:
# Show the first nine entries of `x` in a 3x3 grid (filled row by row) and
# write the figure to disk.
fig, ax = plt.subplots(nrows=3, ncols=3)
for i, panel in enumerate(ax.flat):
    # permute(1,2,0) moves dim 0 to the end before display
    # (presumably channels-first -> channels-last; confirm against the dataset).
    panel.imshow(x[i].cpu().permute(1,2,0))

plt.savefig('divastamporg.png')