In [1]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime


writer = SummaryWriter(f"{link}/saved_models/VAE13/tensorboard")

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:

    def __init__(self, z_dim=64, d_dim=45, x_dim=7500, y_dim=2,
                 beta=10, rec_alpha = 1, rec_beta = 1, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z_dim = z_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [15]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_col, x_bar = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len2.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar2.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col2.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        

        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]

    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        return (x, y, d, x_len, x_col, x_bar)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,5,200), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]) ,2*j] = 1 
                    out_col[i, self.get_color(x[i,:,13,j]), 2*j+1] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len2.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col2.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar2.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():  
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():  
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():  
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():  
            return 4 # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [16]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim):
        super(px, self).__init__()

        self.fc1 = nn.Sequential(nn.Linear(z_dim, 1600, bias=False),  
                                 nn.ReLU())
        
        self.fc2 = nn.Sequential(nn.Linear(1600, 1000, bias=False),  
                                 nn.ReLU())
        # Predicting length and color of each bar
        self.up1 = nn.Upsample(scale_factor=5)
        self.de1 = nn.Sequential(nn.ConvTranspose1d(50,100,kernel_size = 5,
                                                    stride = 1, padding = 2),
                                 nn.ReLU(),)
        self.up2 = nn.Upsample(scale_factor=2)
        self.de2 = nn.Sequential(
                                 nn.ConvTranspose1d(100,100,kernel_size = 5,
                                                    stride = 1, padding = 2),
                                 nn.ReLU(),
                                 )
        # Predicting color of each bar
        self.color_bar = nn.Sequential(nn.Conv1d(100,5, kernel_size = 9, padding = 'same', padding_mode='reflect'),
                                      nn.Softmax(dim=1))
        
        # Predicting the length of each bar
        self.length_bar = nn.Sequential(nn.Conv1d(100, 2, kernel_size = 9, padding = 4, padding_mode='reflect', stride=2), nn.Softplus())
        #self.length_bar_scale = nn.Sequential(nn.Conv1d(100, 1, kernel_size = 3, padding = 'same', bias=False), nn.Sigmoid())
        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(1000,400), nn.ReLU(),nn.Linear(400,1), nn.Softplus())
        #self.length_RNA_scale = nn.Sequential(nn.Linear(400,1, bias=False), nn.Sigmoid())
        
    def forward(self, z):
        
        h = self.fc1(z)
        h = self.fc2(h)
        
        len_RNA = self.length_RNA(h)
        
        len_RNA_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_RNA_sc = torch.exp(self.length_RNA_scale(h))
        
        h = h.view(-1, 50, 20)
        h = self.up1(h)
        h = self.de1(h)
        h = self.up2(h)
        h = self.de2(h)
        len_bar = self.length_bar(h)
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        #len_bar_sc = torch.exp(self.length_bar_scale(h))
        
        
        col_bar = self.color_bar(h)
        
        
        return len_RNA, len_RNA_sc, len_bar, len_bar_sc, col_bar

    def reconstruct_image(self, len_RNA, var_RNA, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }
    
        
        len_RNA = len_RNA.cpu().numpy()
        var_RNA = var_RNA.cpu().numpy()
        #.reshape((100,))
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_RNA.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = int(np.round(np.random.normal(loc=len_RNA[i], scale=var_RNA[i])))
            else:
                limit = int(np.round(len_RNA[i]))
            limit = min(100, limit)
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1 = np.argmax(col_bar[i,:, 2*j])
                    _col_bar_2 = np.argmax(col_bar[i,:, 2*j+1])
                
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output


In [17]:
int(np.round(3.7, 0))
int(3.7)

3

In [18]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [19]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [20]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim):
        super(qz, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )

        self.fc11 = nn.Sequential(nn.Linear(5632, z_dim))
        self.fc12 = nn.Sequential(nn.Linear(5632, z_dim), nn.Softplus())

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[3].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.fc12[0].weight)
        self.fc12[0].bias.data.zero_()

    def forward(self, x):
        h = self.encoder(x)
        h = h.view(-1, 5632)
        z_loc = self.fc11(h)
        z_scale = self.fc12(h) + 1e-7

        return z_loc, z_scale




In [21]:
enc = qz(128,10,10,10)
summary(enc, (1,3,25,100))

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 2, 11]           --
│    └─Conv2d: 2-1                       [1, 64, 25, 100]          4,800
│    └─ReLU: 2-2                         [1, 64, 25, 100]          --
│    └─MaxPool2d: 2-3                    [1, 64, 12, 50]           --
│    └─Conv2d: 2-4                       [1, 128, 12, 50]          73,728
│    └─ReLU: 2-5                         [1, 128, 12, 50]          --
│    └─MaxPool2d: 2-6                    [1, 128, 6, 25]           --
│    └─Conv2d: 2-7                       [1, 256, 4, 23]           294,912
│    └─ReLU: 2-8                         [1, 256, 4, 23]           --
│    └─MaxPool2d: 2-9                    [1, 256, 2, 11]           --
├─Sequential: 1-2                        [1, 10]                   --
│    └─Linear: 2-10                      [1, 10]                   56,330

## Full model class

In [22]:
class StampDIVA(nn.Module):
    def __init__(self, args):
        super(StampDIVA, self).__init__()
        self.z_dim = args.z_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim

        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.cuda()

    def forward(self, d, x, y):
        # Encode
        zd_q_loc, zd_q_scale = self.qz(x)
        
        # Reparameterization trick
        qz = dist.Normal(zd_q_loc, zd_q_scale)
        z_q = qz.rsample()
        
        
        # Decode
        x_len, x_len_scale, x_bar, x_bar_scale, x_col = self.px(z_q)
        z_p_loc, z_p_scale = torch.zeros(z_q.size()[0], self.z_dim).cuda(),\
                        torch.ones(z_q.size()[0], self.z_dim).cuda()
        pz = dist.Normal(z_p_loc, z_p_scale)

        # Reparameterization trick
        pz = dist.Normal(z_p_loc, z_p_scale)
        
        return x_len, x_len_scale, x_bar, x_bar_scale, x_col, qz, pz, z_q

    def loss_function(self, d, x, y, out_len, out_bar, out_col):
        
        x_len, x_len_scale, x_bar, x_bar_scale, x_col, qz, pz, z_q = self.forward(d, x, y)
        mask = 1 - F.one_hot(torch.round(out_len).to(torch.int64)*2-1, 200).cumsum(dim=1)[:,None,:]
        mask1 = (1 - F.one_hot(torch.round(out_len).to(torch.int64)-1, 100).cumsum(dim=1)[:,None,:]).repeat(1,2,1)

        
        #x_bar = mask.repeat(1,1,1)*x_bar
        
        #x_bar_scale = mask.repeat(1,1,1)*x_bar_scale
        x_col = mask.repeat(1,5,1)*x_col
        
        
        dist_len = dist.Normal(x_len, x_len_scale+1e-7)
        log_len = dist_len.log_prob(out_len[:,None]).mean()
        
       # dist_bar = dist.Normal(x_bar, x_bar_scale+1e-7)
        #log_bar = dist_bar.log_prob(out_bar)*mask1
        #log_bar = (log_bar).sum(dim=(1,2)).mean()
        
        mse_bar = ((((x_bar - out_bar)**2)*mask1).sum(dim=(1,2))/(mask1.sum(dim=(1,2)))).sum()#.detach().item()
        
        max_bar = torch.argmax(x_col, dim=1)
        #acc_bar = (((max_bar==out_col)*mask).sum(dim=(1,2))/mask.sum(dim=(1,2))).sum().detach().item()/5
        acc_bar = (max_bar==torch.argmax(out_col, dim=1)).sum().float()
        
        CE_len = -log_len
        CE_bar = mse_bar#-log_bar
        CE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        KL_z = torch.sum(pz.log_prob(z_q) - qz.log_prob(z_q))
          
        return self.rec_alpha * CE_len \
                  + self.rec_beta * CE_bar \
                  + self.rec_gamma * CE_col \
                  - self.beta * KL_z, \
                  CE_bar, CE_len, CE_col, mse_bar, acc_bar

In [23]:
default_args = diva_args(z_dim=1600, rec_alpha = 10, rec_beta = 10, rec_gamma = 10, 
                         beta=1, warmup=1, prewarmup=0)
enc = StampDIVA(default_args)
summary(enc,[ (1,1),(1,3,25,100),(1,1)])

Layer (type:depth-idx)                   Output Shape              Param #
StampDIVA                                --                        --
├─qz: 1-1                                [1, 1600]                 --
│    └─Sequential: 2-1                   [1, 256, 2, 11]           --
│    │    └─Conv2d: 3-1                  [1, 64, 25, 100]          4,800
│    │    └─ReLU: 3-2                    [1, 64, 25, 100]          --
│    │    └─MaxPool2d: 3-3               [1, 64, 12, 50]           --
│    │    └─Conv2d: 3-4                  [1, 128, 12, 50]          73,728
│    │    └─ReLU: 3-5                    [1, 128, 12, 50]          --
│    │    └─MaxPool2d: 3-6               [1, 128, 6, 25]           --
│    │    └─Conv2d: 3-7                  [1, 256, 4, 23]           294,912
│    │    └─ReLU: 3-8                    [1, 256, 4, 23]           --
│    │    └─MaxPool2d: 3-9               [1, 256, 2, 11]           --
│    └─Sequential: 2-2                   [1, 1600]                 --
│  

# Training the model

## Loading dataset

In [24]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [25]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [26]:
len(RNA_dataset)

34721

In [27]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    no_batches = 0
    train_corr = 0
    mse_bar = 0
    acc_bar = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, len_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        epoch_len_loss += len_loss
        mse_bar += mse
        acc_bar += acc
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar

In [28]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    mse_bar = 0
    acc_bar = 0        
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)
            loss, bar_loss, len_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            epoch_len_loss += len_loss
            mse_bar += mse
            acc_bar += acc
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_len_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss, mse_bar, acc_bar
  

In [29]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col, mtr, atr = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test, mte, ate = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train': atr, 'test':ate}, epoch)

        if epoch % save_interval == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')


    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.cpu().detach().numpy() for i in epoch_loss_sup]
    test_loss = [i.cpu().detach().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [30]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        x_1, x_1var, x_2, x_2var, x_3, _, _, _ = diva(d,x,y)
        out = diva.px.reconstruct_image(x_1, x_1var, x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [31]:
DEVICE

device(type='cuda')

## Model Training

In [51]:
default_args = diva_args(z_dim=1600, rec_alpha = 10, rec_beta = 5, rec_gamma = 5, 
                         beta=1, warmup=1, prewarmup=0)

In [64]:
diva = StampDIVA(default_args).to(DEVICE)

In [65]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE10/checkpoints/905.pth'))

In [66]:
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [67]:
optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)

In [68]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [69]:
writer.flush()

In [37]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/VAE13/tensorboard/"

In [70]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 2000, 0, save_folder="VAE13",save_interval=5)

Epoch 1: 272batch [00:22, 12.03batch/s, loss=1.02e+3]


epoch 1: avg train loss 1101.23, bar train loss 12.306, len train loss 4.835, col train loss 169.585
epoch 1: avg test  loss 1017.75, bar  test loss 10.684, len  test loss 0.996, col  test loss 169.390


Epoch 2: 272batch [00:22, 12.01batch/s, loss=918]    


epoch 2: avg train loss 983.15, bar train loss 10.630, len train loss 0.998, col train loss 169.432


Epoch 3: 2batch [00:00, 11.83batch/s, loss=959]

epoch 2: avg test  loss 913.82, bar  test loss 10.730, len  test loss 0.925, col  test loss 169.367


Epoch 3: 272batch [00:22, 12.06batch/s, loss=912]


epoch 3: avg train loss 911.36, bar train loss 10.610, len train loss 0.832, col train loss 169.412


Epoch 4: 0batch [00:00, ?batch/s, loss=893]

epoch 3: avg test  loss 908.48, bar  test loss 10.661, len  test loss 0.609, col  test loss 169.380


Epoch 4: 272batch [00:22, 12.12batch/s, loss=896]


epoch 4: avg train loss 908.99, bar train loss 10.620, len train loss 0.432, col train loss 169.407


Epoch 5: 2batch [00:00, 11.90batch/s, loss=925]

epoch 4: avg test  loss 905.48, bar  test loss 10.647, len  test loss 0.272, col  test loss 169.341


Epoch 5: 272batch [00:22, 12.05batch/s, loss=1.01e+3]


epoch 5: avg train loss 906.93, bar train loss 10.596, len train loss 0.296, col train loss 169.374
epoch 5: avg test  loss 904.89, bar  test loss 10.608, len  test loss 0.242, col  test loss 169.319


Epoch 6: 272batch [00:22, 12.10batch/s, loss=991]


epoch 6: avg train loss 905.72, bar train loss 10.558, len train loss 0.273, col train loss 169.339


Epoch 7: 2batch [00:00, 12.66batch/s, loss=918]

epoch 6: avg test  loss 906.52, bar  test loss 10.572, len  test loss 0.313, col  test loss 169.276


Epoch 7: 272batch [00:22, 12.07batch/s, loss=887]


epoch 7: avg train loss 904.87, bar train loss 10.542, len train loss 0.272, col train loss 169.297


Epoch 8: 2batch [00:00, 11.83batch/s, loss=892]

epoch 7: avg test  loss 904.65, bar  test loss 10.564, len  test loss 0.271, col  test loss 169.229


Epoch 8: 272batch [00:22, 12.16batch/s, loss=942]


epoch 8: avg train loss 904.99, bar train loss 10.531, len train loss 0.281, col train loss 169.269


Epoch 9: 2batch [00:00, 12.20batch/s, loss=896]

epoch 8: avg test  loss 903.88, bar  test loss 10.547, len  test loss 0.284, col  test loss 169.213


Epoch 9: 272batch [00:22, 12.08batch/s, loss=911]


epoch 9: avg train loss 904.74, bar train loss 10.523, len train loss 0.289, col train loss 169.223


Epoch 10: 2batch [00:00, 12.50batch/s, loss=904]

epoch 9: avg test  loss 904.26, bar  test loss 10.559, len  test loss 0.287, col  test loss 169.168


Epoch 10: 272batch [00:22, 12.04batch/s, loss=928]


epoch 10: avg train loss 904.16, bar train loss 10.520, len train loss 0.286, col train loss 169.188
epoch 10: avg test  loss 903.71, bar  test loss 10.537, len  test loss 0.290, col  test loss 169.199


Epoch 11: 272batch [00:22, 12.10batch/s, loss=883]


epoch 11: avg train loss 904.30, bar train loss 10.514, len train loss 0.309, col train loss 169.157


Epoch 12: 0batch [00:00, ?batch/s, loss=923]

epoch 11: avg test  loss 903.38, bar  test loss 10.554, len  test loss 0.294, col  test loss 169.121


Epoch 12: 272batch [00:22, 11.90batch/s, loss=876]


epoch 12: avg train loss 904.36, bar train loss 10.511, len train loss 0.318, col train loss 169.100


Epoch 13: 2batch [00:00, 12.50batch/s, loss=907]

epoch 12: avg test  loss 903.85, bar  test loss 10.548, len  test loss 0.331, col  test loss 169.006


Epoch 13: 272batch [00:22, 12.04batch/s, loss=849]


epoch 13: avg train loss 903.75, bar train loss 10.509, len train loss 0.338, col train loss 169.031


Epoch 14: 2batch [00:00, 11.70batch/s, loss=942]

epoch 13: avg test  loss 903.34, bar  test loss 10.561, len  test loss 0.365, col  test loss 168.979


Epoch 14: 272batch [00:22, 12.04batch/s, loss=900]


epoch 14: avg train loss 904.20, bar train loss 10.502, len train loss 0.364, col train loss 168.948


Epoch 15: 0batch [00:00, ?batch/s, loss=896]

epoch 14: avg test  loss 902.96, bar  test loss 10.525, len  test loss 0.344, col  test loss 168.888


Epoch 15: 272batch [00:22, 11.91batch/s, loss=934]


epoch 15: avg train loss 903.48, bar train loss 10.484, len train loss 0.387, col train loss 168.789
epoch 15: avg test  loss 901.63, bar  test loss 10.483, len  test loss 0.364, col  test loss 168.582


Epoch 16: 272batch [00:22, 12.14batch/s, loss=868]


epoch 16: avg train loss 902.90, bar train loss 10.428, len train loss 0.387, col train loss 168.540


Epoch 17: 2batch [00:00, 12.35batch/s, loss=933]

epoch 16: avg test  loss 899.83, bar  test loss 10.382, len  test loss 0.335, col  test loss 168.231


Epoch 17: 272batch [00:22, 11.90batch/s, loss=973]


epoch 17: avg train loss 901.02, bar train loss 10.326, len train loss 0.330, col train loss 168.303


Epoch 18: 2batch [00:00, 12.12batch/s, loss=865]

epoch 17: avg test  loss 901.15, bar  test loss 10.345, len  test loss 0.281, col  test loss 168.404


Epoch 18: 272batch [00:22, 11.93batch/s, loss=889]


epoch 18: avg train loss 899.77, bar train loss 10.188, len train loss 0.314, col train loss 168.230


Epoch 19: 2batch [00:00, 11.98batch/s, loss=857]

epoch 18: avg test  loss 897.68, bar  test loss 10.101, len  test loss 0.281, col  test loss 168.013


Epoch 19: 272batch [00:22, 12.11batch/s, loss=882]


epoch 19: avg train loss 898.74, bar train loss 10.027, len train loss 0.299, col train loss 168.128


Epoch 20: 2batch [00:00, 12.27batch/s, loss=897]

epoch 19: avg test  loss 896.27, bar  test loss 9.850, len  test loss 0.277, col  test loss 167.946


Epoch 20: 272batch [00:22, 12.08batch/s, loss=902]


epoch 20: avg train loss 896.50, bar train loss 9.716, len train loss 0.278, col train loss 167.978
epoch 20: avg test  loss 894.82, bar  test loss 9.495, len  test loss 0.254, col  test loss 167.896


Epoch 21: 272batch [00:22, 12.05batch/s, loss=823]


epoch 21: avg train loss 895.70, bar train loss 9.487, len train loss 0.276, col train loss 167.931


Epoch 22: 2batch [00:00, 12.27batch/s, loss=901]

epoch 21: avg test  loss 893.80, bar  test loss 9.341, len  test loss 0.238, col  test loss 167.815


Epoch 22: 272batch [00:22, 12.10batch/s, loss=932]


epoch 22: avg train loss 894.93, bar train loss 9.334, len train loss 0.272, col train loss 167.917


Epoch 23: 2batch [00:00, 11.43batch/s, loss=907]

epoch 22: avg test  loss 893.57, bar  test loss 9.247, len  test loss 0.258, col  test loss 167.783


Epoch 23: 272batch [00:22, 11.96batch/s, loss=928]


epoch 23: avg train loss 894.41, bar train loss 9.204, len train loss 0.272, col train loss 167.884


Epoch 24: 2batch [00:00, 12.12batch/s, loss=854]

epoch 23: avg test  loss 893.47, bar  test loss 9.180, len  test loss 0.260, col  test loss 167.794


Epoch 24: 272batch [00:22, 12.07batch/s, loss=914]


epoch 24: avg train loss 893.34, bar train loss 9.064, len train loss 0.266, col train loss 167.840


Epoch 25: 2batch [00:00, 11.76batch/s, loss=902]

epoch 24: avg test  loss 891.93, bar  test loss 8.979, len  test loss 0.243, col  test loss 167.697


Epoch 25: 272batch [00:22, 12.06batch/s, loss=852]


epoch 25: avg train loss 892.99, bar train loss 8.966, len train loss 0.262, col train loss 167.809
epoch 25: avg test  loss 892.77, bar  test loss 8.931, len  test loss 0.284, col  test loss 167.770


Epoch 26: 272batch [00:22, 12.08batch/s, loss=920]


epoch 26: avg train loss 891.96, bar train loss 8.817, len train loss 0.257, col train loss 167.766


Epoch 27: 2batch [00:00, 12.27batch/s, loss=868]

epoch 26: avg test  loss 890.76, bar  test loss 8.656, len  test loss 0.256, col  test loss 167.616


Epoch 27: 272batch [00:22, 11.94batch/s, loss=818]


epoch 27: avg train loss 891.51, bar train loss 8.717, len train loss 0.256, col train loss 167.735


Epoch 28: 0batch [00:00, ?batch/s, loss=862]

epoch 27: avg test  loss 891.68, bar  test loss 8.697, len  test loss 0.311, col  test loss 167.739


Epoch 28: 272batch [00:22, 11.92batch/s, loss=887]


epoch 28: avg train loss 891.02, bar train loss 8.619, len train loss 0.254, col train loss 167.695


Epoch 29: 2batch [00:00, 11.98batch/s, loss=910]

epoch 28: avg test  loss 890.30, bar  test loss 8.516, len  test loss 0.251, col  test loss 167.671


Epoch 29: 272batch [00:22, 11.87batch/s, loss=934]


epoch 29: avg train loss 890.54, bar train loss 8.516, len train loss 0.254, col train loss 167.659


Epoch 30: 2batch [00:00, 12.38batch/s, loss=907]

epoch 29: avg test  loss 889.73, bar  test loss 8.385, len  test loss 0.236, col  test loss 167.580


Epoch 30: 272batch [00:22, 11.92batch/s, loss=883]


epoch 30: avg train loss 889.86, bar train loss 8.416, len train loss 0.249, col train loss 167.609
epoch 30: avg test  loss 888.94, bar  test loss 8.333, len  test loss 0.235, col  test loss 167.491


Epoch 31: 272batch [00:22, 12.04batch/s, loss=906]


epoch 31: avg train loss 889.42, bar train loss 8.317, len train loss 0.249, col train loss 167.570


Epoch 32: 0batch [00:00, ?batch/s, loss=896]

epoch 31: avg test  loss 888.52, bar  test loss 8.237, len  test loss 0.233, col  test loss 167.491


Epoch 32: 272batch [00:22, 12.04batch/s, loss=915]


epoch 32: avg train loss 888.98, bar train loss 8.246, len train loss 0.245, col train loss 167.533


Epoch 33: 2batch [00:00, 11.90batch/s, loss=893]

epoch 32: avg test  loss 889.34, bar  test loss 8.460, len  test loss 0.228, col  test loss 167.521


Epoch 33: 272batch [00:22, 12.04batch/s, loss=913]


epoch 33: avg train loss 888.58, bar train loss 8.167, len train loss 0.243, col train loss 167.506


Epoch 34: 0batch [00:00, ?batch/s, loss=885]

epoch 33: avg test  loss 887.94, bar  test loss 8.083, len  test loss 0.247, col  test loss 167.490


Epoch 34: 272batch [00:22, 12.01batch/s, loss=847]


epoch 34: avg train loss 888.09, bar train loss 8.080, len train loss 0.242, col train loss 167.469


Epoch 35: 2batch [00:00, 11.90batch/s, loss=883]

epoch 34: avg test  loss 887.60, bar  test loss 8.113, len  test loss 0.227, col  test loss 167.419


Epoch 35: 272batch [00:22, 11.91batch/s, loss=910]


epoch 35: avg train loss 887.89, bar train loss 8.034, len train loss 0.241, col train loss 167.453
epoch 35: avg test  loss 887.22, bar  test loss 7.978, len  test loss 0.248, col  test loss 167.380


Epoch 36: 272batch [00:22, 12.00batch/s, loss=921]


epoch 36: avg train loss 887.55, bar train loss 7.985, len train loss 0.239, col train loss 167.407


Epoch 37: 2batch [00:00, 12.50batch/s, loss=885]

epoch 36: avg test  loss 887.21, bar  test loss 7.985, len  test loss 0.230, col  test loss 167.421


Epoch 37: 272batch [00:22, 12.35batch/s, loss=864]


epoch 37: avg train loss 887.20, bar train loss 7.908, len train loss 0.240, col train loss 167.387


Epoch 38: 2batch [00:00, 12.05batch/s, loss=904]

epoch 37: avg test  loss 886.88, bar  test loss 7.931, len  test loss 0.222, col  test loss 167.354


Epoch 38: 272batch [00:22, 12.28batch/s, loss=812]


epoch 38: avg train loss 886.96, bar train loss 7.871, len train loss 0.237, col train loss 167.367


Epoch 39: 2batch [00:00, 12.12batch/s, loss=879]

epoch 38: avg test  loss 886.30, bar  test loss 7.816, len  test loss 0.231, col  test loss 167.294


Epoch 39: 272batch [00:22, 12.28batch/s, loss=888]


epoch 39: avg train loss 886.69, bar train loss 7.821, len train loss 0.236, col train loss 167.338


Epoch 40: 2batch [00:00, 12.27batch/s, loss=884]

epoch 39: avg test  loss 886.79, bar  test loss 7.714, len  test loss 0.224, col  test loss 167.305


Epoch 40: 272batch [00:22, 12.31batch/s, loss=874]


epoch 40: avg train loss 886.37, bar train loss 7.762, len train loss 0.234, col train loss 167.307
epoch 40: avg test  loss 886.37, bar  test loss 7.675, len  test loss 0.220, col  test loss 167.274


Epoch 41: 272batch [00:22, 12.25batch/s, loss=949]


epoch 41: avg train loss 886.25, bar train loss 7.732, len train loss 0.238, col train loss 167.286


Epoch 42: 2batch [00:00, 12.42batch/s, loss=882]

epoch 41: avg test  loss 885.84, bar  test loss 7.801, len  test loss 0.231, col  test loss 167.264


Epoch 42: 272batch [00:22, 12.31batch/s, loss=914]


epoch 42: avg train loss 885.91, bar train loss 7.689, len train loss 0.234, col train loss 167.257


Epoch 43: 2batch [00:00, 11.98batch/s, loss=885]

epoch 42: avg test  loss 885.93, bar  test loss 7.622, len  test loss 0.220, col  test loss 167.273


Epoch 43: 272batch [00:22, 12.29batch/s, loss=943]


epoch 43: avg train loss 885.74, bar train loss 7.634, len train loss 0.241, col train loss 167.228


Epoch 44: 2batch [00:00, 12.50batch/s, loss=905]

epoch 43: avg test  loss 885.09, bar  test loss 7.630, len  test loss 0.217, col  test loss 167.174


Epoch 44: 272batch [00:22, 12.26batch/s, loss=770]


epoch 44: avg train loss 885.34, bar train loss 7.584, len train loss 0.236, col train loss 167.188


Epoch 45: 2batch [00:00, 12.35batch/s, loss=853]

epoch 44: avg test  loss 884.69, bar  test loss 7.494, len  test loss 0.240, col  test loss 167.163


Epoch 45: 272batch [00:22, 12.30batch/s, loss=880]


epoch 45: avg train loss 885.19, bar train loss 7.536, len train loss 0.239, col train loss 167.162
epoch 45: avg test  loss 884.93, bar  test loss 7.516, len  test loss 0.222, col  test loss 167.135


Epoch 46: 272batch [00:22, 12.28batch/s, loss=840]


epoch 46: avg train loss 884.77, bar train loss 7.456, len train loss 0.241, col train loss 167.119


Epoch 47: 2batch [00:00, 11.83batch/s, loss=887]

epoch 46: avg test  loss 884.06, bar  test loss 7.382, len  test loss 0.246, col  test loss 167.071


Epoch 47: 272batch [00:22, 12.27batch/s, loss=885]


epoch 47: avg train loss 884.68, bar train loss 7.416, len train loss 0.246, col train loss 167.106


Epoch 48: 2batch [00:00, 12.05batch/s, loss=900]

epoch 47: avg test  loss 884.00, bar  test loss 7.346, len  test loss 0.255, col  test loss 167.029


Epoch 48: 272batch [00:22, 12.20batch/s, loss=970]


epoch 48: avg train loss 884.26, bar train loss 7.330, len train loss 0.244, col train loss 167.074


Epoch 49: 2batch [00:00, 12.05batch/s, loss=875]

epoch 48: avg test  loss 884.38, bar  test loss 7.202, len  test loss 0.262, col  test loss 167.018


Epoch 49: 272batch [00:22, 12.17batch/s, loss=859]


epoch 49: avg train loss 883.84, bar train loss 7.230, len train loss 0.248, col train loss 167.050


Epoch 50: 2batch [00:00, 11.83batch/s, loss=938]

epoch 49: avg test  loss 883.33, bar  test loss 7.140, len  test loss 0.222, col  test loss 167.020


Epoch 50: 272batch [00:22, 11.83batch/s, loss=886]


epoch 50: avg train loss 883.68, bar train loss 7.155, len train loss 0.248, col train loss 167.048
epoch 50: avg test  loss 882.93, bar  test loss 7.030, len  test loss 0.246, col  test loss 166.967


Epoch 51: 272batch [00:23, 11.66batch/s, loss=899]


epoch 51: avg train loss 883.14, bar train loss 7.078, len train loss 0.242, col train loss 166.995


Epoch 52: 0batch [00:00, ?batch/s, loss=893]

epoch 51: avg test  loss 882.75, bar  test loss 6.919, len  test loss 0.237, col  test loss 166.913


Epoch 52: 272batch [00:22, 12.08batch/s, loss=948]


epoch 52: avg train loss 882.88, bar train loss 6.996, len train loss 0.245, col train loss 166.994


Epoch 53: 2batch [00:00, 12.12batch/s, loss=882]

epoch 52: avg test  loss 882.69, bar  test loss 6.892, len  test loss 0.235, col  test loss 166.927


Epoch 53: 272batch [00:22, 12.08batch/s, loss=1.05e+3]


epoch 53: avg train loss 882.69, bar train loss 6.954, len train loss 0.241, col train loss 166.983


Epoch 54: 2batch [00:00, 12.12batch/s, loss=893]

epoch 53: avg test  loss 881.69, bar  test loss 6.827, len  test loss 0.231, col  test loss 166.921


Epoch 54: 272batch [00:22, 12.10batch/s, loss=882]


epoch 54: avg train loss 882.36, bar train loss 6.881, len train loss 0.239, col train loss 166.963


Epoch 55: 2batch [00:00, 12.05batch/s, loss=865]

epoch 54: avg test  loss 881.60, bar  test loss 6.848, len  test loss 0.253, col  test loss 166.903


Epoch 55: 272batch [00:22, 12.06batch/s, loss=969]


epoch 55: avg train loss 881.89, bar train loss 6.800, len train loss 0.238, col train loss 166.933
epoch 55: avg test  loss 881.32, bar  test loss 6.855, len  test loss 0.221, col  test loss 166.890


Epoch 56: 272batch [00:22, 12.10batch/s, loss=918]


epoch 56: avg train loss 881.64, bar train loss 6.748, len train loss 0.234, col train loss 166.914


Epoch 57: 2batch [00:00, 11.83batch/s, loss=879]

epoch 56: avg test  loss 881.29, bar  test loss 6.647, len  test loss 0.231, col  test loss 166.827


Epoch 57: 272batch [00:22, 12.12batch/s, loss=868]


epoch 57: avg train loss 881.40, bar train loss 6.691, len train loss 0.234, col train loss 166.893


Epoch 58: 2batch [00:00, 12.50batch/s, loss=903]

epoch 57: avg test  loss 880.73, bar  test loss 6.567, len  test loss 0.220, col  test loss 166.822


Epoch 58: 272batch [00:22, 12.10batch/s, loss=860]


epoch 58: avg train loss 881.13, bar train loss 6.651, len train loss 0.227, col train loss 166.870


Epoch 59: 0batch [00:00, ?batch/s, loss=898]

epoch 58: avg test  loss 880.39, bar  test loss 6.537, len  test loss 0.230, col  test loss 166.803


Epoch 59: 272batch [00:22, 12.11batch/s, loss=874]


epoch 59: avg train loss 880.92, bar train loss 6.596, len train loss 0.230, col train loss 166.851


Epoch 60: 2batch [00:00, 11.98batch/s, loss=895]

epoch 59: avg test  loss 880.11, bar  test loss 6.464, len  test loss 0.240, col  test loss 166.754


Epoch 60: 272batch [00:22, 12.09batch/s, loss=840]


epoch 60: avg train loss 880.66, bar train loss 6.546, len train loss 0.229, col train loss 166.822
epoch 60: avg test  loss 879.88, bar  test loss 6.476, len  test loss 0.218, col  test loss 166.758


Epoch 61: 272batch [00:22, 12.11batch/s, loss=862]


epoch 61: avg train loss 880.42, bar train loss 6.494, len train loss 0.230, col train loss 166.809


Epoch 62: 2batch [00:00, 12.50batch/s, loss=880]

epoch 61: avg test  loss 879.99, bar  test loss 6.405, len  test loss 0.246, col  test loss 166.758


Epoch 62: 272batch [00:22, 12.03batch/s, loss=859]


epoch 62: avg train loss 880.30, bar train loss 6.462, len train loss 0.228, col train loss 166.792


Epoch 63: 2batch [00:00, 12.20batch/s, loss=834]

epoch 62: avg test  loss 879.65, bar  test loss 6.397, len  test loss 0.218, col  test loss 166.737


Epoch 63: 272batch [00:23, 11.73batch/s, loss=901]


epoch 63: avg train loss 880.09, bar train loss 6.431, len train loss 0.229, col train loss 166.761


Epoch 64: 2batch [00:00, 11.90batch/s, loss=870]

epoch 63: avg test  loss 879.50, bar  test loss 6.334, len  test loss 0.205, col  test loss 166.685


Epoch 64: 272batch [00:22, 11.99batch/s, loss=868]


epoch 64: avg train loss 879.81, bar train loss 6.389, len train loss 0.226, col train loss 166.733


Epoch 65: 2batch [00:00, 12.27batch/s, loss=850]

epoch 64: avg test  loss 879.89, bar  test loss 6.425, len  test loss 0.270, col  test loss 166.706


Epoch 65: 272batch [00:22, 12.01batch/s, loss=875]


epoch 65: avg train loss 879.72, bar train loss 6.364, len train loss 0.226, col train loss 166.722
epoch 65: avg test  loss 879.19, bar  test loss 6.234, len  test loss 0.227, col  test loss 166.684


Epoch 66: 272batch [00:23, 11.82batch/s, loss=893]


epoch 66: avg train loss 879.47, bar train loss 6.317, len train loss 0.224, col train loss 166.705


Epoch 67: 2batch [00:00, 12.74batch/s, loss=885]

epoch 66: avg test  loss 879.16, bar  test loss 6.367, len  test loss 0.211, col  test loss 166.655


Epoch 67: 272batch [00:22, 11.91batch/s, loss=908]


epoch 67: avg train loss 879.35, bar train loss 6.298, len train loss 0.224, col train loss 166.683


Epoch 68: 0batch [00:00, ?batch/s, loss=856]

epoch 67: avg test  loss 879.51, bar  test loss 6.346, len  test loss 0.240, col  test loss 166.723


Epoch 68: 272batch [00:22, 11.96batch/s, loss=1.02e+3]


epoch 68: avg train loss 879.23, bar train loss 6.273, len train loss 0.223, col train loss 166.670


Epoch 69: 2batch [00:00, 11.83batch/s, loss=865]

epoch 68: avg test  loss 878.92, bar  test loss 6.148, len  test loss 0.222, col  test loss 166.629


Epoch 69: 272batch [00:22, 12.01batch/s, loss=861]


epoch 69: avg train loss 879.08, bar train loss 6.248, len train loss 0.225, col train loss 166.650


Epoch 70: 0batch [00:00, ?batch/s, loss=848]

epoch 69: avg test  loss 879.68, bar  test loss 6.435, len  test loss 0.219, col  test loss 166.713


Epoch 70: 272batch [00:23, 11.79batch/s, loss=831]


epoch 70: avg train loss 878.91, bar train loss 6.226, len train loss 0.224, col train loss 166.623
epoch 70: avg test  loss 878.26, bar  test loss 6.136, len  test loss 0.209, col  test loss 166.496


Epoch 71: 272batch [00:23, 11.77batch/s, loss=842]


epoch 71: avg train loss 878.79, bar train loss 6.211, len train loss 0.224, col train loss 166.601


Epoch 72: 2batch [00:00, 11.83batch/s, loss=896]

epoch 71: avg test  loss 878.78, bar  test loss 6.146, len  test loss 0.242, col  test loss 166.622


Epoch 72: 272batch [00:22, 11.97batch/s, loss=889]


epoch 72: avg train loss 878.60, bar train loss 6.184, len train loss 0.225, col train loss 166.572


Epoch 73: 2batch [00:00, 12.05batch/s, loss=899]

epoch 72: avg test  loss 879.10, bar  test loss 6.021, len  test loss 0.254, col  test loss 166.445


Epoch 73: 272batch [00:22, 11.97batch/s, loss=970]


epoch 73: avg train loss 878.55, bar train loss 6.182, len train loss 0.225, col train loss 166.536


Epoch 74: 2batch [00:00, 12.05batch/s, loss=883]

epoch 73: avg test  loss 879.00, bar  test loss 6.046, len  test loss 0.253, col  test loss 166.444


Epoch 74: 272batch [00:22, 11.95batch/s, loss=986]


epoch 74: avg train loss 878.35, bar train loss 6.164, len train loss 0.224, col train loss 166.489


Epoch 75: 0batch [00:00, ?batch/s, loss=913]

epoch 74: avg test  loss 877.68, bar  test loss 6.010, len  test loss 0.224, col  test loss 166.372


Epoch 75: 272batch [00:22, 11.90batch/s, loss=883]


epoch 75: avg train loss 878.14, bar train loss 6.170, len train loss 0.225, col train loss 166.415
epoch 75: avg test  loss 878.20, bar  test loss 6.018, len  test loss 0.239, col  test loss 166.292


Epoch 76: 272batch [00:22, 11.86batch/s, loss=800]


epoch 76: avg train loss 878.00, bar train loss 6.179, len train loss 0.224, col train loss 166.344


Epoch 77: 2batch [00:00, 11.83batch/s, loss=905]

epoch 76: avg test  loss 877.11, bar  test loss 6.152, len  test loss 0.239, col  test loss 166.239


Epoch 77: 272batch [00:22, 11.93batch/s, loss=923]


epoch 77: avg train loss 877.73, bar train loss 6.170, len train loss 0.225, col train loss 166.266


Epoch 78: 2batch [00:00, 12.05batch/s, loss=875]

epoch 77: avg test  loss 877.12, bar  test loss 5.970, len  test loss 0.217, col  test loss 166.141


Epoch 78: 272batch [00:22, 11.91batch/s, loss=911]


epoch 78: avg train loss 877.46, bar train loss 6.152, len train loss 0.223, col train loss 166.220


Epoch 79: 2batch [00:00, 11.90batch/s, loss=878]

epoch 78: avg test  loss 876.85, bar  test loss 6.077, len  test loss 0.233, col  test loss 166.134


Epoch 79: 272batch [00:23, 11.75batch/s, loss=915]


epoch 79: avg train loss 877.27, bar train loss 6.134, len train loss 0.225, col train loss 166.167


Epoch 80: 2batch [00:00, 12.20batch/s, loss=878]

epoch 79: avg test  loss 876.40, bar  test loss 6.036, len  test loss 0.221, col  test loss 166.067


Epoch 80: 272batch [00:23, 11.81batch/s, loss=841]


epoch 80: avg train loss 876.96, bar train loss 6.118, len train loss 0.222, col train loss 166.127
epoch 80: avg test  loss 876.40, bar  test loss 5.997, len  test loss 0.231, col  test loss 165.946


Epoch 81: 272batch [00:22, 11.85batch/s, loss=892]


epoch 81: avg train loss 876.72, bar train loss 6.098, len train loss 0.224, col train loss 166.069


Epoch 82: 2batch [00:00, 12.05batch/s, loss=898]

epoch 81: avg test  loss 876.78, bar  test loss 6.054, len  test loss 0.297, col  test loss 166.067


Epoch 82: 272batch [00:22, 11.90batch/s, loss=763]


epoch 82: avg train loss 876.56, bar train loss 6.071, len train loss 0.226, col train loss 166.042


Epoch 83: 2batch [00:00, 12.12batch/s, loss=871]

epoch 82: avg test  loss 876.11, bar  test loss 5.987, len  test loss 0.220, col  test loss 165.972


Epoch 83: 272batch [00:22, 11.83batch/s, loss=990]


epoch 83: avg train loss 876.45, bar train loss 6.079, len train loss 0.222, col train loss 166.018


Epoch 84: 2batch [00:00, 11.70batch/s, loss=887]

epoch 83: avg test  loss 875.57, bar  test loss 5.974, len  test loss 0.216, col  test loss 165.957


Epoch 84: 272batch [00:22, 11.89batch/s, loss=837]


epoch 84: avg train loss 876.23, bar train loss 6.034, len train loss 0.223, col train loss 165.988


Epoch 85: 2batch [00:00, 11.98batch/s, loss=872]

epoch 84: avg test  loss 875.54, bar  test loss 5.957, len  test loss 0.216, col  test loss 165.896


Epoch 85: 272batch [00:23, 11.81batch/s, loss=887]


epoch 85: avg train loss 876.08, bar train loss 6.040, len train loss 0.222, col train loss 165.950
epoch 85: avg test  loss 875.78, bar  test loss 5.997, len  test loss 0.200, col  test loss 165.906


Epoch 86: 272batch [00:23, 11.81batch/s, loss=881]


epoch 86: avg train loss 876.13, bar train loss 6.028, len train loss 0.226, col train loss 165.944


Epoch 87: 2batch [00:00, 12.12batch/s, loss=904]

epoch 86: avg test  loss 876.01, bar  test loss 6.082, len  test loss 0.205, col  test loss 165.892


Epoch 87: 272batch [00:22, 11.86batch/s, loss=899]


epoch 87: avg train loss 875.76, bar train loss 5.997, len train loss 0.222, col train loss 165.899


Epoch 88: 0batch [00:00, ?batch/s, loss=858]

epoch 87: avg test  loss 874.99, bar  test loss 6.012, len  test loss 0.225, col  test loss 165.840


Epoch 88: 272batch [00:22, 11.83batch/s, loss=964]


epoch 88: avg train loss 875.70, bar train loss 5.998, len train loss 0.221, col train loss 165.868


Epoch 89: 2batch [00:00, 12.20batch/s, loss=886]

epoch 88: avg test  loss 875.00, bar  test loss 5.854, len  test loss 0.226, col  test loss 165.743


Epoch 89: 272batch [00:23, 11.82batch/s, loss=874]


epoch 89: avg train loss 875.44, bar train loss 5.976, len train loss 0.222, col train loss 165.827


Epoch 90: 2batch [00:00, 12.20batch/s, loss=861]

epoch 89: avg test  loss 875.08, bar  test loss 6.036, len  test loss 0.212, col  test loss 165.824


Epoch 90: 272batch [00:22, 11.86batch/s, loss=860]


epoch 90: avg train loss 875.40, bar train loss 5.970, len train loss 0.221, col train loss 165.819
epoch 90: avg test  loss 875.11, bar  test loss 5.914, len  test loss 0.235, col  test loss 165.757


Epoch 91: 272batch [00:23, 11.74batch/s, loss=831]


epoch 91: avg train loss 875.36, bar train loss 5.971, len train loss 0.220, col train loss 165.791


Epoch 92: 0batch [00:00, ?batch/s, loss=857]

epoch 91: avg test  loss 874.38, bar  test loss 5.864, len  test loss 0.217, col  test loss 165.689


Epoch 92: 272batch [00:22, 11.85batch/s, loss=953]


epoch 92: avg train loss 875.01, bar train loss 5.929, len train loss 0.221, col train loss 165.739


Epoch 93: 2batch [00:00, 12.35batch/s, loss=811]

epoch 92: avg test  loss 874.57, bar  test loss 5.778, len  test loss 0.204, col  test loss 165.708


Epoch 93: 272batch [00:22, 11.85batch/s, loss=864]


epoch 93: avg train loss 875.06, bar train loss 5.946, len train loss 0.220, col train loss 165.710


Epoch 94: 2batch [00:00, 11.98batch/s, loss=876]

epoch 93: avg test  loss 874.28, bar  test loss 5.850, len  test loss 0.207, col  test loss 165.563


Epoch 94: 272batch [00:22, 11.85batch/s, loss=895]


epoch 94: avg train loss 874.82, bar train loss 5.933, len train loss 0.219, col train loss 165.673


Epoch 95: 2batch [00:00, 12.20batch/s, loss=864]

epoch 94: avg test  loss 874.16, bar  test loss 5.837, len  test loss 0.209, col  test loss 165.616


Epoch 95: 272batch [00:23, 11.80batch/s, loss=898]


epoch 95: avg train loss 874.68, bar train loss 5.924, len train loss 0.221, col train loss 165.627
epoch 95: avg test  loss 874.73, bar  test loss 5.874, len  test loss 0.225, col  test loss 165.514


Epoch 96: 272batch [00:23, 11.79batch/s, loss=922]


epoch 96: avg train loss 874.53, bar train loss 5.926, len train loss 0.220, col train loss 165.573


Epoch 97: 2batch [00:00, 11.90batch/s, loss=876]

epoch 96: avg test  loss 873.71, bar  test loss 5.827, len  test loss 0.207, col  test loss 165.529


Epoch 97: 272batch [00:23, 11.79batch/s, loss=857]


epoch 97: avg train loss 874.42, bar train loss 5.927, len train loss 0.220, col train loss 165.526


Epoch 98: 2batch [00:00, 11.83batch/s, loss=859]

epoch 97: avg test  loss 873.75, bar  test loss 5.809, len  test loss 0.205, col  test loss 165.470


Epoch 98: 272batch [00:23, 11.81batch/s, loss=947]


epoch 98: avg train loss 874.21, bar train loss 5.922, len train loss 0.219, col train loss 165.469


Epoch 99: 2batch [00:00, 11.90batch/s, loss=902]

epoch 98: avg test  loss 874.18, bar  test loss 5.908, len  test loss 0.206, col  test loss 165.467


Epoch 99: 272batch [00:23, 11.83batch/s, loss=826]


epoch 99: avg train loss 874.16, bar train loss 5.945, len train loss 0.221, col train loss 165.404


Epoch 100: 2batch [00:00, 11.98batch/s, loss=905]

epoch 99: avg test  loss 873.50, bar  test loss 5.808, len  test loss 0.210, col  test loss 165.308


Epoch 100: 272batch [00:23, 11.81batch/s, loss=879]


epoch 100: avg train loss 873.97, bar train loss 5.959, len train loss 0.217, col train loss 165.328
epoch 100: avg test  loss 873.30, bar  test loss 5.796, len  test loss 0.198, col  test loss 165.194


Epoch 101: 272batch [00:23, 11.76batch/s, loss=979]


epoch 101: avg train loss 873.79, bar train loss 5.955, len train loss 0.220, col train loss 165.243


Epoch 102: 0batch [00:00, ?batch/s, loss=915]

epoch 101: avg test  loss 873.22, bar  test loss 5.926, len  test loss 0.216, col  test loss 165.205


Epoch 102: 272batch [00:22, 11.87batch/s, loss=918]


epoch 102: avg train loss 873.44, bar train loss 5.953, len train loss 0.220, col train loss 165.162


Epoch 103: 0batch [00:00, ?batch/s, loss=831]

epoch 102: avg test  loss 872.93, bar  test loss 5.788, len  test loss 0.232, col  test loss 164.984


Epoch 103: 272batch [00:23, 11.79batch/s, loss=895]


epoch 103: avg train loss 873.31, bar train loss 5.994, len train loss 0.219, col train loss 165.062


Epoch 104: 2batch [00:00, 12.20batch/s, loss=872]

epoch 103: avg test  loss 872.54, bar  test loss 5.801, len  test loss 0.209, col  test loss 164.984


Epoch 104: 272batch [00:23, 11.71batch/s, loss=753]


epoch 104: avg train loss 872.90, bar train loss 5.972, len train loss 0.219, col train loss 164.968


Epoch 105: 2batch [00:00, 12.27batch/s, loss=889]

epoch 104: avg test  loss 872.02, bar  test loss 5.862, len  test loss 0.216, col  test loss 164.783


Epoch 105: 272batch [00:23, 11.79batch/s, loss=897]


epoch 105: avg train loss 872.92, bar train loss 6.008, len train loss 0.218, col train loss 164.901
epoch 105: avg test  loss 871.95, bar  test loss 5.835, len  test loss 0.198, col  test loss 164.787


Epoch 106: 272batch [00:23, 11.61batch/s, loss=923]


epoch 106: avg train loss 872.55, bar train loss 6.000, len train loss 0.217, col train loss 164.801


Epoch 107: 0batch [00:00, ?batch/s, loss=899]

epoch 106: avg test  loss 872.06, bar  test loss 5.845, len  test loss 0.221, col  test loss 164.663


Epoch 107: 272batch [00:22, 11.83batch/s, loss=878]


epoch 107: avg train loss 872.43, bar train loss 6.024, len train loss 0.217, col train loss 164.712


Epoch 108: 2batch [00:00, 11.56batch/s, loss=871]

epoch 107: avg test  loss 871.88, bar  test loss 5.959, len  test loss 0.213, col  test loss 164.670


Epoch 108: 272batch [00:23, 11.80batch/s, loss=820]


epoch 108: avg train loss 872.19, bar train loss 6.003, len train loss 0.221, col train loss 164.640


Epoch 109: 0batch [00:00, ?batch/s, loss=902]

epoch 108: avg test  loss 871.29, bar  test loss 5.929, len  test loss 0.215, col  test loss 164.561


Epoch 109: 272batch [00:23, 11.66batch/s, loss=990]


epoch 109: avg train loss 871.89, bar train loss 6.015, len train loss 0.221, col train loss 164.538


Epoch 110: 2batch [00:00, 12.50batch/s, loss=928]

epoch 109: avg test  loss 872.14, bar  test loss 5.884, len  test loss 0.291, col  test loss 164.380


Epoch 110: 272batch [00:22, 11.96batch/s, loss=814]


epoch 110: avg train loss 871.51, bar train loss 6.001, len train loss 0.221, col train loss 164.433
epoch 110: avg test  loss 870.56, bar  test loss 5.937, len  test loss 0.212, col  test loss 164.337


Epoch 111: 272batch [00:22, 11.96batch/s, loss=992]


epoch 111: avg train loss 871.45, bar train loss 6.032, len train loss 0.220, col train loss 164.363


Epoch 112: 0batch [00:00, ?batch/s, loss=858]

epoch 111: avg test  loss 870.79, bar  test loss 5.929, len  test loss 0.219, col  test loss 164.191


Epoch 112: 272batch [00:23, 11.70batch/s, loss=868]


epoch 112: avg train loss 870.66, bar train loss 6.019, len train loss 0.218, col train loss 164.149


Epoch 113: 2batch [00:00, 11.98batch/s, loss=873]

epoch 112: avg test  loss 868.19, bar  test loss 5.774, len  test loss 0.202, col  test loss 163.726


Epoch 113: 272batch [00:21, 12.76batch/s, loss=827]


epoch 113: avg train loss 868.36, bar train loss 5.935, len train loss 0.219, col train loss 163.635


Epoch 114: 2batch [00:00, 14.08batch/s, loss=848]

epoch 113: avg test  loss 866.25, bar  test loss 5.805, len  test loss 0.219, col  test loss 163.383


Epoch 114: 272batch [00:21, 12.46batch/s, loss=890]


epoch 114: avg train loss 856.57, bar train loss 5.799, len train loss 0.217, col train loss 160.990


Epoch 115: 2batch [00:00, 12.58batch/s, loss=850]

epoch 114: avg test  loss 847.84, bar  test loss 5.333, len  test loss 0.204, col  test loss 159.456


Epoch 115: 272batch [00:21, 12.46batch/s, loss=936]


epoch 115: avg train loss 847.15, bar train loss 5.516, len train loss 0.218, col train loss 159.147
epoch 115: avg test  loss 846.24, bar  test loss 5.262, len  test loss 0.291, col  test loss 158.683


Epoch 116: 272batch [00:21, 12.40batch/s, loss=884]


epoch 116: avg train loss 845.30, bar train loss 5.492, len train loss 0.213, col train loss 158.787


Epoch 117: 2batch [00:00, 12.42batch/s, loss=883]

epoch 116: avg test  loss 843.64, bar  test loss 5.368, len  test loss 0.200, col  test loss 158.637


Epoch 117: 272batch [00:21, 12.40batch/s, loss=995]


epoch 117: avg train loss 844.30, bar train loss 5.410, len train loss 0.219, col train loss 158.625


Epoch 118: 2batch [00:00, 12.50batch/s, loss=857]

epoch 117: avg test  loss 843.19, bar  test loss 5.336, len  test loss 0.244, col  test loss 158.615


Epoch 118: 272batch [00:21, 12.41batch/s, loss=903]


epoch 118: avg train loss 843.44, bar train loss 5.390, len train loss 0.214, col train loss 158.449


Epoch 119: 2batch [00:00, 12.35batch/s, loss=840]

epoch 118: avg test  loss 842.96, bar  test loss 5.322, len  test loss 0.228, col  test loss 158.578


Epoch 119: 272batch [00:21, 12.42batch/s, loss=902]


epoch 119: avg train loss 842.56, bar train loss 5.342, len train loss 0.214, col train loss 158.310


Epoch 120: 2batch [00:00, 12.20batch/s, loss=878]

epoch 119: avg test  loss 842.38, bar  test loss 5.115, len  test loss 0.233, col  test loss 157.982


Epoch 120: 272batch [00:22, 12.36batch/s, loss=779]


epoch 120: avg train loss 842.04, bar train loss 5.324, len train loss 0.215, col train loss 158.197
epoch 120: avg test  loss 842.21, bar  test loss 5.373, len  test loss 0.238, col  test loss 158.355


Epoch 121: 272batch [00:22, 12.27batch/s, loss=870]


epoch 121: avg train loss 841.66, bar train loss 5.325, len train loss 0.217, col train loss 158.096


Epoch 122: 2batch [00:00, 12.58batch/s, loss=826]

epoch 121: avg test  loss 841.04, bar  test loss 5.296, len  test loss 0.213, col  test loss 158.072


Epoch 122: 272batch [00:22, 12.03batch/s, loss=855]


epoch 122: avg train loss 841.06, bar train loss 5.325, len train loss 0.214, col train loss 157.964


Epoch 123: 2batch [00:00, 12.42batch/s, loss=837]

epoch 122: avg test  loss 839.57, bar  test loss 5.227, len  test loss 0.204, col  test loss 157.794


Epoch 123: 272batch [00:23, 11.81batch/s, loss=912]


epoch 123: avg train loss 840.22, bar train loss 5.289, len train loss 0.215, col train loss 157.813


Epoch 124: 2batch [00:00, 12.05batch/s, loss=867]

epoch 123: avg test  loss 839.63, bar  test loss 5.163, len  test loss 0.218, col  test loss 157.682


Epoch 124: 272batch [00:22, 11.98batch/s, loss=757]


epoch 124: avg train loss 839.90, bar train loss 5.256, len train loss 0.213, col train loss 157.749


Epoch 125: 2batch [00:00, 11.83batch/s, loss=839]

epoch 124: avg test  loss 838.82, bar  test loss 5.160, len  test loss 0.224, col  test loss 157.601


Epoch 125: 272batch [00:22, 11.90batch/s, loss=811]


epoch 125: avg train loss 839.59, bar train loss 5.251, len train loss 0.213, col train loss 157.677
epoch 125: avg test  loss 838.61, bar  test loss 5.181, len  test loss 0.223, col  test loss 157.652


Epoch 126: 272batch [00:23, 11.74batch/s, loss=853]


epoch 126: avg train loss 839.22, bar train loss 5.263, len train loss 0.211, col train loss 157.571


Epoch 127: 2batch [00:00, 11.90batch/s, loss=836]

epoch 126: avg test  loss 838.26, bar  test loss 5.118, len  test loss 0.206, col  test loss 157.508


Epoch 127: 272batch [00:22, 11.85batch/s, loss=904]


epoch 127: avg train loss 838.71, bar train loss 5.220, len train loss 0.210, col train loss 157.507


Epoch 128: 2batch [00:00, 11.98batch/s, loss=842]

epoch 127: avg test  loss 838.75, bar  test loss 5.078, len  test loss 0.225, col  test loss 157.504


Epoch 128: 272batch [00:22, 11.85batch/s, loss=889]


epoch 128: avg train loss 838.62, bar train loss 5.220, len train loss 0.214, col train loss 157.457


Epoch 129: 2batch [00:00, 11.98batch/s, loss=822]

epoch 128: avg test  loss 838.27, bar  test loss 5.029, len  test loss 0.219, col  test loss 157.287


Epoch 129: 272batch [00:22, 11.86batch/s, loss=875]


epoch 129: avg train loss 838.00, bar train loss 5.178, len train loss 0.212, col train loss 157.364


Epoch 130: 2batch [00:00, 11.98batch/s, loss=834]

epoch 129: avg test  loss 837.77, bar  test loss 5.130, len  test loss 0.224, col  test loss 157.407


Epoch 130: 272batch [00:23, 11.65batch/s, loss=829]


epoch 130: avg train loss 837.86, bar train loss 5.172, len train loss 0.212, col train loss 157.317
epoch 130: avg test  loss 836.76, bar  test loss 5.126, len  test loss 0.210, col  test loss 157.322


Epoch 131: 272batch [00:23, 11.73batch/s, loss=889]


epoch 131: avg train loss 837.54, bar train loss 5.158, len train loss 0.213, col train loss 157.239


Epoch 132: 2batch [00:00, 11.63batch/s, loss=854]

epoch 131: avg test  loss 837.97, bar  test loss 5.103, len  test loss 0.220, col  test loss 157.311


Epoch 132: 272batch [00:22, 11.86batch/s, loss=776]


epoch 132: avg train loss 837.25, bar train loss 5.132, len train loss 0.211, col train loss 157.204


Epoch 133: 2batch [00:00, 11.76batch/s, loss=800]

epoch 132: avg test  loss 836.25, bar  test loss 4.997, len  test loss 0.195, col  test loss 156.927


Epoch 133: 272batch [00:23, 11.81batch/s, loss=822]


epoch 133: avg train loss 837.10, bar train loss 5.126, len train loss 0.211, col train loss 157.161


Epoch 134: 2batch [00:00, 12.20batch/s, loss=893]

epoch 133: avg test  loss 837.70, bar  test loss 5.112, len  test loss 0.197, col  test loss 157.340


Epoch 134: 272batch [00:22, 11.85batch/s, loss=957]


epoch 134: avg train loss 836.40, bar train loss 5.081, len train loss 0.213, col train loss 157.053


Epoch 135: 2batch [00:00, 11.90batch/s, loss=832]

epoch 134: avg test  loss 836.21, bar  test loss 4.959, len  test loss 0.200, col  test loss 156.967


Epoch 135: 272batch [00:23, 11.79batch/s, loss=814]


epoch 135: avg train loss 836.23, bar train loss 5.085, len train loss 0.211, col train loss 156.999
epoch 135: avg test  loss 835.43, bar  test loss 5.015, len  test loss 0.193, col  test loss 156.988


Epoch 136: 272batch [00:23, 11.50batch/s, loss=772]


epoch 136: avg train loss 835.82, bar train loss 5.067, len train loss 0.208, col train loss 156.941


Epoch 137: 2batch [00:00, 12.42batch/s, loss=842]

epoch 136: avg test  loss 835.27, bar  test loss 4.949, len  test loss 0.228, col  test loss 156.951


Epoch 137: 272batch [00:22, 11.87batch/s, loss=875]


epoch 137: avg train loss 835.52, bar train loss 5.058, len train loss 0.206, col train loss 156.878


Epoch 138: 2batch [00:00, 12.12batch/s, loss=805]

epoch 137: avg test  loss 835.03, bar  test loss 4.922, len  test loss 0.202, col  test loss 156.727


Epoch 138: 272batch [00:23, 11.74batch/s, loss=834]


epoch 138: avg train loss 835.32, bar train loss 5.045, len train loss 0.207, col train loss 156.815


Epoch 139: 2batch [00:00, 12.35batch/s, loss=845]

epoch 138: avg test  loss 834.33, bar  test loss 4.904, len  test loss 0.196, col  test loss 156.735


Epoch 139: 272batch [00:23, 11.81batch/s, loss=865]


epoch 139: avg train loss 834.57, bar train loss 5.002, len train loss 0.207, col train loss 156.712


Epoch 140: 1batch [00:00,  9.90batch/s, loss=811]

epoch 139: avg test  loss 834.21, bar  test loss 4.989, len  test loss 0.202, col  test loss 156.794


Epoch 140: 272batch [00:23, 11.60batch/s, loss=862]


epoch 140: avg train loss 834.81, bar train loss 5.016, len train loss 0.210, col train loss 156.725
epoch 140: avg test  loss 833.99, bar  test loss 4.899, len  test loss 0.199, col  test loss 156.683


Epoch 141: 272batch [00:23, 11.70batch/s, loss=918]


epoch 141: avg train loss 834.09, bar train loss 4.983, len train loss 0.207, col train loss 156.606


Epoch 142: 2batch [00:00, 11.83batch/s, loss=794]

epoch 141: avg test  loss 833.79, bar  test loss 4.913, len  test loss 0.214, col  test loss 156.520


Epoch 142: 272batch [00:23, 11.79batch/s, loss=889]


epoch 142: avg train loss 833.92, bar train loss 4.984, len train loss 0.208, col train loss 156.554


Epoch 143: 2batch [00:00, 11.76batch/s, loss=846]

epoch 142: avg test  loss 833.23, bar  test loss 4.914, len  test loss 0.191, col  test loss 156.521


Epoch 143: 272batch [00:23, 11.78batch/s, loss=784]


epoch 143: avg train loss 833.78, bar train loss 4.968, len train loss 0.209, col train loss 156.506


Epoch 144: 2batch [00:00, 11.76batch/s, loss=845]

epoch 143: avg test  loss 833.27, bar  test loss 4.931, len  test loss 0.200, col  test loss 156.538


Epoch 144: 272batch [00:23, 11.77batch/s, loss=932]


epoch 144: avg train loss 833.36, bar train loss 4.959, len train loss 0.208, col train loss 156.435


Epoch 145: 2batch [00:00, 11.76batch/s, loss=838]

epoch 144: avg test  loss 833.35, bar  test loss 4.900, len  test loss 0.232, col  test loss 156.521


Epoch 145: 272batch [00:23, 11.73batch/s, loss=847]


epoch 145: avg train loss 833.06, bar train loss 4.933, len train loss 0.210, col train loss 156.384
epoch 145: avg test  loss 832.31, bar  test loss 4.842, len  test loss 0.200, col  test loss 156.409


Epoch 146: 272batch [00:23, 11.73batch/s, loss=920]


epoch 146: avg train loss 832.95, bar train loss 4.933, len train loss 0.205, col train loss 156.352


Epoch 147: 0batch [00:00, ?batch/s, loss=850]

epoch 146: avg test  loss 833.81, bar  test loss 4.846, len  test loss 0.229, col  test loss 156.238


Epoch 147: 272batch [00:23, 11.73batch/s, loss=863]


epoch 147: avg train loss 832.67, bar train loss 4.929, len train loss 0.206, col train loss 156.286


Epoch 148: 0batch [00:00, ?batch/s, loss=815]

epoch 147: avg test  loss 832.69, bar  test loss 4.921, len  test loss 0.203, col  test loss 156.357


Epoch 148: 272batch [00:23, 11.77batch/s, loss=856]


epoch 148: avg train loss 832.41, bar train loss 4.903, len train loss 0.209, col train loss 156.238


Epoch 149: 2batch [00:00, 11.76batch/s, loss=830]

epoch 148: avg test  loss 832.73, bar  test loss 4.749, len  test loss 0.212, col  test loss 156.181


Epoch 149: 272batch [00:23, 11.48batch/s, loss=811]


epoch 149: avg train loss 832.22, bar train loss 4.901, len train loss 0.205, col train loss 156.182


Epoch 150: 0batch [00:00, ?batch/s, loss=829]

epoch 149: avg test  loss 832.25, bar  test loss 4.810, len  test loss 0.221, col  test loss 156.080


Epoch 150: 272batch [00:23, 11.75batch/s, loss=838]


epoch 150: avg train loss 832.19, bar train loss 4.906, len train loss 0.207, col train loss 156.179
epoch 150: avg test  loss 831.98, bar  test loss 4.782, len  test loss 0.196, col  test loss 156.032


Epoch 151: 272batch [00:23, 11.68batch/s, loss=856]


epoch 151: avg train loss 831.74, bar train loss 4.884, len train loss 0.207, col train loss 156.098


Epoch 152: 2batch [00:00, 11.98batch/s, loss=831]

epoch 151: avg test  loss 832.32, bar  test loss 4.774, len  test loss 0.211, col  test loss 156.071


Epoch 152: 272batch [00:23, 11.75batch/s, loss=849]


epoch 152: avg train loss 831.66, bar train loss 4.885, len train loss 0.209, col train loss 156.053


Epoch 153: 0batch [00:00, ?batch/s, loss=812]

epoch 152: avg test  loss 831.25, bar  test loss 4.810, len  test loss 0.206, col  test loss 156.044


Epoch 153: 272batch [00:23, 11.73batch/s, loss=865]


epoch 153: avg train loss 831.59, bar train loss 4.886, len train loss 0.207, col train loss 156.025


Epoch 154: 2batch [00:00, 12.05batch/s, loss=794]

epoch 153: avg test  loss 831.12, bar  test loss 4.767, len  test loss 0.190, col  test loss 155.849


Epoch 154: 272batch [00:23, 11.75batch/s, loss=803]


epoch 154: avg train loss 830.95, bar train loss 4.849, len train loss 0.202, col train loss 155.948


Epoch 155: 2batch [00:00, 12.05batch/s, loss=800]

epoch 154: avg test  loss 831.11, bar  test loss 4.812, len  test loss 0.196, col  test loss 155.931


Epoch 155: 272batch [00:23, 11.74batch/s, loss=850]


epoch 155: avg train loss 830.89, bar train loss 4.851, len train loss 0.208, col train loss 155.907
epoch 155: avg test  loss 831.74, bar  test loss 4.859, len  test loss 0.212, col  test loss 156.124


Epoch 156: 272batch [00:23, 11.66batch/s, loss=882]


epoch 156: avg train loss 830.96, bar train loss 4.864, len train loss 0.207, col train loss 155.888


Epoch 157: 0batch [00:00, ?batch/s, loss=818]

epoch 156: avg test  loss 830.45, bar  test loss 4.769, len  test loss 0.199, col  test loss 155.855


Epoch 157: 272batch [00:23, 11.73batch/s, loss=847]


epoch 157: avg train loss 830.60, bar train loss 4.862, len train loss 0.205, col train loss 155.810


Epoch 158: 2batch [00:00, 12.05batch/s, loss=836]

epoch 157: avg test  loss 830.48, bar  test loss 4.772, len  test loss 0.222, col  test loss 155.812


Epoch 158: 272batch [00:23, 11.69batch/s, loss=799]


epoch 158: avg train loss 830.45, bar train loss 4.835, len train loss 0.207, col train loss 155.794


Epoch 159: 2batch [00:00, 11.90batch/s, loss=833]

epoch 158: avg test  loss 829.90, bar  test loss 4.727, len  test loss 0.201, col  test loss 155.678


Epoch 159: 272batch [00:23, 11.69batch/s, loss=817]


epoch 159: avg train loss 830.18, bar train loss 4.832, len train loss 0.208, col train loss 155.725


Epoch 160: 0batch [00:00, ?batch/s, loss=805]

epoch 159: avg test  loss 829.96, bar  test loss 4.716, len  test loss 0.207, col  test loss 155.716


Epoch 160: 272batch [00:23, 11.52batch/s, loss=797]


epoch 160: avg train loss 830.07, bar train loss 4.831, len train loss 0.205, col train loss 155.705
epoch 160: avg test  loss 830.07, bar  test loss 4.793, len  test loss 0.207, col  test loss 155.680


Epoch 161: 272batch [00:23, 11.59batch/s, loss=816]


epoch 161: avg train loss 829.80, bar train loss 4.817, len train loss 0.208, col train loss 155.640


Epoch 162: 0batch [00:00, ?batch/s]

epoch 161: avg test  loss 829.31, bar  test loss 4.723, len  test loss 0.188, col  test loss 155.534


Epoch 162: 272batch [00:23, 11.59batch/s, loss=861]


epoch 162: avg train loss 829.74, bar train loss 4.818, len train loss 0.203, col train loss 155.637


Epoch 163: 2batch [00:00, 11.56batch/s, loss=795]

epoch 162: avg test  loss 829.72, bar  test loss 4.760, len  test loss 0.196, col  test loss 155.597


Epoch 163: 272batch [00:23, 11.66batch/s, loss=855]


epoch 163: avg train loss 829.86, bar train loss 4.825, len train loss 0.205, col train loss 155.627


Epoch 164: 2batch [00:00, 12.05batch/s, loss=884]

epoch 163: avg test  loss 829.22, bar  test loss 4.690, len  test loss 0.194, col  test loss 155.466


Epoch 164: 272batch [00:23, 11.66batch/s, loss=967]


epoch 164: avg train loss 829.65, bar train loss 4.819, len train loss 0.203, col train loss 155.573


Epoch 165: 0batch [00:00, ?batch/s, loss=818]

epoch 164: avg test  loss 830.75, bar  test loss 4.663, len  test loss 0.304, col  test loss 155.354


Epoch 165: 272batch [00:23, 11.57batch/s, loss=796]


epoch 165: avg train loss 829.14, bar train loss 4.799, len train loss 0.203, col train loss 155.495
epoch 165: avg test  loss 829.93, bar  test loss 4.797, len  test loss 0.207, col  test loss 155.703


Epoch 166: 272batch [00:23, 11.65batch/s, loss=882]


epoch 166: avg train loss 828.96, bar train loss 4.792, len train loss 0.204, col train loss 155.459


Epoch 167: 0batch [00:00, ?batch/s, loss=845]

epoch 166: avg test  loss 828.74, bar  test loss 4.761, len  test loss 0.222, col  test loss 155.550


Epoch 167: 272batch [00:23, 11.60batch/s, loss=831]


epoch 167: avg train loss 828.96, bar train loss 4.800, len train loss 0.202, col train loss 155.441


Epoch 168: 0batch [00:00, ?batch/s, loss=834]

epoch 167: avg test  loss 828.79, bar  test loss 4.729, len  test loss 0.234, col  test loss 155.466


Epoch 168: 272batch [00:23, 11.64batch/s, loss=728]


epoch 168: avg train loss 828.87, bar train loss 4.793, len train loss 0.204, col train loss 155.405


Epoch 169: 2batch [00:00, 12.05batch/s, loss=823]

epoch 168: avg test  loss 829.84, bar  test loss 4.751, len  test loss 0.214, col  test loss 155.604


Epoch 169: 272batch [00:23, 11.76batch/s, loss=774]


epoch 169: avg train loss 828.40, bar train loss 4.768, len train loss 0.199, col train loss 155.350


Epoch 170: 2batch [00:00, 12.12batch/s, loss=876]

epoch 169: avg test  loss 828.45, bar  test loss 4.778, len  test loss 0.205, col  test loss 155.363


Epoch 170: 272batch [00:23, 11.76batch/s, loss=833]


epoch 170: avg train loss 828.32, bar train loss 4.765, len train loss 0.197, col train loss 155.318
epoch 170: avg test  loss 828.48, bar  test loss 4.742, len  test loss 0.200, col  test loss 155.418


Epoch 171: 272batch [00:23, 11.66batch/s, loss=799]


epoch 171: avg train loss 828.30, bar train loss 4.774, len train loss 0.196, col train loss 155.308


Epoch 172: 2batch [00:00, 12.58batch/s, loss=828]

epoch 171: avg test  loss 828.16, bar  test loss 4.700, len  test loss 0.190, col  test loss 155.281


Epoch 172: 272batch [00:23, 11.53batch/s, loss=857]


epoch 172: avg train loss 828.08, bar train loss 4.769, len train loss 0.196, col train loss 155.260


KeyboardInterrupt: 

In [None]:
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 3500, 2000, save_folder="VAE10")

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 5600, 2200, save_folder="VAE4")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')