<a href="https://colab.research.google.com/github/nverchev/ExplainableVAE/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Explainable VAE


In [0]:
#@title Libraries
#%load_ext autoreload
#%autoreload 2
from google.colab import files as fdownload
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch import optim
from torch.distributions.normal import Normal
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from abc import ABCMeta,abstractmethod
from sklearn.metrics import jaccard_score,confusion_matrix,roc_auc_score

In [0]:
#@title Hyperparameters: { display-mode: "form" }
batch_size =  32#@param {type: "number"}
dim_latent=20 #@param {type: "number"}
initial_learning_rate=0.001 #@param {type: "number"}
weight_decay=0.0000001 #@param {type: "number"}
block = 'mobile' #@param ["mobile","resnet"]


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [0]:
Download = True #@param {type:"boolean"}
#prevents warnings
class ToNumpy(object):
    def __call__(self, sample):
        return np.array(sample)

transform=transforms.Compose([
                        ToNumpy(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])

pin_memory=torch.cuda.is_available()
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transform), batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,  download=True,
                   transform=transform), batch_size=batch_size, shuffle=True, pin_memory=pin_memory)

In [0]:
#@title Trainer
class Trainer(metaclass=ABCMeta):
    def __init__(self,model,version,optim,lr,weight_decay, train_loader, device,**kwargs):
        self.model=model.to(device)
        self.version=version
        self.optimizer=optim(model.parameters(), lr, weight_decay)
        self.device=device
        self.train_loader=train_loader
        self.val_loader=kwargs.get('val_loader',None)
        self.test_loader=kwargs.get('test_loader',None)
        self.losses=['loss']
        self.train_losses={}
        self.val_losses={}
        self.test_losses={}
        self.epoch=0

    def train(self,num_epoch):
        print('Version ',self.version)
        for _ in range(num_epoch):
            self.epoch+=1
            print('====> Epoch:{}'.format(self.epoch))        
            self.run_session(mode='train')
            if self.val_loader:
              self.run_session(mode='val')
        return

    def run_session(self,mode=train):
        if mode=='train':
            self.model.train()
            torch.set_grad_enabled(True)
            loader=self.train_loader
            dict_losses= self.train_losses           
        elif mode=='val':
            self.model.eval()
            torch.set_grad_enabled(False)
            loader=self.val_loader
            dict_losses= self.val_losses
        elif mode=='test':
            self.model.eval()
            torch.set_grad_enabled(False)
            loader=self.test_loader
            dict_losses= self.test_losses
            test_img,test_targets,test_output= [], [], []

        len_sess=len(loader.dataset)
        epoch_loss= {loss:0 for loss in  self.losses}
        iterable=tqdm(enumerate(loader),total=len(loader))
        for batch_idx, (img, targets) in iterable:
            
            img, targets = img.to(self.device).float(), targets.to(self.device)
            output =  self.model(img)
            batch_loss = self.loss(output, img, targets)
            for loss in self.losses:
              epoch_loss[loss]+=batch_loss[loss].item()
            if  mode=='train':
              batch_loss['loss'].backward()
              self.optimizer.step()
              self.optimizer.zero_grad()
            if  mode=='test':
              test_output.append(output)
              test_img.append(img)
              test_targets.append(targets)
            if batch_idx % (len(train_loader)//10) == 0 and mode=='train':
                iterable.set_description('Train [{:4d}/{:4d} ]\tLoss {:4f}'.format(
                     batch_idx * loader.batch_size, len_sess,batch_loss['loss'].item()))
            if batch_idx == len(train_loader)-1 and mode=='train':
                iterable.set_description('')           
        num_batch=batch_idx+1
        for loss in self.losses:
            try:
              dict_losses[loss].append(epoch_loss[loss]/ num_batch)
            except:
              dict_losses[loss]= [epoch_loss[loss]/ num_batch]
        print('Average {} losses :'.format(mode))
        for loss in self.losses:
            print('{}: {:.2f}'.format(loss,dict_losses[loss][-1]), end='\t')
        print()
        if  mode=='test':
          return  test_img,test_targets,test_output
        else:
          return

    def update_learning_rate(self,new_lr):
        for g in self.optimizer.param_groups:
            g['lr']= new_lr

    def to_numpy(tensor):
      return tensor.squeeze().detach().cpu().numpy()

    def load(self, epoch=None):
        raise NotImplementedError

    def save(self,new_version=None):
        raise NotImplementedError

    def paths(self,new_version=None):
        raise NotImplementedError

    def plot_losses(self):
        raise NotImplementedError

    @abstractmethod
    def loss(self,output, img, targets):
        pass



### Training Standard classifier on MNIST

In [0]:
#@title Classifier_nets

class MNISTclassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return {'regression': x}
  

In [0]:
#@title Train classifier

class trainer_classifier(Trainer):
    def __init__(self, model,version,block_args):
        super().__init__(model,version,**block_args)
    def loss(self, output, img, targets):
        pred=output['regression']
        return {'loss':F.cross_entropy(pred,targets)}



block_args={
    'optim':optim.Adadelta,
    'lr': 1e-3,
    'weight_decay':0,
    'train_loader':train_loader,
    'device':device,
    'test_loader': test_loader
}
classifier=MNISTclassifier()
training_classifier=trainer_classifier(classifier,'Mnistclassifier',block_args)



In [134]:
training_classifier.train(5)
training_classifier.update_learning_rate(1e-4)
training_classifier.train(10)

Version  Mnistclassifier
====> Epoch:1


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 1.90	
====> Epoch:2


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.96	
====> Epoch:3


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.68	
====> Epoch:4


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.60	
====> Epoch:5


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.55	
Version  Mnistclassifier
====> Epoch:6


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.54	
====> Epoch:7


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.53	
====> Epoch:8


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.53	
====> Epoch:9


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:10


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:11


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:12


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:13


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
====> Epoch:14


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
====> Epoch:15


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
