In [1]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime


writer = SummaryWriter(f"{link}/saved_models/new/NVAE1/tensorboard")

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:

    def __init__(self, z_dim=128, d_dim=45, x_dim=7500, y_dim=2,
                 beta=10, rec_alpha = 1, rec_beta = 1, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z_dim = z_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images_org = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_bar, x_col, images = self.get_encoded_values(self.images_org, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len_new.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar_new.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col_new.npz')
            images= np.load(f'{link}/data/modmirbase_{ds}_images_cat_new.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        self.images= images
        
        
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')
        
        
        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.images_org = self.images_org[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]
    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]
        x_org = np.transpose(self.images_org[idx], (2,0,1))
        return (x, y, d, x_len, x_col, x_bar, mount, x_org)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,26,100), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)
        out_x = np.zeros((n,25,100,5))

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    out_col[i,25,j] = 1
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x,i,j),j] = 1 
                    #out_col[i, self.get_color(x[i,:,13,j]), 1, j] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1
                    tc = int(self._get_color(x[i,:,12,j]))
                    out_x[i,13-len1:13,j,tc]=1
                    

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
                    bc = int(self._get_color(x[i,:,13,j]))
                    out_x[i,13:13+len2,j,bc]=1
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len_new.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col_new.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar_new.npz', 'wb') as f:
            np.save(f, out_bar)
        with open(f'{link}/data/modmirbase_{ds}_images_cat_new.npz', 'wb') as f:
            np.save(f, out_x)
        

        return out_len, out_bar, out_col, out_x

    def get_color(self, x, i, j):
        
        col = self._get_color(x[i,:,12,j])+self._get_color(x[i,:,13,j])
        if col == '00':
            return 0
        if col == '01':
            return 1
        if col == '02':
            return 2
        if col == '03':
            return 3
        if col == '04':
            return 4
        if col == '10':
            return 5
        if col == '11':
            return 6
        if col == '12':
            return 7
        if col == '13':
            return 8
        if col == '14':
            return 9
        if col == '20':
            return 10
        if col == '21':
            return 11
        if col == '22':
            return 12
        if col == '23':
            return 13
        if col == '24':
            return 14
        if col == '30':
            return 15
        if col == '31':
            return 16
        if col == '32':
            return 17
        if col == '33':
            return 18
        if col == '34':
            return 19
        if col == '40':
            return 20
        if col == '41':
            return 21
        if col == '42':
            return 22
        if col == '43':
            return 23
        if col == '44':
            return 24
        
        
    
    def _get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return "0" # black
        elif (pixel == np.array([1,0,0])).all():  
            return "1" # red
        elif (pixel == np.array([0,0,1])).all():  
            return "2" # blue
        elif (pixel == np.array([0,1,0])).all():  
            return "3" # green
        elif (pixel == np.array([1,1,0])).all():  
            return "4" # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim, dim1=256, dim2=512):
        super(px, self).__init__()

        self.fc = nn.Sequential(nn.Linear(z_dim, dim1, bias=False),  
                                 nn.ReLU(),
                                nn.Linear(dim1, dim2),
                                nn.ReLU())
        
        # Predicting length and color of each bar
        
        self.color = nn.Sequential(nn.Linear(dim2, dim2),
                                   nn.ReLU(),
                                   nn.Linear(dim2, 2600))
        
        
        self.length_bar = nn.Sequential(nn.Linear(dim2,200), nn.Softplus())
        
        
    def forward(self, z):
        
        h = self.fc(z)
        
        
        
        len_bar = self.length_bar(h).reshape(-1,2,100)
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        
        
        col = self.color(h).reshape(-1,26,100)
        col_bar = nn.Softmax(dim=1)(col)
        
        return len_bar, len_bar_sc, col_bar

    def reconstruct_image(self, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0]),  # yellow
                  5: np.array([1,1,1])  # white
                  }
    
        _color_dict =  {0: (0,0),
                        1: (0,1),
                        2: (0,2),
                        3: (0,3),
                        4: (0,4),
                        5: (1,0),
                        6: (1,1),
                        7: (1,2),
                        8: (1,3),
                        9: (1,4),
                        10: (1,0),
                        11: (2,1),
                        12: (2,2),
                        13: (2,3),
                        14: (2,4),
                        15: (2,0),
                        16: (3,1),
                        17: (3,2),
                        18: (3,3),
                        19: (3,4),
                        20: (3,0),
                        21: (4,1),
                        22: (4,2),
                        23: (4,3),
                        24: (4,4),
                        25: (5,5)
                        }       
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_bar.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            limit = 100
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1, _col_bar_2 = _color_dict[np.argmax(col_bar[i,:,j])]
                    
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output

In [9]:
int(np.round(3.7, 0))
int(3.7)

3

In [10]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [11]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [12]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim, h_dim=1440):
        super(qz, self).__init__()
        self.h_dim = h_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(5, 48, kernel_size=9, stride=1, padding = 'same'),
            nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=9, stride=1, padding = 'same'),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(48, 96, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(96, 144, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(), 
            nn.Conv2d(144, 144, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(144, 240, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(), 
            nn.Conv2d(240, 240, kernel_size=3, stride=1, padding = 'same'),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )

        self.fc11 = nn.Sequential(nn.Linear(self.h_dim, z_dim))
        self.fc12 = nn.Sequential(nn.Linear(self.h_dim, z_dim), nn.Softplus())


    def forward(self, x):
        h = self.encoder(x)
        h = h.view(-1, self.h_dim)
        z_loc = self.fc11(h)
        z_scale = self.fc12(h) + 1e-7

        return z_loc, z_scale




In [13]:
enc = qz(128,10,10,1024)
summary(enc, (1,5,25,100))

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 240, 1, 6]            --
│    └─Conv2d: 2-1                       [1, 48, 25, 100]          19,488
│    └─ReLU: 2-2                         [1, 48, 25, 100]          --
│    └─Conv2d: 2-3                       [1, 48, 25, 100]          186,672
│    └─ReLU: 2-4                         [1, 48, 25, 100]          --
│    └─MaxPool2d: 2-5                    [1, 48, 12, 50]           --
│    └─Conv2d: 2-6                       [1, 96, 12, 50]           41,568
│    └─ReLU: 2-7                         [1, 96, 12, 50]           --
│    └─Conv2d: 2-8                       [1, 96, 12, 50]           83,040
│    └─ReLU: 2-9                         [1, 96, 12, 50]           --
│    └─MaxPool2d: 2-10                   [1, 96, 6, 25]            --
│    └─Conv2d: 2-11                      [1, 144, 6, 25]           1

## Full model class

In [14]:
class StampDIVA(nn.Module):
    def __init__(self, args):
        super(StampDIVA, self).__init__()
        self.z_dim = args.z_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim

        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.cuda()

    def forward(self, d, x, y):
        # Encode
        zd_q_loc, zd_q_scale = self.qz(x)
        
        # Reparameterization trick
        qz = dist.Normal(zd_q_loc, zd_q_scale)
        z_q = qz.rsample()
        
        
        # Decode
        x_bar, x_bar_scale, x_col = self.px(z_q)
        z_p_loc, z_p_scale = torch.zeros(z_q.size()[0], self.z_dim).cuda(),\
                        torch.ones(z_q.size()[0], self.z_dim).cuda()
        pz = dist.Normal(z_p_loc, z_p_scale)

        # Reparameterization trick
        pz = dist.Normal(z_p_loc, z_p_scale)
        
        return x_bar, x_bar_scale, x_col, qz, pz, z_q

    def loss_function(self, d, x, y, out_len, out_bar, out_col):
        
        x_bar, x_bar_scale, x_col, qz, pz, z_q = self.forward(d, x, y)
       
        mse_bar = (((x_bar - out_bar)**2)).mean(dim=(1,2)).sum()
        
        max_bar = torch.argmax(x_col, dim=1)
        acc_bar = (max_bar==torch.argmax(out_col, dim=1)).float().mean((1)).sum().float()
        
        CE_bar = mse_bar#-log_bar
        CE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        KL_z = torch.sum(pz.log_prob(z_q) - qz.log_prob(z_q))
          
        return self.rec_beta * CE_bar \
                  + self.rec_gamma * CE_col \
                  - self.beta * KL_z, \
                  CE_bar, CE_col, mse_bar, acc_bar

In [15]:
default_args = diva_args(z_dim=256, rec_alpha = 10, rec_beta = 10, rec_gamma = 10, 
                         beta=1, warmup=1, prewarmup=0)
enc = StampDIVA(default_args)
summary(enc,[ (1,1),(1,5,25,100),(1,1)])

Layer (type:depth-idx)                   Output Shape              Param #
StampDIVA                                --                        --
├─qz: 1-1                                [1, 256]                  --
│    └─Sequential: 2-1                   [1, 240, 1, 6]            --
│    │    └─Conv2d: 3-1                  [1, 48, 25, 100]          19,488
│    │    └─ReLU: 3-2                    [1, 48, 25, 100]          --
│    │    └─Conv2d: 3-3                  [1, 48, 25, 100]          186,672
│    │    └─ReLU: 3-4                    [1, 48, 25, 100]          --
│    │    └─MaxPool2d: 3-5               [1, 48, 12, 50]           --
│    │    └─Conv2d: 3-6                  [1, 96, 12, 50]           41,568
│    │    └─ReLU: 3-7                    [1, 96, 12, 50]           --
│    │    └─Conv2d: 3-8                  [1, 96, 12, 50]           83,040
│    │    └─ReLU: 3-9                    [1, 96, 12, 50]           --
│    │    └─MaxPool2d: 3-10              [1, 96, 6, 25]            -

# Training the model

## Loading dataset

In [16]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [17]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [18]:
len(RNA_dataset)

34721

In [19]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    no_batches = 0
    mse_bar = 0
    acc_bar = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar,_,_) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        mse_bar += mse
        acc_bar += acc
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_col_loss, mse_bar, acc_bar

In [20]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    mse_bar = 0
    acc_bar = 0        
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar,_,_) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)
            loss, bar_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            mse_bar += mse
            acc_bar += acc
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_col_loss, mse_bar, acc_bar
  

In [21]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_col, mtr, atr = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_col_test, mte, ate = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train': atr, 'test':ate}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')
        
        
        if epoch % 50 == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [22]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        x_org = a[1][-1][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        m = a[1][-2][:10].to(DEVICE).float()
        x_2, x_2var, x_3 ,qz, pz, z_q = diva(d,x,y)
        out = diva.px.reconstruct_image(x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x_org[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [23]:
DEVICE

device(type='cuda')

## Model Training

In [24]:
default_args = diva_args(z_dim=512, rec_alpha = 100, rec_beta = 20, rec_gamma = 10, 
                         beta=1, warmup=1, prewarmup=0)

In [25]:
diva = StampDIVA(default_args).to(DEVICE)

In [29]:
diva.load_state_dict(torch.load(f'{link}/saved_models/new/NVAE/checkpoints/193.pth'))

<All keys matched successfully>

In [30]:
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [31]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=0.0005)

In [32]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [33]:
writer.flush()

In [34]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/new/NVAE1/tensorboard/"

Reusing TensorBoard on port 6006 (pid 24040), started 13:36:14 ago. (Use '!kill 24040' to kill it.)

In [38]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 5, save_folder="new/NVAE1",save_interval=5)

Epoch 6: 272batch [00:48,  5.57batch/s, loss=2.87e+3]


epoch 6: avg train loss 2892.97, bar train loss 4.532, col train loss 278.664


Epoch 7: 0batch [00:00, ?batch/s]

epoch 6: avg test  loss 2891.75, bar  test loss 4.389, col  test loss 278.515


Epoch 7: 272batch [00:48,  5.58batch/s, loss=2.88e+3]


epoch 7: avg train loss 2883.72, bar train loss 4.268, col train loss 278.100


Epoch 8: 0batch [00:00, ?batch/s, loss=2.87e+3]

epoch 7: avg test  loss 2879.29, bar  test loss 4.142, col  test loss 277.900


Epoch 8: 272batch [00:48,  5.57batch/s, loss=2.9e+3] 


epoch 8: avg train loss 2875.65, bar train loss 4.029, col train loss 277.692


Epoch 9: 0batch [00:00, ?batch/s]

epoch 8: avg test  loss 2872.78, bar  test loss 3.880, col  test loss 277.563


Epoch 9: 272batch [00:48,  5.56batch/s, loss=2.91e+3]


epoch 9: avg train loss 2869.35, bar train loss 3.844, col train loss 277.396


Epoch 10: 0batch [00:00, ?batch/s]

epoch 9: avg test  loss 2867.11, bar  test loss 3.751, col  test loss 277.301


Epoch 10: 272batch [00:48,  5.55batch/s, loss=2.84e+3]


epoch 10: avg train loss 2864.24, bar train loss 3.690, col train loss 277.161
epoch 10: avg test  loss 2862.76, bar  test loss 3.621, col  test loss 277.188


Epoch 11: 272batch [00:48,  5.57batch/s, loss=2.82e+3]


epoch 11: avg train loss 2860.68, bar train loss 3.575, col train loss 276.992


Epoch 12: 0batch [00:00, ?batch/s, loss=2.84e+3]

epoch 11: avg test  loss 2861.15, bar  test loss 3.592, col  test loss 277.090


Epoch 12: 272batch [00:48,  5.56batch/s, loss=2.85e+3]


epoch 12: avg train loss 2857.28, bar train loss 3.477, col train loss 276.798


Epoch 13: 1batch [00:00,  5.59batch/s, loss=2.85e+3]

epoch 12: avg test  loss 2856.07, bar  test loss 3.454, col  test loss 276.764


Epoch 13: 272batch [00:48,  5.57batch/s, loss=2.89e+3]


epoch 13: avg train loss 2853.46, bar train loss 3.389, col train loss 276.543


Epoch 14: 0batch [00:00, ?batch/s]

epoch 13: avg test  loss 2852.15, bar  test loss 3.326, col  test loss 276.531


Epoch 14: 272batch [00:48,  5.57batch/s, loss=2.88e+3]


epoch 14: avg train loss 2849.65, bar train loss 3.308, col train loss 276.293


Epoch 15: 0batch [00:00, ?batch/s]

epoch 14: avg test  loss 2849.98, bar  test loss 3.245, col  test loss 276.236


Epoch 15: 272batch [00:48,  5.57batch/s, loss=2.87e+3]


epoch 15: avg train loss 2845.56, bar train loss 3.231, col train loss 275.993
epoch 15: avg test  loss 2846.67, bar  test loss 3.376, col  test loss 276.099


Epoch 16: 272batch [00:48,  5.57batch/s, loss=2.86e+3]


epoch 16: avg train loss 2842.19, bar train loss 3.165, col train loss 275.775


Epoch 17: 0batch [00:00, ?batch/s, loss=2.84e+3]

epoch 16: avg test  loss 2842.04, bar  test loss 3.230, col  test loss 275.830


Epoch 17: 272batch [00:48,  5.57batch/s, loss=2.84e+3]


epoch 17: avg train loss 2839.06, bar train loss 3.104, col train loss 275.557


Epoch 18: 0batch [00:00, ?batch/s, loss=2.8e+3]

epoch 17: avg test  loss 2839.69, bar  test loss 3.091, col  test loss 275.644


Epoch 18: 272batch [00:48,  5.60batch/s, loss=2.8e+3] 


epoch 18: avg train loss 2836.58, bar train loss 3.043, col train loss 275.412


Epoch 19: 0batch [00:00, ?batch/s]

epoch 18: avg test  loss 2837.14, bar  test loss 3.007, col  test loss 275.404


Epoch 19: 272batch [00:48,  5.60batch/s, loss=2.79e+3]


epoch 19: avg train loss 2834.09, bar train loss 2.993, col train loss 275.235


Epoch 20: 1batch [00:00,  5.65batch/s, loss=2.84e+3]

epoch 19: avg test  loss 2835.06, bar  test loss 2.993, col  test loss 275.378


Epoch 20: 272batch [00:48,  5.59batch/s, loss=2.86e+3]


epoch 20: avg train loss 2831.65, bar train loss 2.940, col train loss 275.071
epoch 20: avg test  loss 2833.15, bar  test loss 2.899, col  test loss 275.166


Epoch 21: 272batch [00:48,  5.58batch/s, loss=2.9e+3] 


epoch 21: avg train loss 2829.67, bar train loss 2.893, col train loss 274.951


Epoch 22: 1batch [00:00,  5.56batch/s, loss=2.9e+3]

epoch 21: avg test  loss 2831.03, bar  test loss 2.905, col  test loss 275.144


Epoch 22: 272batch [00:48,  5.59batch/s, loss=2.88e+3]


epoch 22: avg train loss 2828.11, bar train loss 2.846, col train loss 274.858


Epoch 23: 1batch [00:00,  5.65batch/s, loss=2.82e+3]

epoch 22: avg test  loss 2829.93, bar  test loss 2.863, col  test loss 275.055


Epoch 23: 272batch [00:48,  5.58batch/s, loss=2.84e+3]


epoch 23: avg train loss 2826.80, bar train loss 2.814, col train loss 274.774


Epoch 24: 0batch [00:00, ?batch/s]

epoch 23: avg test  loss 2828.52, bar  test loss 2.818, col  test loss 274.947


Epoch 24: 272batch [00:48,  5.58batch/s, loss=2.79e+3]


epoch 24: avg train loss 2825.12, bar train loss 2.775, col train loss 274.667


Epoch 25: 0batch [00:00, ?batch/s]

epoch 24: avg test  loss 2826.91, bar  test loss 2.760, col  test loss 274.878


Epoch 25: 272batch [00:48,  5.57batch/s, loss=2.81e+3]


epoch 25: avg train loss 2823.69, bar train loss 2.735, col train loss 274.592
epoch 25: avg test  loss 2825.47, bar  test loss 2.750, col  test loss 274.775


Epoch 26: 272batch [00:49,  5.54batch/s, loss=2.87e+3]


epoch 26: avg train loss 2822.61, bar train loss 2.705, col train loss 274.526


Epoch 27: 0batch [00:00, ?batch/s]

epoch 26: avg test  loss 2824.44, bar  test loss 2.655, col  test loss 274.692


Epoch 27: 272batch [00:48,  5.55batch/s, loss=2.82e+3]


epoch 27: avg train loss 2821.12, bar train loss 2.670, col train loss 274.429


Epoch 28: 1batch [00:00,  5.56batch/s, loss=2.83e+3]

epoch 27: avg test  loss 2824.19, bar  test loss 2.625, col  test loss 274.650


Epoch 28: 272batch [00:49,  5.53batch/s, loss=2.81e+3]


epoch 28: avg train loss 2819.62, bar train loss 2.634, col train loss 274.331


Epoch 29: 0batch [00:00, ?batch/s]

epoch 28: avg test  loss 2822.64, bar  test loss 2.595, col  test loss 274.558


Epoch 29: 272batch [00:49,  5.55batch/s, loss=2.87e+3]


epoch 29: avg train loss 2818.24, bar train loss 2.602, col train loss 274.249


Epoch 30: 0batch [00:00, ?batch/s]

epoch 29: avg test  loss 2821.02, bar  test loss 2.608, col  test loss 274.561


Epoch 30: 272batch [00:49,  5.54batch/s, loss=2.81e+3]


epoch 30: avg train loss 2816.70, bar train loss 2.576, col train loss 274.129
epoch 30: avg test  loss 2820.05, bar  test loss 2.558, col  test loss 274.316


Epoch 31: 272batch [00:49,  5.54batch/s, loss=2.81e+3]


epoch 31: avg train loss 2814.91, bar train loss 2.543, col train loss 273.991


Epoch 32: 1batch [00:00,  5.52batch/s, loss=2.78e+3]

epoch 31: avg test  loss 2817.27, bar  test loss 2.526, col  test loss 274.223


Epoch 32: 272batch [00:48,  5.57batch/s, loss=2.84e+3]


epoch 32: avg train loss 2813.35, bar train loss 2.525, col train loss 273.863


Epoch 33: 0batch [00:00, ?batch/s]

epoch 32: avg test  loss 2816.34, bar  test loss 2.534, col  test loss 274.145


Epoch 33: 272batch [00:48,  5.55batch/s, loss=2.81e+3]


epoch 33: avg train loss 2811.80, bar train loss 2.498, col train loss 273.738


Epoch 34: 0batch [00:00, ?batch/s]

epoch 33: avg test  loss 2814.72, bar  test loss 2.496, col  test loss 274.077


Epoch 34: 272batch [00:48,  5.55batch/s, loss=2.84e+3]


epoch 34: avg train loss 2810.72, bar train loss 2.481, col train loss 273.652


Epoch 35: 1batch [00:00,  5.52batch/s, loss=2.81e+3]

epoch 34: avg test  loss 2814.06, bar  test loss 2.449, col  test loss 273.979


Epoch 35: 272batch [00:49,  5.53batch/s, loss=2.8e+3] 


epoch 35: avg train loss 2809.61, bar train loss 2.452, col train loss 273.577
epoch 35: avg test  loss 2813.31, bar  test loss 2.467, col  test loss 273.927


Epoch 36: 272batch [00:49,  5.53batch/s, loss=2.81e+3]


epoch 36: avg train loss 2808.57, bar train loss 2.428, col train loss 273.513


Epoch 37: 0batch [00:00, ?batch/s]

epoch 36: avg test  loss 2812.82, bar  test loss 2.447, col  test loss 273.939


Epoch 37: 272batch [00:49,  5.54batch/s, loss=2.82e+3]


epoch 37: avg train loss 2807.59, bar train loss 2.408, col train loss 273.445


Epoch 38: 0batch [00:00, ?batch/s]

epoch 37: avg test  loss 2812.46, bar  test loss 2.491, col  test loss 273.891


Epoch 38: 272batch [00:49,  5.55batch/s, loss=2.82e+3]


epoch 38: avg train loss 2806.58, bar train loss 2.391, col train loss 273.366


Epoch 39: 0batch [00:00, ?batch/s]

epoch 38: avg test  loss 2811.86, bar  test loss 2.405, col  test loss 273.791


Epoch 39: 272batch [00:48,  5.56batch/s, loss=2.8e+3] 


epoch 39: avg train loss 2805.37, bar train loss 2.367, col train loss 273.275


Epoch 40: 0batch [00:00, ?batch/s]

epoch 39: avg test  loss 2809.93, bar  test loss 2.357, col  test loss 273.717


Epoch 40: 272batch [00:49,  5.55batch/s, loss=2.79e+3]


epoch 40: avg train loss 2804.24, bar train loss 2.343, col train loss 273.193
epoch 40: avg test  loss 2809.32, bar  test loss 2.354, col  test loss 273.610


Epoch 41: 272batch [00:49,  5.53batch/s, loss=2.78e+3]


epoch 41: avg train loss 2803.34, bar train loss 2.331, col train loss 273.120


Epoch 42: 0batch [00:00, ?batch/s]

epoch 41: avg test  loss 2808.44, bar  test loss 2.340, col  test loss 273.602


Epoch 42: 272batch [00:49,  5.55batch/s, loss=2.86e+3]


epoch 42: avg train loss 2802.24, bar train loss 2.315, col train loss 273.028


Epoch 43: 0batch [00:00, ?batch/s]

epoch 42: avg test  loss 2807.77, bar  test loss 2.353, col  test loss 273.612


Epoch 43: 272batch [00:49,  5.54batch/s, loss=2.83e+3]


epoch 43: avg train loss 2801.43, bar train loss 2.298, col train loss 272.963


Epoch 44: 0batch [00:00, ?batch/s]

epoch 43: avg test  loss 2806.86, bar  test loss 2.310, col  test loss 273.491


Epoch 44: 272batch [00:49,  5.55batch/s, loss=2.8e+3] 


epoch 44: avg train loss 2800.46, bar train loss 2.282, col train loss 272.892


Epoch 45: 0batch [00:00, ?batch/s]

epoch 44: avg test  loss 2806.02, bar  test loss 2.314, col  test loss 273.422


Epoch 45: 272batch [00:49,  5.54batch/s, loss=2.77e+3]


epoch 45: avg train loss 2799.89, bar train loss 2.273, col train loss 272.842
epoch 45: avg test  loss 2806.31, bar  test loss 2.312, col  test loss 273.463


Epoch 46: 272batch [00:49,  5.54batch/s, loss=2.79e+3]


epoch 46: avg train loss 2799.20, bar train loss 2.255, col train loss 272.800


Epoch 47: 0batch [00:00, ?batch/s]

epoch 46: avg test  loss 2805.46, bar  test loss 2.282, col  test loss 273.376


Epoch 47: 272batch [00:49,  5.55batch/s, loss=2.78e+3]


epoch 47: avg train loss 2798.64, bar train loss 2.247, col train loss 272.750


Epoch 48: 0batch [00:00, ?batch/s]

epoch 47: avg test  loss 2804.78, bar  test loss 2.229, col  test loss 273.312


Epoch 48: 272batch [00:49,  5.55batch/s, loss=2.77e+3]


epoch 48: avg train loss 2797.95, bar train loss 2.236, col train loss 272.692


Epoch 49: 0batch [00:00, ?batch/s]

epoch 48: avg test  loss 2804.75, bar  test loss 2.260, col  test loss 273.305


Epoch 49: 272batch [00:49,  5.54batch/s, loss=2.81e+3]


epoch 49: avg train loss 2797.32, bar train loss 2.223, col train loss 272.642


Epoch 50: 0batch [00:00, ?batch/s]

epoch 49: avg test  loss 2804.15, bar  test loss 2.251, col  test loss 273.288


Epoch 50: 272batch [00:49,  5.53batch/s, loss=2.84e+3]


epoch 50: avg train loss 2796.89, bar train loss 2.218, col train loss 272.604
epoch 50: avg test  loss 2803.68, bar  test loss 2.241, col  test loss 273.229


Epoch 51: 272batch [00:49,  5.54batch/s, loss=2.83e+3]


epoch 51: avg train loss 2795.99, bar train loss 2.201, col train loss 272.539


Epoch 52: 1batch [00:00,  5.62batch/s, loss=2.82e+3]

epoch 51: avg test  loss 2803.47, bar  test loss 2.271, col  test loss 273.180


Epoch 52: 272batch [00:49,  5.53batch/s, loss=2.81e+3]


epoch 52: avg train loss 2795.37, bar train loss 2.188, col train loss 272.478


Epoch 53: 0batch [00:00, ?batch/s]

epoch 52: avg test  loss 2802.66, bar  test loss 2.215, col  test loss 273.185


Epoch 53: 272batch [00:49,  5.54batch/s, loss=2.78e+3]


epoch 53: avg train loss 2794.58, bar train loss 2.175, col train loss 272.416


Epoch 54: 0batch [00:00, ?batch/s, loss=2.79e+3]

epoch 53: avg test  loss 2802.55, bar  test loss 2.198, col  test loss 273.076


Epoch 54: 272batch [00:49,  5.54batch/s, loss=2.83e+3]


epoch 54: avg train loss 2794.23, bar train loss 2.173, col train loss 272.380


Epoch 55: 0batch [00:00, ?batch/s]

epoch 54: avg test  loss 2802.11, bar  test loss 2.205, col  test loss 273.104


Epoch 55: 272batch [00:49,  5.54batch/s, loss=2.81e+3]


epoch 55: avg train loss 2793.47, bar train loss 2.161, col train loss 272.316
epoch 55: avg test  loss 2801.71, bar  test loss 2.194, col  test loss 273.080


Epoch 56: 272batch [00:49,  5.53batch/s, loss=2.81e+3]


epoch 56: avg train loss 2792.99, bar train loss 2.154, col train loss 272.274


Epoch 57: 0batch [00:00, ?batch/s]

epoch 56: avg test  loss 2801.40, bar  test loss 2.197, col  test loss 273.026


Epoch 57: 272batch [00:49,  5.54batch/s, loss=2.81e+3]


epoch 57: avg train loss 2792.15, bar train loss 2.138, col train loss 272.217


Epoch 58: 0batch [00:00, ?batch/s]

epoch 57: avg test  loss 2800.91, bar  test loss 2.179, col  test loss 273.007


Epoch 58: 272batch [00:49,  5.53batch/s, loss=2.82e+3]


epoch 58: avg train loss 2791.50, bar train loss 2.130, col train loss 272.147


Epoch 59: 0batch [00:00, ?batch/s]

epoch 58: avg test  loss 2800.15, bar  test loss 2.161, col  test loss 272.948


Epoch 59: 272batch [00:49,  5.53batch/s, loss=2.77e+3]


epoch 59: avg train loss 2790.62, bar train loss 2.122, col train loss 272.067


Epoch 60: 0batch [00:00, ?batch/s, loss=2.79e+3]

epoch 59: avg test  loss 2799.28, bar  test loss 2.161, col  test loss 272.859


Epoch 60: 272batch [00:49,  5.54batch/s, loss=2.8e+3] 


epoch 60: avg train loss 2789.83, bar train loss 2.113, col train loss 271.993
epoch 60: avg test  loss 2798.95, bar  test loss 2.179, col  test loss 272.828


Epoch 61: 272batch [00:49,  5.53batch/s, loss=2.8e+3] 


epoch 61: avg train loss 2789.10, bar train loss 2.107, col train loss 271.928


Epoch 62: 0batch [00:00, ?batch/s]

epoch 61: avg test  loss 2798.34, bar  test loss 2.174, col  test loss 272.796


Epoch 62: 272batch [00:49,  5.53batch/s, loss=2.78e+3]


epoch 62: avg train loss 2788.73, bar train loss 2.101, col train loss 271.885


Epoch 63: 0batch [00:00, ?batch/s]

epoch 62: avg test  loss 2798.34, bar  test loss 2.179, col  test loss 272.740


Epoch 63: 272batch [00:49,  5.53batch/s, loss=2.79e+3]


epoch 63: avg train loss 2787.91, bar train loss 2.089, col train loss 271.827


Epoch 64: 0batch [00:00, ?batch/s]

epoch 63: avg test  loss 2797.76, bar  test loss 2.138, col  test loss 272.677


Epoch 64: 272batch [00:49,  5.54batch/s, loss=2.8e+3] 


epoch 64: avg train loss 2787.36, bar train loss 2.078, col train loss 271.773


Epoch 65: 1batch [00:00,  5.56batch/s, loss=2.79e+3]

epoch 64: avg test  loss 2797.66, bar  test loss 2.106, col  test loss 272.703


Epoch 65: 272batch [00:49,  5.54batch/s, loss=2.79e+3]


epoch 65: avg train loss 2786.73, bar train loss 2.071, col train loss 271.724
epoch 65: avg test  loss 2797.31, bar  test loss 2.132, col  test loss 272.666


Epoch 66: 272batch [00:49,  5.52batch/s, loss=2.82e+3]


epoch 66: avg train loss 2786.26, bar train loss 2.068, col train loss 271.678


Epoch 67: 0batch [00:00, ?batch/s]

epoch 66: avg test  loss 2796.48, bar  test loss 2.138, col  test loss 272.644


Epoch 67: 272batch [00:49,  5.53batch/s, loss=2.75e+3]


epoch 67: avg train loss 2785.56, bar train loss 2.057, col train loss 271.610


Epoch 68: 0batch [00:00, ?batch/s]

epoch 67: avg test  loss 2795.98, bar  test loss 2.113, col  test loss 272.586


Epoch 68: 272batch [00:49,  5.52batch/s, loss=2.75e+3]


epoch 68: avg train loss 2785.05, bar train loss 2.050, col train loss 271.559


Epoch 69: 1batch [00:00,  5.49batch/s, loss=2.79e+3]

epoch 68: avg test  loss 2795.34, bar  test loss 2.097, col  test loss 272.524


Epoch 69: 272batch [00:49,  5.53batch/s, loss=2.8e+3] 


epoch 69: avg train loss 2784.22, bar train loss 2.043, col train loss 271.497


Epoch 70: 1batch [00:00,  5.62batch/s, loss=2.78e+3]

epoch 69: avg test  loss 2795.73, bar  test loss 2.096, col  test loss 272.537


Epoch 70: 272batch [00:49,  5.53batch/s, loss=2.77e+3]


epoch 70: avg train loss 2783.75, bar train loss 2.038, col train loss 271.440
epoch 70: avg test  loss 2796.00, bar  test loss 2.098, col  test loss 272.472


Epoch 71: 272batch [00:49,  5.52batch/s, loss=2.77e+3]


epoch 71: avg train loss 2783.29, bar train loss 2.029, col train loss 271.409


Epoch 72: 1batch [00:00,  5.65batch/s, loss=2.75e+3]

epoch 71: avg test  loss 2794.75, bar  test loss 2.110, col  test loss 272.482


Epoch 72: 272batch [00:49,  5.54batch/s, loss=2.78e+3]


epoch 72: avg train loss 2782.78, bar train loss 2.023, col train loss 271.354


Epoch 73: 1batch [00:00,  5.65batch/s, loss=2.77e+3]

epoch 72: avg test  loss 2794.58, bar  test loss 2.085, col  test loss 272.480


Epoch 73: 272batch [00:49,  5.54batch/s, loss=2.78e+3]


epoch 73: avg train loss 2782.36, bar train loss 2.023, col train loss 271.315


Epoch 74: 0batch [00:00, ?batch/s]

epoch 73: avg test  loss 2794.17, bar  test loss 2.085, col  test loss 272.410


Epoch 74: 272batch [00:49,  5.55batch/s, loss=2.8e+3] 


epoch 74: avg train loss 2781.63, bar train loss 2.013, col train loss 271.251


Epoch 75: 0batch [00:00, ?batch/s]

epoch 74: avg test  loss 2793.49, bar  test loss 2.046, col  test loss 272.357


Epoch 75: 272batch [00:49,  5.52batch/s, loss=2.75e+3]


epoch 75: avg train loss 2781.19, bar train loss 2.003, col train loss 271.211
epoch 75: avg test  loss 2794.08, bar  test loss 2.055, col  test loss 272.380


Epoch 76: 272batch [00:49,  5.53batch/s, loss=2.83e+3]


epoch 76: avg train loss 2780.64, bar train loss 1.999, col train loss 271.161


Epoch 77: 1batch [00:00,  5.52batch/s, loss=2.79e+3]

epoch 76: avg test  loss 2793.76, bar  test loss 2.053, col  test loss 272.359


Epoch 77: 272batch [00:49,  5.53batch/s, loss=2.79e+3]


epoch 77: avg train loss 2780.23, bar train loss 1.997, col train loss 271.112


Epoch 78: 0batch [00:00, ?batch/s]

epoch 77: avg test  loss 2793.15, bar  test loss 2.045, col  test loss 272.292


Epoch 78: 272batch [00:49,  5.53batch/s, loss=2.79e+3]


epoch 78: avg train loss 2779.51, bar train loss 1.983, col train loss 271.067


Epoch 79: 1batch [00:00,  5.56batch/s, loss=2.78e+3]

epoch 78: avg test  loss 2793.43, bar  test loss 2.096, col  test loss 272.290


Epoch 79: 272batch [00:49,  5.53batch/s, loss=2.77e+3]


epoch 79: avg train loss 2779.08, bar train loss 1.982, col train loss 271.013


Epoch 80: 0batch [00:00, ?batch/s]

epoch 79: avg test  loss 2793.11, bar  test loss 2.065, col  test loss 272.356


Epoch 80: 272batch [00:49,  5.52batch/s, loss=2.8e+3] 


epoch 80: avg train loss 2778.56, bar train loss 1.978, col train loss 270.970
epoch 80: avg test  loss 2793.20, bar  test loss 2.039, col  test loss 272.308


Epoch 81: 272batch [00:49,  5.55batch/s, loss=2.74e+3]


epoch 81: avg train loss 2778.22, bar train loss 1.973, col train loss 270.927


Epoch 82: 1batch [00:00,  5.62batch/s, loss=2.76e+3]

epoch 81: avg test  loss 2792.68, bar  test loss 2.034, col  test loss 272.270


Epoch 82: 272batch [00:49,  5.52batch/s, loss=2.77e+3]


epoch 82: avg train loss 2777.83, bar train loss 1.971, col train loss 270.885


Epoch 83: 0batch [00:00, ?batch/s]

epoch 82: avg test  loss 2792.44, bar  test loss 2.036, col  test loss 272.275


Epoch 83: 272batch [00:49,  5.52batch/s, loss=2.77e+3]


epoch 83: avg train loss 2777.46, bar train loss 1.965, col train loss 270.862


Epoch 84: 0batch [00:00, ?batch/s]

epoch 83: avg test  loss 2792.00, bar  test loss 2.051, col  test loss 272.198


Epoch 84: 272batch [00:49,  5.53batch/s, loss=2.77e+3]


epoch 84: avg train loss 2776.80, bar train loss 1.958, col train loss 270.802


Epoch 85: 0batch [00:00, ?batch/s]

epoch 84: avg test  loss 2792.61, bar  test loss 2.017, col  test loss 272.261


Epoch 85: 272batch [00:49,  5.52batch/s, loss=2.78e+3]


epoch 85: avg train loss 2776.49, bar train loss 1.957, col train loss 270.772
epoch 85: avg test  loss 2792.06, bar  test loss 2.011, col  test loss 272.182


Epoch 86: 272batch [00:49,  5.53batch/s, loss=2.72e+3]


epoch 86: avg train loss 2776.04, bar train loss 1.952, col train loss 270.727


Epoch 87: 0batch [00:00, ?batch/s]

epoch 86: avg test  loss 2791.46, bar  test loss 2.016, col  test loss 272.169


Epoch 87: 272batch [00:49,  5.53batch/s, loss=2.78e+3]


epoch 87: avg train loss 2775.50, bar train loss 1.943, col train loss 270.686


Epoch 88: 1batch [00:00,  5.52batch/s, loss=2.76e+3]

epoch 87: avg test  loss 2791.55, bar  test loss 1.999, col  test loss 272.146


Epoch 88: 272batch [00:49,  5.52batch/s, loss=2.78e+3]


epoch 88: avg train loss 2775.11, bar train loss 1.939, col train loss 270.645


Epoch 89: 1batch [00:00,  5.46batch/s, loss=2.8e+3]

epoch 88: avg test  loss 2791.46, bar  test loss 2.019, col  test loss 272.175


Epoch 89: 272batch [00:49,  5.50batch/s, loss=2.74e+3]


epoch 89: avg train loss 2774.79, bar train loss 1.937, col train loss 270.606


Epoch 90: 0batch [00:00, ?batch/s]

epoch 89: avg test  loss 2790.75, bar  test loss 1.998, col  test loss 272.135


Epoch 90: 272batch [00:49,  5.52batch/s, loss=2.79e+3]


epoch 90: avg train loss 2774.28, bar train loss 1.934, col train loss 270.559
epoch 90: avg test  loss 2791.07, bar  test loss 2.017, col  test loss 272.106


Epoch 91: 272batch [00:49,  5.52batch/s, loss=2.79e+3]


epoch 91: avg train loss 2773.74, bar train loss 1.929, col train loss 270.510


Epoch 92: 0batch [00:00, ?batch/s]

epoch 91: avg test  loss 2791.18, bar  test loss 1.994, col  test loss 272.120


Epoch 92: 272batch [00:49,  5.52batch/s, loss=2.77e+3]


epoch 92: avg train loss 2773.56, bar train loss 1.923, col train loss 270.484


Epoch 93: 0batch [00:00, ?batch/s]

epoch 92: avg test  loss 2790.89, bar  test loss 2.003, col  test loss 272.075


Epoch 93: 272batch [00:49,  5.52batch/s, loss=2.76e+3]


epoch 93: avg train loss 2772.98, bar train loss 1.915, col train loss 270.433


Epoch 94: 0batch [00:00, ?batch/s]

epoch 93: avg test  loss 2790.51, bar  test loss 1.998, col  test loss 272.024


Epoch 94: 272batch [00:49,  5.52batch/s, loss=2.75e+3]


epoch 94: avg train loss 2772.54, bar train loss 1.913, col train loss 270.399


Epoch 95: 0batch [00:00, ?batch/s]

epoch 94: avg test  loss 2790.54, bar  test loss 2.008, col  test loss 272.075


Epoch 95: 272batch [00:49,  5.49batch/s, loss=2.78e+3]


epoch 95: avg train loss 2772.29, bar train loss 1.912, col train loss 270.367
epoch 95: avg test  loss 2791.19, bar  test loss 1.978, col  test loss 272.070


Epoch 96: 272batch [00:49,  5.52batch/s, loss=2.79e+3]


epoch 96: avg train loss 2771.69, bar train loss 1.910, col train loss 270.310


Epoch 97: 0batch [00:00, ?batch/s]

epoch 96: avg test  loss 2790.39, bar  test loss 1.994, col  test loss 272.027


Epoch 97: 272batch [00:49,  5.54batch/s, loss=2.8e+3] 


epoch 97: avg train loss 2771.16, bar train loss 1.903, col train loss 270.266


Epoch 98: 1batch [00:00,  5.59batch/s, loss=2.77e+3]

epoch 97: avg test  loss 2791.06, bar  test loss 1.966, col  test loss 271.995


Epoch 98: 272batch [00:49,  5.52batch/s, loss=2.77e+3]


epoch 98: avg train loss 2770.81, bar train loss 1.899, col train loss 270.221


Epoch 99: 0batch [00:00, ?batch/s]

epoch 98: avg test  loss 2789.89, bar  test loss 1.967, col  test loss 271.971


Epoch 99: 272batch [00:49,  5.51batch/s, loss=2.76e+3]


epoch 99: avg train loss 2770.50, bar train loss 1.898, col train loss 270.202


Epoch 100: 0batch [00:00, ?batch/s]

epoch 99: avg test  loss 2789.44, bar  test loss 1.977, col  test loss 272.005


Epoch 100: 272batch [00:49,  5.50batch/s, loss=2.85e+3]


epoch 100: avg train loss 2770.22, bar train loss 1.892, col train loss 270.174
epoch 100: avg test  loss 2789.53, bar  test loss 1.968, col  test loss 271.980


Epoch 101: 272batch [00:50,  5.43batch/s, loss=2.76e+3]


epoch 101: avg train loss 2769.66, bar train loss 1.887, col train loss 270.119


Epoch 102: 0batch [00:00, ?batch/s]

epoch 101: avg test  loss 2789.51, bar  test loss 1.952, col  test loss 271.946


Epoch 102: 272batch [00:49,  5.49batch/s, loss=2.73e+3]


epoch 102: avg train loss 2769.15, bar train loss 1.879, col train loss 270.068


Epoch 103: 1batch [00:00,  5.68batch/s, loss=2.76e+3]

epoch 102: avg test  loss 2789.68, bar  test loss 1.972, col  test loss 271.976


Epoch 103: 272batch [00:49,  5.51batch/s, loss=2.77e+3]


epoch 103: avg train loss 2768.65, bar train loss 1.880, col train loss 270.031


Epoch 104: 0batch [00:00, ?batch/s]

epoch 103: avg test  loss 2789.60, bar  test loss 1.952, col  test loss 271.969


Epoch 104: 272batch [00:49,  5.51batch/s, loss=2.78e+3]


epoch 104: avg train loss 2768.62, bar train loss 1.879, col train loss 270.011


Epoch 105: 0batch [00:00, ?batch/s, loss=2.75e+3]

epoch 104: avg test  loss 2789.16, bar  test loss 1.958, col  test loss 271.943


Epoch 105: 272batch [00:49,  5.51batch/s, loss=2.8e+3] 


epoch 105: avg train loss 2768.16, bar train loss 1.873, col train loss 269.967
epoch 105: avg test  loss 2788.56, bar  test loss 1.954, col  test loss 271.875


Epoch 106: 272batch [00:49,  5.50batch/s, loss=2.76e+3]


epoch 106: avg train loss 2767.93, bar train loss 1.869, col train loss 269.949


Epoch 107: 1batch [00:00,  5.59batch/s, loss=2.76e+3]

epoch 106: avg test  loss 2788.42, bar  test loss 1.956, col  test loss 271.915


Epoch 107: 272batch [00:49,  5.51batch/s, loss=2.73e+3]


epoch 107: avg train loss 2767.19, bar train loss 1.864, col train loss 269.882


Epoch 108: 0batch [00:00, ?batch/s]

epoch 107: avg test  loss 2788.60, bar  test loss 1.973, col  test loss 271.868


Epoch 108: 272batch [00:50,  5.43batch/s, loss=2.78e+3]


epoch 108: avg train loss 2766.80, bar train loss 1.859, col train loss 269.851


Epoch 109: 0batch [00:00, ?batch/s]

epoch 108: avg test  loss 2788.94, bar  test loss 1.915, col  test loss 271.890


Epoch 109: 272batch [00:49,  5.47batch/s, loss=2.73e+3]


epoch 109: avg train loss 2766.65, bar train loss 1.854, col train loss 269.823


Epoch 110: 0batch [00:00, ?batch/s]

epoch 109: avg test  loss 2788.86, bar  test loss 1.955, col  test loss 271.931


Epoch 110: 272batch [00:49,  5.50batch/s, loss=2.78e+3]


epoch 110: avg train loss 2766.13, bar train loss 1.852, col train loss 269.786
epoch 110: avg test  loss 2788.85, bar  test loss 1.940, col  test loss 271.885


Epoch 111: 272batch [00:49,  5.51batch/s, loss=2.77e+3]


epoch 111: avg train loss 2766.05, bar train loss 1.854, col train loss 269.761


Epoch 112: 0batch [00:00, ?batch/s]

epoch 111: avg test  loss 2788.15, bar  test loss 1.952, col  test loss 271.842


Epoch 112: 272batch [00:49,  5.51batch/s, loss=2.76e+3]


epoch 112: avg train loss 2765.35, bar train loss 1.844, col train loss 269.705


Epoch 113: 0batch [00:00, ?batch/s]

epoch 112: avg test  loss 2788.13, bar  test loss 1.967, col  test loss 271.845


Epoch 113: 272batch [00:49,  5.51batch/s, loss=2.8e+3] 


epoch 113: avg train loss 2765.12, bar train loss 1.843, col train loss 269.686


Epoch 114: 0batch [00:00, ?batch/s]

epoch 113: avg test  loss 2788.51, bar  test loss 1.952, col  test loss 271.861


Epoch 114: 272batch [00:49,  5.52batch/s, loss=2.74e+3]


epoch 114: avg train loss 2764.77, bar train loss 1.839, col train loss 269.653


Epoch 115: 0batch [00:00, ?batch/s]

epoch 114: avg test  loss 2788.44, bar  test loss 1.925, col  test loss 271.835


Epoch 115: 272batch [00:49,  5.50batch/s, loss=2.78e+3]


epoch 115: avg train loss 2764.53, bar train loss 1.835, col train loss 269.630
epoch 115: avg test  loss 2788.46, bar  test loss 1.922, col  test loss 271.849


Epoch 116: 272batch [00:49,  5.52batch/s, loss=2.74e+3]


epoch 116: avg train loss 2764.06, bar train loss 1.832, col train loss 269.592


Epoch 117: 0batch [00:00, ?batch/s]

epoch 116: avg test  loss 2788.33, bar  test loss 1.928, col  test loss 271.867


Epoch 117: 272batch [00:49,  5.48batch/s, loss=2.77e+3]


epoch 117: avg train loss 2763.75, bar train loss 1.823, col train loss 269.562


Epoch 118: 1batch [00:00,  5.59batch/s, loss=2.75e+3]

epoch 117: avg test  loss 2787.86, bar  test loss 1.933, col  test loss 271.791


Epoch 118: 272batch [00:48,  5.58batch/s, loss=2.77e+3]


epoch 118: avg train loss 2763.39, bar train loss 1.824, col train loss 269.518


Epoch 119: 0batch [00:00, ?batch/s]

epoch 118: avg test  loss 2788.30, bar  test loss 1.948, col  test loss 271.833


Epoch 119: 272batch [00:48,  5.58batch/s, loss=2.75e+3]


epoch 119: avg train loss 2763.06, bar train loss 1.825, col train loss 269.477


Epoch 120: 0batch [00:00, ?batch/s, loss=2.76e+3]

epoch 119: avg test  loss 2788.06, bar  test loss 1.910, col  test loss 271.803


Epoch 120: 272batch [00:51,  5.24batch/s, loss=2.77e+3]


epoch 120: avg train loss 2762.65, bar train loss 1.817, col train loss 269.458
epoch 120: avg test  loss 2787.71, bar  test loss 1.920, col  test loss 271.798


Epoch 121: 272batch [00:49,  5.48batch/s, loss=2.76e+3]


epoch 121: avg train loss 2762.38, bar train loss 1.814, col train loss 269.433


Epoch 122: 0batch [00:00, ?batch/s]

epoch 121: avg test  loss 2787.83, bar  test loss 1.903, col  test loss 271.794


Epoch 122: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 122: avg train loss 2762.04, bar train loss 1.811, col train loss 269.383


Epoch 123: 1batch [00:00,  5.65batch/s, loss=2.73e+3]

epoch 122: avg test  loss 2787.80, bar  test loss 1.923, col  test loss 271.793


Epoch 123: 272batch [00:49,  5.49batch/s, loss=2.73e+3]


epoch 123: avg train loss 2761.87, bar train loss 1.811, col train loss 269.384


Epoch 124: 0batch [00:00, ?batch/s, loss=2.76e+3]

epoch 123: avg test  loss 2788.28, bar  test loss 1.924, col  test loss 271.806


Epoch 124: 272batch [00:49,  5.51batch/s, loss=2.77e+3]


epoch 124: avg train loss 2761.54, bar train loss 1.804, col train loss 269.342


Epoch 125: 1batch [00:00,  5.56batch/s, loss=2.75e+3]

epoch 124: avg test  loss 2787.68, bar  test loss 1.888, col  test loss 271.774


Epoch 125: 272batch [00:48,  5.57batch/s, loss=2.76e+3]


epoch 125: avg train loss 2761.21, bar train loss 1.801, col train loss 269.315
epoch 125: avg test  loss 2787.33, bar  test loss 1.904, col  test loss 271.729


Epoch 126: 272batch [00:50,  5.41batch/s, loss=2.79e+3]


epoch 126: avg train loss 2760.88, bar train loss 1.801, col train loss 269.281


Epoch 127: 0batch [00:00, ?batch/s]

epoch 126: avg test  loss 2787.56, bar  test loss 1.898, col  test loss 271.765


Epoch 127: 272batch [00:50,  5.43batch/s, loss=2.71e+3]


epoch 127: avg train loss 2760.48, bar train loss 1.790, col train loss 269.255


Epoch 128: 0batch [00:00, ?batch/s]

epoch 127: avg test  loss 2787.78, bar  test loss 1.911, col  test loss 271.820


Epoch 128: 272batch [00:50,  5.43batch/s, loss=2.76e+3]


epoch 128: avg train loss 2760.28, bar train loss 1.794, col train loss 269.232


Epoch 129: 0batch [00:00, ?batch/s]

epoch 128: avg test  loss 2787.76, bar  test loss 1.887, col  test loss 271.763


Epoch 129: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 129: avg train loss 2759.97, bar train loss 1.789, col train loss 269.196


Epoch 130: 1batch [00:00,  5.59batch/s, loss=2.76e+3]

epoch 129: avg test  loss 2787.55, bar  test loss 1.894, col  test loss 271.767


Epoch 130: 272batch [00:48,  5.55batch/s, loss=2.75e+3]


epoch 130: avg train loss 2759.54, bar train loss 1.783, col train loss 269.152
epoch 130: avg test  loss 2787.20, bar  test loss 1.883, col  test loss 271.773


Epoch 131: 272batch [00:48,  5.55batch/s, loss=2.8e+3] 


epoch 131: avg train loss 2759.29, bar train loss 1.782, col train loss 269.132


Epoch 132: 0batch [00:00, ?batch/s]

epoch 131: avg test  loss 2787.49, bar  test loss 1.906, col  test loss 271.714


Epoch 132: 272batch [00:49,  5.46batch/s, loss=2.73e+3]


epoch 132: avg train loss 2759.13, bar train loss 1.782, col train loss 269.109


Epoch 133: 0batch [00:00, ?batch/s]

epoch 132: avg test  loss 2787.35, bar  test loss 1.905, col  test loss 271.752


Epoch 133: 272batch [00:50,  5.39batch/s, loss=2.74e+3]


epoch 133: avg train loss 2758.89, bar train loss 1.779, col train loss 269.098


Epoch 134: 1batch [00:00,  5.62batch/s, loss=2.78e+3]

epoch 133: avg test  loss 2787.23, bar  test loss 1.916, col  test loss 271.708


Epoch 134: 272batch [00:50,  5.37batch/s, loss=2.8e+3] 


epoch 134: avg train loss 2758.44, bar train loss 1.772, col train loss 269.053


Epoch 135: 1batch [00:00,  5.59batch/s, loss=2.75e+3]

epoch 134: avg test  loss 2787.36, bar  test loss 1.881, col  test loss 271.748


Epoch 135: 272batch [00:49,  5.47batch/s, loss=2.75e+3]


epoch 135: avg train loss 2758.16, bar train loss 1.772, col train loss 269.017
epoch 135: avg test  loss 2787.17, bar  test loss 1.873, col  test loss 271.727


Epoch 136: 272batch [00:49,  5.54batch/s, loss=2.78e+3]


epoch 136: avg train loss 2757.75, bar train loss 1.769, col train loss 268.990


Epoch 137: 1batch [00:00,  5.65batch/s, loss=2.76e+3]

epoch 136: avg test  loss 2787.27, bar  test loss 1.864, col  test loss 271.724


Epoch 137: 272batch [00:48,  5.58batch/s, loss=2.78e+3]


epoch 137: avg train loss 2757.64, bar train loss 1.765, col train loss 268.981


Epoch 138: 0batch [00:00, ?batch/s]

epoch 137: avg test  loss 2786.85, bar  test loss 1.892, col  test loss 271.716


Epoch 138: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 138: avg train loss 2757.44, bar train loss 1.764, col train loss 268.952


Epoch 139: 0batch [00:00, ?batch/s]

epoch 138: avg test  loss 2787.47, bar  test loss 1.882, col  test loss 271.764


Epoch 139: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 139: avg train loss 2757.36, bar train loss 1.763, col train loss 268.944


Epoch 140: 0batch [00:00, ?batch/s]

epoch 139: avg test  loss 2787.04, bar  test loss 1.879, col  test loss 271.700


Epoch 140: 272batch [00:48,  5.59batch/s, loss=2.76e+3]


epoch 140: avg train loss 2757.00, bar train loss 1.759, col train loss 268.905
epoch 140: avg test  loss 2787.16, bar  test loss 1.860, col  test loss 271.773


Epoch 141: 272batch [00:48,  5.58batch/s, loss=2.76e+3]


epoch 141: avg train loss 2756.75, bar train loss 1.758, col train loss 268.879


Epoch 142: 0batch [00:00, ?batch/s, loss=2.75e+3]

epoch 141: avg test  loss 2786.62, bar  test loss 1.883, col  test loss 271.719


Epoch 142: 272batch [00:48,  5.55batch/s, loss=2.78e+3]


epoch 142: avg train loss 2756.31, bar train loss 1.757, col train loss 268.849


Epoch 143: 0batch [00:00, ?batch/s]

epoch 142: avg test  loss 2787.01, bar  test loss 1.861, col  test loss 271.741


Epoch 143: 272batch [00:48,  5.56batch/s, loss=2.73e+3]


epoch 143: avg train loss 2756.08, bar train loss 1.753, col train loss 268.824


Epoch 144: 0batch [00:00, ?batch/s]

epoch 143: avg test  loss 2787.00, bar  test loss 1.864, col  test loss 271.719


Epoch 144: 272batch [00:48,  5.56batch/s, loss=2.77e+3]


epoch 144: avg train loss 2755.83, bar train loss 1.750, col train loss 268.803


Epoch 145: 0batch [00:00, ?batch/s]

epoch 144: avg test  loss 2787.18, bar  test loss 1.867, col  test loss 271.728


Epoch 145: 272batch [00:48,  5.57batch/s, loss=2.72e+3]


epoch 145: avg train loss 2755.48, bar train loss 1.751, col train loss 268.764
epoch 145: avg test  loss 2787.48, bar  test loss 1.857, col  test loss 271.731


Epoch 146: 272batch [00:48,  5.56batch/s, loss=2.75e+3]


epoch 146: avg train loss 2755.47, bar train loss 1.748, col train loss 268.757


Epoch 147: 0batch [00:00, ?batch/s]

epoch 146: avg test  loss 2786.81, bar  test loss 1.879, col  test loss 271.699


Epoch 147: 272batch [00:48,  5.56batch/s, loss=2.71e+3]


epoch 147: avg train loss 2755.04, bar train loss 1.745, col train loss 268.709


Epoch 148: 0batch [00:00, ?batch/s]

epoch 147: avg test  loss 2786.96, bar  test loss 1.875, col  test loss 271.736


Epoch 148: 272batch [00:48,  5.56batch/s, loss=2.76e+3]


epoch 148: avg train loss 2754.82, bar train loss 1.743, col train loss 268.696


Epoch 149: 0batch [00:00, ?batch/s, loss=2.78e+3]

epoch 148: avg test  loss 2787.03, bar  test loss 1.867, col  test loss 271.727


Epoch 149: 272batch [00:49,  5.55batch/s, loss=2.75e+3]


epoch 149: avg train loss 2754.62, bar train loss 1.742, col train loss 268.674


Epoch 150: 1batch [00:00,  5.65batch/s, loss=2.75e+3]

epoch 149: avg test  loss 2786.94, bar  test loss 1.848, col  test loss 271.713


Epoch 150: 272batch [00:48,  5.55batch/s, loss=2.75e+3]


epoch 150: avg train loss 2754.59, bar train loss 1.744, col train loss 268.664
epoch 150: avg test  loss 2787.18, bar  test loss 1.852, col  test loss 271.734


Epoch 151: 272batch [00:48,  5.55batch/s, loss=2.75e+3]


epoch 151: avg train loss 2754.13, bar train loss 1.737, col train loss 268.630


Epoch 152: 0batch [00:00, ?batch/s]

epoch 151: avg test  loss 2787.08, bar  test loss 1.859, col  test loss 271.708


Epoch 152: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 152: avg train loss 2754.03, bar train loss 1.736, col train loss 268.622


Epoch 153: 0batch [00:00, ?batch/s]

epoch 152: avg test  loss 2787.10, bar  test loss 1.857, col  test loss 271.701


Epoch 153: 272batch [00:48,  5.55batch/s, loss=2.75e+3]


epoch 153: avg train loss 2753.70, bar train loss 1.735, col train loss 268.595


Epoch 154: 1batch [00:00,  5.56batch/s, loss=2.77e+3]

epoch 153: avg test  loss 2787.22, bar  test loss 1.849, col  test loss 271.709


Epoch 154: 272batch [00:48,  5.56batch/s, loss=2.78e+3]


epoch 154: avg train loss 2753.64, bar train loss 1.732, col train loss 268.581


Epoch 155: 1batch [00:00,  5.65batch/s, loss=2.73e+3]

epoch 154: avg test  loss 2787.25, bar  test loss 1.843, col  test loss 271.703


Epoch 155: 272batch [00:48,  5.56batch/s, loss=2.74e+3]


epoch 155: avg train loss 2753.18, bar train loss 1.729, col train loss 268.539
epoch 155: avg test  loss 2786.79, bar  test loss 1.860, col  test loss 271.682


Epoch 156: 272batch [00:48,  5.57batch/s, loss=2.75e+3]


epoch 156: avg train loss 2753.01, bar train loss 1.732, col train loss 268.520


Epoch 157: 0batch [00:00, ?batch/s]

epoch 156: avg test  loss 2786.65, bar  test loss 1.862, col  test loss 271.715


Epoch 157: 272batch [00:48,  5.59batch/s, loss=2.74e+3]


epoch 157: avg train loss 2752.91, bar train loss 1.730, col train loss 268.508


Epoch 158: 0batch [00:00, ?batch/s]

epoch 157: avg test  loss 2786.45, bar  test loss 1.856, col  test loss 271.671


Epoch 158: 272batch [00:49,  5.55batch/s, loss=2.76e+3]


epoch 158: avg train loss 2752.61, bar train loss 1.725, col train loss 268.475


Epoch 159: 0batch [00:00, ?batch/s]

epoch 158: avg test  loss 2787.02, bar  test loss 1.863, col  test loss 271.675


Epoch 159: 272batch [00:49,  5.55batch/s, loss=2.75e+3]


epoch 159: avg train loss 2752.49, bar train loss 1.722, col train loss 268.464


Epoch 160: 0batch [00:00, ?batch/s]

epoch 159: avg test  loss 2787.07, bar  test loss 1.858, col  test loss 271.683


Epoch 160: 272batch [00:48,  5.56batch/s, loss=2.77e+3]


epoch 160: avg train loss 2752.32, bar train loss 1.721, col train loss 268.445
epoch 160: avg test  loss 2786.92, bar  test loss 1.839, col  test loss 271.706


Epoch 161: 272batch [00:48,  5.55batch/s, loss=2.76e+3]


epoch 161: avg train loss 2752.18, bar train loss 1.720, col train loss 268.436


Epoch 162: 1batch [00:00,  5.62batch/s, loss=2.75e+3]

epoch 161: avg test  loss 2787.32, bar  test loss 1.856, col  test loss 271.677


Epoch 162: 272batch [00:48,  5.55batch/s, loss=2.79e+3]


epoch 162: avg train loss 2751.97, bar train loss 1.718, col train loss 268.419


Epoch 163: 0batch [00:00, ?batch/s]

epoch 162: avg test  loss 2787.02, bar  test loss 1.851, col  test loss 271.666


Epoch 163: 272batch [00:49,  5.55batch/s, loss=2.77e+3]


epoch 163: avg train loss 2751.68, bar train loss 1.716, col train loss 268.392


Epoch 164: 0batch [00:00, ?batch/s]

epoch 163: avg test  loss 2786.87, bar  test loss 1.853, col  test loss 271.696


Epoch 164: 272batch [00:48,  5.55batch/s, loss=2.79e+3]


epoch 164: avg train loss 2751.55, bar train loss 1.716, col train loss 268.374


Epoch 165: 0batch [00:00, ?batch/s]

epoch 164: avg test  loss 2787.14, bar  test loss 1.853, col  test loss 271.710


Epoch 165: 272batch [00:48,  5.56batch/s, loss=2.76e+3]


epoch 165: avg train loss 2751.39, bar train loss 1.714, col train loss 268.360
epoch 165: avg test  loss 2786.58, bar  test loss 1.837, col  test loss 271.689


Epoch 166: 272batch [00:49,  5.55batch/s, loss=2.78e+3]


epoch 166: avg train loss 2751.07, bar train loss 1.710, col train loss 268.328


Epoch 167: 0batch [00:00, ?batch/s]

epoch 166: avg test  loss 2786.80, bar  test loss 1.834, col  test loss 271.664


Epoch 167: 272batch [00:49,  5.55batch/s, loss=2.78e+3]


epoch 167: avg train loss 2751.14, bar train loss 1.715, col train loss 268.333


Epoch 168: 0batch [00:00, ?batch/s]

epoch 167: avg test  loss 2787.45, bar  test loss 1.840, col  test loss 271.661


Epoch 168: 272batch [00:49,  5.55batch/s, loss=2715.25]


epoch 168: avg train loss 2750.66, bar train loss 1.708, col train loss 268.290


Epoch 169: 0batch [00:00, ?batch/s]

epoch 168: avg test  loss 2786.86, bar  test loss 1.835, col  test loss 271.693


Epoch 169: 272batch [00:49,  5.55batch/s, loss=2.73e+3]


epoch 169: avg train loss 2750.72, bar train loss 1.706, col train loss 268.292


Epoch 170: 0batch [00:00, ?batch/s, loss=2.74e+3]

epoch 169: avg test  loss 2786.56, bar  test loss 1.855, col  test loss 271.657


Epoch 170: 272batch [00:49,  5.54batch/s, loss=2.74e+3]


epoch 170: avg train loss 2750.45, bar train loss 1.707, col train loss 268.260
epoch 170: avg test  loss 2787.29, bar  test loss 1.857, col  test loss 271.708


Epoch 171: 272batch [00:49,  5.55batch/s, loss=2.75e+3]


epoch 171: avg train loss 2750.30, bar train loss 1.704, col train loss 268.244


Epoch 172: 0batch [00:00, ?batch/s, loss=2.75e+3]

epoch 171: avg test  loss 2787.18, bar  test loss 1.836, col  test loss 271.687


Epoch 172: 272batch [00:49,  5.55batch/s, loss=2.76e+3]


epoch 172: avg train loss 2750.15, bar train loss 1.704, col train loss 268.232


Epoch 173: 0batch [00:00, ?batch/s]

epoch 172: avg test  loss 2787.01, bar  test loss 1.840, col  test loss 271.658


Epoch 173: 272batch [00:48,  5.57batch/s, loss=2.72e+3]


epoch 173: avg train loss 2749.96, bar train loss 1.701, col train loss 268.215


Epoch 174: 0batch [00:00, ?batch/s]

epoch 173: avg test  loss 2787.23, bar  test loss 1.839, col  test loss 271.683


Epoch 174: 272batch [00:49,  5.53batch/s, loss=2.76e+3]


epoch 174: avg train loss 2749.78, bar train loss 1.699, col train loss 268.199


Epoch 175: 1batch [00:00,  5.65batch/s, loss=2.75e+3]

epoch 174: avg test  loss 2787.65, bar  test loss 1.840, col  test loss 271.694


Epoch 175: 272batch [00:49,  5.52batch/s, loss=2.7e+3] 


epoch 175: avg train loss 2749.52, bar train loss 1.701, col train loss 268.167
epoch 175: avg test  loss 2786.75, bar  test loss 1.838, col  test loss 271.714


Epoch 176: 272batch [00:49,  5.54batch/s, loss=2.73e+3]


epoch 176: avg train loss 2749.25, bar train loss 1.697, col train loss 268.151


Epoch 177: 1batch [00:00,  5.56batch/s, loss=2.74e+3]

epoch 176: avg test  loss 2787.06, bar  test loss 1.830, col  test loss 271.733


Epoch 177: 272batch [00:49,  5.55batch/s, loss=2.75e+3]


epoch 177: avg train loss 2749.23, bar train loss 1.696, col train loss 268.142


Epoch 178: 0batch [00:00, ?batch/s]

epoch 177: avg test  loss 2787.32, bar  test loss 1.832, col  test loss 271.685


Epoch 178: 272batch [00:48,  5.58batch/s, loss=2.72e+3]


epoch 178: avg train loss 2749.06, bar train loss 1.697, col train loss 268.126


Epoch 179: 0batch [00:00, ?batch/s]

epoch 178: avg test  loss 2786.78, bar  test loss 1.840, col  test loss 271.686


Epoch 179: 272batch [00:48,  5.56batch/s, loss=2.77e+3]


epoch 179: avg train loss 2748.80, bar train loss 1.692, col train loss 268.107


Epoch 180: 0batch [00:00, ?batch/s]

epoch 179: avg test  loss 2787.38, bar  test loss 1.817, col  test loss 271.680


Epoch 180: 272batch [00:48,  5.55batch/s, loss=2.76e+3]


epoch 180: avg train loss 2748.57, bar train loss 1.689, col train loss 268.087
epoch 180: avg test  loss 2787.18, bar  test loss 1.835, col  test loss 271.648


Epoch 181: 272batch [00:49,  5.55batch/s, loss=2.76e+3]


epoch 181: avg train loss 2748.56, bar train loss 1.692, col train loss 268.074


Epoch 182: 0batch [00:00, ?batch/s, loss=2.74e+3]

epoch 181: avg test  loss 2786.81, bar  test loss 1.832, col  test loss 271.666


Epoch 182: 272batch [00:49,  5.55batch/s, loss=2.73e+3]


epoch 182: avg train loss 2748.50, bar train loss 1.692, col train loss 268.077


Epoch 183: 0batch [00:00, ?batch/s]

epoch 182: avg test  loss 2787.19, bar  test loss 1.825, col  test loss 271.675


Epoch 183: 272batch [00:49,  5.54batch/s, loss=2.73e+3]


epoch 183: avg train loss 2748.29, bar train loss 1.691, col train loss 268.054


Epoch 184: 1batch [00:00,  5.62batch/s, loss=2.74e+3]

epoch 183: avg test  loss 2787.26, bar  test loss 1.842, col  test loss 271.682


Epoch 184: 272batch [00:49,  5.54batch/s, loss=2.77e+3]


epoch 184: avg train loss 2748.09, bar train loss 1.689, col train loss 268.033


Epoch 185: 0batch [00:00, ?batch/s]

epoch 184: avg test  loss 2787.04, bar  test loss 1.823, col  test loss 271.660


Epoch 185: 272batch [00:49,  5.55batch/s, loss=2.73e+3]


epoch 185: avg train loss 2747.75, bar train loss 1.685, col train loss 268.014
epoch 185: avg test  loss 2787.06, bar  test loss 1.833, col  test loss 271.670


Epoch 186: 272batch [00:49,  5.55batch/s, loss=2.8e+3] 


epoch 186: avg train loss 2747.81, bar train loss 1.689, col train loss 268.007


Epoch 187: 0batch [00:00, ?batch/s]

epoch 186: avg test  loss 2787.06, bar  test loss 1.834, col  test loss 271.654


Epoch 187: 272batch [00:49,  5.55batch/s, loss=2.7e+3] 


epoch 187: avg train loss 2747.64, bar train loss 1.686, col train loss 267.990


Epoch 188: 0batch [00:00, ?batch/s]

epoch 187: avg test  loss 2787.15, bar  test loss 1.824, col  test loss 271.665


Epoch 188: 272batch [00:49,  5.54batch/s, loss=2.72e+3]


epoch 188: avg train loss 2747.37, bar train loss 1.682, col train loss 267.971


Epoch 189: 0batch [00:00, ?batch/s]

epoch 188: avg test  loss 2786.48, bar  test loss 1.828, col  test loss 271.666


Epoch 189: 272batch [00:48,  5.55batch/s, loss=2.75e+3]


epoch 189: avg train loss 2747.18, bar train loss 1.682, col train loss 267.942


Epoch 190: 0batch [00:00, ?batch/s]

epoch 189: avg test  loss 2786.84, bar  test loss 1.829, col  test loss 271.662


Epoch 190: 272batch [00:49,  5.53batch/s, loss=2.77e+3]


epoch 190: avg train loss 2747.16, bar train loss 1.682, col train loss 267.937
epoch 190: avg test  loss 2787.18, bar  test loss 1.831, col  test loss 271.694


Epoch 191: 272batch [00:49,  5.53batch/s, loss=2.72e+3]


epoch 191: avg train loss 2747.22, bar train loss 1.681, col train loss 267.945


Epoch 192: 1batch [00:00,  5.52batch/s, loss=2.75e+3]

epoch 191: avg test  loss 2787.42, bar  test loss 1.814, col  test loss 271.677


Epoch 192: 272batch [00:49,  5.53batch/s, loss=2.73e+3]


epoch 192: avg train loss 2746.76, bar train loss 1.680, col train loss 267.908


Epoch 193: 0batch [00:00, ?batch/s]

epoch 192: avg test  loss 2787.37, bar  test loss 1.832, col  test loss 271.697


Epoch 193: 272batch [00:50,  5.44batch/s, loss=2.76e+3]


epoch 193: avg train loss 2746.67, bar train loss 1.678, col train loss 267.895


Epoch 194: 0batch [00:00, ?batch/s]

epoch 193: avg test  loss 2787.09, bar  test loss 1.821, col  test loss 271.677


Epoch 194: 49batch [00:09,  5.26batch/s, loss=2.74e+3]


KeyboardInterrupt: 

In [35]:
#torch.save(diva.state_dict(), f'{link}/saved_models/new/NVAE/checkpoints/193.pth')

In [None]:
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 193, save_folder="new/NVAE1")

Epoch 194: 272batch [01:15,  3.58batch/s, loss=2.74e+3]


epoch 194: avg train loss 2747.05, bar train loss 1.683, col train loss 267.919
epoch 194: avg test  loss 2787.57, bar  test loss 1.810, col  test loss 271.655


Epoch 195: 272batch [01:07,  4.00batch/s, loss=2.71e+3]


epoch 195: avg train loss 2746.20, bar train loss 1.674, col train loss 267.840
epoch 195: avg test  loss 2787.51, bar  test loss 1.822, col  test loss 271.740


Epoch 196: 272batch [01:08,  4.00batch/s, loss=2.72e+3]


epoch 196: avg train loss 2746.22, bar train loss 1.674, col train loss 267.852


Epoch 197: 0batch [00:00, ?batch/s]

epoch 196: avg test  loss 2786.61, bar  test loss 1.828, col  test loss 271.638


Epoch 197: 272batch [01:07,  4.00batch/s, loss=2.78e+3]


epoch 197: avg train loss 2746.09, bar train loss 1.674, col train loss 267.832


Epoch 198: 0batch [00:00, ?batch/s]

epoch 197: avg test  loss 2787.20, bar  test loss 1.822, col  test loss 271.642


Epoch 198: 272batch [01:07,  4.01batch/s, loss=2.76e+3]


epoch 198: avg train loss 2745.82, bar train loss 1.670, col train loss 267.811


Epoch 199: 0batch [00:00, ?batch/s]

epoch 198: avg test  loss 2787.07, bar  test loss 1.827, col  test loss 271.696


Epoch 199: 272batch [01:07,  4.01batch/s, loss=2.75e+3]


epoch 199: avg train loss 2745.71, bar train loss 1.671, col train loss 267.802


Epoch 200: 0batch [00:00, ?batch/s]

epoch 199: avg test  loss 2786.75, bar  test loss 1.832, col  test loss 271.630


Epoch 200: 272batch [01:07,  4.00batch/s, loss=2.71e+3]


epoch 200: avg train loss 2745.57, bar train loss 1.671, col train loss 267.780
epoch 200: avg test  loss 2786.82, bar  test loss 1.822, col  test loss 271.629


Epoch 201: 272batch [01:08,  4.00batch/s, loss=2.77e+3]


epoch 201: avg train loss 2745.62, bar train loss 1.669, col train loss 267.785


Epoch 202: 0batch [00:00, ?batch/s]

epoch 201: avg test  loss 2787.37, bar  test loss 1.828, col  test loss 271.683


Epoch 202: 272batch [01:07,  4.00batch/s, loss=2.73e+3]


epoch 202: avg train loss 2745.47, bar train loss 1.671, col train loss 267.768


Epoch 203: 0batch [00:00, ?batch/s]

epoch 202: avg test  loss 2786.80, bar  test loss 1.818, col  test loss 271.641


Epoch 203: 272batch [01:08,  4.00batch/s, loss=2.74e+3]


epoch 203: avg train loss 2745.23, bar train loss 1.667, col train loss 267.744


Epoch 204: 0batch [00:00, ?batch/s]

epoch 203: avg test  loss 2786.97, bar  test loss 1.826, col  test loss 271.652


Epoch 204: 272batch [01:07,  4.00batch/s, loss=2.79e+3]


epoch 204: avg train loss 2745.13, bar train loss 1.666, col train loss 267.740


Epoch 205: 0batch [00:00, ?batch/s]

epoch 204: avg test  loss 2787.29, bar  test loss 1.816, col  test loss 271.632


Epoch 205: 272batch [01:07,  4.00batch/s, loss=2.74e+3]


epoch 205: avg train loss 2744.91, bar train loss 1.667, col train loss 267.717
epoch 205: avg test  loss 2787.28, bar  test loss 1.807, col  test loss 271.661


Epoch 206: 272batch [01:07,  4.00batch/s, loss=2.74e+3]


epoch 206: avg train loss 2744.91, bar train loss 1.666, col train loss 267.717


Epoch 207: 0batch [00:00, ?batch/s]

epoch 206: avg test  loss 2787.17, bar  test loss 1.806, col  test loss 271.655


Epoch 207: 272batch [01:08,  4.00batch/s, loss=2.73e+3]


epoch 207: avg train loss 2744.85, bar train loss 1.665, col train loss 267.703


Epoch 208: 0batch [00:00, ?batch/s]

epoch 207: avg test  loss 2787.18, bar  test loss 1.808, col  test loss 271.680


Epoch 208: 272batch [01:07,  4.00batch/s, loss=2.73e+3]


epoch 208: avg train loss 2744.59, bar train loss 1.662, col train loss 267.692


Epoch 209: 0batch [00:00, ?batch/s]

epoch 208: avg test  loss 2786.46, bar  test loss 1.824, col  test loss 271.622


Epoch 209: 272batch [01:07,  4.00batch/s, loss=2.79e+3]


epoch 209: avg train loss 2744.43, bar train loss 1.666, col train loss 267.671


Epoch 210: 0batch [00:00, ?batch/s]

epoch 209: avg test  loss 2786.86, bar  test loss 1.819, col  test loss 271.650


Epoch 210: 272batch [01:07,  4.00batch/s, loss=2.7e+3] 


epoch 210: avg train loss 2744.35, bar train loss 1.659, col train loss 267.667
epoch 210: avg test  loss 2787.08, bar  test loss 1.825, col  test loss 271.641


Epoch 211: 272batch [01:07,  4.00batch/s, loss=2.75e+3]


epoch 211: avg train loss 2744.29, bar train loss 1.664, col train loss 267.651


Epoch 212: 0batch [00:00, ?batch/s]

epoch 211: avg test  loss 2787.50, bar  test loss 1.818, col  test loss 271.674


Epoch 212: 272batch [01:07,  4.00batch/s, loss=2.72e+3]


epoch 212: avg train loss 2744.22, bar train loss 1.663, col train loss 267.647


Epoch 213: 0batch [00:00, ?batch/s]

epoch 212: avg test  loss 2786.67, bar  test loss 1.839, col  test loss 271.651


Epoch 213: 272batch [01:08,  4.00batch/s, loss=2.72e+3]


epoch 213: avg train loss 2744.03, bar train loss 1.662, col train loss 267.635


Epoch 214: 0batch [00:00, ?batch/s]

epoch 213: avg test  loss 2787.08, bar  test loss 1.823, col  test loss 271.659


Epoch 214: 272batch [01:07,  4.00batch/s, loss=2.7e+3] 


epoch 214: avg train loss 2743.83, bar train loss 1.659, col train loss 267.611


Epoch 215: 0batch [00:00, ?batch/s]

epoch 214: avg test  loss 2787.03, bar  test loss 1.808, col  test loss 271.636


Epoch 215: 272batch [01:07,  4.00batch/s, loss=2.74e+3]


epoch 215: avg train loss 2743.74, bar train loss 1.658, col train loss 267.592
epoch 215: avg test  loss 2787.29, bar  test loss 1.802, col  test loss 271.648


Epoch 216: 272batch [01:07,  4.01batch/s, loss=2.77e+3]


epoch 216: avg train loss 2743.33, bar train loss 1.656, col train loss 267.563


Epoch 217: 0batch [00:00, ?batch/s]

epoch 216: avg test  loss 2787.74, bar  test loss 1.822, col  test loss 271.682


Epoch 217: 272batch [01:08,  4.00batch/s, loss=2.75e+3]


epoch 217: avg train loss 2743.37, bar train loss 1.653, col train loss 267.558


Epoch 218: 0batch [00:00, ?batch/s]

epoch 217: avg test  loss 2787.32, bar  test loss 1.812, col  test loss 271.668


Epoch 218: 272batch [01:07,  4.00batch/s, loss=2.75e+3]


epoch 218: avg train loss 2743.18, bar train loss 1.657, col train loss 267.552


Epoch 219: 0batch [00:00, ?batch/s]

epoch 218: avg test  loss 2787.69, bar  test loss 1.818, col  test loss 271.673


Epoch 219: 272batch [01:08,  4.00batch/s, loss=2.74e+3]


epoch 219: avg train loss 2743.12, bar train loss 1.654, col train loss 267.532


Epoch 220: 0batch [00:00, ?batch/s]

epoch 219: avg test  loss 2787.17, bar  test loss 1.802, col  test loss 271.662


Epoch 220: 272batch [01:07,  4.00batch/s, loss=2.77e+3]


epoch 220: avg train loss 2743.12, bar train loss 1.656, col train loss 267.532
epoch 220: avg test  loss 2786.98, bar  test loss 1.804, col  test loss 271.634


Epoch 221: 272batch [01:07,  4.00batch/s, loss=2.71e+3]


epoch 221: avg train loss 2742.86, bar train loss 1.653, col train loss 267.516


Epoch 222: 0batch [00:00, ?batch/s]

epoch 221: avg test  loss 2787.03, bar  test loss 1.805, col  test loss 271.651


Epoch 222: 272batch [01:07,  4.01batch/s, loss=2.77e+3]


epoch 222: avg train loss 2743.04, bar train loss 1.655, col train loss 267.522


Epoch 223: 0batch [00:00, ?batch/s]

epoch 222: avg test  loss 2787.73, bar  test loss 1.816, col  test loss 271.712


Epoch 223: 272batch [01:08,  3.99batch/s, loss=2.75e+3]


epoch 223: avg train loss 2742.77, bar train loss 1.650, col train loss 267.512


Epoch 224: 0batch [00:00, ?batch/s]

epoch 223: avg test  loss 2787.60, bar  test loss 1.813, col  test loss 271.666


Epoch 224: 272batch [01:08,  3.98batch/s, loss=2.74e+3]


epoch 224: avg train loss 2742.64, bar train loss 1.653, col train loss 267.491


Epoch 225: 0batch [00:00, ?batch/s]

epoch 224: avg test  loss 2787.32, bar  test loss 1.823, col  test loss 271.694


Epoch 225: 272batch [01:08,  4.00batch/s, loss=2.75e+3]


epoch 225: avg train loss 2742.51, bar train loss 1.649, col train loss 267.479
epoch 225: avg test  loss 2787.10, bar  test loss 1.815, col  test loss 271.674


Epoch 226: 272batch [01:08,  3.99batch/s, loss=2.72e+3]


epoch 226: avg train loss 2742.37, bar train loss 1.646, col train loss 267.468


Epoch 227: 0batch [00:00, ?batch/s]

epoch 226: avg test  loss 2787.11, bar  test loss 1.831, col  test loss 271.657


Epoch 227: 272batch [01:08,  4.00batch/s, loss=2.73e+3]


epoch 227: avg train loss 2742.11, bar train loss 1.648, col train loss 267.435


Epoch 228: 0batch [00:00, ?batch/s]

epoch 227: avg test  loss 2786.91, bar  test loss 1.818, col  test loss 271.653


Epoch 228: 272batch [01:08,  4.00batch/s, loss=2.74e+3]


epoch 228: avg train loss 2742.07, bar train loss 1.646, col train loss 267.435


Epoch 229: 0batch [00:00, ?batch/s]

epoch 228: avg test  loss 2787.04, bar  test loss 1.832, col  test loss 271.672


Epoch 229: 272batch [01:07,  4.01batch/s, loss=2.79e+3]


epoch 229: avg train loss 2742.09, bar train loss 1.648, col train loss 267.434


Epoch 230: 0batch [00:00, ?batch/s]

epoch 229: avg test  loss 2786.97, bar  test loss 1.802, col  test loss 271.693


Epoch 230: 272batch [01:08,  4.00batch/s, loss=2.73e+3]


epoch 230: avg train loss 2741.98, bar train loss 1.647, col train loss 267.422
epoch 230: avg test  loss 2787.58, bar  test loss 1.813, col  test loss 271.654


Epoch 231: 272batch [01:07,  4.00batch/s, loss=2.76e+3]


epoch 231: avg train loss 2741.90, bar train loss 1.645, col train loss 267.411


Epoch 232: 0batch [00:00, ?batch/s]

epoch 231: avg test  loss 2787.04, bar  test loss 1.803, col  test loss 271.677


Epoch 232: 272batch [01:08,  4.00batch/s, loss=2.77e+3]


epoch 232: avg train loss 2741.71, bar train loss 1.645, col train loss 267.388


Epoch 233: 0batch [00:00, ?batch/s]

epoch 232: avg test  loss 2787.38, bar  test loss 1.805, col  test loss 271.658


Epoch 233: 272batch [01:07,  4.01batch/s, loss=2.73e+3]


epoch 233: avg train loss 2741.71, bar train loss 1.647, col train loss 267.391


Epoch 234: 0batch [00:00, ?batch/s]

epoch 233: avg test  loss 2787.72, bar  test loss 1.815, col  test loss 271.685


Epoch 234: 272batch [01:07,  4.01batch/s, loss=2.75e+3]


epoch 234: avg train loss 2741.73, bar train loss 1.649, col train loss 267.392


In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')