# Paper Inspiring EVAE-Net
https://www.mdpi.com/2075-4418/12/11/2569

The implementation has been simplified and slightly adapted from the paper.

# Organizing Data #

In [None]:
from google.colab import drive
import sys
import os

# Make Google Drive visible under /content/drive so the project folder and
# the fine-tuned checkpoints can be read.
drive.mount('/content/drive/')

# Project folder on Drive (CHECK PATH): add it to the import path and make it
# the working directory, so relative paths like './train_data2' resolve here.
path_to_utils = '/content/drive/MyDrive/Colab Notebooks/healthcare_data'
sys.path.append(path_to_utils)
os.chdir(path_to_utils)

In [None]:
working_dir = os.getcwd()  # expected to be the Drive project folder after the chdir above
print(working_dir)

In [None]:
from google.colab import auth
# Authenticate this Colab session with the user's Google account.
auth.authenticate_user()

# Advanced Models: Modified EVAE-Net

In [None]:
# imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms, utils
import torchvision.models as models
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings("ignore")

In [None]:
# import tensorflow_probability as tfp
# import tensorflow as tf

# def mmd_loss(source_features, target_features):
#     rbf_kernel = tfp.math.psd_kernels.ExponentiatedQuadratic()
#     loss = tfp.stats.maximum_mean_discrepancy(source_features, target_features, kernel=rbf_kernel)
#     return loss

In [None]:
class EVAE(nn.Module):
    """Simplified EVAE-Net: a two-backbone variational autoencoder classifier.

    A ResNet-18 encoder and the convolutional part of a VGG-16 (both
    initialised from fine-tuned checkpoints on Drive) encode the same image.
    Their flattened features are concatenated and passed through three linear
    layers that produce the mean and log-variance of a diagonal Gaussian over
    a ``latent_dim``-dimensional latent vector ``z``.  ``z`` feeds a linear
    classification head and a transposed-convolution decoder.

    Args:
        latent_dim: size of the latent vector ``z``.
        num_classes: number of logits produced by the classification head.
    """

    def __init__(self, latent_dim, num_classes):
        super(EVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes  # 4 classes in this project
        # Learned 14x upsampling, applied only in loss_function: maps the
        # decoder's 16x16 output to 224x224 so it can be compared with x.
        self.conv_transpose = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=14, stride=14, padding=0)

        # ResNet-18 encoder. The 4-way fc layer mirrors the fine-tuned
        # checkpoint's head (needed for load_state_dict) and is then dropped,
        # keeping every layer up to and including the global average pool.
        # map_location="cpu" lets the checkpoint load on CPU-only runtimes;
        # the caller moves the whole model to its device afterwards.
        resnet = models.resnet18(weights=None)
        resnet.fc = torch.nn.Linear(in_features=512, out_features=4)
        resnet.load_state_dict(torch.load("/content/drive/MyDrive/finetuned_resnet.pth", map_location="cpu"))  # CHECK PATH.
        resnet_layers = list(resnet.children())[:-1]
        self.resnet_encoder = nn.Sequential(*resnet_layers)

        # VGG-16 encoder: only the convolutional `features` stack is used,
        # minus its final max-pool layer.
        # NOTE(review): torchvision's VGG has no `fc` attribute (its head is
        # `classifier`); assigning one registers an extra submodule, presumably
        # so the keys match finetuned_vgg.pth -- verify against the
        # fine-tuning code.
        vgg16 = models.vgg16(weights=None)
        vgg16.fc = torch.nn.Linear(in_features=512, out_features=4)
        vgg16.load_state_dict(torch.load("/content/drive/MyDrive/finetuned_vgg.pth", map_location="cpu"))  # CHECK PATH.
        vgg16_layers = list(vgg16.features.children())[:-1]
        self.vgg16_encoder = nn.Sequential(*vgg16_layers)

        # Reparameterization MLP.
        # 100864 = 512 (pooled ResNet-18 features) + 512 * 14 * 14 (VGG-16
        # features without the last pool); assumes 224x224 input images.
        self.fc0 = nn.Linear(100864, latent_dim)
        self.fc1 = nn.Linear(latent_dim, 512)
        self.fc2 = nn.Linear(512, latent_dim * 2)  # -> [mu | log_var]

        # Classification head on the latent vector.
        self.classification_head = nn.Linear(latent_dim, num_classes)

        # Decoder: z is viewed as (B, latent_dim, 1, 1) and each stride-2
        # layer doubles the spatial size: 1 -> 2 -> 4 -> 8 -> 16.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def decode(self, z):
        """Decode latent vectors of shape (B, latent_dim) into (B, 3, 16, 16) images in (0, 1)."""
        x_hat = self.decoder(z.unsqueeze(-1).unsqueeze(-1))
        return x_hat

    def encode(self, x):
        """Encode a batch of images.

        Returns:
            (z, mu, log_var), each of shape (B, latent_dim). In training mode
            ``z`` is sampled with the reparameterization trick. In eval mode
            ``z`` is ``mu``, so predictions are deterministic and reproducible.
        """
        resnet_features = self.resnet_encoder(x)
        vgg16_features = self.vgg16_encoder(x)

        # Flatten each backbone's output per sample and concatenate them.
        features = torch.cat([resnet_features.view(x.size(0), -1),
                              vgg16_features.view(x.size(0), -1)], dim=1)

        h = F.relu(self.fc0(features))
        h = F.relu(self.fc1(h))
        mu, log_var = torch.chunk(self.fc2(h), 2, dim=-1)

        if self.training:
            # z = mu + sigma * eps keeps the sampling step differentiable.
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            # Use the posterior mean at evaluation time instead of a random draw.
            z = mu

        return z, mu, log_var

    def forward(self, x):
        """Return (x_hat, logits, mu, log_var); x_hat is the 16x16 decoder output, not yet upsampled."""
        z, mu, log_var = self.encode(x)
        x_hat = self.decode(z)
        y = self.classification_head(z)
        return x_hat, y, mu, log_var

    def loss_function(self, x_hat, x, y, target, mu, log_var):
        """Compute the three EVAE loss terms for a batch.

        Args:
            x_hat: decoder output from ``forward``; upsampled here with
                ``self.conv_transpose`` (x14) to match ``x`` (assumes 224x224).
            x: input batch the reconstruction is compared against.
            y: classification logits.
            target: one-hot targets of shape (B, num_classes).
            mu: posterior means from ``encode``.
            log_var: posterior log-variances from ``encode``.

        Returns:
            (recons_loss, kld_loss, Lcls): summed MSE reconstruction loss,
            summed KL divergence to N(0, I), and batch-mean cross-entropy.

        NOTE(review): the first two terms are sums over the batch while the
        cross-entropy is a mean, so the classification term carries little
        relative weight; consider explicit loss weights.
        """
        x_hat_upsampled = self.conv_transpose(x_hat)

        recons_loss = F.mse_loss(x_hat_upsampled, x, reduction='sum')
        # Closed-form KL(N(mu, sigma^2) || N(0, I)).
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        Lcls = F.cross_entropy(y, target.argmax(dim=1))

        return recons_loss, kld_loss, Lcls

# Training EVAE

In [None]:
# IPython shell escape: print the current working directory.
!pwd

In [None]:
# check data we are using
import os

# Sanity check: how many PNG images are in one class folder of the training set.
directory = '/content/drive/MyDrive/Colab Notebooks/healthcare_data/train_data2/class2'
extension = '.png'

num_files = sum(1 for name in os.listdir(directory) if name.endswith(extension))

print(f"There are {num_files} {extension} files in {directory}")


In [None]:
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms # need to adapt image format


# Training images: ImageFolder assigns one class index per sub-folder of train_dir.
train_dir = './train_data2'  # CHECK PATH.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_data = datasets.ImageFolder(train_dir, transform=transform)

In [None]:
# Shuffled mini-batches of 4 training images.
train_loader = DataLoader(dataset=train_data, batch_size=4, shuffle=True)

In [None]:
# Number of training images; as the cell's last expression, the notebook displays it.
len(train_data)

In [None]:
from tqdm import tqdm
# Training loop

losses = []  # summed training loss of each epoch, appended by train()


def train(model, optimizer, train_loader, device, epoch=None):
    """Train ``model`` for one epoch and record the epoch's summed loss.

    The per-batch objective is reconstruction MSE + KL divergence +
    classification cross-entropy, as returned by ``model.loss_function``.

    Args:
        model: EVAE-style module. ``model(x)`` returns
            ``(x_hat, logits, mu, log_var)`` and ``model.loss_function`` returns
            ``(recons_loss, kld_loss, cls_loss)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader yielding ``(images, integer_labels)`` batches.
        device: device each batch is moved to.
        epoch: epoch index shown in progress messages. Defaults to the number
            of epochs already recorded in ``losses``, which matches the loop
            index when called once per epoch starting from an empty list.

    Side effects:
        Updates ``model`` in place and appends the epoch's summed loss to the
        module-level ``losses`` list.
    """
    if epoch is None:
        epoch = len(losses)
    # One-hot width follows the model's head when it exposes num_classes
    # (EVAE does); otherwise fall back to the project's 4 classes.
    num_classes = getattr(model, "num_classes", 4)

    model.train()
    train_loss = 0

    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        recon_batch, y, mu, log_var = model(data)
        target_onehot = F.one_hot(target, num_classes=num_classes).float()
        mse, kld, Lcls = model.loss_function(recon_batch, data, y, target_onehot, mu, log_var)
        loss = mse + kld + Lcls
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    losses.append(train_loss)


In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters (both can be adjusted).
learning_rate = 3e-5
epochs = 20

model = EVAE(latent_dim=256, num_classes=4)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    train(model, optimizer, train_loader, device)

# EVAE Evaluation


In [None]:
# Evaluation images, preprocessed exactly like the training images.
test_dir = './test_data2'  # CHECK PATH.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_data = datasets.ImageFolder(test_dir, transform=transform)

In [None]:
# No shuffling for evaluation: the metrics do not depend on order, and a fixed
# order keeps y_true / y_pred reproducible across runs.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False)

In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score

def get_results(model, loader, device=None):
    """Run ``model`` over ``loader`` and collect labels and class probabilities.

    Also prints the overall accuracy and the macro one-vs-rest ROC AUC.

    Args:
        model: EVAE-style module whose forward returns
            ``(x_hat, logits, mu, log_var)``.
        loader: DataLoader yielding ``(images, integer_labels)`` batches.
        device: device batches are moved to. Defaults to the device holding
            the model's parameters.

    Returns:
        (y_true, y_pred): ``y_true`` is an (N, C) one-hot NumPy array and
        ``y_pred`` an (N, C) array of softmax probabilities, where C is the
        width of the model's logits.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    labels = []
    probs = []
    with torch.no_grad():
        for data, target in tqdm(loader):
            _, logits, _, _ = model(data.to(device))
            probs.extend(torch.softmax(logits, dim=1).tolist())
            labels.extend(target.tolist())

    if not labels:
        raise ValueError("get_results: loader yielded no samples")

    y_pred = np.array(probs)
    # Size the one-hot from the logits width so y_true and y_pred always have
    # the same number of columns, even if a class is missing from the labels.
    y_true = F.one_hot(torch.tensor(labels), num_classes=y_pred.shape[1]).numpy()

    acc = accuracy_score(y_true.argmax(axis=1), y_pred.argmax(axis=1))
    print(f"Accuracy: {acc:.4f}")
    # roc_auc_score raises ValueError when a class has no positive samples;
    # report that instead of discarding the predictions collected above.
    try:
        auc = roc_auc_score(y_true, y_pred, multi_class='ovr')
        print(f"Macro one-vs-rest AUC: {auc:.4f}")
    except ValueError as err:
        print(f"AUC not computed: {err}")

    return y_true, y_pred

In [None]:
# Run the trained model on the test set: one-hot labels and softmax probabilities.
y_true, y_pred = get_results(model, loader=test_loader)

# Inference on Test Set

In [None]:
# Convert get_results' one-hot y_true rows into integer class labels
# (index of the first maximum in each row).
y_true_int = [
    values.index(max(values))
    for values in (row.tolist() for row in y_true)
]

In [None]:
# Predicted class per test image: position of the highest softmax probability
# in each row of y_pred (the first one on ties).
y_pred_int = [
    probs.index(max(probs))
    for probs in (row.tolist() for row in y_pred)
]

In [None]:
# Per-class (one-vs-rest) accuracy and AUC from hard label predictions.
def multiclass_metrics(test, pred, average="macro"):
    """Compute one-vs-rest accuracy and AUC for every class present in ``test``.

    For each class ``c`` seen in the true labels, both label lists are
    binarized (1 where the label equals ``c``, else 0) and scored with
    sklearn's ``accuracy_score`` and ``roc_auc_score``.

    Args:
        test: true integer class labels.
        pred: predicted integer class labels, same length as ``test``.
        average: forwarded to ``roc_auc_score``. It is effectively ignored for
            these binary 0/1 targets.

    Returns:
        (acc_dict, auc_dict): dicts mapping each class label (in sorted order)
        to its one-vs-rest accuracy and AUC.

    NOTE: ``pred`` holds hard labels, not scores, so each "AUC" comes from a
    single operating point and equals (sensitivity + specificity) / 2, i.e. the
    balanced accuracy for that class.
    """
    acc_dict = {}
    auc_dict = {}

    for class_i in sorted(set(test)):
        # Binarize by equality, so a predicted label that never occurs in
        # `test` counts as a miss for class_i rather than a hit.
        new_test = [int(x == class_i) for x in test]
        new_pred = [int(x == class_i) for x in pred]

        acc_dict[class_i] = accuracy_score(new_test, new_pred)
        auc_dict[class_i] = roc_auc_score(new_test, new_pred, average=average)

    return acc_dict, auc_dict

In [None]:
multi_acc, multi_auc = multiclass_metrics(y_true_int, y_pred_int, average="macro")

# Print each per-class metric dict under its heading.
for title, scores in (('Multiclass AUC scores:', multi_auc),
                      ('Multiclass accuracy scores:', multi_acc)):
    print(title)
    print(scores)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
overall_accuracy = accuracy_score(y_true_int, y_pred_int)
# Macro averaging: unweighted mean of the per-class scores.
overall_precision, overall_recall, overall_f1 = (
    scorer(y_true_int, y_pred_int, average="macro")
    for scorer in (precision_score, recall_score, f1_score)
)

for label, value in (
    ("Overall Accuracy: ", overall_accuracy),
    ("Overall Precision: ", overall_precision),
    ("Overall Recall: ", overall_recall),
    ("Overall F1: ", overall_f1),
):
    print(label, value)