In [1]:
import os

# Project root holding data/ and saved_models/. Set the MIRNA_DIR environment variable to
# run on another machine instead of editing this absolute, machine-specific default.
link = os.environ.get('MIRNA_DIR', 'D:/users/Marko/downloads/mirna/')

# Imports

In [2]:
# Load the TensorBoard extension so the %tensorboard magic used further down is available.
%load_ext tensorboard

In [3]:
import sys
#sys.path.insert(0,'/content/drive/MyDrive/Marko/master')
sys.path.insert(0, link)
import numpy as np
import matplotlib.pyplot as plt

#import tensorflow as tf

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as dist

from torch.nn import functional as F
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm
from tqdm import trange

import datetime
import os
import wandb


# Start the W&B run that wandb.watch / wandb.log inside train() report to.
wandb.init(project="VAE", entity="generativemirna")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: marko5kovic (generativemirna). Use `wandb login --relogin` to force relogin


In [4]:
# Device for the model and for every batch moved by the training / evaluation helpers.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Show which device was selected.
DEVICE

device(type='cuda')

In [6]:
# Optimisation
LR = 1e-3          # Adam learning rate
BATCH_SIZE = 16    # training mini-batch size
# Model
LATENT_DIM = 64    # size of the latent vector z
# KL annealing (see train()): weight 1 until PREWARMUP, then ramp to BETA_KL over WARMUP epochs
PREWARMUP = 100
WARMUP = 100
BETA_KL = 30

In [7]:
config = {
    'lr': LR,
    'latent_dim': LATENT_DIM,
    'batch_size': BATCH_SIZE,
    'beta': BETA_KL,
    'prewarmup': PREWARMUP,
    'warmup': WARMUP,
}
# Attach the hyper-parameters to the active W&B run. Assigning `wandb.config = config`
# would only rebind the module attribute to a plain dict; the run would never see them.
wandb.config.update(config)

In [8]:
# Run directory named after the hyper-parameters: LR_LATENTDIM_BATCHSIZE_BETA_PREWARMUP_WARMUP.
model_folder = f"{link}/saved_models/wandb_models/" + "_".join(
    str(value) for value in (LR, LATENT_DIM, BATCH_SIZE, BETA_KL, PREWARMUP, WARMUP)
)

In [9]:
# TensorBoard writer; train() reads this module-level global. Logs go to <model_folder>/tensorboard.
writer = SummaryWriter(f"{model_folder}/tensorboard")

# Model Classes

In [10]:
class diva_args:
    """Plain hyper-parameter container read by StampDIVA and train().

    Dimensions: z_dim (latent size), d_dim, x_dim, y_dim.
    Loss weights: beta (KL term), rec (reconstruction term).
    KL schedule: warmup / prewarmup epoch counts used by train().
    """

    def __init__(self, z_dim=128, d_dim=45, x_dim=7500, y_dim=2,
                 beta=1, rec=1,
                 warmup=1, prewarmup=1):
        # dimensionalities
        self.z_dim, self.d_dim, self.x_dim, self.y_dim = z_dim, d_dim, x_dim, y_dim
        # loss weighting and KL warm-up schedule
        self.beta, self.rec = beta, rec
        self.warmup, self.prewarmup = warmup, prewarmup


## Dataset Class

In [11]:
class MicroRNADataset(Dataset):
    """Dataset over the pre-processed modmirbase files of one split (``ds``).

    Each item is ``(x_cat, y, d, x)``:

    * ``x_cat`` -- one-hot colour encoding of the image, shape ``(5, 25, 100)``
      (channel = colour class from :meth:`get_color`; white background stays all-zero);
    * ``y`` -- one-hot label row;
    * ``d`` -- one-hot species row over ``sorted(self.species + ['notfound'])``;
    * ``x`` -- the RGB image, transposed to channel-first.
    """

    def __init__(self, ds='train', create_encodings=False, use_subset=False):
        # NOTE: use_subset is accepted but not used.

        # Divide by 255 so white is exactly 1.0 (get_encoded_values compares pixels to 0/1).
        self.images = np.load(f'{link}/data/modmirbase_{ds}_images.npz')['arr_0'] / 255

        if create_encodings:
            x_cat = self.get_encoded_values(self.images, ds)
        else:
            x_cat = np.load(f'{link}/data/modmirbase_{ds}_images_cattt.npz')['arr_0']
        self.images_cat = x_cat

        # loading labels
        print('Loading Labels! (~10s)')
        labels = np.load(f'{link}/data/modmirbase_{ds}_labels.npz')['arr_0']
        self.labels = self._one_hot_encoder(categories='auto').fit_transform(labels)

        # Loaded for completeness; __getitem__ does not return it.
        self.mountain = np.load(f'{link}/data/modmirbase_{ds}_mountain.npy')

        # loading names
        print('Loading Names! (~5s)')
        names = [i.decode('utf-8') for i in np.load(f'{link}/data/modmirbase_{ds}_names.npz')['arr_0']]
        # species with more than 200 observations in past research
        self.species = ['mmu', 'prd', 'hsa', 'ptr', 'efu', 'cbn', 'gma', 'pma',
                        'cel', 'gga', 'ipu', 'ptc', 'mdo', 'cgr', 'bta', 'cin',
                        'ppy', 'ssc', 'ath', 'cfa', 'osa', 'mtr', 'gra', 'mml',
                        'stu', 'bdi', 'rno', 'oan', 'dre', 'aca', 'eca', 'chi',
                        'bmo', 'ggo', 'aly', 'dps', 'mdm', 'ame', 'ppc', 'ssa',
                        'ppt', 'tca', 'dme', 'sbi']
        self.names = [self._species_of(name) for name in names]

        # Fixed category list (instead of 'auto') so the train and test splits get the same
        # columns in the same order even if a species is missing from one split. Sorting
        # reproduces the order 'auto' would produce when every category is present.
        species_categories = sorted(set(self.species) | {'notfound'})
        ohe = self._one_hot_encoder(categories=[species_categories])
        self.names_ohe = ohe.fit_transform(np.array(self.names).reshape(-1, 1))

    @staticmethod
    def _one_hot_encoder(**kwargs):
        """Dense-output OneHotEncoder across scikit-learn versions (``sparse`` was renamed)."""
        try:
            return OneHotEncoder(sparse_output=False, **kwargs)
        except TypeError:  # scikit-learn < 1.2 has no `sparse_output`
            return OneHotEncoder(sparse=False, **kwargs)

    def _species_of(self, name):
        """Return the first species code contained in ``name`` (case-insensitive).

        Names containing 'random' or made only of digits are assigned 'hsa' (kept from the
        original labelling; the reason is not visible here); anything else is 'notfound'.
        """
        lowered = name.lower()
        for code in self.species:
            if code in lowered:
                return code
        if 'random' in lowered or name.isdigit():
            return 'hsa'
        return 'notfound'

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        d = self.names_ohe[idx]
        y = self.labels[idx]
        x = np.transpose(self.images[idx], (2, 0, 1))  # HWC -> CHW
        x_cat = self.images_cat[idx]
        return (x_cat, y, d, x)

    def get_encoded_values(self, x, ds):
        """One-hot encode the colour of every non-white pixel.

        ``x`` holds images as ``(n, H, W, 3)`` with components in {0, 1}; the loops assume
        H=25 and W=100. Columns are scanned left to right and the scan of an image stops at
        the first column whose centre row (12) is white. White pixels stay all-zero.

        The ``(n, 5, 25, 100)`` uint8 result is cached to
        ``{link}/data/modmirbase_{ds}_images_cattt.npz`` and returned.
        """
        n = x.shape[0]
        x = np.transpose(x, (0, 3, 1, 2))
        x_cat = np.zeros((n, 5, 25, 100), dtype=np.uint8)
        white = np.array([1, 1, 1])

        for i in range(n):
            if i % 100 == 0:
                print(f'at {i} out of {n}')
            for j in range(100):
                if (x[i, :, 12, j] == white).all():
                    break  # centre row is white: end of the sequence for this image
                for k in range(25):
                    if not (x[i, :, k, j] == white).all():
                        x_cat[i, self.get_color(x[i, :, k, j]), k, j] = 1

        np.savez_compressed(f'{link}/data/modmirbase_{ds}_images_cattt.npz', x_cat)
        return x_cat

    def get_color(self, pixel):
        """Return the colour class of an RGB pixel with components in {0, 1}.

        0 black, 1 red, 2 blue, 3 green, 4 yellow.

        Raises:
            ValueError: for any other pixel. Returning None here would make the caller
            index ``x_cat`` with a new axis and silently corrupt (or crash) the encoding.
        """
        palette = ([0, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 1, 0])
        for code, rgb in enumerate(palette):
            if (pixel == np.array(rgb)).all():
                return code
        raise ValueError(f"Something wrong! Unexpected pixel colour {pixel!r}")


## Decoder classes

In [38]:
# Decoders
class px(nn.Module):
    """Decoder p(x|z) producing a soft 5-colour "stamp" image of shape (B, 5, 25, 100).

    From z it predicts two categorical heads:

    * ``length_bar`` -- for each of 200 bar slots, a distribution over 14 bar lengths
      (0..13). Slots 0..99 are drawn downward from the centre (rows 13..24), slots
      100..199 upward from it (rows 12..0).
    * ``color`` -- for each of 2 halves x 100 positions, a distribution over 5 colours.

    Each output pixel is colour probability x probability that a bar covers the pixel.
    ``d_dim``, ``x_dim``, ``y_dim`` and ``dim1`` are accepted for interface compatibility
    but are not used.
    """

    def __init__(self, d_dim, x_dim, y_dim, z_dim, dim1=256, dim2=512):
        super(px, self).__init__()

        self.fc = nn.Sequential(nn.Linear(z_dim, dim2),
                                nn.ReLU(),
                                nn.Dropout(.1),
                                nn.Linear(dim2, dim2),
                                nn.Dropout(.4),
                                nn.ReLU())

        # 5 colours x 2 halves x 100 positions
        self.color = nn.Sequential(nn.Linear(dim2, 1000))
        # 14 lengths x 200 bar slots
        self.length_bar = nn.Sequential(nn.Linear(dim2, 2800))

        # stamp[0, l, k] == 1 iff k < l, so (length distribution) @ stamp gives P(length > k):
        # the probability that pixel k, counted from the centre, is covered by the bar.
        # A non-persistent buffer follows .to()/.cuda() with the module while leaving the
        # state_dict (and existing checkpoints) unchanged.
        self.register_buffer('stamp', torch.tril(torch.ones(14, 13), diagonal=-1)[None],
                             persistent=False)

    def forward(self, z):
        """Decode ``z`` (B, z_dim).

        Returns ``(rna, col, len_probs)``: the soft image (B, 5, 25, 100), colour
        probabilities (B, 5, 2, 100) and bar-length probabilities (B, 14, 200).
        """
        h = self.fc(z)

        len_probs = F.softmax(self.length_bar(h).reshape(-1, 14, 200), dim=1)
        # (B, 200, 14) @ (14, 13) -> (B, 200, 13) coverage probabilities per slot and pixel
        coverage = torch.matmul(len_probs.permute(0, 2, 1), self.stamp[0])
        bottom = coverage[:, :100, :12]        # slots 0..99 -> rows 13..24
        top = coverage[:, 100:, :].flip(2)     # slots 100..199 -> rows 0..12 (growing upward)
        bars = torch.cat([top, bottom], 2).permute(0, 2, 1)[:, None]  # (B, 1, 25, 100)

        col = F.softmax(self.color(h).reshape(-1, 5, 2, 100), dim=1)
        color = torch.cat([col[:, :, 0, None, :].expand(-1, -1, 13, -1),
                           col[:, :, 1, None, :].expand(-1, -1, 12, -1)], 2)  # (B, 5, 25, 100)

        rna = color * bars  # broadcast over the colour channel
        return rna, col, len_probs

    def sample(self, color, len_bar, mean=True):
        """Render RGB images from the decoder's distributions.

        color: (B, 5, 2, 100) colour probabilities; len_bar: (B, 14, 200) length
        probabilities (both as returned by ``forward``). ``mean=True`` takes the most
        likely class, otherwise classes are sampled.

        Returns a CPU float tensor (B, 25, 100, 3) with a white background.
        """
        if mean:
            bars = torch.argmax(len_bar, dim=1)
            col = torch.argmax(color, dim=1)
        else:
            # Categorical uses the LAST dimension as the categories, so move the class axis
            # (dim 1) there first: (B, 200, 14) and (B, 2, 100, 5).
            bars = dist.Categorical(probs=len_bar.permute(0, 2, 1)).sample()
            col = dist.Categorical(probs=color.permute(0, 2, 3, 1)).sample()
        bars = bars.cpu().tolist()  # (B, 200)
        col = col.cpu().tolist()    # (B, 2, 100)

        out = torch.ones((color.shape[0], 25, 100, 3))
        for i in range(color.shape[0]):
            for j in range(100):
                out[i, 13 - bars[i][100 + j]:13, j] = self.get_color(col[i][0][j])
                out[i, 13:13 + bars[i][j], j] = self.get_color(col[i][1][j])
        return out

    def get_color(self, color):
        """Map a colour class (0 black, 1 red, 2 blue, 3 green, 4 yellow) to an RGB tensor."""
        palette = ([0, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 1, 0])
        code = int(color)
        if not 0 <= code < len(palette):
            raise ValueError(f"unknown colour class {code}")
        return torch.tensor(palette[code])
        

In [39]:
# Shape check for the decoder on its own (z_dim=512 here, not the LATENT_DIM used for training).
pzy_ = px(45, 7500, 2, 512)
summary(pzy_, (1,512))

Layer (type:depth-idx)                   Output Shape              Param #
px                                       --                        --
├─Sequential: 1-1                        [1, 512]                  --
│    └─Linear: 2-1                       [1, 512]                  262,656
│    └─ReLU: 2-2                         [1, 512]                  --
│    └─Dropout: 2-3                      [1, 512]                  --
│    └─Linear: 2-4                       [1, 512]                  262,656
│    └─Dropout: 2-5                      [1, 512]                  --
│    └─ReLU: 2-6                         [1, 512]                  --
├─Sequential: 1-2                        [1, 2800]                 --
│    └─Linear: 2-7                       [1, 2800]                 1,436,400
├─Sequential: 1-3                        [1, 1000]                 --
│    └─Linear: 2-8                       [1, 1000]                 513,000
Total params: 2,474,712
Trainable params: 2,474,712
Non-trainab

## Encoder Classes

In [40]:
class qz(nn.Module):
    """Encoder q(z|x): three conv stages followed by two linear heads (mean, scale).

    ``h_dim`` = 2592 = 72 * 3 * 12 matches a (B, 5, 25, 100) input after three 2x2 poolings.
    ``d_dim``, ``x_dim`` and ``y_dim`` are accepted for interface compatibility only.
    """

    def __init__(self, d_dim, x_dim, y_dim, z_dim, h_dim=2592):
        super().__init__()
        self.h_dim = h_dim

        # Each stage is conv-relu-conv-relu-maxpool; building them in this order keeps the
        # module indices (and therefore the state_dict keys) of a hand-written Sequential.
        stages = []
        in_channels = 5
        for out_channels, kernel in ((48, 5), (60, 3), (72, 3)):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=1, padding='same'),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=kernel, stride=1, padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
            in_channels = out_channels
        self.encoder = nn.Sequential(*stages)

        self.fc11 = nn.Sequential(nn.Linear(h_dim, z_dim))
        self.fc12 = nn.Sequential(nn.Linear(h_dim, z_dim), nn.Softplus())

    def forward(self, x):
        """Return ``(z_loc, z_scale)`` of the Gaussian posterior for ``x``."""
        features = self.encoder(x).view(-1, self.h_dim)
        z_loc = self.fc11(features)
        # Softplus keeps the scale non-negative; the epsilon keeps it strictly positive.
        z_scale = self.fc12(features) + 1e-7
        return z_loc, z_scale




In [41]:
# Shape check for the encoder on a single 5x25x100 one-hot input.
enc = qz(128,10,10,512)
summary(enc, (1,5,25,100))

Layer (type:depth-idx)                   Output Shape              Param #
qz                                       --                        --
├─Sequential: 1-1                        [1, 72, 3, 12]            --
│    └─Conv2d: 2-1                       [1, 48, 25, 100]          6,048
│    └─ReLU: 2-2                         [1, 48, 25, 100]          --
│    └─Conv2d: 2-3                       [1, 48, 25, 100]          57,648
│    └─ReLU: 2-4                         [1, 48, 25, 100]          --
│    └─MaxPool2d: 2-5                    [1, 48, 12, 50]           --
│    └─Conv2d: 2-6                       [1, 60, 12, 50]           25,980
│    └─ReLU: 2-7                         [1, 60, 12, 50]           --
│    └─Conv2d: 2-8                       [1, 60, 12, 50]           32,460
│    └─ReLU: 2-9                         [1, 60, 12, 50]           --
│    └─MaxPool2d: 2-10                   [1, 60, 6, 25]            --
│    └─Conv2d: 2-11                      [1, 72, 6, 25]            38,

## Full model class

In [42]:
class StampDIVA(nn.Module):
    """VAE pairing the convolutional encoder ``qz`` with the bar/colour decoder ``px``.

    ``args`` provides z_dim, d_dim, x_dim, y_dim, beta, rec, warmup and prewarmup
    (see diva_args). ``d`` and ``y`` are accepted by forward/loss_function but not used.
    """

    def __init__(self, args):
        super(StampDIVA, self).__init__()
        self.z_dim = args.z_dim
        self.d_dim = args.d_dim
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim

        self.px = px(self.d_dim, self.x_dim, self.y_dim, self.z_dim)
        self.qz = qz(self.d_dim, self.x_dim, self.y_dim, self.z_dim)

        self.beta = args.beta
        self.rec = args.rec
        self.warmup = args.warmup
        self.prewarmup = args.prewarmup

        # Keep moving to the GPU on construction when one exists, but don't crash on CPU-only hosts.
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, d, x, y):
        """Encode ``x``, sample z with the reparameterisation trick and decode it.

        Returns ``(x_hat, q_z, p_z, z_q, color, bars)``.
        """
        z_loc, z_scale = self.qz(x)
        q_z = dist.Normal(z_loc, z_scale)
        z_q = q_z.rsample()

        x_hat, color, bars = self.px(z_q)
        # Standard-normal prior, created on the same device / dtype as the sample.
        p_z = dist.Normal(torch.zeros_like(z_q), torch.ones_like(z_q))

        return x_hat, q_z, p_z, z_q, color, bars

    def loss_function(self, d, x, y):
        """Return ``(loss, rec_loss, KL_z)`` summed over the batch.

        rec_loss is the summed squared error between the soft reconstruction and ``x``.
        KL_z is a single-sample estimate of ``log p(z) - log q(z|x)``, i.e. the NEGATIVE
        KL divergence, which is why the loss subtracts ``beta * KL_z``.
        """
        x_hat, q_z, p_z, z_q, _, _ = self.forward(d, x, y)

        rec_loss = F.mse_loss(x_hat, x, reduction='sum')
        KL_z = torch.sum(p_z.log_prob(z_q) - q_z.log_prob(z_q))

        return self.rec * rec_loss - self.beta * KL_z, rec_loss, KL_z

# Training the model

## Loading dataset

In [43]:
# Training split (ds defaults to 'train'), reading the cached one-hot encodings from disk.
RNA_dataset = MicroRNADataset(create_encodings=False)

Loading Labels! (~10s)
Loading Names! (~5s)


In [44]:
# Test split, also from the cached one-hot encodings.
RNA_dataset_test = MicroRNADataset('test', create_encodings=False)

Loading Labels! (~10s)
Loading Names! (~5s)


In [45]:
# Number of training examples.
len(RNA_dataset)

34721

In [46]:
def train_single_epoch(train_loader, model, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns ``(loss, mse, kl)``: each summed over all batches and divided by the dataset
    size, as detached tensors.
    """
    model.train()
    train_loss = 0
    mse_t = 0
    kl_t = 0
    # Iterating the loader directly lets tqdm show the total number of batches.
    pbar = tqdm(train_loader, unit="batch", desc=f'Epoch {epoch}')
    for x, y, d, _ in pbar:
        x, y, d = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE)

        optimizer.zero_grad()
        loss, mse, kl = model.loss_function(d.float(), x.float(), y.float())
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item() / x.shape[0])
        # Accumulate detached values: adding the raw tensors would chain every batch's
        # autograd graph onto the running totals for the whole epoch.
        train_loss += loss.detach()
        mse_t += mse.detach()
        kl_t += kl.detach()

    n_obs = len(train_loader.dataset)
    return train_loss / n_obs, mse_t / n_obs, kl_t / n_obs

In [47]:
def test_single_epoch(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` without gradients.

    Returns ``(loss, mse, kl)`` summed over all batches and divided by the dataset size.
    ``epoch`` is accepted for symmetry with train_single_epoch and is not used.
    """
    model.eval()
    totals = [0, 0, 0]
    with torch.no_grad():
        for x, y, d, _ in test_loader:
            x, y, d = x.to(DEVICE), y.to(DEVICE), d.to(DEVICE)
            batch_terms = model.loss_function(d.float(), x.float(), y.float())
            totals = [acc + term for acc, term in zip(totals, batch_terms)]
    n_obs = len(test_loader.dataset)
    return tuple(acc / n_obs for acc in totals)
  

In [48]:
def train(args, train_loader, test_loader, diva, optimizer, end_epoch, start_epoch=0, save_folder='sd_1.0.0',save_interval=5, checkpoint_interval=50):
    """Train ``diva`` for epochs ``start_epoch+1 .. end_epoch`` with a KL warm-up.

    Each epoch: set the KL weight, run one training and one test pass, log to W&B and
    (when the global ``writer`` is not None) TensorBoard, save reconstruction figures every
    ``save_interval`` epochs and a state_dict checkpoint every ``checkpoint_interval``
    epochs. All artefacts go under ``save_folder``.

    Returns:
        (train_losses, test_losses): per-epoch average losses as numpy values.
    """
    os.makedirs(f"{save_folder}/checkpoints/", exist_ok=True)
    os.makedirs(f"{save_folder}/reconstructions/", exist_ok=True)

    wandb.watch(diva)
    epoch_loss_sup = []
    test_loss = []

    for epoch in range(start_epoch + 1, end_epoch + 1):
        # KL weight: 1 before `prewarmup`, then a linear ramp reaching args.beta after
        # `warmup` further epochs; never below 1.
        diva.beta = max(1, min([args.beta, args.beta * (epoch - args.prewarmup * 1.) / (args.warmup)]))
        if epoch < args.prewarmup:
            diva.beta = 1

        train_loss, mtr, _ = train_single_epoch(train_loader, diva, optimizer, epoch)
        epoch_loss_sup.append(train_loss)
        print("epoch {}: avg train loss {:.2f}, ce {:.3f}".format(epoch, train_loss, mtr))

        # Loss = rec * mse + (beta-weighted KL part), so the remainder is the KL part.
        rec_loss_train = diva.rec * mtr
        dis_loss_train = train_loss - rec_loss_train

        test_lss, mte, _ = test_single_epoch(test_loader, diva, epoch)
        test_loss.append(test_lss)
        print("epoch {}: avg test  loss {:.2f}, ce {:.3f}".format(epoch, test_lss, mte))

        rec_loss_test = diva.rec * mte
        dis_loss_test = test_lss - rec_loss_test

        # NOTE: 'train_kl' / 'test_kl' are the beta-weighted KL part of the loss, not raw KL.
        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'test_loss': test_lss,
                   'train_mse': mtr,
                   'test_mse': mte,
                   'train_rec': rec_loss_train,
                   'test_rec': rec_loss_test,
                   'train_kl': dis_loss_train,
                   'test_kl': dis_loss_test,
                   'diva_beta': diva.beta})

        if writer is not None:
            writer.add_scalars("Total_Loss", {'train': train_loss, 'test': test_lss}, epoch)
            writer.add_scalars("Reconstruction_vs_Disentanglement", {'rec': rec_loss_train, 'dis': dis_loss_train}, epoch)
            writer.add_scalars("bar_mse", {'train': mtr, 'test': mte}, epoch)

        if epoch % save_interval == 0:
            save_reconstructions(epoch, test_loader, diva, name=save_folder)
            save_reconstructions(epoch, train_loader, diva, name=save_folder, estr='tr')

        if epoch % checkpoint_interval == 0:
            # Save into save_folder (created above), not the notebook-global model_folder.
            torch.save(diva.state_dict(), f'{save_folder}/checkpoints/{epoch}.pth')

    if writer is not None:
        writer.flush()

    epoch_loss_sup = [i.detach().cpu().numpy() for i in epoch_loss_sup]
    test_loss = [i.detach().cpu().numpy() for i in test_loss]
    return epoch_loss_sup, test_loss

In [49]:
def save_reconstructions(epoch, test_loader, diva, name='diva', estr='', figsize=None):
    """Save a figure of up to 10 originals (left) next to their decoded reconstructions (right).

    Uses the first batch of ``test_loader``; writes
    ``{name}/reconstructions/e{epoch}{estr}.png`` (the directory is created if missing).
    ``figsize`` is passed to ``plt.subplots``; None keeps matplotlib's default size.
    The model's train/eval mode is restored afterwards.
    """
    x_cat, y, d, x_org = next(iter(test_loader))
    n_show = min(10, x_org.shape[0])
    was_training = diva.training
    with torch.no_grad():
        diva.eval()
        d = d[:n_show].to(DEVICE).float()
        x = x_cat[:n_show].to(DEVICE).float()
        y = y[:n_show].to(DEVICE).float()
        x_org = x_org[:n_show]
        *_, color, bar = diva(d, x, y)
        rec = diva.px.sample(color, bar)
    diva.train(was_training)

    # plt.subplots opens its own figure, so any size must be given here.
    fig, ax = plt.subplots(nrows=n_show, ncols=2, figsize=figsize, squeeze=False)
    ax[0, 0].set_title("Original")
    ax[0, 1].set_title("Reconstructed")
    for i in range(n_show):
        ax[i, 0].imshow(x_org[i].cpu().permute(1, 2, 0))
        ax[i, 1].imshow(rec[i].cpu())
        for axis in (ax[i, 0], ax[i, 1]):
            axis.xaxis.set_visible(False)
            axis.yaxis.set_visible(False)
    fig.tight_layout(pad=0.1)

    out_dir = f'{name}/reconstructions'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/e{epoch}{estr}.png')
    plt.close(fig)

In [50]:
# Confirm the device once more before building the model.
DEVICE

device(type='cuda')

## Model Training

In [51]:
# StampDIVA hyper-parameters from the config cell; the other dimensions keep diva_args' defaults.
default_args = diva_args(z_dim=LATENT_DIM, rec = 1, 
                         beta=BETA_KL, warmup=WARMUP, prewarmup=PREWARMUP)

In [52]:
# Build the model and place it on DEVICE.
diva = StampDIVA(default_args).to(DEVICE)

In [53]:
# StampDIVA.forward takes (d, x, y) but only uses x, so (1, 1) placeholders suffice for d and y.
summary(diva,[ (1,1),(1,5,25,100),(1,1)])

Layer (type:depth-idx)                   Output Shape              Param #
StampDIVA                                --                        --
├─qz: 1-1                                [1, 64]                   --
│    └─Sequential: 2-1                   [1, 72, 3, 12]            --
│    │    └─Conv2d: 3-1                  [1, 48, 25, 100]          6,048
│    │    └─ReLU: 3-2                    [1, 48, 25, 100]          --
│    │    └─Conv2d: 3-3                  [1, 48, 25, 100]          57,648
│    │    └─ReLU: 3-4                    [1, 48, 25, 100]          --
│    │    └─MaxPool2d: 3-5               [1, 48, 12, 50]           --
│    │    └─Conv2d: 3-6                  [1, 60, 12, 50]           25,980
│    │    └─ReLU: 3-7                    [1, 60, 12, 50]           --
│    │    └─Conv2d: 3-8                  [1, 60, 12, 50]           32,460
│    │    └─ReLU: 3-9                    [1, 60, 12, 50]           --
│    │    └─MaxPool2d: 3-10              [1, 60, 6, 25]            --


In [54]:
#diva.load_state_dict(torch.load(f'{link}/saved_models/new/CIMVAE1/checkpoints/1750.pth'))

In [55]:
# Shuffled mini-batches for training; larger, unshuffled batches for evaluation.
train_loader = DataLoader(RNA_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(RNA_dataset_test, batch_size=256)

In [56]:
# Adam with the configured learning rate (an SGD alternative is kept commented out).
#optimizer = optim.SGD(diva.parameters(), lr=0.00001, momentum=0.1, nesterov=True)
optimizer = optim.Adam(diva.parameters(), lr=LR)

In [57]:
#RNA_dataset.x_len.min(), RNA_dataset.x_len.max()

In [58]:
# Push any pending TensorBoard events to disk.
writer.flush()

In [62]:
%tensorboard --logdir=D:/users/Marko/downloads/mirna/saved_models/wandb_models/  --host localhost

Reusing TensorBoard on port 6006 (pid 19552), started 0:00:32 ago. (Use '!kill 19552' to kill it.)

In [None]:
# Main run: epochs 1..1000, reconstruction figures every 25 epochs, artefacts under model_folder.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1000, 0, save_folder=model_folder,save_interval=25)

Epoch 1: 2171batch [00:41, 52.18batch/s, loss=559]


epoch 1: avg train loss 451.67, ce 449.246
epoch 1: avg test  loss 446.76, ce 443.628


Epoch 2: 2171batch [00:39, 54.68batch/s, loss=285]


epoch 2: avg train loss 442.45, ce 437.824


Epoch 3: 6batch [00:00, 54.55batch/s, loss=444]

epoch 2: avg test  loss 436.63, ce 430.546


Epoch 3: 2171batch [00:39, 54.68batch/s, loss=311]


epoch 3: avg train loss 436.49, ce 430.259


Epoch 4: 5batch [00:00, 45.87batch/s, loss=420]

epoch 3: avg test  loss 432.48, ce 426.032


Epoch 4: 2171batch [00:38, 56.25batch/s, loss=415]


epoch 4: avg train loss 433.82, ce 426.970


Epoch 5: 6batch [00:00, 53.57batch/s, loss=396]

epoch 4: avg test  loss 430.31, ce 423.184


Epoch 5: 2171batch [00:39, 55.54batch/s, loss=348]


epoch 5: avg train loss 432.43, ce 425.253


Epoch 6: 6batch [00:00, 56.08batch/s, loss=368]

epoch 5: avg test  loss 429.12, ce 421.915


Epoch 6: 2171batch [00:38, 55.98batch/s, loss=558]


epoch 6: avg train loss 431.12, ce 423.559


Epoch 7: 6batch [00:00, 51.72batch/s, loss=430]

epoch 6: avg test  loss 427.03, ce 418.992


Epoch 7: 2171batch [00:38, 56.13batch/s, loss=350]


epoch 7: avg train loss 429.05, ce 420.919


Epoch 8: 6batch [00:00, 55.05batch/s, loss=427]

epoch 7: avg test  loss 425.25, ce 417.131


Epoch 8: 2171batch [00:38, 55.85batch/s, loss=219]


epoch 8: avg train loss 427.61, ce 419.132


Epoch 9: 6batch [00:00, 56.08batch/s, loss=419]

epoch 8: avg test  loss 424.00, ce 415.494


Epoch 9: 2171batch [00:38, 56.84batch/s, loss=318]


epoch 9: avg train loss 426.74, ce 418.057


Epoch 10: 6batch [00:00, 54.05batch/s, loss=416]

epoch 9: avg test  loss 422.97, ce 414.407


Epoch 10: 2171batch [00:38, 56.41batch/s, loss=208]


epoch 10: avg train loss 426.08, ce 417.194


Epoch 11: 6batch [00:00, 53.57batch/s, loss=431]

epoch 10: avg test  loss 421.93, ce 413.026


Epoch 11: 2171batch [00:38, 56.41batch/s, loss=199]


epoch 11: avg train loss 425.46, ce 416.406


Epoch 12: 6batch [00:00, 54.55batch/s, loss=368]

epoch 11: avg test  loss 421.45, ce 412.127


Epoch 12: 2171batch [00:38, 55.97batch/s, loss=234]


epoch 12: avg train loss 425.01, ce 415.856


Epoch 13: 6batch [00:00, 57.69batch/s, loss=430]

epoch 12: avg test  loss 420.94, ce 411.710


Epoch 13: 2171batch [00:38, 55.96batch/s, loss=361]


epoch 13: avg train loss 424.70, ce 415.497


Epoch 14: 6batch [00:00, 51.28batch/s, loss=420]

epoch 13: avg test  loss 420.86, ce 411.464


Epoch 14: 2171batch [00:39, 55.58batch/s, loss=399]


epoch 14: avg train loss 424.41, ce 415.102
epoch 14: avg test  loss 420.15, ce 410.604


Epoch 15: 2171batch [00:39, 55.41batch/s, loss=466]


epoch 15: avg train loss 424.15, ce 414.715


Epoch 16: 6batch [00:00, 52.63batch/s, loss=470]

epoch 15: avg test  loss 420.85, ce 412.049


Epoch 16: 2171batch [00:39, 55.03batch/s, loss=395]


epoch 16: avg train loss 423.80, ce 414.302


Epoch 17: 6batch [00:00, 53.57batch/s, loss=405]

epoch 16: avg test  loss 419.96, ce 410.493


Epoch 17: 2171batch [00:39, 54.90batch/s, loss=423]


epoch 17: avg train loss 423.60, ce 414.015


Epoch 18: 6batch [00:00, 53.57batch/s, loss=431]

epoch 17: avg test  loss 420.09, ce 410.623


Epoch 18: 2171batch [00:39, 54.62batch/s, loss=436]


epoch 18: avg train loss 423.44, ce 413.795


Epoch 19: 6batch [00:00, 52.17batch/s, loss=433]

epoch 18: avg test  loss 419.36, ce 409.950


Epoch 19: 2171batch [00:39, 54.38batch/s, loss=558]


epoch 19: avg train loss 423.10, ce 413.352


Epoch 20: 6batch [00:00, 52.17batch/s, loss=400]

epoch 19: avg test  loss 420.45, ce 411.490


Epoch 20: 2171batch [00:40, 54.06batch/s, loss=399]


epoch 20: avg train loss 422.91, ce 413.054


Epoch 21: 6batch [00:00, 52.63batch/s, loss=430]

epoch 20: avg test  loss 419.05, ce 408.934


Epoch 21: 2171batch [00:40, 53.84batch/s, loss=416]


epoch 21: avg train loss 422.64, ce 412.703


Epoch 22: 6batch [00:00, 53.10batch/s, loss=432]

epoch 21: avg test  loss 418.76, ce 408.313


Epoch 22: 2171batch [00:40, 53.52batch/s, loss=512]


epoch 22: avg train loss 422.37, ce 412.431


Epoch 23: 6batch [00:00, 52.17batch/s, loss=395]

epoch 22: avg test  loss 418.36, ce 408.296


Epoch 23: 2171batch [00:40, 53.12batch/s, loss=464]


epoch 23: avg train loss 422.20, ce 412.130


Epoch 24: 6batch [00:00, 53.57batch/s, loss=385]

epoch 23: avg test  loss 418.72, ce 408.788


Epoch 24: 2171batch [00:41, 52.76batch/s, loss=347]


epoch 24: avg train loss 422.09, ce 411.934


Epoch 25: 5batch [00:00, 46.73batch/s, loss=369]

epoch 24: avg test  loss 418.59, ce 408.539


Epoch 25: 2171batch [00:41, 51.72batch/s, loss=328]


epoch 25: avg train loss 422.02, ce 411.866
epoch 25: avg test  loss 417.83, ce 407.750


Epoch 26: 2171batch [00:45, 48.17batch/s, loss=473]


epoch 26: avg train loss 421.68, ce 411.498


Epoch 27: 5batch [00:00, 49.51batch/s, loss=425]

epoch 26: avg test  loss 418.98, ce 409.303


Epoch 27: 2171batch [00:43, 50.12batch/s, loss=514]


epoch 27: avg train loss 421.59, ce 411.375


Epoch 28: 5batch [00:00, 49.02batch/s, loss=407]

epoch 27: avg test  loss 417.45, ce 407.345


Epoch 28: 2171batch [00:43, 50.13batch/s, loss=261]


epoch 28: avg train loss 421.46, ce 411.191


Epoch 29: 5batch [00:00, 45.46batch/s, loss=428]

epoch 28: avg test  loss 417.80, ce 407.755


Epoch 29: 2171batch [00:43, 49.85batch/s, loss=438]


epoch 29: avg train loss 421.39, ce 411.082


Epoch 30: 5batch [00:00, 46.73batch/s, loss=489]

epoch 29: avg test  loss 417.87, ce 407.051


Epoch 30: 2171batch [00:44, 48.78batch/s, loss=490]


epoch 30: avg train loss 421.30, ce 410.930


Epoch 31: 5batch [00:00, 46.30batch/s, loss=416]

epoch 30: avg test  loss 417.40, ce 407.024


Epoch 31: 2171batch [00:44, 48.25batch/s, loss=659]


epoch 31: avg train loss 421.19, ce 410.819


Epoch 32: 6batch [00:00, 52.63batch/s, loss=408]

epoch 31: avg test  loss 417.57, ce 407.087


Epoch 32: 2171batch [00:45, 47.39batch/s, loss=398]


epoch 32: avg train loss 421.11, ce 410.711


Epoch 33: 5batch [00:00, 49.51batch/s, loss=422]

epoch 32: avg test  loss 417.29, ce 407.082


Epoch 33: 2171batch [00:44, 48.69batch/s, loss=384]


epoch 33: avg train loss 420.96, ce 410.549


Epoch 34: 4batch [00:00, 35.71batch/s, loss=392]

epoch 33: avg test  loss 416.94, ce 406.352


Epoch 34: 2171batch [00:45, 48.18batch/s, loss=753]


epoch 34: avg train loss 420.96, ce 410.487


Epoch 35: 6batch [00:00, 50.85batch/s, loss=375]

epoch 34: avg test  loss 416.96, ce 406.643


Epoch 35: 2171batch [00:49, 44.06batch/s, loss=425]


epoch 35: avg train loss 420.71, ce 410.156


Epoch 36: 5batch [00:00, 47.62batch/s, loss=417]

epoch 35: avg test  loss 417.98, ce 407.440


Epoch 36: 317batch [00:06, 45.55batch/s, loss=426]

In [None]:
# Continue training for epochs 1001..2000 into the relative folder "new/CIMVAE1".
lss2, lss_t2 = train(default_args, train_loader, test_loader, diva, optimizer, 2000, 1000, save_folder="new/CIMVAE1", save_interval=25)

In [None]:
# NOTE(review): resumes at epoch 1001 (to 1600) into "VAEFC" and overwrites lss / lss_t from the main run.
lss, lss_t = train(default_args, train_loader, test_loader, diva, optimizer, 1600, 1000, save_folder="VAEFC")

In [None]:
                


def get_color(color):
    """Map a colour class index (int or 0-d tensor) to its RGB triple as a tensor.

    0 black, 1 red, 2 blue, 3 green, 4 yellow; any other value yields None.
    """
    palette = (
        (0, [0, 0, 0]),
        (1, [1, 0, 0]),
        (2, [0, 0, 1]),
        (3, [0, 1, 0]),
        (4, [1, 1, 0]),
    )
    for code, rgb in palette:
        if color == code:
            return torch.tensor(rgb)
    return None

In [None]:
# Inspect the first training batch: encode/decode 10 samples without gradients.
# The posterior / prior are bound to q_z / p_z: binding them to `qz` would overwrite the
# encoder class `qz` at module level and break any later StampDIVA(...) construction.
# NOTE: diva is left in eval mode here.
batch = next(iter(train_loader))
with torch.no_grad():
    diva.eval()
    d = batch[2][:10].to(DEVICE).float()
    x = batch[0][:10].to(DEVICE).float()
    y = batch[1][:10].to(DEVICE).float()
    x_org = batch[-1][:10]
    x_hat, q_z, p_z, z_q, color, bar = diva(d, x, y)

#x_hat[x_hat < 0.5] = 0

In [None]:
# `bar` holds the decoder's bar-length probabilities, shape (batch, 14 lengths, 200 slots).
# Returning one tuple displays all three values; separate bare expressions would show only the last.
bar.shape, bar[0, :, 1], bar[0, :, 0]

In [None]:
# Colour probabilities (5 colours x 25 rows) of the first reconstruction at sequence position 1.
x_hat[0,:,:,1]

In [None]:
# Decode the soft reconstructions `x_hat` (N, 5, 25, 100) into RGB images. Each pixel
# independently takes the colour of its most likely class when that probability is >= 0.5,
# otherwise it stays white (the background of the original images). Stopping a column at
# its first empty pixel would skip the whole column, since the upper rows are empty above
# the bars.
probs = x_hat[:10].detach().cpu().numpy()
palette = np.array([[0, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 1, 0]], dtype=float)
filled = probs.max(axis=1) >= 0.5
out = np.where(filled[..., None], palette[probs.argmax(axis=1)], 1.0)  # (N, 25, 100, 3)


In [None]:
# Show the first decoded image.
plt.imshow(out[0])

In [None]:
def plot_loss_acc(lss, lss_t, yacc=None, yacc_t=None):
    """Plot train/test loss curves, optionally with train/test accuracy on a twin y-axis.

    lss, lss_t: per-epoch train / test losses (e.g. the lists returned by train()).
    yacc, yacc_t: optional per-epoch accuracies, drawn dashed on a secondary axis.
    """
    fig, ax = plt.subplots()
    ax.plot(lss, label="train loss")
    ax.plot(lss_t, label="test loss")
    ax.set_xlabel("epoch index")
    ax.set_ylabel("loss")
    lines, labels = ax.get_legend_handles_labels()

    if yacc is not None or yacc_t is not None:
        ax1 = ax.twinx()
        if yacc is not None:
            ax1.plot(yacc, label="train accuracy", ls='--')
        if yacc_t is not None:
            ax1.plot(yacc_t, label="test accuracy", ls='--')
        ax1.set_ylabel("accuracy")
        lines2, labels2 = ax1.get_legend_handles_labels()
        lines, labels = lines + lines2, labels + labels2

    # One combined legend for both axes.
    ax.legend(lines, labels)

In [None]:
# Loss curves of the most recent run stored in lss / lss_t.
plot_loss_acc(lss, lss_t)

In [None]:
# FIXME(review): lss3, lss_t3, yacc3 and yacc_t3 are not defined anywhere in this notebook (NameError on a fresh kernel).
plot_loss_acc(lss3, lss_t3, yacc3, yacc_t3)

In [None]:
def plot_change_latent_var(diva, lat_space="y", var_idx=[0,1,2,3,4,5,6,7], step = 5):
    """Latent-traversal plot: one row per entry of ``var_idx``, one column per swept value.

    Encodes the first ``len(var_idx)`` samples of the global ``test_loader`` and delegates
    the sweep to ``change``.

    FIXME(review): written for a model with separate x / y / d encoders. StampDIVA in this
    notebook defines only ``diva.qz`` and ``diva.px``; ``diva.qzx`` / ``diva.qzy`` /
    ``diva.qzd`` are not defined anywhere here, so this raises AttributeError as written.
    """
    a = next(enumerate(test_loader))
    with torch.no_grad():
        diva.eval()
        d = a[1][2][:len(var_idx)].to(DEVICE).float()
        x = a[1][0][:len(var_idx)].to(DEVICE).float()
        y = a[1][1][:len(var_idx)].to(DEVICE).float()

        # FIXME(review): qzx / qzy / qzd do not exist on StampDIVA (single encoder `qz`).
        zx, zx_sc = diva.qzx(x)
        zy, zy_sc = diva.qzy(x)
        zd, zd_sc =  diva.qzd(x)

        print(torch.max(zy), torch.min(zy), "sdmax:", torch.max(zy_sc))

        out = change(zx, zy, zd, var_idx, lat_space, diva, step)
    
    # Grid: rows = samples / latent indices, columns = swept values.
    fig, ax = plt.subplots(ncols=out.shape[0],nrows=len(var_idx),figsize=(10*4*out.shape[0],10*len(var_idx)))
    for i in range(out.shape[0]):
      for j in range(len(var_idx)):
        ax[j,i].imshow(out[i,j])

In [None]:
def change(zx, zy, zd, idx, lat = "y", model=diva, step = 2):
    """Sweep latent values over ``np.arange(-30, 15, step)`` and decode each setting.

    Returns an array of shape ``(n_steps, len(idx), 25, 100, 3)``. ``model`` defaults to
    the global ``diva`` as it exists when this cell is executed.

    FIXME(review): stale for StampDIVA. ``px.forward`` takes a single ``z`` and ``px`` has
    no ``reconstruct_image`` method, so the decode lines below fail as written.
    NOTE(review): ``zy[j, idx] = dif[i]`` sets *every* index in ``idx`` for sample j; if the
    intent was one latent dimension per row, this should probably be ``idx[j]`` -- TODO confirm.
    The latent tensors are modified in place.
    """
    
    dif = np.arange(-30,15,step)
    print(torch.max(zy), torch.min(zy))
    out = np.zeros((dif.shape[0], len(idx), 25, 100 ,3))  
    #print(zy.shape, dif.shape[0])
    for i in range(dif.shape[0]):
      for j in range(len(idx)):
        if lat == "y":
            zy[j,idx] = dif[i]
        elif lat == "x":
            zx[j,idx] = dif[i]
        elif lat == "d":
            zd[j,idx] = dif[i]
        # FIXME(review): px takes one argument and has no reconstruct_image (see docstring).
        len_, bar, col = model.px(zd[j],zx[j],zy[j])
        out[i,j] = model.px.reconstruct_image(len_[None,:], bar, col)
    
    return out



In [None]:
# FIXME(review): plot_change_latent_var calls diva.qzx / qzy / qzd, which StampDIVA does not define.
plot_change_latent_var(diva)

In [None]:
# FIXME(review): stale cell. yacc2 / yacc_t2 are never defined in this notebook; train()
# already returns numpy values, so .cpu() on the elements of lss2 raises AttributeError;
# and np.arange(50,120) (70 points) does not match the 1000 epochs of the lss2 run.
fig,ax = plt.subplots()
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss2], label="train loss")
ax.plot(np.arange(50,120), [i.cpu().detach().numpy() for i in lss_t2], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(50,120), yacc2, label = "train")
ax1.plot(np.arange(50,120), yacc_t2, label = "test")

plt.legend()

In [None]:
# FIXME(review): stale cell. lss3, lss_t3, yacc3 and yacc_t3 are never defined in this notebook.
fig,ax = plt.subplots()
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss3], label="train loss")
ax.plot(np.arange(120,180), [i.cpu().detach().numpy() for i in lss_t3], label = "testloss")
ax1 = ax.twinx()
ax1.plot(np.arange(120,180), yacc3, label = "train",c='green')
ax1.plot(np.arange(120,180), yacc_t3, label = "test")

plt.legend()

# Model Evaluation

## Sampling from trained model

In [None]:
def plot_latent_space(lat_space="y"):
    '''Plot a latent space of the trained model (not implemented yet).

    lat_space: y, d, x

    Raises NotImplementedError rather than silently returning None.
    '''
    raise NotImplementedError("plot_latent_space is not implemented yet")

    

In [None]:
# FIXME(review): `plot` is not defined anywhere in this notebook (NameError on a fresh kernel).
plot(x, out, 0)

In [None]:
# Preview nine original RGB images from the inspected batch. `x` holds the 5-channel one-hot
# encodings, which imshow cannot display, so use the channel-first RGB images in `x_org`.
fig, ax = plt.subplots(nrows=3, ncols=3)
for i in range(9):
    ax[i // 3, i % 3].imshow(x_org[i].cpu().permute(1, 2, 0))
fig.savefig('divastamporg.png')