In [1]:
link = 'D:/users/Marko/downloads/mirna/'

# Imports

In [2]:
%load_ext tensorboard

In [19]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime


writer = SummaryWriter(f"{link}/saved_models/new/IVAE1/tensorboard")

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
DEVICE

device(type='cuda')

# Model Classes

In [6]:
class diva_args:

    def __init__(self, z_dim=128, d_dim=45, x_dim=7500, y_dim=2,
                 beta=10, rec_alpha = 1, rec_beta = 1, 
                 rec_gamma = 1, warmup = 1, prewarmup = 1):

        self.z_dim = z_dim
        self.d_dim = d_dim
        self.x_dim = x_dim
        self.y_dim = y_dim
        
        self.beta = beta
        self.rec_alpha = rec_alpha
        self.rec_beta = rec_beta
        self.rec_gamma = rec_gamma
        self.warmup = warmup
        self.prewarmup = prewarmup


## Dataset Class

In [7]:
class MicroRNADataset(Dataset):

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        
        # loading images
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0']/255
        
        
        # loading labels
        print('Loading Labels! (~10s)')     
        ohe = OneHotEncoder(categories='auto', sparse=False)
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = ohe.fit_transform(labels)
        
        # loading encoded images
        print("loading encodings")
        if create_encodings:
            x_len, x_bar, x_col = self.get_encoded_values(self.images, ds)
        else:
            x_len = np.load(f'{link}/data/modmirbase_{ds}_images_len_new.npz')
            x_bar = np.load(f'{link}/data/modmirbase_{ds}_images_bar_new.npz')
            x_col = np.load(f'{link}/data/modmirbase_{ds}_images_col_new.npz')
        
        self.x_len = x_len
        self.x_bar = x_bar
        self.x_col = x_col
        
        
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')
        
        
        # loading names
        print('Loading Names! (~5s)')
        names =  np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']
        names = [i.decode('utf-8') for i in names]
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin', 
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        # assigning a species label to each observation from species
        # with more than 200 observations from past research
        self.names = []
        for i in names:
            append = False
            for j in self.species:
                if j in i.lower():
                    self.names.append(j)
                    append = True
                    break
            if not append:
                if 'random' in i.lower() or i.isdigit():
                    self.names.append('hsa')
                else:
                    self.names.append('notfound')
        
        # performing one hot encoding
        ohe = OneHotEncoder(categories='auto', sparse=False)
        
       
        
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1,1))
          
        if use_subset:    
            idxes = [i == 'hsa' and np.random.choice([True, False]) for i in self.names]
            self.names_ohe = self.names_ohe[idxes]
            self.labels = self.labels[idxes]
            self.images = self.images[idxes]
            self.x_len = self.x_len[idxes]
            self.x_col = self.x_col[idxes]
            self.x_bar = self.x_bar[idxes]
            self.mountain = self.mountain[idxes]
    
    def __len__(self):
        return(self.images.shape[0])

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = self.images[idx]
        x = np.transpose(x, (2,0,1))
        x_len = self.x_len[idx]
        x_col = self.x_col[idx]
        x_bar = self.x_bar[idx]
        mount = self.mountain[idx]                        
        return (x, y, d, x_len, x_col, x_bar, mount)


    def get_encoded_values(self, x, ds):
        """
        given an image or batch of images
        returns length of strand, length of bars and colors of bars
        """
        n = x.shape[0]
        x = np.transpose(x, (0,3,1,2))
        out_len = np.zeros((n), dtype=np.uint8)
        out_col = np.zeros((n,26,100), dtype=np.uint8)
        out_bar = np.zeros((n,2,100), dtype=np.uint8)

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            rna_len = 0
            broke = False
            for j in range(100):
                if (x[i,:,12,j] == np.array([1,1,1])).all():
                    out_len[i] = rna_len
                    broke = True
                    out_col[i,25,j] = 1
                else:
                    rna_len += 1
                    # check color of bars
                    out_col[i, self.get_color(x,i,j),j] = 1 
                    #out_col[i, self.get_color(x[i,:,13,j]), 1, j] = 1
                    # check length of bars
                    len1 = 0
                    # loop until white pixel
                    while not (x[i,:,12-len1,j] == np.array([1.,1.,1.])).all():
                        len1 += 1
                        if 13-len1 == 0:
                            break
                    out_bar[i, 0, j] = len1

                    len2 = 0
                    while not (x[i,:,13+len2,j] == np.array([1.,1.,1.])).all():
                        len2 += 1
                        if 13+len2 == 25:
                            break
                    out_bar[i, 1, j] = len2
            if not broke:
                out_len[i] = rna_len


        with open(f'{link}/data/modmirbase_{ds}_images_len_new.npz', 'wb') as f:
            np.save(f, out_len)
        with open(f'{link}/data/modmirbase_{ds}_images_col_new.npz', 'wb') as f:
            np.save(f, out_col)
        with open(f'{link}/data/modmirbase_{ds}_images_bar_new.npz', 'wb') as f:
            np.save(f, out_bar)
        

        return out_len, out_bar, out_col

    def get_color(self, x, i, j):
        
        col = self._get_color(x[i,:,12,j])+self._get_color(x[i,:,13,j])
        if col == '00':
            return 0
        if col == '01':
            return 1
        if col == '02':
            return 2
        if col == '03':
            return 3
        if col == '04':
            return 4
        if col == '10':
            return 5
        if col == '11':
            return 6
        if col == '12':
            return 7
        if col == '13':
            return 8
        if col == '14':
            return 9
        if col == '20':
            return 10
        if col == '21':
            return 11
        if col == '22':
            return 12
        if col == '23':
            return 13
        if col == '24':
            return 14
        if col == '30':
            return 15
        if col == '31':
            return 16
        if col == '32':
            return 17
        if col == '33':
            return 18
        if col == '34':
            return 19
        if col == '40':
            return 20
        if col == '41':
            return 21
        if col == '42':
            return 22
        if col == '43':
            return 23
        if col == '44':
            return 24
        
        
    
    def _get_color(self, pixel):
        """
        returns the encoded value for a pixel
        """
        if (pixel == np.array([0,0,0])).all():  
            return "0" # black
        elif (pixel == np.array([1,0,0])).all():  
            return "1" # red
        elif (pixel == np.array([0,0,1])).all():  
            return "2" # blue
        elif (pixel == np.array([0,1,0])).all():  
            return "3" # green
        elif (pixel == np.array([1,1,0])).all():  
            return "4" # yellow
        else:
            print("Something wrong!")


## Decoder classes

In [8]:
# Decoders
class px(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim, dim1=256, dim2=512):
        super(px, self).__init__()

        self.fc = nn.Sequential(nn.Linear(z_dim, dim1, bias=False),  
                                 nn.ReLU(),
                                nn.Linear(dim1, dim2),
                                nn.ReLU())
        
        # Predicting length and color of each bar
        
        self.color = nn.Sequential(nn.Linear(dim2, 2600))
        
        
        self.length_bar = nn.Sequential(nn.Linear(dim2,200), nn.Softplus())
        
        
    def forward(self, z):
        
        h = self.fc(z)
        
        
        
        len_bar = self.length_bar(h).reshape(-1,2,100)
        len_bar_sc = nn.Parameter(torch.tensor([1.])).to(DEVICE)
        
        
        col = self.color(h).reshape(-1,26,100)
        col_bar = nn.Softmax(dim=1)(col)
        
        return len_bar, len_bar_sc, col_bar

    def reconstruct_image(self, len_bar, var_bar ,col_bar, sample=False):
        """
        reconstructs RNA image given output from decoder
        even indexes of len_bar and col_bar   -> top
        uneven indexes of len_bar and col_bar -> bottom
        function does not support sampling yet
        color reconstructions: 0: black
                               1: red
                               2: blue
                               3: green
                               4: yellow
        """
        color_dict = {
                  0: np.array([0,0,0]), # black
                  1: np.array([1,0,0]), # red
                  3: np.array([0,1,0]), # green
                  2: np.array([0,0,1]), # blue
                  4: np.array([1,1,0]),  # yellow
                  5: np.array([1,1,1])  # white
                  }
    
        _color_dict =  {0: (0,0),
                        1: (0,1),
                        2: (0,2),
                        3: (0,3),
                        4: (0,4),
                        5: (1,0),
                        6: (1,1),
                        7: (1,2),
                        8: (1,3),
                        9: (1,4),
                        10: (1,0),
                        11: (2,1),
                        12: (2,2),
                        13: (2,3),
                        14: (2,4),
                        15: (2,0),
                        16: (3,1),
                        17: (3,2),
                        18: (3,3),
                        19: (3,4),
                        20: (3,0),
                        21: (4,1),
                        22: (4,2),
                        23: (4,3),
                        24: (4,4),
                        25: (5,5)
                        }       
        len_bar = len_bar.cpu().numpy()
        var_bar = var_bar.cpu().numpy()
        col_bar = col_bar.cpu().numpy()
        n = len_bar.shape[0]
        output = np.ones((n,25,100,3))

        for i in range(n):
            limit = 100
            for j in range(limit):
                if sample:
                    _len_bar_1 = int(np.round(np.random.normal(loc=len_bar[i,0,j], scale=var_bar[i,0,j])))
                    _len_bar_2 = int(np.round(np.random.normal(loc=len_bar[i,1,j], scale=var_bar[i,1,j])))
                    _col_bar_1 = np.random.choice(np.arange(5), p = col_bar[i, :, 2*j])
                    _col_bar_2 = np.random.choice(np.arange(5), p = col_bar[i,:, 2*j+1])
                else:
                    _len_bar_1 = int(np.round(len_bar[i,0,j])) 
                    _len_bar_2 = int(np.round(len_bar[i,1,j]))
                    _col_bar_1, _col_bar_2 = _color_dict[np.argmax(col_bar[i,:,j])]
                    
                h1 = max(0,13-_len_bar_1)
                # paint upper bar
                output[i, h1:13, j] = color_dict[_col_bar_1]
                h2 = min(25,13+_len_bar_2)
                # paint lower bar
                output[i, 13:h2, j] = color_dict[_col_bar_2]
        
        
        return output

In [9]:
int(np.round(3.7, 0))
int(3.7)

3

In [10]:
# pzy_ = pzy(45, 7500, 2, 32,32,32)
# summary(pzy_, (1,2))
# pzy_ = px(45, 7500, 2, 32,32,32)
# summary(pzy_, [(1,32),(1,32),(1,32)])

## Endcoder Classes

In [11]:
#pzy_.reconstruct_image(torch.zeros((1,100)), torch.zeros((1,13,200)), torch.zeros(1,5,200)).shape

In [12]:
class inception_color(nn.Module):
    def __init__(self, filters):
        super(inception_color, self).__init__()
        
        self.filters = filters
        
        self.color_tower = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=1, stride=1, padding = 'same'),
            nn.ReLU(),
            nn.MaxPool2d((13,1),stride=(13,1), ceil_mode=True)
            
        )
        self.shape_tower = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.ZeroPad2d((0,0,0,1)), # shape -1,26,100,12
            nn.Conv2d(12, self.filters, kernel_size=(13,5), stride=(13,1)),
            nn.ReLU())



    def forward(self, x):
        col = self.color_tower(x)
        col = col.view(-1,200)
        shp = self.shape_tower(x)
        shp = shp.view(-1, 2*96*self.filters)

        out = torch.cat([col,shp],1)
        return out

In [13]:
class inception_A(nn.Module):
    def __init__(self, in_channels=3, hidden_channels=16, out_channels=32):
        super(inception_A, self).__init__()
        
        self.tower_1 = nn.Sequential(
            nn.AvgPool2d((3,3), stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_2 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_3 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU())
        
        self.tower_3a = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(7,3), padding='same'),
            nn.ReLU()
        )
        
        self.tower_3b = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(3,7), padding='same'),
            nn.ReLU()
        )
        
        self.tower_4 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=(1,1), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(3,7), padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(7,3), padding='same'),
            nn.ReLU()
        )
        
        self.out = nn.ReLU()
        
    def forward(self, x):
        t1 = self.tower_1(x)
        t2 = self.tower_2(x)
        t3 = self.tower_3(x)
        ta = self.tower_3a(t3)
        tb = self.tower_3b(t3)
        t4 = self.tower_4(x)
        cat = torch.cat([t1,t2,ta,tb,t4],1)
        out = self.out(cat)
        return out
    
    
class inception_B(nn.Module):
    def __init__(self, in_channels=128, hidden_channels=64, out_channels=128):
        super(inception_B, self).__init__()
        
        self.tower_1 = nn.Sequential(
            nn.AvgPool2d((3,3), stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_2 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_3 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(1,7), padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(1,7), padding='same'),
            nn.ReLU(),
            
            
        )
        
        
        self.tower_4 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=(1,1), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=(1,7), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(7,1), padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(1,7), padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(7,1), padding='same'),
            nn.ReLU()
        )
        
        self.out = nn.ReLU()
        
    def forward(self, x):
        t1 = self.tower_1(x)
        t2 = self.tower_2(x)
        t3 = self.tower_3(x)
        t4 = self.tower_4(x)
        cat = torch.cat([t1,t2,t3,t4],1)
        out = self.out(cat)
        return out
    
class inception_C(nn.Module):
    def __init__(self, in_channels=128, hidden_channels=64, out_channels=128):
        super(inception_C, self).__init__()
        
        self.tower_1 = nn.Sequential(
            nn.AvgPool2d((3,3), stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_2 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU()
        )
        
        self.tower_3 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU())
        
        self.tower_3a = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(5,1), padding='same'),
            nn.ReLU()
        )
        
        self.tower_3b = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(1,5), padding='same'),
            nn.ReLU()
        )
        
        
        self.tower_4 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=(1,1), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=(1,3), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=(3,1), padding='same'),
            nn.ReLU(),
        )
        
        self.tower_4a = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(3,1), padding='same'),
            nn.ReLU()
        )
        
        self.tower_4b = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(1,3), padding='same'),
            nn.ReLU()
        )
        
        self.out = nn.ReLU()
        
    def forward(self, x):
        t1 = self.tower_1(x)
        t2 = self.tower_2(x)
        t3 = self.tower_3(x)
        t3a = self.tower_3a(t3)
        t3b = self.tower_3b(t3)
        t4 = self.tower_4(x)
        t4a = self.tower_4a(t4)
        t4b = self.tower_4b(t4)
        cat = torch.cat([t1,t2,t3a,t3b,t4a,t4b],1)
        out = self.out(cat)
        return out

In [14]:
class reduction_A(nn.Module):
    def __init__(self, in_channels=128, hidden_channels=128, out_channels=128):
        super(reduction_A, self).__init__()
        
        self.tower_1 = nn.Sequential(
            nn.MaxPool2d((3,7), stride=(2,3)),
        )
        
        self.tower_2 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=(3,7), stride=(2,3), padding='valid'),
            nn.ReLU()
        )
        
        self.tower_3 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(3,7), stride=(2,3), padding='valid'),
            nn.ReLU(),
        )
        
        self.out = nn.ReLU()
        
    def forward(self, x):
        t1 = self.tower_1(x)
        t2 = self.tower_2(x)
        t3 = self.tower_3(x)
        cat = torch.cat([t1,t2,t3],1)
        out = self.out(cat)
        return out
    
    
class reduction_B(nn.Module):
    def __init__(self, in_channels=384, hidden_channels=64, out_channels=128):
        super(reduction_B, self).__init__()
        
        self.tower_1 = nn.Sequential(
            nn.MaxPool2d((3,7), stride=(2,3)),
        )
        
        self.tower_2 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=(3,7), stride=(2,3), padding='valid'),
            nn.ReLU()
        )
        
        self.tower_3 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=(1,7), padding='same'),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=(7,1), padding='same'),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3,7), stride=(2,3), padding='valid'),
            nn.ReLU(),
        )
        
        self.out = nn.ReLU()
        
    def forward(self, x):
        t1 = self.tower_1(x)
        t2 = self.tower_2(x)
        t3 = self.tower_3(x)
        cat = torch.cat([t1,t2,t3],1)
        out = self.out(cat)
        return out

In [15]:
inc = inception_color(4)
summary(inc, (1,3,25,100))



Layer (type:depth-idx)                   Output Shape              Param #
inception_color                          --                        --
├─Sequential: 1-1                        [1, 1, 2, 100]            --
│    └─Conv2d: 2-1                       [1, 1, 25, 100]           4
│    └─ReLU: 2-2                         [1, 1, 25, 100]           --
│    └─MaxPool2d: 2-3                    [1, 1, 2, 100]            --
├─Sequential: 1-2                        [1, 4, 2, 96]             --
│    └─Conv2d: 2-4                       [1, 12, 25, 100]          48
│    └─ReLU: 2-5                         [1, 12, 25, 100]          --
│    └─ZeroPad2d: 2-6                    [1, 12, 26, 100]          --
│    └─Conv2d: 2-7                       [1, 4, 2, 96]             3,124
│    └─ReLU: 2-8                         [1, 4, 2, 96]             --
Total params: 3,176
Trainable params: 3,176
Non-trainable params: 0
Total mult-adds (M): 0.73
Input size (MB): 0.03
Forward/backward pass size (MB): 0.27

In [16]:
inc = inc.to(DEVICE)

In [17]:
inc(torch.zeros((1,3,25,100)).to(DEVICE)).shape

torch.Size([1, 968])

In [35]:
class qz(nn.Module):
    def __init__(self, d_dim, x_dim, y_dim, z_dim):
        super(qz, self).__init__()

        self.inc_A = nn.Sequential( 
            inception_A(3, 8, 8),
            inception_A(40,16,16),
        )
        
        self.top = nn.Sequential(
            nn.Conv2d(80, 96, kernel_size=(3,3), stride=2, padding='valid'),
            nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=(3,3), stride=2, padding='valid'),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )
#         self.red_A = nn.Sequential(
#             reduction_A(80,16,16)
#         )
#         self.inc_B = nn.Sequential(
#             inception_B(112,16,16)
#         )
#         self.red_B = nn.Sequential(
#             reduction_B(64,16,16)
#         )
#         self.inc_C = nn.Sequential(
#             inception_C(96,64,96),
#             #inception_C(512,32,64),
#             nn.MaxPool2d(2,2)
#         )
        
        self.inc_COL = inception_color(4)
        
        self.fc = nn.Sequential(nn.Linear(2304, 512), nn.ReLU())
        
        self.z_mu = nn.Sequential(nn.Linear(512+968, z_dim))
        self.z_si = nn.Sequential(nn.Linear(512+968, z_dim), nn.Softplus())


    def forward(self, x):
        h = self.inc_A(x)
        h = self.top(h)
        h = h.view(-1,2304)
        h = self.fc(h)
        
        c = self.inc_COL(x)
        ch = torch.cat([c,h],1)
        z_loc = self.z_mu(ch)
        z_scale = self.z_si(ch) + 1e-7

        return z_loc, z_scale




In [36]:
enc = qz(128,10,10,1024)
summary(enc, (1,3,25,100))

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 80, 25, 100]          --
│    └─inception_A: 2-1                  [1, 40, 25, 100]          --
│    │    └─Sequential: 3-1              [1, 8, 25, 100]           32
│    │    └─Sequential: 3-2              [1, 8, 25, 100]           32
│    │    └─Sequential: 3-3              [1, 8, 25, 100]           32
│    │    └─Sequential: 3-4              [1, 8, 25, 100]           1,352
│    │    └─Sequential: 3-5              [1, 8, 25, 100]           1,352
│    │    └─Sequential: 3-6              [1, 8, 25, 100]           2,736
│    │    └─ReLU: 3-7                    [1, 40, 25, 100]          --
│    └─inception_A: 2-2                  [1, 80, 25, 100]          --
│    │    └─Sequential: 3-8              [1, 16, 25, 100]          656
│    │    └─Sequential: 3-9              [1, 16, 25, 100]          656
│   

## Full model class

In [37]:
class StampDIVA(nn.Module):
    def __init__(self, args):
        super(StampDIVA, self).__init__()
        self.z_dim = args.z_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim

        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        

        self.beta = args.beta
        
        self.rec_alpha = args.rec_alpha
        self.rec_beta = args.rec_beta
        self.rec_gamma = args.rec_gamma

        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        self.cuda()

    def forward(self, d, x, y):
        # Encode
        zd_q_loc, zd_q_scale = self.qz(x)
        
        # Reparameterization trick
        qz = dist.Normal(zd_q_loc, zd_q_scale)
        z_q = qz.rsample()
        
        
        # Decode
        x_bar, x_bar_scale, x_col = self.px(z_q)
        z_p_loc, z_p_scale = torch.zeros(z_q.size()[0], self.z_dim).cuda(),\
                        torch.ones(z_q.size()[0], self.z_dim).cuda()
        pz = dist.Normal(z_p_loc, z_p_scale)

        # Reparameterization trick
        pz = dist.Normal(z_p_loc, z_p_scale)
        
        return x_bar, x_bar_scale, x_col, qz, pz, z_q

    def loss_function(self, d, x, y, out_len, out_bar, out_col):
        
        x_bar, x_bar_scale, x_col, qz, pz, z_q = self.forward(d, x, y)
       
        mse_bar = (((x_bar - out_bar)**2)).mean(dim=(1,2)).sum()
        
        max_bar = torch.argmax(x_col, dim=1)
        acc_bar = (max_bar==torch.argmax(out_col, dim=1)).float().mean((1)).sum().float()
        
        CE_bar = mse_bar#-log_bar
        CE_col = F.cross_entropy(x_col, out_col, reduction='sum')

        KL_z = torch.sum(pz.log_prob(z_q) - qz.log_prob(z_q))
          
        return self.rec_beta * CE_bar \
                  + self.rec_gamma * CE_col \
                  - self.beta * KL_z, \
                  CE_bar, CE_col, mse_bar, acc_bar

# Training the model

## Loading dataset

In [38]:
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [39]:
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
loading encodings
Loading Names! (~5s)


In [40]:
len(RNA_dataset)

34721

In [41]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    model.train()
    train_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    no_batches = 0
    mse_bar = 0
    acc_bar = 0
    pbar = tqdm(enumerate(train_loader), unit="batch", 
                                     desc=f'Epoch {epoch}')
    for batch_idx, (x, y, d, x_len, x_col, x_bar,_) in pbar:
        # To device
        x, y, d , x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)

        optimizer.zero_grad()
        loss, bar_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(), x_len.float(), x_bar.float(), x_col.float())
      
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item()/x.shape[0])
        train_loss += loss
        epoch_bar_loss += bar_loss
        epoch_col_loss += col_loss
        mse_bar += mse
        acc_bar += acc
        no_batches += 1

    train_loss /= len(train_loader.dataset)
    epoch_bar_loss /= len(train_loader.dataset)
    epoch_col_loss /= len(train_loader.dataset)
    acc_bar /= len(train_loader.dataset)
    mse_bar /= len(train_loader.dataset)
    
    return train_loss, epoch_bar_loss, epoch_col_loss, mse_bar, acc_bar

In [42]:
def test_single_epoch(test_loader, model, epoch):
    model.eval()
    test_loss = 0
    epoch_bar_loss = 0
    epoch_col_loss = 0
    mse_bar = 0
    acc_bar = 0        
    with torch.no_grad():
        for batch_idx, (x,y,d,x_len,x_col,x_bar,_) in enumerate(test_loader):
            x, y, d, x_len, x_bar, x_col = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE), x_len.to(DEVICE), x_bar.to(DEVICE), x_col.to(DEVICE)
            loss, bar_loss, col_loss, mse, acc = model.loss_function(d.float(), x.float(), y.float(),x_len.float(),x_bar.float(),x_col.float())
            test_loss += loss
            epoch_bar_loss += bar_loss
            epoch_col_loss += col_loss
            mse_bar += mse
            acc_bar += acc
    test_loss /= len(test_loader.dataset)
    epoch_bar_loss /= len(test_loader.dataset)
    epoch_col_loss /= len(test_loader.dataset)
    acc_bar /= len(test_loader.dataset)
    mse_bar /= len(test_loader.dataset)
    
    return test_loss, epoch_bar_loss, epoch_col_loss, mse_bar, acc_bar
  

In [43]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5):
    
    epoch_loss_sup = []
    test_loss = []
    
    for epoch in range(start_epoch+1, end_epoch+1):
        diva.beta = min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)])
        if epoch< args.prewarmup:
            diva.beta = args.beta/args.prewarmup
        train_loss, avg_loss_bar, avg_loss_col, mtr, atr = train_single_epoch(train_loader, diva, optimizer, epoch)
        str_loss_sup = train_loss
        epoch_loss_sup.append(train_loss)
        str_print = "epoch {}: avg train loss {:.2f}".format(epoch, str_loss_sup)
        str_print += ", bar train loss {:.3f}".format(avg_loss_bar)
        str_print += ", col train loss {:.3f}".format(avg_loss_col)
        print(str_print)

        rec_loss_train = diva.rec_beta * avg_loss_bar + diva.rec_gamma * avg_loss_col
        dis_loss_train = train_loss - rec_loss_train

        test_lss, avg_loss_bar_test, avg_loss_col_test, mte, ate = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
       
        str_print = "epoch {}: avg test  loss {:.2f}".format(epoch, test_lss)
        str_print += ", bar  test loss {:.3f}".format(avg_loss_bar_test)
        str_print += ", col  test loss {:.3f}".format(avg_loss_col_test)
        print(str_print)

        rec_loss_test = diva.rec_beta * avg_loss_bar_test + diva.rec_gamma * avg_loss_col_test
        dis_loss_test = test_lss - rec_loss_test

        if writer is not None:
            
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss} ,epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement",{'rec':rec_loss_train, 'dis':dis_loss_train}, epoch)
            writer.add_scalars("bar_mse",{'train': mtr, 'test':mte}, epoch)
            writer.add_scalars("bar_acc",{'train': atr, 'test':ate}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')
        
        
        if epoch % 50 == 0:
            torch.save(diva.state_dict(), f'{link}/saved_models/{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [44]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr=''):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:10].to(DEVICE).float()
        x = a[1][0][:10].to(DEVICE).float()
        y = a[1][1][:10].to(DEVICE).float()
        m = a[1][-1][:10].to(DEVICE).float()
        x_2, x_2var, x_3 ,qz, pz, z_q = diva(d,x,y)
        out = diva.px.reconstruct_image(x_2, x_2var, x_3)

    plt.figure(figsize=(80,20))
    fig, ax = plt.subplots(nrows=10, ncols=2)

    ax[0,0].set_title("Original")
    ax[0,1].set_title("Reconstructed")

    for i in range(10):
        ax[i, 1].imshow(out[i])
        ax[i, 0].imshow(x[i].cpu().permute(1,2,0))
        ax[i, 0].xaxis.set_visible(False)
        ax[i, 0].yaxis.set_visible(False)
        ax[i, 1].xaxis.set_visible(False)
        ax[i, 1].yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)
    plt.savefig(f'{link}/saved_models/{name}/reconstructions/e{epoch}{estr}.png')
    plt.close('all')

In [45]:
DEVICE

device(type='cuda')

## Model Training

In [46]:
default_args = diva_args(z_dim=1024, rec_alpha = 20, rec_beta = 10, rec_gamma = 10, 
                         beta=1, warmup=1, prewarmup=0)

In [47]:
diva = StampDIVA(default_args).to(DEVICE)

In [48]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/VAE10/checkpoints/905.pth'))

In [49]:
train_loader = DataLoader(RNA_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=128)

In [54]:
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
#optimizer = optim.Adam(diva.parameters(), lr=0.005)
optimizer = optim.RMSprop(diva.parameters(), lr=0.001, eps=0.1, momentum=0.05)

In [55]:
RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

(10, 100)

In [56]:
writer.flush()

In [57]:
%tensorboard --logdir="D:/users/Marko/downloads/mirna/saved_models/new/IVAE1/tensorboard/"

In [58]:
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 500, 0, save_folder="new/IVAE1",save_interval=5)

Epoch 1: 272batch [00:45,  6.04batch/s, loss=2.86e+3]


epoch 1: avg train loss 2922.62, bar train loss 9.990, col train loss 281.673
epoch 1: avg test  loss 2890.41, bar  test loss 7.907, col  test loss 281.076


Epoch 2: 272batch [00:44,  6.07batch/s, loss=2.84e+3]


epoch 2: avg train loss 2891.20, bar train loss 7.924, col train loss 281.091


Epoch 3: 1batch [00:00,  5.95batch/s, loss=2.88e+3]

epoch 2: avg test  loss 2890.98, bar  test loss 7.923, col  test loss 281.062


Epoch 3: 272batch [00:44,  6.06batch/s, loss=2.8e+3] 


epoch 3: avg train loss 2889.37, bar train loss 7.709, col train loss 281.020


Epoch 4: 1batch [00:00,  6.02batch/s, loss=2.88e+3]

epoch 3: avg test  loss 2882.04, bar  test loss 7.083, col  test loss 280.849


Epoch 4: 272batch [00:45,  6.02batch/s, loss=2.88e+3]


epoch 4: avg train loss 2880.16, bar train loss 6.835, col train loss 280.793


Epoch 5: 1batch [00:00,  5.99batch/s, loss=2.87e+3]

epoch 4: avg test  loss 2876.03, bar  test loss 6.468, col  test loss 280.709


Epoch 5: 272batch [00:45,  6.03batch/s, loss=2.85e+3]


epoch 5: avg train loss 2874.42, bar train loss 6.322, col train loss 280.606
epoch 5: avg test  loss 2870.95, bar  test loss 6.165, col  test loss 280.404


Epoch 6: 272batch [00:45,  6.00batch/s, loss=2.84e+3]


epoch 6: avg train loss 2869.84, bar train loss 6.036, col train loss 280.333


Epoch 7: 1batch [00:00,  5.99batch/s, loss=2.86e+3]

epoch 6: avg test  loss 2867.56, bar  test loss 5.928, col  test loss 280.181


Epoch 7: 272batch [00:45,  6.00batch/s, loss=2.89e+3]


epoch 7: avg train loss 2867.16, bar train loss 5.902, col train loss 280.141


Epoch 8: 1batch [00:00,  6.14batch/s, loss=2.84e+3]

epoch 7: avg test  loss 2865.10, bar  test loss 5.733, col  test loss 280.071


Epoch 8: 272batch [00:45,  6.02batch/s, loss=2.85e+3]


epoch 8: avg train loss 2864.73, bar train loss 5.744, col train loss 279.990


Epoch 9: 1batch [00:00,  6.02batch/s, loss=2.89e+3]

epoch 8: avg test  loss 2864.26, bar  test loss 5.901, col  test loss 279.850


Epoch 9: 272batch [00:45,  6.01batch/s, loss=2.85e+3]


epoch 9: avg train loss 2861.69, bar train loss 5.630, col train loss 279.741


Epoch 10: 1batch [00:00,  5.95batch/s, loss=2.86e+3]

epoch 9: avg test  loss 2859.84, bar  test loss 5.566, col  test loss 279.636


Epoch 10: 272batch [00:45,  6.00batch/s, loss=2.88e+3]


epoch 10: avg train loss 2859.72, bar train loss 5.524, col train loss 279.587
epoch 10: avg test  loss 2858.70, bar  test loss 5.444, col  test loss 279.518


Epoch 11: 272batch [00:45,  6.01batch/s, loss=2.81e+3]


epoch 11: avg train loss 2857.34, bar train loss 5.345, col train loss 279.462


Epoch 12: 1batch [00:00,  6.06batch/s, loss=2.88e+3]

epoch 11: avg test  loss 2855.81, bar  test loss 5.219, col  test loss 279.404


Epoch 12: 272batch [00:45,  5.98batch/s, loss=2.93e+3]


epoch 12: avg train loss 2855.65, bar train loss 5.189, col train loss 279.399


Epoch 13: 1batch [00:00,  6.02batch/s, loss=2.85e+3]

epoch 12: avg test  loss 2854.89, bar  test loss 5.140, col  test loss 279.392


Epoch 13: 272batch [00:45,  5.98batch/s, loss=2.84e+3]


epoch 13: avg train loss 2854.32, bar train loss 5.075, col train loss 279.334


Epoch 14: 0batch [00:00, ?batch/s]

epoch 13: avg test  loss 2853.32, bar  test loss 4.958, col  test loss 279.311


Epoch 14: 272batch [00:44,  6.05batch/s, loss=2.9e+3] 


epoch 14: avg train loss 2853.37, bar train loss 5.009, col train loss 279.273


Epoch 15: 0batch [00:00, ?batch/s]

epoch 14: avg test  loss 2853.53, bar  test loss 5.093, col  test loss 279.264


Epoch 15: 272batch [00:45,  6.00batch/s, loss=2.86e+3]


epoch 15: avg train loss 2851.63, bar train loss 4.958, col train loss 279.110
epoch 15: avg test  loss 2850.54, bar  test loss 4.948, col  test loss 279.071


Epoch 16: 272batch [00:45,  6.04batch/s, loss=2.89e+3]


epoch 16: avg train loss 2850.62, bar train loss 4.935, col train loss 278.993


Epoch 17: 1batch [00:00,  6.13batch/s, loss=2.86e+3]

epoch 16: avg test  loss 2850.25, bar  test loss 5.044, col  test loss 278.933


Epoch 17: 272batch [00:45,  6.04batch/s, loss=2.86e+3]


epoch 17: avg train loss 2848.87, bar train loss 4.879, col train loss 278.826


Epoch 18: 1batch [00:00,  6.06batch/s, loss=2.85e+3]

epoch 17: avg test  loss 2847.63, bar  test loss 4.753, col  test loss 278.813


Epoch 18: 272batch [00:45,  5.97batch/s, loss=2.89e+3]


epoch 18: avg train loss 2847.17, bar train loss 4.804, col train loss 278.689


Epoch 19: 1batch [00:00,  5.95batch/s, loss=2.89e+3]

epoch 18: avg test  loss 2846.58, bar  test loss 4.759, col  test loss 278.636


Epoch 19: 272batch [00:45,  6.03batch/s, loss=2.78e+3]


epoch 19: avg train loss 2846.14, bar train loss 4.756, col train loss 278.593


Epoch 20: 0batch [00:00, ?batch/s]

epoch 19: avg test  loss 2845.45, bar  test loss 4.603, col  test loss 278.605


Epoch 20: 272batch [00:45,  6.03batch/s, loss=2.79e+3]


epoch 20: avg train loss 2845.28, bar train loss 4.705, col train loss 278.529
epoch 20: avg test  loss 2844.72, bar  test loss 4.643, col  test loss 278.515


Epoch 21: 272batch [00:45,  5.98batch/s, loss=2.85e+3]


epoch 21: avg train loss 2844.23, bar train loss 4.692, col train loss 278.423


Epoch 22: 0batch [00:00, ?batch/s]

epoch 21: avg test  loss 2843.90, bar  test loss 4.652, col  test loss 278.371


Epoch 22: 272batch [00:45,  6.02batch/s, loss=2.87e+3]


epoch 22: avg train loss 2843.39, bar train loss 4.670, col train loss 278.321


Epoch 23: 1batch [00:00,  6.06batch/s, loss=2.84e+3]

epoch 22: avg test  loss 2843.25, bar  test loss 4.630, col  test loss 278.325


Epoch 23: 272batch [00:45,  5.99batch/s, loss=2.89e+3]


epoch 23: avg train loss 2842.47, bar train loss 4.651, col train loss 278.222


Epoch 24: 1batch [00:00,  6.13batch/s, loss=2.85e+3]

epoch 23: avg test  loss 2842.24, bar  test loss 4.567, col  test loss 278.237


Epoch 24: 272batch [00:45,  6.01batch/s, loss=2.87e+3]


epoch 24: avg train loss 2841.96, bar train loss 4.639, col train loss 278.157


Epoch 25: 1batch [00:00,  6.10batch/s, loss=2.84e+3]

epoch 24: avg test  loss 2842.25, bar  test loss 4.525, col  test loss 278.186


Epoch 25: 272batch [00:45,  6.03batch/s, loss=2.82e+3]


epoch 25: avg train loss 2841.36, bar train loss 4.612, col train loss 278.093
epoch 25: avg test  loss 2841.85, bar  test loss 4.730, col  test loss 278.136


Epoch 26: 272batch [00:45,  5.97batch/s, loss=2.76e+3]


epoch 26: avg train loss 2840.62, bar train loss 4.592, col train loss 278.035


Epoch 27: 1batch [00:00,  6.06batch/s, loss=2.86e+3]

epoch 26: avg test  loss 2841.04, bar  test loss 4.466, col  test loss 278.112


Epoch 27: 272batch [00:45,  5.98batch/s, loss=2.81e+3]


epoch 27: avg train loss 2839.93, bar train loss 4.553, col train loss 277.981


Epoch 28: 1batch [00:00,  6.06batch/s, loss=2.84e+3]

epoch 27: avg test  loss 2839.95, bar  test loss 4.508, col  test loss 277.999


Epoch 28: 272batch [00:45,  6.00batch/s, loss=2.86e+3]


epoch 28: avg train loss 2839.26, bar train loss 4.529, col train loss 277.913


Epoch 29: 1batch [00:00,  6.06batch/s, loss=2.85e+3]

epoch 28: avg test  loss 2840.07, bar  test loss 4.528, col  test loss 277.968


Epoch 29: 272batch [00:45,  5.99batch/s, loss=2.8e+3] 


epoch 29: avg train loss 2838.55, bar train loss 4.500, col train loss 277.858


Epoch 30: 1batch [00:00,  6.21batch/s, loss=2.84e+3]

epoch 29: avg test  loss 2839.22, bar  test loss 4.644, col  test loss 277.887


Epoch 30: 272batch [00:45,  6.04batch/s, loss=2.81e+3]


epoch 30: avg train loss 2837.92, bar train loss 4.473, col train loss 277.793
epoch 30: avg test  loss 2837.99, bar  test loss 4.491, col  test loss 277.854


Epoch 31: 272batch [00:45,  6.01batch/s, loss=2.82e+3]


epoch 31: avg train loss 2837.31, bar train loss 4.458, col train loss 277.734


Epoch 32: 1batch [00:00,  5.78batch/s, loss=2.83e+3]

epoch 31: avg test  loss 2838.04, bar  test loss 4.433, col  test loss 277.794


Epoch 32: 272batch [00:45,  6.01batch/s, loss=2.85e+3]


epoch 32: avg train loss 2836.58, bar train loss 4.428, col train loss 277.658


Epoch 33: 1batch [00:00,  6.29batch/s, loss=2.83e+3]

epoch 32: avg test  loss 2836.51, bar  test loss 4.415, col  test loss 277.664


Epoch 33: 272batch [00:45,  6.03batch/s, loss=2.87e+3]


epoch 33: avg train loss 2835.50, bar train loss 4.402, col train loss 277.554


Epoch 34: 1batch [00:00,  6.06batch/s, loss=2.84e+3]

epoch 33: avg test  loss 2835.46, bar  test loss 4.352, col  test loss 277.608


Epoch 34: 272batch [00:45,  6.00batch/s, loss=2.79e+3]


epoch 34: avg train loss 2834.64, bar train loss 4.367, col train loss 277.477


Epoch 35: 1batch [00:00,  6.17batch/s, loss=2.81e+3]

epoch 34: avg test  loss 2835.34, bar  test loss 4.364, col  test loss 277.583


Epoch 35: 272batch [00:45,  6.01batch/s, loss=2.8e+3] 


epoch 35: avg train loss 2834.17, bar train loss 4.343, col train loss 277.438
epoch 35: avg test  loss 2834.55, bar  test loss 4.321, col  test loss 277.495


Epoch 36: 272batch [00:45,  5.98batch/s, loss=2.89e+3]


epoch 36: avg train loss 2833.62, bar train loss 4.317, col train loss 277.383


Epoch 37: 1batch [00:00,  6.06batch/s, loss=2.84e+3]

epoch 36: avg test  loss 2833.85, bar  test loss 4.374, col  test loss 277.410


Epoch 37: 272batch [00:45,  6.00batch/s, loss=2.81e+3]


epoch 37: avg train loss 2832.94, bar train loss 4.292, col train loss 277.322


Epoch 38: 1batch [00:00,  6.10batch/s, loss=2.83e+3]

epoch 37: avg test  loss 2833.34, bar  test loss 4.310, col  test loss 277.375


Epoch 38: 272batch [00:42,  6.41batch/s, loss=2.8e+3] 


epoch 38: avg train loss 2832.24, bar train loss 4.271, col train loss 277.259


Epoch 39: 0batch [00:00, ?batch/s]

epoch 38: avg test  loss 2832.38, bar  test loss 4.223, col  test loss 277.336


Epoch 39: 272batch [00:46,  5.81batch/s, loss=2.79e+3]


epoch 39: avg train loss 2831.91, bar train loss 4.257, col train loss 277.211


Epoch 40: 0batch [00:00, ?batch/s, loss=2.82e+3]

epoch 39: avg test  loss 2832.29, bar  test loss 4.251, col  test loss 277.293


Epoch 40: 272batch [00:46,  5.82batch/s, loss=2.86e+3]


epoch 40: avg train loss 2831.32, bar train loss 4.234, col train loss 277.167
epoch 40: avg test  loss 2832.06, bar  test loss 4.236, col  test loss 277.230


Epoch 41: 272batch [00:46,  5.83batch/s, loss=2.81e+3]


epoch 41: avg train loss 2830.29, bar train loss 4.199, col train loss 277.073


Epoch 42: 0batch [00:00, ?batch/s]

epoch 41: avg test  loss 2830.79, bar  test loss 4.281, col  test loss 277.169


Epoch 42: 272batch [00:46,  5.81batch/s, loss=2.84e+3]


epoch 42: avg train loss 2830.08, bar train loss 4.191, col train loss 277.040


Epoch 43: 1batch [00:00,  5.92batch/s, loss=2.81e+3]

epoch 42: avg test  loss 2830.54, bar  test loss 4.154, col  test loss 277.107


Epoch 43: 272batch [00:46,  5.83batch/s, loss=2.77e+3]


epoch 43: avg train loss 2829.64, bar train loss 4.168, col train loss 276.996


Epoch 44: 1batch [00:00,  5.88batch/s, loss=2.82e+3]

epoch 43: avg test  loss 2830.03, bar  test loss 4.154, col  test loss 277.081


Epoch 44: 272batch [00:46,  5.82batch/s, loss=2.81e+3]


epoch 44: avg train loss 2828.95, bar train loss 4.142, col train loss 276.942


Epoch 45: 1batch [00:00,  5.88batch/s, loss=2.81e+3]

epoch 44: avg test  loss 2829.74, bar  test loss 4.262, col  test loss 277.028


Epoch 45: 272batch [00:46,  5.80batch/s, loss=2.83e+3]


epoch 45: avg train loss 2828.60, bar train loss 4.129, col train loss 276.908
epoch 45: avg test  loss 2829.45, bar  test loss 4.149, col  test loss 277.038


Epoch 46: 272batch [00:46,  5.79batch/s, loss=2.85e+3]


epoch 46: avg train loss 2828.34, bar train loss 4.120, col train loss 276.885


Epoch 47: 1batch [00:00,  5.68batch/s, loss=2.86e+3]

epoch 46: avg test  loss 2828.47, bar  test loss 4.126, col  test loss 276.986


Epoch 47: 272batch [00:46,  5.80batch/s, loss=2.81e+3]


epoch 47: avg train loss 2827.94, bar train loss 4.101, col train loss 276.851


Epoch 48: 1batch [00:00,  5.78batch/s, loss=2.81e+3]

epoch 47: avg test  loss 2828.39, bar  test loss 4.079, col  test loss 276.968


Epoch 48: 272batch [00:47,  5.73batch/s, loss=2.83e+3]


epoch 48: avg train loss 2827.30, bar train loss 4.079, col train loss 276.791


Epoch 49: 1batch [00:00,  5.62batch/s, loss=2.84e+3]

epoch 48: avg test  loss 2827.62, bar  test loss 4.047, col  test loss 276.850


Epoch 49: 272batch [00:47,  5.71batch/s, loss=2.83e+3]


epoch 49: avg train loss 2826.61, bar train loss 4.060, col train loss 276.722


Epoch 50: 0batch [00:00, ?batch/s, loss=2.8e+3]

epoch 49: avg test  loss 2827.56, bar  test loss 4.080, col  test loss 276.925


Epoch 50: 272batch [00:47,  5.73batch/s, loss=2.79e+3]


epoch 50: avg train loss 2826.25, bar train loss 4.045, col train loss 276.701
epoch 50: avg test  loss 2826.80, bar  test loss 4.086, col  test loss 276.789


Epoch 51: 272batch [00:47,  5.76batch/s, loss=2.84e+3]


epoch 51: avg train loss 2825.81, bar train loss 4.031, col train loss 276.659


Epoch 52: 1batch [00:00,  5.75batch/s, loss=2.82e+3]

epoch 51: avg test  loss 2825.92, bar  test loss 3.969, col  test loss 276.717


Epoch 52: 169batch [00:29,  5.68batch/s, loss=2.82e+3]


KeyboardInterrupt: 

In [68]:
optimizer = optim.RMSprop(diva.parameters(), lr=0.001, eps=0.1, momentum=0.2)

In [64]:
diva.load_state_dict(torch.load(f'{link}/saved_models/new/IVAE1/checkpoints/50.pth'))

<All keys matched successfully>

In [67]:
 lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 100, 63, save_folder="new/IVAE1")

Epoch 64: 272batch [00:46,  5.83batch/s, loss=2.86e+3]


epoch 64: avg train loss 2828.26, bar train loss 4.152, col train loss 276.759


Epoch 65: 1batch [00:00,  5.99batch/s, loss=2.81e+3]

epoch 64: avg test  loss 2826.25, bar  test loss 4.014, col  test loss 276.771


Epoch 65: 272batch [00:46,  5.80batch/s, loss=2.85e+3]


epoch 65: avg train loss 2825.43, bar train loss 4.010, col train loss 276.615
epoch 65: avg test  loss 2826.56, bar  test loss 4.085, col  test loss 276.692


Epoch 66: 272batch [00:46,  5.80batch/s, loss=2.85e+3]


epoch 66: avg train loss 2824.57, bar train loss 3.983, col train loss 276.542


Epoch 67: 1batch [00:00,  5.81batch/s, loss=2.82e+3]

epoch 66: avg test  loss 2825.07, bar  test loss 3.991, col  test loss 276.665


Epoch 67: 272batch [00:47,  5.77batch/s, loss=2.8e+3] 


epoch 67: avg train loss 2823.83, bar train loss 3.959, col train loss 276.487


Epoch 68: 0batch [00:00, ?batch/s]

epoch 67: avg test  loss 2824.46, bar  test loss 3.903, col  test loss 276.579


Epoch 68: 272batch [00:47,  5.77batch/s, loss=2.86e+3]


epoch 68: avg train loss 2823.18, bar train loss 3.948, col train loss 276.438


Epoch 69: 1batch [00:00,  5.65batch/s, loss=2.84e+3]

epoch 68: avg test  loss 2824.96, bar  test loss 4.092, col  test loss 276.549


Epoch 69: 272batch [00:47,  5.70batch/s, loss=2.88e+3]


epoch 69: avg train loss 2822.59, bar train loss 3.925, col train loss 276.387


Epoch 70: 1batch [00:00,  5.85batch/s, loss=2.82e+3]

epoch 69: avg test  loss 2823.75, bar  test loss 3.926, col  test loss 276.463


Epoch 70: 272batch [00:47,  5.74batch/s, loss=2.82e+3]


epoch 70: avg train loss 2822.15, bar train loss 3.914, col train loss 276.347
epoch 70: avg test  loss 2822.92, bar  test loss 3.944, col  test loss 276.396


Epoch 71: 272batch [00:47,  5.76batch/s, loss=2.83e+3]


epoch 71: avg train loss 2821.34, bar train loss 3.892, col train loss 276.296


Epoch 72: 0batch [00:00, ?batch/s]

epoch 71: avg test  loss 2822.43, bar  test loss 3.937, col  test loss 276.339


Epoch 72: 272batch [00:47,  5.78batch/s, loss=2.84e+3]


epoch 72: avg train loss 2820.63, bar train loss 3.874, col train loss 276.234


Epoch 73: 1batch [00:00,  5.81batch/s, loss=2.8e+3]

epoch 72: avg test  loss 2821.43, bar  test loss 3.873, col  test loss 276.315


Epoch 73: 272batch [00:46,  5.85batch/s, loss=2.81e+3]


epoch 73: avg train loss 2819.67, bar train loss 3.846, col train loss 276.151


Epoch 74: 1batch [00:00,  5.78batch/s, loss=2.84e+3]

epoch 73: avg test  loss 2820.55, bar  test loss 3.827, col  test loss 276.263


Epoch 74: 272batch [00:46,  5.82batch/s, loss=2.86e+3]


epoch 74: avg train loss 2819.35, bar train loss 3.844, col train loss 276.126


Epoch 75: 1batch [00:00,  5.81batch/s, loss=2.84e+3]

epoch 74: avg test  loss 2820.02, bar  test loss 3.767, col  test loss 276.215


Epoch 75: 272batch [00:47,  5.76batch/s, loss=2.85e+3]


epoch 75: avg train loss 2818.85, bar train loss 3.829, col train loss 276.093
epoch 75: avg test  loss 2819.65, bar  test loss 3.778, col  test loss 276.152


Epoch 76: 272batch [00:47,  5.75batch/s, loss=2.82e+3]


epoch 76: avg train loss 2818.06, bar train loss 3.811, col train loss 276.031


Epoch 77: 0batch [00:00, ?batch/s]

epoch 76: avg test  loss 2818.88, bar  test loss 3.816, col  test loss 276.129


Epoch 77: 272batch [00:47,  5.77batch/s, loss=2.85e+3]


epoch 77: avg train loss 2817.24, bar train loss 3.790, col train loss 275.957


Epoch 78: 0batch [00:00, ?batch/s]

epoch 77: avg test  loss 2817.79, bar  test loss 3.746, col  test loss 276.068


Epoch 78: 272batch [00:46,  5.84batch/s, loss=2.82e+3]


epoch 78: avg train loss 2816.61, bar train loss 3.773, col train loss 275.916


Epoch 79: 1batch [00:00,  5.78batch/s, loss=2.81e+3]

epoch 78: avg test  loss 2818.07, bar  test loss 3.782, col  test loss 276.080


Epoch 79: 272batch [00:46,  5.83batch/s, loss=2.75e+3]


epoch 79: avg train loss 2816.11, bar train loss 3.763, col train loss 275.878


Epoch 80: 0batch [00:00, ?batch/s]

epoch 79: avg test  loss 2817.84, bar  test loss 3.806, col  test loss 276.066


Epoch 80: 272batch [00:46,  5.80batch/s, loss=2.84e+3]


epoch 80: avg train loss 2815.59, bar train loss 3.745, col train loss 275.838
epoch 80: avg test  loss 2816.76, bar  test loss 3.706, col  test loss 275.930


Epoch 81: 272batch [00:47,  5.78batch/s, loss=2.8e+3] 


epoch 81: avg train loss 2815.27, bar train loss 3.722, col train loss 275.824


Epoch 82: 1batch [00:00,  5.99batch/s, loss=2.8e+3]

epoch 81: avg test  loss 2816.23, bar  test loss 3.731, col  test loss 275.934


Epoch 82: 272batch [00:47,  5.72batch/s, loss=2.83e+3]


epoch 82: avg train loss 2814.80, bar train loss 3.708, col train loss 275.779


Epoch 83: 0batch [00:00, ?batch/s]

epoch 82: avg test  loss 2815.55, bar  test loss 3.687, col  test loss 275.940


Epoch 83: 272batch [00:47,  5.77batch/s, loss=2.83e+3]


epoch 83: avg train loss 2814.06, bar train loss 3.689, col train loss 275.736


Epoch 84: 1batch [00:00,  5.59batch/s, loss=2.84e+3]

epoch 83: avg test  loss 2814.88, bar  test loss 3.669, col  test loss 275.846


Epoch 84: 272batch [00:47,  5.77batch/s, loss=2.81e+3]


epoch 84: avg train loss 2813.79, bar train loss 3.683, col train loss 275.707


Epoch 85: 0batch [00:00, ?batch/s]

epoch 84: avg test  loss 2814.97, bar  test loss 3.673, col  test loss 275.827


Epoch 85: 272batch [00:46,  5.82batch/s, loss=2.8e+3] 


epoch 85: avg train loss 2813.35, bar train loss 3.674, col train loss 275.665
epoch 85: avg test  loss 2813.65, bar  test loss 3.628, col  test loss 275.749


Epoch 86: 272batch [00:47,  5.78batch/s, loss=2.83e+3]


epoch 86: avg train loss 2812.81, bar train loss 3.651, col train loss 275.639


Epoch 87: 0batch [00:00, ?batch/s]

epoch 86: avg test  loss 2813.90, bar  test loss 3.555, col  test loss 275.779


Epoch 87: 272batch [00:46,  5.80batch/s, loss=2.79e+3]


epoch 87: avg train loss 2812.66, bar train loss 3.641, col train loss 275.629


Epoch 88: 1batch [00:00,  5.88batch/s, loss=2.83e+3]

epoch 87: avg test  loss 2813.51, bar  test loss 3.656, col  test loss 275.726


Epoch 88: 272batch [00:47,  5.78batch/s, loss=2.81e+3]


epoch 88: avg train loss 2812.29, bar train loss 3.625, col train loss 275.596


Epoch 89: 1batch [00:00,  5.81batch/s, loss=2.79e+3]

epoch 88: avg test  loss 2814.12, bar  test loss 3.601, col  test loss 275.818


Epoch 89: 272batch [00:47,  5.75batch/s, loss=2.83e+3]


epoch 89: avg train loss 2812.03, bar train loss 3.619, col train loss 275.591


Epoch 90: 1batch [00:00,  5.85batch/s, loss=2.8e+3]

epoch 89: avg test  loss 2813.26, bar  test loss 3.617, col  test loss 275.684


Epoch 90: 272batch [00:47,  5.73batch/s, loss=2.83e+3]


epoch 90: avg train loss 2811.82, bar train loss 3.610, col train loss 275.574
epoch 90: avg test  loss 2813.06, bar  test loss 3.607, col  test loss 275.675


Epoch 91: 272batch [00:47,  5.75batch/s, loss=2.83e+3]


epoch 91: avg train loss 2811.44, bar train loss 3.603, col train loss 275.534


Epoch 92: 1batch [00:00,  5.81batch/s, loss=2.81e+3]

epoch 91: avg test  loss 2812.93, bar  test loss 3.567, col  test loss 275.663


Epoch 92: 272batch [00:47,  5.71batch/s, loss=2.84e+3]


epoch 92: avg train loss 2811.39, bar train loss 3.582, col train loss 275.554


Epoch 93: 0batch [00:00, ?batch/s]

epoch 92: avg test  loss 2812.69, bar  test loss 3.595, col  test loss 275.667


Epoch 93: 272batch [00:47,  5.75batch/s, loss=2.79e+3]


epoch 93: avg train loss 2811.16, bar train loss 3.583, col train loss 275.521


Epoch 94: 1batch [00:00,  5.71batch/s, loss=2.85e+3]

epoch 93: avg test  loss 2812.46, bar  test loss 3.547, col  test loss 275.676


Epoch 94: 272batch [00:47,  5.70batch/s, loss=2.83e+3]


epoch 94: avg train loss 2810.70, bar train loss 3.566, col train loss 275.487


Epoch 95: 0batch [00:00, ?batch/s]

epoch 94: avg test  loss 2811.83, bar  test loss 3.541, col  test loss 275.589


Epoch 95: 272batch [00:47,  5.74batch/s, loss=2.78e+3]


epoch 95: avg train loss 2810.42, bar train loss 3.563, col train loss 275.462
epoch 95: avg test  loss 2811.98, bar  test loss 3.625, col  test loss 275.686


Epoch 96: 272batch [00:47,  5.75batch/s, loss=2.76e+3]


epoch 96: avg train loss 2810.38, bar train loss 3.552, col train loss 275.459


Epoch 97: 0batch [00:00, ?batch/s, loss=2.81e+3]

epoch 96: avg test  loss 2811.57, bar  test loss 3.505, col  test loss 275.637


Epoch 97: 272batch [00:47,  5.76batch/s, loss=2.83e+3]


epoch 97: avg train loss 2810.14, bar train loss 3.543, col train loss 275.443


Epoch 98: 0batch [00:00, ?batch/s]

epoch 97: avg test  loss 2811.56, bar  test loss 3.589, col  test loss 275.560


Epoch 98: 272batch [00:47,  5.78batch/s, loss=2.81e+3]


epoch 98: avg train loss 2809.84, bar train loss 3.537, col train loss 275.407


Epoch 99: 0batch [00:00, ?batch/s]

epoch 98: avg test  loss 2811.49, bar  test loss 3.543, col  test loss 275.588


Epoch 99: 272batch [00:47,  5.75batch/s, loss=2.8e+3] 


epoch 99: avg train loss 2809.56, bar train loss 3.518, col train loss 275.399


Epoch 100: 1batch [00:00,  5.88batch/s, loss=2.82e+3]

epoch 99: avg test  loss 2810.79, bar  test loss 3.493, col  test loss 275.521


Epoch 100: 272batch [00:47,  5.77batch/s, loss=2.85e+3]


epoch 100: avg train loss 2809.42, bar train loss 3.516, col train loss 275.373
epoch 100: avg test  loss 2810.52, bar  test loss 3.542, col  test loss 275.530


In [69]:
lss3, lss_t3 = train(default_args, train_loader, test_loader, diva, optimizer, 500, 100, save_folder="new/IVAE1")

Epoch 101: 272batch [00:46,  5.81batch/s, loss=2.79e+3]


epoch 101: avg train loss 2812.91, bar train loss 3.693, col train loss 275.573


Epoch 102: 1batch [00:00,  5.85batch/s, loss=2.79e+3]

epoch 101: avg test  loss 2811.46, bar  test loss 3.589, col  test loss 275.574


Epoch 102: 272batch [00:46,  5.82batch/s, loss=2.81e+3]


epoch 102: avg train loss 2810.02, bar train loss 3.544, col train loss 275.412


Epoch 103: 1batch [00:00,  5.85batch/s, loss=2.81e+3]

epoch 102: avg test  loss 2810.96, bar  test loss 3.529, col  test loss 275.481


Epoch 103: 272batch [00:46,  5.81batch/s, loss=2.82e+3]


epoch 103: avg train loss 2809.42, bar train loss 3.519, col train loss 275.365


Epoch 104: 0batch [00:00, ?batch/s]

epoch 103: avg test  loss 2810.96, bar  test loss 3.464, col  test loss 275.548


Epoch 104: 272batch [00:46,  5.83batch/s, loss=2.83e+3]


epoch 104: avg train loss 2809.05, bar train loss 3.514, col train loss 275.332


Epoch 105: 0batch [00:00, ?batch/s]

epoch 104: avg test  loss 2809.91, bar  test loss 3.537, col  test loss 275.460


Epoch 105: 272batch [00:46,  5.81batch/s, loss=2.82e+3]


epoch 105: avg train loss 2808.28, bar train loss 3.496, col train loss 275.262
epoch 105: avg test  loss 2809.57, bar  test loss 3.484, col  test loss 275.390


Epoch 106: 272batch [00:46,  5.82batch/s, loss=2.8e+3] 


epoch 106: avg train loss 2807.87, bar train loss 3.492, col train loss 275.211


Epoch 107: 1batch [00:00,  5.99batch/s, loss=2.79e+3]

epoch 106: avg test  loss 2808.87, bar  test loss 3.449, col  test loss 275.346


Epoch 107: 272batch [00:46,  5.82batch/s, loss=2.76e+3]


epoch 107: avg train loss 2807.65, bar train loss 3.478, col train loss 275.216


Epoch 108: 1batch [00:00,  5.99batch/s, loss=2.8e+3]

epoch 107: avg test  loss 2809.81, bar  test loss 3.628, col  test loss 275.324


Epoch 108: 272batch [00:46,  5.83batch/s, loss=2.86e+3]


epoch 108: avg train loss 2807.41, bar train loss 3.493, col train loss 275.161


Epoch 109: 0batch [00:00, ?batch/s]

epoch 108: avg test  loss 2808.82, bar  test loss 3.455, col  test loss 275.330


Epoch 109: 272batch [00:46,  5.80batch/s, loss=2.78e+3]


epoch 109: avg train loss 2807.34, bar train loss 3.479, col train loss 275.179


Epoch 110: 0batch [00:00, ?batch/s]

epoch 109: avg test  loss 2808.57, bar  test loss 3.498, col  test loss 275.248


Epoch 110: 272batch [00:47,  5.69batch/s, loss=2.83e+3]


epoch 110: avg train loss 2807.21, bar train loss 3.475, col train loss 275.152
epoch 110: avg test  loss 2808.37, bar  test loss 3.482, col  test loss 275.279


Epoch 111: 272batch [00:46,  5.81batch/s, loss=2.76e+3]


epoch 111: avg train loss 2807.02, bar train loss 3.465, col train loss 275.137


Epoch 112: 1batch [00:00,  5.85batch/s, loss=2.78e+3]

epoch 111: avg test  loss 2808.15, bar  test loss 3.466, col  test loss 275.334


Epoch 112: 56batch [00:09,  5.76batch/s, loss=2.81e+3]


KeyboardInterrupt: 

In [None]:
def plot_loss_acc(lss, lss_t):
    fig,ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label = "test loss")
    #ax1 = ax.twinx()
    #ax1.plot(yacc, label = "train accuracy", ls='--')
    #ax1.plot(yacc_t, label = "test accuracy", ls='--')

    lines, labels = ax.get_legend_handles_labels()
    #lines2, labels2 = ax1.get_legend_handles_labels()

    ax.legend(lines, labels)

In [None]:
plot_loss_acc(lss, lss_t)

In [None]:
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
plot_change_latent_var(diva)

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''
    lat_space: y, d, x
    '''

    

In [None]:
plot(x, out, 0)

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
  ax[i//3, i%3].imshow(x[i].cpu().permute(1,2,0))
  
plt.savefig('divastamporg.png')