<a href="https://colab.research.google.com/github/nverchev/ExplainableVAE/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Explainable VAE


In [0]:
#@title Libraries
#%load_ext autoreload
#%autoreload 2
from google.colab import files as fdownload
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch import optim
from torch.distributions.normal import Normal
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from abc import ABCMeta,abstractmethod
from sklearn.metrics import jaccard_score,confusion_matrix,roc_auc_score

In [0]:
#@title Hyperparameters: { display-mode: "form" }
batch_size =  32#@param {type: "number"}
dim_latent=20 #@param {type: "number"}
initial_learning_rate=0.001 #@param {type: "number"}
weight_decay=0.0000001 #@param {type: "number"}
block = 'mobile' #@param ["mobile","resnet"]


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [0]:
Download = True #@param {type:"boolean"}
transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])

pin_memory=torch.cuda.is_available()
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transform), batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,  download=True,
                   transform=transform), batch_size=batch_size, shuffle=True, pin_memory=pin_memory)

In [0]:
#@title Trainer
class Trainer(metaclass=ABCMeta):
    def __init__(self,model,version,optim,lr,weight_decay, train_loader, device,**kwargs):
        self.model=model.to(device)
        self.version=version
        self.optimizer=optim(model.parameters(), lr, weight_decay)
        self.device=device
        self.train_loader=train_loader
        self.val_loader=kwargs.get('val_loader',None)
        self.test_loader=kwargs.get('test_loader',None)
        self.losses=['loss']
        self.train_losses={}
        self.val_losses={}
        self.test_losses={}
        self.epoch=0

    def train(self,num_epoch):
        print('Version ',self.version)
        for _ in range(num_epoch):
            self.epoch+=1
            print('====> Epoch:{}'.format(self.epoch))        
            self.run_session(mode='train')
            if self.val_loader:
              self.run_session(mode='val')
        return


    def auxiliary_inputs(self,img,labels):
      return {}

    def run_session(self,mode=train):
        if mode=='train':
            self.model.train()
            torch.set_grad_enabled(True)
            loader=self.train_loader
            dict_losses= self.train_losses           
        elif mode=='val':
            self.model.eval()
            torch.set_grad_enabled(False)
            loader=self.val_loader
            dict_losses= self.val_losses
        elif mode=='test':
            self.model.eval()
            torch.set_grad_enabled(False)
            loader=self.test_loader
            dict_losses= self.test_losses
            test_img,test_targets,test_outputs= [], [], {}
        else:
          raise ValueError('mode options are "train", "val", "test" ')
        
        len_sess=len(loader.dataset)
        epoch_loss= {loss:0 for loss in  self.losses}
        iterable=tqdm(enumerate(loader),total=len(loader))
        for batch_idx, (img, targets) in iterable:
            
            img, targets = img.to(self.device).float(), targets.to(self.device)
            aux=self.auxiliary_inputs(img,targets)
            outputs =  self.model(img,**aux)
            batch_loss = self.loss(outputs, img, targets)
            for loss in self.losses:
              epoch_loss[loss]+=batch_loss[loss].item()
            if  mode=='train':
              batch_loss['loss'].backward()
              self.optimizer.step()
              self.optimizer.zero_grad()
            if  mode=='test':
              for key,value in outputs.items():
                if key not in test_outputs.keys():
                  test_outputs[key]=[]
                test_outputs[key].extend(to_numpy(value))
              test_img.extend(to_numpy(img))
              test_targets.extend(to_numpy(targets))
            if batch_idx % (len(train_loader)//10) == 0 and mode=='train':
                iterable.set_description('Train [{:4d}/{:4d} ]\tLoss {:4f}'.format(
                     batch_idx * loader.batch_size, len_sess,batch_loss['loss'].item()))
            if batch_idx == len(train_loader)-1 and mode=='train':
                iterable.set_description('')           
        num_batch=batch_idx+1
        for loss in self.losses:
            if loss not in dict_losses.keys():
              dict_losses[loss]=[]
            dict_losses[loss].append(epoch_loss[loss]/ num_batch)
        print('Average {} losses :'.format(mode))
        for loss in self.losses:
            print('{}: {:.2f}'.format(loss,dict_losses[loss][-1]), end='\t')
        print()
        if  mode=='test':
          return  test_img,test_targets,test_outputs
        else:
          return

    def update_learning_rate(self,new_lr):
        for g in self.optimizer.param_groups:
            g['lr']= new_lr

    def to_numpy(tensor):
      return tensor.squeeze().detach().cpu().numpy()

    def load(self, epoch=None):
        raise NotImplementedError

    def save(self,new_version=None):
        raise NotImplementedError

    def paths(self,new_version=None):
        raise NotImplementedError

    def plot_losses(self):
        raise NotImplementedError

    @abstractmethod
    def loss(self,output, img, targets):
        pass



### Training Standard classifier on MNIST

In [0]:
#@title Classifier_nets

class MNISTclassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return {'regression': x}
classifier=MNISTclassifier()


In [0]:
#@title Train classifier

class trainer_classifier(Trainer):
    def __init__(self, model,version,block_args):
        super().__init__(model,version,**block_args)
    def loss(self, output, img, targets):
        pred=output['regression']
        return {'loss':F.cross_entropy(pred,targets)}



block_args={
    'optim':optim.Adadelta,
    'lr': 1e-3,
    'weight_decay':0,
    'train_loader':train_loader,
    'device':device,
    'test_loader': test_loader
}
training_classifier=trainer_classifier(classifier,'Mnistclassifier',block_args)



In [43]:
training_classifier.train(5)
training_classifier.update_learning_rate(1e-4)
training_classifier.train(10)

Version  Mnistclassifier
====> Epoch:1


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 1.92	
====> Epoch:2


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.95	
====> Epoch:3


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.68	
====> Epoch:4


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.59	
====> Epoch:5


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.55	
Version  Mnistclassifier
====> Epoch:6


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.53	
====> Epoch:7


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.53	
====> Epoch:8


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:9


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:10


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:11


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.52	
====> Epoch:12


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
====> Epoch:13


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
====> Epoch:14


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.51	
====> Epoch:15


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.50	


In [62]:
test_img,test_targets,test_output= training_classifier.run_session(mode='test')
pixels=test_img[5]
plt.imshow(pixels, cmap='gray')
plt.show()
pred=np.array(test_output['regression']).argmax(axis=1)
accuracy=(pred==test_targets).sum()/len(pred)
print('Accuracy : {}'.format(accuracy))

HBox(children=(FloatProgress(value=0.0, max=313.0), HTML(value='')))


Average test losses :
loss: 0.35	


In [0]:

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape=shape   
    def forward(self, input):
        return input.view(*self.shape)



class VAE(nn.Module):
    def __init__(self, x_dim, h_dims, k_dims, strides, z_dim):
        super(VAE, self).__init__()

        # encoder part
        self.fc0 = nn.Linear(10, x_dim)
        self.view = View([-1,1]+[int(np.sqrt(x_dim))]*2)
        self.conv1 = nn.Conv2d(2, h_dims[0], k_dims[0],stride=strides[0])
        self.conv2 = nn.Conv2d(h_dims[0], h_dims[1], k_dims[1],stride=strides[1])
        self.conv3 = nn.Conv2d(h_dims[1], h_dims[2], k_dims[2],stride=strides[2])
        self.fc1 = nn.Linear(h_dims[2], 2*z_dim)

        # decoder part
        self.fc2 = nn.Linear(z_dim, h_dims[2])
        self.convt3=nn.ConvTranspose2d(h_dims[2], h_dims[1], k_dims[2],stride=strides[2])
        self.convt2=nn.ConvTranspose2d(h_dims[1], h_dims[0], k_dims[1],stride=strides[1])
        self.convt1=nn.ConvTranspose2d(h_dims[0], h_dims[0], k_dims[0],stride=strides[0])
        self.convt0 = nn.ConvTranspose2d(h_dims[0], 1, 7)
        

    def encoder(self, x, regression):
        x=torch.cat([x,self.view(self.fc0(regression))],dim=1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.flatten(1)
        return self.fc1(x).chunk(2,dim=1) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def decoder(self, x):
        x = F.relu(self.fc2(x))
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = F.relu(self.convt3(x))
        x = F.relu(self.convt2(x))
        x = F.relu(self.convt1(x))
        return torch.sigmoid(self.convt0(x)) 

    def forward(self, x, regression):
        mu, log_var = self.encoder(x, regression)
        z = self.sampling(mu, log_var)
        outputs={
            'recon':self.decoder(z),
            'mu':mu,
            'log_var':log_var
        }
        return outputs
x_dim=784
z_dim=dim_latent
h_dims=[16,32,64]
k_dims=[4,4,4]
strides=[2,2,2]

vae = VAE(x_dim=x_dim, h_dims= h_dims,k_dims=k_dims, strides=strides, z_dim=z_dim)


In [0]:
#@title Train classifier

class vae_classifier_trainer(Trainer):
    def __init__(self, model,version,classifier,block_args):
        self.classifier=classifier
        super().__init__(model,version,**block_args)
    def loss(self, outputs, img, targets):
        recon=outputs['recon']
        return {'loss':F.binary_cross_entropy(outputs['recon'],img)}
    def auxiliary_inputs(self,img,labels):
      with torch.no_grad():
        self.classifier.eval()
        output=self.classifier(img)
      return output


block_args={
    'optim':optim.Adadelta,
    'lr': 1e-3,
    'weight_decay':0,
    'train_loader':train_loader,
    'device':device,
    'test_loader': test_loader,
    
}
vae_classifier_training=vae_classifier_trainer(vae,'VAEexplanability',classifier,block_args)

In [0]:
vae_classifier_training.train(6)

Version  VAEexplanability
====> Epoch:1


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.64	
====> Epoch:2


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))


Average train losses :
loss: 0.55	
====> Epoch:3


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))