In [1]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime


writer = SummaryWriter(f"{link}/saved_models/VAE5/tensorboard")

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:

    def __init__(self, z_dim=64, d_dim=45, x_dim=7500, y_dim=2,
                 beta=10, rec_alpha = 1, rec_beta = 1, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z_dim = z_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_col, x_bar = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        

        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
      
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        return (x, y, d, x_len, x_col, x_bar)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        print(x.shape, "x")
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n,100), dtype=np.uint8)
        out_col = np.zeros((n,5,200), dtype=np.uint8)
        out_bar = np.zeros((n,13,200), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                #print(x[i,:,12,j])
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i,rna_len-1] = 1
                    broke = True
                    break
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x[i,:,12,j]) ,2*j] = 1 
                    out_col[i, self.get_color(x[i,:,13,j]), 2*j+1] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, len1-1, 2*j] = 1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, len2-1, 2*j+1] = 1
            if not broke:
                out_len[i, rna_len-1] = 1


        with open(f'{link}/data/modmirbase_{ds}_images_len.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return 0 # black
        elif (pixel == np.array([1,0,0])).all():  
            return 1 # red
        elif (pixel == np.array([0,0,1])).all():  
            return 2 # blue
        elif (pixel == np.array([0,1,0])).all():  
            return 3 # green
        elif (pixel == np.array([1,1,0])).all():  
            return 4 # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim):
        super(px, self).__init__()

        self.fc1 = nn.Sequential(nn.Linear(z_dim, 400, bias=False),  
                                 nn.ReLU())
        
        self.fc2 = nn.Sequential(nn.Linear(400, 200, bias=False),  
                                 nn.ReLU())
        # Predicting length and color of each bar
        self.up1 = nn.Upsample(scale_factor=5)
        self.de1 = nn.Sequential(nn.ConvTranspose1d(5,25,kernel_size = 5,
                                                    stride = 1, padding = 2,
                                                    bias=False),
                                 nn.ReLU(),
                                 nn.ConvTranspose1d(25,50,kernel_size = 5,
                                                    stride = 1, padding = 2,
                                                    bias=False),
                                 nn.ReLU(),
                                 )
        # Predicting color of each bar
        self.color_bar = nn.Sequential(nn.Conv1d(50,5, kernel_size = 3, padding = 'same'),
                                      nn.Softmax(dim=1))
        
        # Predicting the length of each bar
        self.length_bar = nn.Sequential(nn.Conv1d(50, 13, kernel_size = 3, padding = 'same'),
                                        nn.Softmax(dim=1))

        # Predicting length of the RNA strand
        self.length_RNA = nn.Sequential(nn.Linear(200,100), nn.Softmax(dim=0))
        
    def forward(self, z):
        
        h = self.fc1(z)
        h = self.fc2(h)
        len_RNA = self.length_RNA(h)
        
        h = h.view(-1, 5, 40)
        h = self.up1(h)
        h = self.de1(h)
        
        len_bar = self.length_bar(h)
        col_bar = self.color_bar(h)
        
        return len_RNA, len_bar, col_bar

    def reconstruct_image(self, len_RNA, len_bar, col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0])  # yellow
                  }
    
        
        len_RNA = len_RNA.cpu().numpy()#.reshape((100,))
        len_bar = len_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_RNA.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            if sample:
                limit = np.random.choice(np.arange(100), p = len_RNA[i])
            else:
                limit = np.argmax(len_RNA[i])

            for j in range(limit+1):
                if sample:
                    _len_bar_1 = np.random.choice(np.arange(1,14), p = len_bar[i, :,2*j]) 
                    _len_bar_2 = np.random.choice(np.arange(1,14), p = len_bar[i, :, 2*j+1])
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = np.argmax(len_bar[i,:, 2*j]) + 1 
                    _len_bar_2 = np.argmax(len_bar[i,:, 2*j + 1]) + 1
                    _col_bar_1 = np.argmax(col_bar[i,:, 2*j])
                    _col_bar_2 = np.argmax(col_bar[i,:, 2*j+1])
                
                h1 = 13-_len_bar_1
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
        
                # paint lower bar
                output[i, 13:13+_len_bar_2, j] = color_dict[_col_bar_2]
        
        
        return output


In [9]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [10]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [11]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim):
        super(qz, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding = 'same',bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding = 'same', bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, bias=False),
            nn.ReLU(), 
            nn.MaxPool2d(2, 2),
        )

        self.fc11 = nn.Sequential(nn.Linear(5632, z_dim))
        self.fc12 = nn.Sequential(nn.Linear(5632, z_dim), nn.Softplus())

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[3].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()
        torch.nn.init.xavier_uniform_(self.fc12[0].weight)
        self.fc12[0].bias.data.zero_()

    def forward(self, x):
        h = self.encoder(x)
        h = h.view(-1, 5632)
        z_loc = self.fc11(h)
        z_scale = self.fc12(h) + 1e-7

        return z_loc, z_scale




In [12]:
enc = qz(128,10,10,10)
summary(enc, (1,3,25,100))

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 256, 2, 11]           --
│    └─Conv2d: 2-1                       [1, 64, 25, 100]          4,800
│    └─ReLU: 2-2                         [1, 64, 25, 100]          --
│    └─MaxPool2d: 2-3                    [1, 64, 12, 50]           --
│    └─Conv2d: 2-4                       [1, 128, 12, 50]          73,728
│    └─ReLU: 2-5                         [1, 128, 12, 50]          --
│    └─MaxPool2d: 2-6                    [1, 128, 6, 25]           --
│    └─Conv2d: 2-7                       [1, 256, 4, 23]           294,912
│    └─ReLU: 2-8                         [1, 256, 4, 23]           --
│    └─MaxPool2d: 2-9                    [1, 256, 2, 11]           --
├─Sequential: 1-2                        [1, 10]                   --
│    └─Linear: 2-10                      [1, 10]                   56,330

## Full model class

In [13]:
class StampDIVA(nn.Module):
    def __init__(self, args):
        super(StampDIVA, self).__init__()
        self.z_dim = args.z_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim

        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.cuda()

    def forward(self, d, x, y):
        # Encode
        zd_q_loc, zd_q_scale = self.qz(x)
        
        # Reparameterization trick
        qz = dist.Normal(zd_q_loc, zd_q_scale)
        z_q = qz.rsample()
        
        
        # Decode
        x_len, x_bar, x_col = self.px(z_q)
        
        z_p_loc, z_p_scale = torch.zeros(z_q.size()[0], self.z_dim).cuda(),\
                        torch.ones(z_q.size()[0], self.z_dim).cuda()
        pz = dist.Normal(z_p_loc, z_p_scale)

        # Reparameterization trick
        pz = dist.Normal(z_p_loc, z_p_scale)
        
        return x_len, x_bar, x_col, qz, pz, z_q

    def loss_function(self, d, x, y, out_len, out_bar, out_col):
        #print(1111)
        x_len, x_bar, x_col, qz, pz, z_q = self.forward(d, x, y)
          
          #print(x_len.shape, x_bar.shape)

        #print(out_len.shape)
        mask = 1 - F.one_hot(torch.argmax(out_len, dim =1)*2+1, 200).cumsum(dim=1)[:,None,:]
         # print(mask.shape)
        #print(222222) 
        x_bar = mask.repeat(1,13,1)*x_bar
        x_col = mask.repeat(1,5,1)*x_col
        CE_len = F.cross_entropy(x_len, out_len, reduction='sum')
        CE_bar = F.cross_entropy(x_bar, out_bar, reduction='sum')
        CE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        KL_z = torch.sum(pz.log_prob(z_q) - qz.log_prob(z_q))
          
        return self.rec_alpha * CE_len \
                  + self.rec_beta * CE_bar \
                  + self.rec_gamma * CE_col \
                  - self.beta * KL_z, \
                  CE_bar, CE_len, CE_col 

In [14]:
w = torch.zeros((5,100))
w[0,55] = 1
w[1,66] = 1
w[2,15] = 1
w[3,35] = 1
w[4,45] = 1

In [15]:
x = torch.argmax(w, dim =1)*2+2

In [16]:
r = F.one_hot(x,200)
b = 1-r.cumsum(dim=1)

In [17]:
out = torch.randn((5,6,200))

In [18]:
b.shape

torch.Size([5, 200])

In [19]:
b = b[:,None,:]

In [20]:
b.shape

torch.Size([5, 1, 200])

# Training the model

## Loading dataset

In [21]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [22]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


## Training functions

In [23]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
    no_batches = 0
    train_corr = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, len_loss, col_loss = model.loss_function(d.float(), x.float(), y.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        epoch_len_loss += len_loss
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    

    return train_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss

In [24]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    epoch_len_loss = 0
        
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)
            loss, bar_loss, len_loss, col_loss = model.loss_function(d.float(), x.float(), y.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            epoch_len_loss += len_loss
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_len_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
  
    return test_loss, epoch_bar_loss, epoch_len_loss, epoch_col_loss
  

In [25]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_len, avg_loss_col = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", len train loss {:.3f}".format(avg_loss_len)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_alpha * avg_loss_len + diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_len_test, avg_loss_col_test = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", len  test loss {:.3f}".format(avg_loss_len_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_alpha * avg_loss_len_test + diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)


        if epoch % save_interval == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')
            save_reconstructions(epoch, test_loader, diva, name=save_folder)


    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.cpu().detach().numpy() for i in epoch_loss_sup]
    test_loss = [i.cpu().detach().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [26]:
def save_reconstructions(epoch, test_loader, diva, name='diva'):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        x_1, x_2, x_3, _, _, _ = diva(d,x,y)
        out = diva.px.reconstruct_image(x_1, x_2, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}.png')
    plt.close(fig)

In [27]:
DEVICE

device(type='cuda')

## Model Training

In [28]:
default_args = diva_args(z_dim=1024, rec_alpha = 1000, rec_beta = 300, rec_gamma = 200, 
                         beta=1, warmup=1, prewarmup=0)

In [29]:
diva = StampDIVA(default_args).to(DEVICE)

In [30]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE1/checkpoints/1400.pth'))

In [31]:
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [32]:
optimizer = optim.Adam(diva.parameters(), lr=0.0005)

In [33]:
%tensorboard  --logdir="D:/users/Marko/downloads/mirna/saved_models/VAE5/tensorboard/"

Reusing TensorBoard on port 6006 (pid 22272), started 3 days, 3:49:50 ago. (Use '!kill 22272' to kill it.)

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 0, save_folder="VAE5")

Epoch 1: 272batch [00:17, 15.60batch/s, loss=1.01e+5]


epoch 1: avg train loss 113555.60, bar train loss 250.212, len train loss 4.601, col train loss 169.092


Epoch 2: 0batch [00:00, ?batch/s]

epoch 1: avg test  loss 112666.75, bar  test loss 106.344, len  test loss 1.972, col  test loss 71.984


Epoch 2: 272batch [00:14, 18.75batch/s, loss=1.1e+5] 


epoch 2: avg train loss 112258.09, bar train loss 247.707, len train loss 4.600, col train loss 166.488


Epoch 3: 2batch [00:00, 18.87batch/s, loss=1.19e+5]

epoch 2: avg test  loss 111576.99, bar  test loss 105.945, len  test loss 1.971, col  test loss 70.180


Epoch 3: 272batch [00:14, 18.25batch/s, loss=1.17e+5]


epoch 3: avg train loss 111269.19, bar train loss 247.001, len train loss 4.597, col train loss 162.388


Epoch 4: 2batch [00:00, 19.42batch/s, loss=1.15e+5]

epoch 3: avg test  loss 111048.46, bar  test loss 105.789, len  test loss 1.969, col  test loss 69.202


Epoch 4: 272batch [00:14, 18.46batch/s, loss=1.05e+5]


epoch 4: avg train loss 110916.99, bar train loss 246.730, len train loss 4.590, col train loss 160.959


Epoch 5: 2batch [00:00, 18.87batch/s, loss=1.08e+5]

epoch 4: avg test  loss 110820.34, bar  test loss 105.738, len  test loss 1.966, col  test loss 68.792


Epoch 5: 272batch [00:14, 18.76batch/s, loss=1.09e+5]


epoch 5: avg train loss 110718.76, bar train loss 246.640, len train loss 4.581, col train loss 160.071
epoch 5: avg test  loss 110643.35, bar  test loss 105.684, len  test loss 1.964, col  test loss 68.464


Epoch 6: 272batch [00:14, 18.64batch/s, loss=1.14e+5]


epoch 6: avg train loss 110474.02, bar train loss 246.506, len train loss 4.576, col train loss 158.962


Epoch 7: 2batch [00:00, 19.05batch/s, loss=1.06e+5]

epoch 6: avg test  loss 110320.91, bar  test loss 105.624, len  test loss 1.960, col  test loss 67.821


Epoch 7: 272batch [00:14, 18.49batch/s, loss=1.09e+5]


epoch 7: avg train loss 109548.83, bar train loss 245.893, len train loss 4.566, col train loss 154.959


Epoch 8: 2batch [00:00, 19.05batch/s, loss=1.08e+5]

epoch 7: avg test  loss 108892.00, bar  test loss 105.146, len  test loss 1.957, col  test loss 65.279


Epoch 8: 272batch [00:14, 18.84batch/s, loss=1.1e+5] 


epoch 8: avg train loss 108556.11, bar train loss 245.051, len train loss 4.555, col train loss 151.009


Epoch 9: 2batch [00:00, 18.52batch/s, loss=1.07e+5]

epoch 8: avg test  loss 108398.05, bar  test loss 104.959, len  test loss 1.953, col  test loss 64.426


Epoch 9: 272batch [00:14, 18.52batch/s, loss=9.64e+4]


epoch 9: avg train loss 108142.55, bar train loss 244.652, len train loss 4.546, col train loss 149.450


Epoch 10: 2batch [00:00, 17.39batch/s, loss=1.12e+5]

epoch 9: avg test  loss 108029.96, bar  test loss 104.780, len  test loss 1.950, col  test loss 63.872


Epoch 10: 272batch [00:14, 18.35batch/s, loss=1.09e+5]


epoch 10: avg train loss 107828.92, bar train loss 244.272, len train loss 4.534, col train loss 148.395
epoch 10: avg test  loss 107789.88, bar  test loss 104.670, len  test loss 1.947, col  test loss 63.535


Epoch 11: 272batch [00:14, 18.51batch/s, loss=1.06e+5]


epoch 11: avg train loss 105974.07, bar train loss 238.346, len train loss 4.533, col train loss 147.784


Epoch 12: 2batch [00:00, 18.87batch/s, loss=1.07e+5]

epoch 11: avg test  loss 105292.84, bar  test loss 101.075, len  test loss 1.946, col  test loss 63.425


Epoch 12: 272batch [00:14, 18.39batch/s, loss=1.17e+5]


epoch 12: avg train loss 104920.49, bar train loss 235.025, len train loss 4.526, col train loss 147.317


Epoch 13: 2batch [00:00, 19.05batch/s, loss=1.06e+5]

epoch 12: avg test  loss 104829.68, bar  test loss 100.590, len  test loss 1.945, col  test loss 63.093


Epoch 13: 272batch [00:14, 18.50batch/s, loss=1.03e+5]


epoch 13: avg train loss 104568.59, bar train loss 234.185, len train loss 4.519, col train loss 146.744


Epoch 14: 2batch [00:00, 18.69batch/s, loss=1.03e+5]

epoch 13: avg test  loss 104641.25, bar  test loss 100.361, len  test loss 1.940, col  test loss 62.978


Epoch 14: 272batch [00:14, 18.52batch/s, loss=1.01e+5]


epoch 14: avg train loss 104303.05, bar train loss 233.593, len train loss 4.511, col train loss 146.240


Epoch 15: 2batch [00:00, 18.87batch/s, loss=1.08e+5]

epoch 14: avg test  loss 104376.54, bar  test loss 100.135, len  test loss 1.939, col  test loss 62.793


Epoch 15: 272batch [00:14, 18.50batch/s, loss=1.06e+5]


epoch 15: avg train loss 104082.23, bar train loss 233.109, len train loss 4.505, col train loss 145.806
epoch 15: avg test  loss 104227.65, bar  test loss 100.005, len  test loss 1.938, col  test loss 62.652


Epoch 16: 272batch [00:14, 18.33batch/s, loss=1.01e+5]


epoch 16: avg train loss 103871.77, bar train loss 232.643, len train loss 4.499, col train loss 145.412


Epoch 17: 2batch [00:00, 19.23batch/s, loss=1e+5]   

epoch 16: avg test  loss 104014.05, bar  test loss 99.768, len  test loss 1.934, col  test loss 62.465


Epoch 17: 272batch [00:14, 18.64batch/s, loss=1.1e+5] 


epoch 17: avg train loss 103670.94, bar train loss 232.212, len train loss 4.490, col train loss 145.026


Epoch 18: 2batch [00:00, 18.69batch/s, loss=1.04e+5]

epoch 17: avg test  loss 103871.12, bar  test loss 99.672, len  test loss 1.930, col  test loss 62.318


Epoch 18: 272batch [00:14, 18.64batch/s, loss=1.01e+5]


epoch 18: avg train loss 103475.47, bar train loss 231.774, len train loss 4.483, col train loss 144.664


Epoch 19: 2batch [00:00, 19.05batch/s, loss=1.03e+5]

epoch 18: avg test  loss 103670.23, bar  test loss 99.457, len  test loss 1.928, col  test loss 62.177


Epoch 19: 272batch [00:14, 18.66batch/s, loss=1.15e+5]


epoch 19: avg train loss 103288.37, bar train loss 231.374, len train loss 4.476, col train loss 144.295


Epoch 20: 2batch [00:00, 18.87batch/s, loss=1.03e+5]

epoch 19: avg test  loss 103510.02, bar  test loss 99.318, len  test loss 1.925, col  test loss 62.072


Epoch 20: 272batch [00:14, 18.73batch/s, loss=9.86e+4]


epoch 20: avg train loss 103114.45, bar train loss 230.982, len train loss 4.466, col train loss 144.018
epoch 20: avg test  loss 103385.87, bar  test loss 99.223, len  test loss 1.920, col  test loss 61.995


Epoch 21: 272batch [00:14, 18.42batch/s, loss=1.08e+5]


epoch 21: avg train loss 102946.08, bar train loss 230.634, len train loss 4.458, col train loss 143.687


Epoch 22: 2batch [00:00, 18.87batch/s, loss=1.03e+5]

epoch 21: avg test  loss 103178.52, bar  test loss 98.996, len  test loss 1.915, col  test loss 61.817


Epoch 22: 272batch [00:14, 18.75batch/s, loss=9.83e+4]


epoch 22: avg train loss 102774.66, bar train loss 230.258, len train loss 4.450, col train loss 143.359


Epoch 23: 2batch [00:00, 18.35batch/s, loss=1.04e+5]

epoch 22: avg test  loss 103088.63, bar  test loss 98.925, len  test loss 1.911, col  test loss 61.745


Epoch 23: 272batch [00:14, 18.61batch/s, loss=1.01e+5]


epoch 23: avg train loss 102618.29, bar train loss 229.899, len train loss 4.441, col train loss 143.115


Epoch 24: 2batch [00:00, 19.05batch/s, loss=1.01e+5]

epoch 23: avg test  loss 102952.27, bar  test loss 98.718, len  test loss 1.908, col  test loss 61.674


Epoch 24: 272batch [00:14, 18.69batch/s, loss=1.02e+5]


epoch 24: avg train loss 102460.37, bar train loss 229.563, len train loss 4.433, col train loss 142.832


Epoch 25: 2batch [00:00, 18.35batch/s, loss=1.05e+5]

epoch 24: avg test  loss 102834.80, bar  test loss 98.635, len  test loss 1.908, col  test loss 61.585


Epoch 25: 272batch [00:14, 18.60batch/s, loss=1.09e+5]


epoch 25: avg train loss 102328.97, bar train loss 229.276, len train loss 4.425, col train loss 142.594
epoch 25: avg test  loss 102731.76, bar  test loss 98.572, len  test loss 1.903, col  test loss 61.534


Epoch 26: 272batch [00:14, 18.34batch/s, loss=9.97e+4]


epoch 26: avg train loss 102212.77, bar train loss 229.034, len train loss 4.413, col train loss 142.406


Epoch 27: 2batch [00:00, 18.35batch/s, loss=1.04e+5]

epoch 26: avg test  loss 102618.56, bar  test loss 98.445, len  test loss 1.900, col  test loss 61.454


Epoch 27: 272batch [00:14, 18.62batch/s, loss=9.65e+4]


epoch 27: avg train loss 102068.02, bar train loss 228.728, len train loss 4.406, col train loss 142.131


Epoch 28: 2batch [00:00, 18.87batch/s, loss=1.02e+5]

epoch 27: avg test  loss 102497.95, bar  test loss 98.395, len  test loss 1.895, col  test loss 61.288


Epoch 28: 272batch [00:14, 18.68batch/s, loss=9.27e+4]


epoch 28: avg train loss 101670.40, bar train loss 227.508, len train loss 4.396, col train loss 141.980


Epoch 29: 2batch [00:00, 18.69batch/s, loss=1e+5]   

epoch 28: avg test  loss 101408.16, bar  test loss 96.723, len  test loss 1.891, col  test loss 61.350


Epoch 29: 272batch [00:14, 18.74batch/s, loss=1.06e+5]


epoch 29: avg train loss 100661.86, bar train loss 224.016, len train loss 4.388, col train loss 142.087


Epoch 30: 2batch [00:00, 18.35batch/s, loss=9.86e+4]

epoch 29: avg test  loss 100966.14, bar  test loss 96.197, len  test loss 1.886, col  test loss 61.299


Epoch 30: 272batch [00:14, 18.47batch/s, loss=1.03e+5]


epoch 30: avg train loss 100444.71, bar train loss 223.436, len train loss 4.380, col train loss 141.875
epoch 30: avg test  loss 100859.76, bar  test loss 96.095, len  test loss 1.880, col  test loss 61.150


Epoch 31: 272batch [00:14, 18.53batch/s, loss=1.01e+5]


epoch 31: avg train loss 100269.84, bar train loss 223.047, len train loss 4.371, col train loss 141.591


Epoch 32: 2batch [00:00, 18.87batch/s, loss=1.02e+5]

epoch 31: avg test  loss 100690.25, bar  test loss 95.923, len  test loss 1.879, col  test loss 61.071


Epoch 32: 272batch [00:14, 18.75batch/s, loss=9.59e+4]


epoch 32: avg train loss 100128.02, bar train loss 222.717, len train loss 4.365, col train loss 141.379


Epoch 33: 2batch [00:00, 18.87batch/s, loss=9.95e+4]

epoch 32: avg test  loss 100538.86, bar  test loss 95.732, len  test loss 1.872, col  test loss 61.028


Epoch 33: 272batch [00:14, 18.85batch/s, loss=9.97e+4]


epoch 33: avg train loss 99995.05, bar train loss 222.418, len train loss 4.361, col train loss 141.173


Epoch 34: 2batch [00:00, 18.87batch/s, loss=1.01e+5]

epoch 33: avg test  loss 100505.41, bar  test loss 95.720, len  test loss 1.873, col  test loss 60.985


Epoch 34: 272batch [00:14, 18.82batch/s, loss=9.9e+4] 


epoch 34: avg train loss 99866.53, bar train loss 222.121, len train loss 4.355, col train loss 140.979


Epoch 35: 2batch [00:00, 19.23batch/s, loss=1.02e+5]

epoch 34: avg test  loss 100326.95, bar  test loss 95.555, len  test loss 1.873, col  test loss 60.850


Epoch 35: 272batch [00:14, 18.74batch/s, loss=8.95e+4]


epoch 35: avg train loss 99744.29, bar train loss 221.841, len train loss 4.352, col train loss 140.782
epoch 35: avg test  loss 100197.37, bar  test loss 95.428, len  test loss 1.870, col  test loss 60.788


Epoch 36: 272batch [00:14, 18.52batch/s, loss=9.74e+4]


epoch 36: avg train loss 99625.02, bar train loss 221.551, len train loss 4.348, col train loss 140.608


Epoch 37: 2batch [00:00, 19.23batch/s, loss=9.99e+4]

epoch 36: avg test  loss 100150.67, bar  test loss 95.344, len  test loss 1.867, col  test loss 60.747


Epoch 37: 272batch [00:14, 18.78batch/s, loss=9.57e+4]


epoch 37: avg train loss 99512.82, bar train loss 221.304, len train loss 4.345, col train loss 140.403


Epoch 38: 2batch [00:00, 18.69batch/s, loss=9.57e+4]

epoch 37: avg test  loss 100001.76, bar  test loss 95.195, len  test loss 1.865, col  test loss 60.682


Epoch 38: 272batch [00:14, 18.74batch/s, loss=9.83e+4]


epoch 38: avg train loss 99403.83, bar train loss 221.066, len train loss 4.338, col train loss 140.236


Epoch 39: 2batch [00:00, 18.87batch/s, loss=9.87e+4]

epoch 38: avg test  loss 100009.98, bar  test loss 95.267, len  test loss 1.866, col  test loss 60.600


Epoch 39: 272batch [00:14, 18.77batch/s, loss=9.68e+4]


epoch 39: avg train loss 99308.83, bar train loss 220.848, len train loss 4.337, col train loss 140.061


Epoch 40: 2batch [00:00, 18.52batch/s, loss=9.88e+4]

epoch 39: avg test  loss 99848.98, bar  test loss 95.065, len  test loss 1.865, col  test loss 60.552


Epoch 40: 272batch [00:14, 18.77batch/s, loss=9.45e+4]


epoch 40: avg train loss 99203.51, bar train loss 220.620, len train loss 4.330, col train loss 139.909
epoch 40: avg test  loss 99823.32, bar  test loss 95.034, len  test loss 1.861, col  test loss 60.513


Epoch 41: 272batch [00:14, 18.49batch/s, loss=1.04e+5]


epoch 41: avg train loss 99129.00, bar train loss 220.464, len train loss 4.328, col train loss 139.754


Epoch 42: 2batch [00:00, 18.69batch/s, loss=9.77e+4]

epoch 41: avg test  loss 99724.70, bar  test loss 94.946, len  test loss 1.861, col  test loss 60.451


Epoch 42: 272batch [00:14, 18.77batch/s, loss=1.06e+5]


epoch 42: avg train loss 99031.91, bar train loss 220.255, len train loss 4.323, col train loss 139.586


Epoch 43: 2batch [00:00, 18.52batch/s, loss=9.95e+4]

epoch 42: avg test  loss 99723.86, bar  test loss 94.979, len  test loss 1.857, col  test loss 60.438


Epoch 43: 272batch [00:14, 18.56batch/s, loss=9.77e+4]


epoch 43: avg train loss 98923.80, bar train loss 220.035, len train loss 4.321, col train loss 139.373


Epoch 44: 2batch [00:00, 18.87batch/s, loss=9.76e+4]

epoch 43: avg test  loss 99573.34, bar  test loss 94.805, len  test loss 1.857, col  test loss 60.304


Epoch 44: 272batch [00:14, 18.52batch/s, loss=9.88e+4]


epoch 44: avg train loss 98858.16, bar train loss 219.873, len train loss 4.318, col train loss 139.289


Epoch 45: 2batch [00:00, 18.69batch/s, loss=9.71e+4]

epoch 44: avg test  loss 99719.57, bar  test loss 94.965, len  test loss 1.858, col  test loss 60.433


Epoch 45: 272batch [00:14, 18.52batch/s, loss=1.02e+5]


epoch 45: avg train loss 98782.51, bar train loss 219.729, len train loss 4.314, col train loss 139.132
epoch 45: avg test  loss 99437.50, bar  test loss 94.674, len  test loss 1.857, col  test loss 60.197


Epoch 46: 272batch [00:14, 18.26batch/s, loss=9.66e+4]


epoch 46: avg train loss 98695.18, bar train loss 219.539, len train loss 4.314, col train loss 138.967


Epoch 47: 2batch [00:00, 19.05batch/s, loss=1.03e+5]

epoch 46: avg test  loss 99398.10, bar  test loss 94.671, len  test loss 1.856, col  test loss 60.195


Epoch 47: 272batch [00:14, 18.56batch/s, loss=9.75e+4]


epoch 47: avg train loss 98613.44, bar train loss 219.391, len train loss 4.312, col train loss 138.780


Epoch 48: 2batch [00:00, 18.18batch/s, loss=9.93e+4]

epoch 47: avg test  loss 99303.73, bar  test loss 94.606, len  test loss 1.854, col  test loss 60.120


Epoch 48: 272batch [00:14, 18.25batch/s, loss=9.38e+4]


epoch 48: avg train loss 98552.56, bar train loss 219.267, len train loss 4.310, col train loss 138.677


Epoch 49: 2batch [00:00, 18.35batch/s, loss=9.95e+4]

epoch 48: avg test  loss 99258.01, bar  test loss 94.561, len  test loss 1.852, col  test loss 60.066


Epoch 49: 272batch [00:15, 18.05batch/s, loss=9.08e+4]


epoch 49: avg train loss 98490.53, bar train loss 219.122, len train loss 4.308, col train loss 138.554


Epoch 50: 2batch [00:00, 18.02batch/s, loss=1.01e+5]

epoch 49: avg test  loss 99218.25, bar  test loss 94.479, len  test loss 1.852, col  test loss 60.023


Epoch 50: 272batch [00:15, 18.03batch/s, loss=9.2e+4] 


epoch 50: avg train loss 98397.74, bar train loss 218.937, len train loss 4.305, col train loss 138.394
epoch 50: avg test  loss 99367.59, bar  test loss 94.617, len  test loss 1.851, col  test loss 60.197


Epoch 51: 272batch [00:15, 17.96batch/s, loss=1.04e+5]


epoch 51: avg train loss 98349.81, bar train loss 218.855, len train loss 4.305, col train loss 138.265


Epoch 52: 2batch [00:00, 19.05batch/s, loss=1.01e+5]

epoch 51: avg test  loss 99160.87, bar  test loss 94.413, len  test loss 1.850, col  test loss 59.994


Epoch 52: 272batch [00:14, 18.35batch/s, loss=1.08e+5]


epoch 52: avg train loss 98284.32, bar train loss 218.700, len train loss 4.303, col train loss 138.163


Epoch 53: 2batch [00:00, 18.69batch/s, loss=9.99e+4]

epoch 52: avg test  loss 99052.51, bar  test loss 94.341, len  test loss 1.854, col  test loss 59.911


Epoch 53: 272batch [00:15, 18.09batch/s, loss=9.24e+4]


epoch 53: avg train loss 98225.15, bar train loss 218.602, len train loss 4.302, col train loss 138.014


Epoch 54: 2batch [00:00, 18.52batch/s, loss=9.55e+4]

epoch 53: avg test  loss 99019.99, bar  test loss 94.350, len  test loss 1.851, col  test loss 59.904


Epoch 54: 272batch [00:14, 18.16batch/s, loss=9.98e+4]


epoch 54: avg train loss 98151.17, bar train loss 218.463, len train loss 4.301, col train loss 137.867


Epoch 55: 2batch [00:00, 18.52batch/s, loss=9.96e+4]

epoch 54: avg test  loss 98964.27, bar  test loss 94.279, len  test loss 1.851, col  test loss 59.851


Epoch 55: 272batch [00:15, 17.86batch/s, loss=1.05e+5]


epoch 55: avg train loss 98102.88, bar train loss 218.361, len train loss 4.300, col train loss 137.760
epoch 55: avg test  loss 98874.04, bar  test loss 94.215, len  test loss 1.851, col  test loss 59.708


Epoch 56: 272batch [00:15, 18.04batch/s, loss=9.97e+4]


epoch 56: avg train loss 98044.40, bar train loss 218.270, len train loss 4.300, col train loss 137.613


Epoch 57: 2batch [00:00, 18.35batch/s, loss=9.87e+4]

epoch 56: avg test  loss 98889.07, bar  test loss 94.223, len  test loss 1.848, col  test loss 59.783


Epoch 57: 272batch [00:14, 18.75batch/s, loss=9.46e+4]


epoch 57: avg train loss 97992.91, bar train loss 218.163, len train loss 4.299, col train loss 137.504


Epoch 58: 2batch [00:00, 18.69batch/s, loss=9.78e+4]

epoch 57: avg test  loss 98851.29, bar  test loss 94.165, len  test loss 1.848, col  test loss 59.780


Epoch 58: 272batch [00:14, 18.66batch/s, loss=9.58e+4]


epoch 58: avg train loss 97910.18, bar train loss 218.014, len train loss 4.296, col train loss 137.324


Epoch 59: 2batch [00:00, 18.69batch/s, loss=9.64e+4]

epoch 58: avg test  loss 98820.85, bar  test loss 94.169, len  test loss 1.848, col  test loss 59.688


Epoch 59: 272batch [00:14, 18.28batch/s, loss=1.02e+5]


epoch 59: avg train loss 97857.84, bar train loss 217.915, len train loss 4.296, col train loss 137.199


Epoch 60: 2batch [00:00, 17.09batch/s, loss=9.86e+4]

epoch 59: avg test  loss 98704.81, bar  test loss 94.066, len  test loss 1.850, col  test loss 59.560


Epoch 60: 272batch [00:14, 18.30batch/s, loss=1.01e+5]


epoch 60: avg train loss 97776.01, bar train loss 217.746, len train loss 4.296, col train loss 137.042
epoch 60: avg test  loss 98724.05, bar  test loss 94.109, len  test loss 1.846, col  test loss 59.604


Epoch 61: 272batch [00:15, 17.75batch/s, loss=9.58e+4]


epoch 61: avg train loss 97746.16, bar train loss 217.712, len train loss 4.294, col train loss 136.943


Epoch 62: 2batch [00:00, 17.39batch/s, loss=9.77e+4]

epoch 61: avg test  loss 98699.48, bar  test loss 94.075, len  test loss 1.850, col  test loss 59.541


Epoch 62: 272batch [00:15, 17.75batch/s, loss=9.41e+4]


epoch 62: avg train loss 97696.14, bar train loss 217.610, len train loss 4.295, col train loss 136.845


Epoch 63: 2batch [00:00, 17.54batch/s, loss=9.96e+4]

epoch 62: avg test  loss 98697.27, bar  test loss 94.080, len  test loss 1.848, col  test loss 59.569


Epoch 63: 272batch [00:15, 18.05batch/s, loss=9.58e+4]


epoch 63: avg train loss 97629.39, bar train loss 217.484, len train loss 4.295, col train loss 136.688


Epoch 64: 2batch [00:00, 17.86batch/s, loss=95594.0]

epoch 63: avg test  loss 98634.54, bar  test loss 94.058, len  test loss 1.848, col  test loss 59.439


Epoch 64: 272batch [00:15, 17.54batch/s, loss=1.03e+5]


epoch 64: avg train loss 97582.67, bar train loss 217.429, len train loss 4.291, col train loss 136.549


Epoch 65: 2batch [00:00, 17.54batch/s, loss=1.02e+5]

epoch 64: avg test  loss 98538.58, bar  test loss 93.950, len  test loss 1.848, col  test loss 59.424


Epoch 65: 272batch [00:15, 17.62batch/s, loss=1.03e+5]


epoch 65: avg train loss 97528.69, bar train loss 217.321, len train loss 4.293, col train loss 136.427
epoch 65: avg test  loss 98507.91, bar  test loss 93.928, len  test loss 1.848, col  test loss 59.361


Epoch 66: 116batch [00:06, 17.07batch/s, loss=9.77e+4]

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 2200, 1000, save_folder="VAE5")

In [None]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 5600, 2200, save_folder="VAE4")

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')