# Advanced Models: EVAE-Net

In [None]:
# imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms, utils
import torchvision.models as models
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings("ignore")

In [None]:
class EVAE(nn.Module):
    """VAE with a dual pretrained encoder (ResNet50 + VGG16) and a classifier head.

    The two convolutional backbones are run on the same input, their pooled
    features are concatenated into a 4096-dimensional vector, and two linear
    heads map that vector to the mean and log-variance of a diagonal Gaussian
    posterior.  A latent sample is decoded by a stack of transposed
    convolutions and also classified by a linear layer.

    Args:
        latent_dim: Size of the latent vector ``z``.
        num_classes: Number of output classes of the classification head.
    """

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # ResNet50 encoder: everything except the final fully connected layer.
        # The retained global average pool yields 2048 x 1 x 1 per sample.
        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=`; the legacy flag is kept for compatibility.
        resnet = models.resnet50(pretrained=True)
        resnet_layers = list(resnet.children())[:-1]
        self.resnet_encoder = nn.Sequential(*resnet_layers)

        # VGG16 encoder: convolutional trunk without its last max-pool.  That
        # trunk emits 512 channels at 1/16 of the input resolution, which is
        # only 2048 values when the map happens to be 2 x 2.  The adaptive
        # pool pins the map to 2 x 2 (512 * 2 * 2 = 2048) for any input size.
        # It has no parameters and is appended last, so the state_dict keys
        # of the convolutional layers are unchanged.
        vgg16 = models.vgg16(pretrained=True)
        vgg16_layers = list(vgg16.features.children())[:-1]
        self.vgg16_encoder = nn.Sequential(
            *vgg16_layers,
            nn.AdaptiveAvgPool2d((2, 2)),
        )

        # Posterior heads: 4096 = 2048 (ResNet50) + 2048 (pooled VGG16).
        self.fc1 = nn.Linear(4096, latent_dim)  # mean
        self.fc2 = nn.Linear(4096, latent_dim)  # log-variance

        # Classification head operating on the sampled latent vector.
        self.classification_head = nn.Linear(latent_dim, num_classes)

        # Decoder: each transposed convolution (kernel 4, stride 2, padding 1)
        # doubles the spatial size, so a 1 x 1 latent map becomes a
        # single-channel 16 x 16 map with values in (0, 1).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map an input batch to the posterior parameters.

        Args:
            x: Image batch accepted by both backbones (presumably 3-channel
                images; confirm against the data pipeline).

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        resnet_features = torch.flatten(self.resnet_encoder(x), start_dim=1)
        vgg16_features = torch.flatten(self.vgg16_encoder(x), start_dim=1)
        features = torch.cat((resnet_features, vgg16_features), dim=1)
        # No activation on either head: the mean must be free to take any
        # sign, and the log-variance must be able to go negative so that the
        # posterior standard deviation can drop below 1.
        mu = self.fc1(features)
        log_var = self.fc2(features)
        return mu, log_var

    def decode(self, z):
        """Decode latent vectors ``z`` of shape ``(batch, latent_dim)``.

        The vector is viewed as a ``latent_dim x 1 x 1`` feature map before
        being passed through the transposed-convolution decoder.
        """
        x_hat = self.decoder(z.unsqueeze(-1).unsqueeze(-1))
        return x_hat

    def forward(self, x):
        """Encode, sample with the reparameterization trick, decode, classify.

        Returns:
            Tuple ``(x_hat, y, mu, log_var)`` where ``x_hat`` is the
            reconstruction and ``y`` holds the unnormalized class scores.
        """
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = eps * std + mu
        x_hat = self.decode(z)
        y = self.classification_head(z)
        return x_hat, y, mu, log_var

    def loss_function(self, x_hat, x, y, mu, log_var, target=None):
        """Compute the three (summed) loss terms of the model.

        Args:
            x_hat: Reconstruction with values in [0, 1].
            x: Reconstruction target with values in [0, 1].  It must contain
                the same number of elements per sample as ``x_hat``; resizing
                the input to the decoder's output size is the caller's job.
            y: Unnormalized class scores of shape ``(batch, num_classes)``.
            mu: Posterior means.
            log_var: Posterior log-variances.
            target: Ground-truth class indices of shape ``(batch,)``.  When
                omitted, a module-level variable named ``target`` is used if
                one exists, which is what the earlier implementation relied on.

        Returns:
            Tuple ``(BCE, KLD, Lcls)`` of scalar tensors: reconstruction
            binary cross-entropy, KL divergence to a standard normal prior,
            and classification cross-entropy, all sum-reduced.

        Raises:
            ValueError: If no classification target can be determined.
        """
        if target is None:
            # Legacy fallback: earlier versions read a global `target`.
            target = globals().get("target")
            if target is None:
                raise ValueError(
                    "loss_function requires `target` (class indices) to "
                    "compute the classification loss"
                )

        # Flatten per sample instead of assuming 1024 * 1024 pixels; with a
        # summed reduction the value is the same for any matching shape.
        BCE = F.binary_cross_entropy(
            x_hat.reshape(x_hat.size(0), -1),
            x.reshape(x.size(0), -1),
            reduction='sum',
        )
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        Lcls = F.cross_entropy(y, target, reduction='sum')
        return BCE, KLD, Lcls

In [None]:
# Training EVAE-Net with pre-trained ResNet50 