# Paper Inspiring EVAE-Net
https://www.mdpi.com/2075-4418/12/11/2569

Implementation has been simplified and adapted slgihtly.

# Organizing Data #

In [93]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


Don't need to re-run -- Data creation

In [None]:
# from sklearn.model_selection import train_test_split
# import os
# from shutil import copyfile

# # Set paths to image folders
# class1_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Viral_Pneumonia/images'
# class2_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Normal/images'
# class3_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Lung_Opacity/images'
# class4_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/COVID/images'

# # Set paths to output directories
# train_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/train'
# val_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/val'
# test_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/test'

# # Create output directories
# os.makedirs(train_dir, exist_ok=True)
# os.makedirs(val_dir, exist_ok=True)
# os.makedirs(test_dir, exist_ok=True)

# # Split images into train, validation, and test sets
# for class_dir, class_name in zip([class1_dir, class2_dir, class3_dir, class4_dir], ['class1', 'class2', 'class3', 'class4']):
#     image_files = os.listdir(class_dir)
#     train_files, test_files = train_test_split(image_files, test_size=0.1, random_state=42)
#     train_files, val_files = train_test_split(train_files, test_size=0.25, random_state=42)

#     # Copy train images to train folder
#     for file_name in train_files:
#         src_path = os.path.join(class_dir, file_name)
#         dst_path = os.path.join(train_dir, class_name, file_name)
#         os.makedirs(os.path.dirname(dst_path), exist_ok=True)
#         copyfile(src_path, dst_path)

#     # Copy validation images to validation folder
#     for file_name in val_files:
#         src_path = os.path.join(class_dir, file_name)
#         dst_path = os.path.join(val_dir, class_name, file_name)
#         os.makedirs(os.path.dirname(dst_path), exist_ok=True)
#         copyfile(src_path, dst_path)

#     # Copy test images to test folder
#     for file_name in test_files:
#         src_path = os.path.join(class_dir, file_name)
#         dst_path = os.path.join(test_dir, class_name, file_name)
#         os.makedirs(os.path.dirname(dst_path), exist_ok=True)
#         copyfile(src_path, dst_path)


# Advanced Models: Modified EVAE-Net

In [26]:
# imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms, utils
import torchvision.models as models
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings("ignore")

In [105]:
class EVAE(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(EVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes # 4 classes
        
        # Define ResNet50 Encoder
        # resnet = models.resnet18(pretrained=True)
        # resnet_layers = list(resnet.children())[:-1]  # Remove last layer (classification head)
        resnet = models.resnet18(pretrained=True)
        resnet.fc = torch.nn.Linear(in_features = 512, out_features = 4)
        resnet.load_state_dict(torch.load("/content/drive/MyDrive/finetuned_resnet.pth"))
        resnet_layers = list(resnet.children())[:-1] 
        # resnet_layers = list(resnet.children())

        # loss_fn = torch.nn.CrossEntropyLoss()
        # optimizer = torch.optim.Adam(resnet18.parameters(), lr = 3e-5)

        self.resnet_encoder = nn.Sequential(*resnet_layers)
        
        # Define VGG16 Encoder
        vgg16 = models.vgg16(pretrained=True)
        vgg16_layers = list(vgg16.features.children())[:-1]  # Remove last layer (max pooling)
        self.vgg16_encoder = nn.Sequential(*vgg16_layers)

        # Define reparameterization layers
        self.fc0 = nn.Linear(100864, latent_dim)
        self.fc1 = nn.Linear(latent_dim, 512)
        self.fc2 = nn.Linear(512, latent_dim*2)

        # Define classification head
        self.classification_head = nn.Linear(latent_dim, num_classes)
        
        # Define decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
    
    def decode(self, z):
        x_hat = self.decoder(z.unsqueeze(-1).unsqueeze(-1))

        return x_hat

    def encode(self, x):
        # encode
        resnet_features = self.resnet_encoder(x)
        vgg16_features = self.vgg16_encoder(x)

        # flatten the features and concatenate them
        features = torch.cat([resnet_features.view(x.size(0), -1), 
                              vgg16_features.view(x.size(0), -1)], dim=1)
        
        # apply reparameterization
        x = F.relu(self.fc0(features))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        mu, log_var = torch.chunk(x, 2, dim=-1)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return z, mu, log_var

        # def forward(self, x):
    #     mu, log_var = self.encode(x)
    #     std = torch.exp(0.5 * log_var)
    #     eps = torch.randn_like(std)
    #     z = eps * std + mu
    #     x_hat = self.decode(z)
    #     y = self.classification_head(z)
    #     return x_hat, y, mu, log_var

    def forward(self, x):
          # print('x:', x.shape)
          z, mu, log_var = self.encode(x)
          x_hat = self.decode(z)
          y = self.classification_head(z)
          # print('x_hat:', x_hat.shape)
          # print('z:', z.shape)
          # print('y:', y.shape)
          return x_hat, y, mu, log_var
    
    def loss_function(self, x_hat, x, y, target, mu, log_var):
        num_pixels = x.shape[1] * x.shape[2] * x.shape[3]

        # Upsample to get x and x_hat pixels matching -- improvements I'm sure can be made here
        x_hat_upsampled = F.interpolate(x_hat, size=x.shape[2:], mode='nearest') # align_corners=False
      
        # Compute reconstruction loss
        recons_loss = F.mse_loss(x_hat_upsampled, x, reduction='sum')
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        # Lcls = F.cross_entropy(torch.argmax(y, dim=1), torch.argmax(target, dim=1))
        Lcls = F.cross_entropy(y, target.argmax(dim=1))
        
        return recons_loss, kld_loss, Lcls

# Training EVAE

In [87]:
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms # need to adapt image format


# Access train data
train_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/train'
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])
train_data = datasets.ImageFolder(train_dir, transform=transform)

In [88]:
# Dataloader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True)

In [106]:
from tqdm import tqdm
# Training loop

def train(model, optimizer, train_loader, device):
    model.train()
    train_loss = 0
    
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        recon_batch, y, mu, log_var = model(data)
        target_onehot = F.one_hot(target, num_classes=4).float() 
        mse, kld, Lcls = model.loss_function(recon_batch, data, y, target_onehot, mu, log_var)
        loss = mse + kld + Lcls
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


In [107]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = EVAE(latent_dim=256, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
epochs = 20

for epoch in range(epochs):
    train(model, optimizer, train_loader, device)

  0%|          | 0/248 [00:00<?, ?it/s]


RuntimeError: ignored

# EVAE Validation

In [96]:
# Access eval data
val_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/val'
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])
val_data = datasets.ImageFolder(val_dir, transform=transform)

In [97]:
val_loader = torch.utils.data.DataLoader(val_data, batch_size=4, shuffle=True)

In [98]:
from sklearn.metrics import accuracy_score, roc_auc_score

def get_results(model, val_loader):
    model.eval()
    y_true = []
    y_pred = []
    
    y_true_int = []
    y_pred_int = []
    with torch.no_grad():
        for data, target in tqdm(val_loader):
            data = data.to(device)
            target = target.to(device)
            _, y_hat, _, _ = model(data)
            y_hat = torch.softmax(y_hat, dim=1)
            y_true.extend(target.tolist())
            y_pred.extend(y_hat.tolist())

    y_true = F.one_hot(torch.tensor(y_true)).numpy()
    y_pred = np.array(y_pred)
    auc = roc_auc_score(y_true, y_pred, multi_class='ovr')
    acc = accuracy_score(y_true.argmax(axis=1), y_pred.argmax(axis=1))

    return y_true, y_pred

In [99]:
y_true, y_pred = get_results(model, val_loader)

100%|██████████| 275/275 [00:14<00:00, 18.91it/s]


# Inference on Test Set

# Metrics for Test Set

In [100]:
# convert format of get_results output
y_true_int = []

for ele in y_true:
    class_label = ele.tolist().index(max(ele.tolist()))
    y_true_int.append(class_label)

In [101]:
y_pred_int = []

for ele in y_pred:
    class_label = ele.tolist().index(max(ele.tolist()))
    y_pred_int.append(class_label)

In [102]:
# Create function to calculate multiclass AUC
def multiclass_auc(test, pred, average="macro"):
    
    # Create set of unique classes
    unique = set(test)
    auc_dict = {}
    
    # Loop through each class
    for class_i in unique:
        
        # Create list of classes other than class_i
        other_class = [x for x in unique if x != class_i]

        # Get test / prediction values for each class
        new_test = [0 if x in other_class else 1 for x in test]
        new_pred = [0 if x in other_class else 1 for x in pred]

        # Calculate AUC, add to dictionary
        auc = roc_auc_score(new_test, new_pred, average = average)
        auc_dict[class_i] = auc

    return auc_dict

In [103]:
auc = multiclass_auc(y_true_int, y_pred_int, average="macro")
print('AUC by Class:', auc)

AUC by Class: {0: 0.90625, 1: 0.7782692975500028, 2: 0.6993939393939393, 3: 0.7220575887785844}


In [104]:
from sklearn.metrics import accuracy_score, precision_score
overall_accuracy = accuracy_score(y_pred_int, y_true_int)
overall_precision = precision_score(y_pred_int, y_true_int, average="macro")

print("Overall Accuracy: ", overall_accuracy)
print("Overall Precision: ", overall_precision)

Overall Accuracy:  0.6672727272727272
Overall Precision:  0.6640768075598857
