# Organizing Data #

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


Don't need to re-run

In [2]:
from sklearn.model_selection import train_test_split
import os
from shutil import copyfile

# Set paths to image folders
class1_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Viral_Pneumonia/images'
class2_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Normal/images'
class3_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/Lung_Opacity/images'
class4_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/COVID/images'

# Set paths to output directories
train_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/train'
val_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/val'
test_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/test'

# Create output directories
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Split images into train, validation, and test sets
for class_dir, class_name in zip([class1_dir, class2_dir, class3_dir, class4_dir], ['class1', 'class2', 'class3', 'class4']):
    image_files = os.listdir(class_dir)
    train_files, test_files = train_test_split(image_files, test_size=0.1, random_state=42)
    train_files, val_files = train_test_split(train_files, test_size=0.25, random_state=42)

    # Copy train images to train folder
    for file_name in train_files:
        src_path = os.path.join(class_dir, file_name)
        dst_path = os.path.join(train_dir, class_name, file_name)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        copyfile(src_path, dst_path)

    # Copy validation images to validation folder
    for file_name in val_files:
        src_path = os.path.join(class_dir, file_name)
        dst_path = os.path.join(val_dir, class_name, file_name)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        copyfile(src_path, dst_path)

    # Copy test images to test folder
    for file_name in test_files:
        src_path = os.path.join(class_dir, file_name)
        dst_path = os.path.join(test_dir, class_name, file_name)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        copyfile(src_path, dst_path)


# Advanced Models: EVAE-Net

In [19]:
# imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import transforms, utils
import torchvision.models as models
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings("ignore")

In [20]:
class EVAE(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(EVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes #4
        
        # Define ResNet50 Encoder
        resnet = models18(pretrained=True)
        resnet_layers = list(resnet.children())[:-1]  # Remove last layer (classification head)
        self.resnet_encoder = nn.Sequential(*resnet_layers)
        
        # Define VGG16 Encoder
        vgg16 = models.vgg16(pretrained=True)
        vgg16_layers = list(vgg16.features.children())[:-1]  # Remove last layer (max pooling)
        self.vgg16_encoder = nn.Sequential(*vgg16_layers)
        
        # Define reparameterization layers
        self.fc1 = nn.Linear(4096, latent_dim)  # 4096 = 2048 from ResNet50 + 2048 from VGG16
        self.fc2 = nn.Linear(4096, latent_dim)
        
        # Define classification head
        self.classification_head = nn.Linear(latent_dim, num_classes)
        
        # Define decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        resnet_features = self.resnet_encoder(x)
        vgg16_features = self.vgg16_encoder(x)
        features = torch.cat((resnet_features.view(resnet_features.size(0), -1), 
                              vgg16_features.view(vgg16_features.size(0), -1)), dim=1)
        h1 = F.relu(self.fc1(features))
        h2 = F.relu(self.fc2(features))
        return h1, h2
    
    def decode(self, z):
        x_hat = self.decoder(z.unsqueeze(-1).unsqueeze(-1))
        return x_hat
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = eps * std + mu
        x_hat = self.decode(z)
        y = self.classification_head(z)
        return x_hat, y, mu, log_var
    
    def loss_function(self, x_hat, x, y, target, mu, log_var):
        BCE = F.binary_cross_entropy(x_hat.view(-1, 1024*1024), x.view(-1, 1024*1024), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        Lcls = F.cross_entropy(y, target, reduction='sum')
        return BCE, KLD, Lcls

### Training

In [24]:
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms # need to adapt image format


# Access train data
train_dir = '/content/drive/MyDrive/COVID-19_Radiography_Dataset/train'
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_data = datasets.ImageFolder(train_dir, transform=transform)

In [21]:
def train(model, optimizer, train_loader, device):
    model.train()
    train_loss = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        recon_batch, y, mu, log_var = model(data)
        BCE, KLD, Lcls = model.loss_function(recon_batch, data, y, mu, log_var)
        loss = BCE + KLD + Lcls
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = EVAE(latent_dim=256, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 5

for epoch in range(epochs):
    train(model, optimizer, train_loader, device)

RuntimeError: ignored

# ----BACK UP------

### Vanilla VAE from Git Repo

In [11]:
class VanillaVAE():


    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

NameError: ignored

# VAE Classifier

In [6]:
class VAEClassifier(BaseVAE):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VAEClassifier, self).__init__()

        self.num_classes = num_classes

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_logits = nn.Linear(hidden_dims[-1]*4, num_classes)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(num_classes, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        logits = self.fc_logits(result)
        return [logits, F.softmax(logits, dim=1)]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, logits: Tensor) -> Tensor:
        std = torch.ones_like(logits)
        eps = torch.randn_like(logits)
        return eps * std + logits

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        logits, probs = self.encode(input)
        z = self.reparameterize(logits)
        return [self.decode(z), input, probs]


# Training

In [25]:
# Define dataloader for batching and shuffling the data
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True)

In [5]:
from abc import abstractmethod
import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Dict, Any
from torch import Tensor

class BaseVAE(nn.Module):
    
    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass

In [None]:
# Pre-process to fixed size 224 x 224


In [13]:
# Instantiate VAE model
model = VAEClassifier(in_channels=3, num_classes=4, hidden_dims=[32, 64, 128, 256])

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train model for specified number of epochs
num_epochs = 5

for epoch in range(num_epochs):
    # Iterate over batches of training data
    for images, labels in train_loader:
        # Zero out gradient
        optimizer.zero_grad()

        # Forward pass through model
        images = images.view(images.size(0), -1) # flatten
        recon_images, mu, log_var = model(images)

        # Compute reconstruction loss
        loss = criterion(recon_images, images)

        # Compute KL divergence loss
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss += kl_loss

        # Backward pass through model and update parameters
        loss.backward()
        optimizer.step()

    # Print loss after each epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")


RuntimeError: ignored

# Training

In [None]:

model = EVAE(latent_dim=256, num_classes=4).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

num_epochs = 5














for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        x, target = x.to(device), target.to(device)
        optimizer.zero_grad()
        x_hat, y, mu, log_var = model(x)
        bce_loss, kld_loss, cls_loss = model.loss_function(x_hat, x, y, target, mu, log_var)
        loss = bce_loss + kld_loss + cls_loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    # with torch.no_grad():
    #     for batch_idx, (x, target) in enumerate(val_loader):
    #         x, target = x.to(device), target.to(device)
    #         x_hat, y, mu, log_var = model(x)
    #         bce_loss, kld_loss, cls_loss = model.loss_function(x_hat, x, y, target, mu, log_var)
    #         loss = bce_loss + kld_loss + cls_loss
    #         val_loss += loss.item()
    #         _, predicted = torch.max(y.data, 1)
    #         total += target.size(0)
    #         correct += (predicted == target).sum().item()
    
    train_loss /= len(train_loader.dataset)
    # val_loss /= len(val_loader.dataset)
    accuracy = 100 * correct / total
    
    print('Epoch: {}, Train Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch+1, train_loss, accuracy))


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class EVAE(nn.Module):
    def init(self, input_dim, latent_dim, hidden_dims):
        super().init()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()
        self.optimizer = torch.optim.Adam(self.parameters())
    
def build_encoder(self):
    layers = []
    layers.append(nn.Linear(self.input_dim, self.hidden_dims[0]))
    layers.append(nn.ReLU())
    for i in range(1, len(self.hidden_dims)):
        layers.append(nn.Linear(self.hidden_dims[i-1], self.hidden_dims[i]))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(self.hidden_dims[-1], self.latent_dim))
    self.classifier = nn.Linear(self.latent_dim, self.num_classes)
    layers.append(nn.Linear(self.hidden_dims[-1], self.latent_dim))
    return nn.Sequential(*layers)

def build_decoder(self):
    layers = []
    layers.append(nn.Linear(self.latent_dim, self.hidden_dims[-1]))
    layers.append(nn.ReLU())
    for i in range(len(self.hidden_dims)-2, -1, -1):
        layers.append(nn.Linear(self.hidden_dims[i+1], self.hidden_dims[i]))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(self.hidden_dims[0], self.input_dim))
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)

def sampling(self, z_mean, z_log_var):
    epsilon = torch.randn_like(z_mean)
    return z_mean + torch.exp(0.5 * z_log_var) * epsilon

def forward(self, x):
    z_mean, z_log_var = self.encoder(x)
    z = self.sampling(z_mean, z_log_var)
    class_probs = F.softmax(self.classifier(z), dim=1)
    x_pred = self.decoder(z)
    return x_pred, z_mean, z_log_var, class_probs

def vae_loss(self, x, x_pred, z_mean, z_log_var):
    recon_loss = F.binary_cross_entropy(x_pred, x, reduction='sum')
    kl_loss = -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return recon_loss + kl_loss

def train(self, x_train, x_val=None, batch_size=32, epochs=5):
    train_dataset = TensorDataset(torch.Tensor(x_train))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    if x_val is not None:
        val_dataset = TensorDataset(torch.Tensor(x_val))
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
    else:
        val_loader = None
    for epoch in range(epochs):
        train_loss = 0
        val_loss = 0
        for x_batch in train_loader:
            x_batch = x_batch[0]
            self.optimizer.zero_grad()
            x_pred, z_mean, z_log_var = self.forward(x_batch)
            loss = self.vae_loss(x_batch, x_pred, z_mean, z_log_var)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader.dataset)
        if val_loader is not None:
            with torch.no_grad():
                for x_batch in val_loader:
                    x_batch = x_batch[0]
                    x_pred, z_mean, z_log_var = self.forward(x_batch)
                    loss = self.vae_loss(x_batch, x_pred, z_mean, z_log_var)
                    val_loss += loss.item()
                val_loss /= len(val_loader.dataset)
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# EVAE-Net with pre-trained ResNet50 
import tensorflow as tf

# We implement EVAE-Net which concatenates the outputs of ResNet50 and VGG16 as the input of the encoder.

class EVAE:
    def __init__(self, input_dim, latent_dim, hidden_dims):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()
        self.optimizer = tf.keras.optimizers.Adam()
        
    def build_encoder(self):
        inputs = tf.keras.layers.Input(shape=(self.input_dim,))
        x = inputs
        for hidden_dim in self.hidden_dims:
            x = tf.keras.layers.Dense(hidden_dim, activation='relu')(x)
        z_mean = tf.keras.layers.Dense(self.latent_dim)(x)
        z_log_var = tf.keras.layers.Dense(self.latent_dim)(x)
        z = tf.keras.layers.Lambda(self.sampling)([z_mean, z_log_var])
        return tf.keras.Model(inputs, [z_mean, z_log_var, z], name='encoder')

    def build_decoder(self):
        latent_inputs = tf.keras.layers.Input(shape=(self.latent_dim,))
        x = latent_inputs
        for hidden_dim in reversed(self.hidden_dims):
            x = tf.keras.layers.Dense(hidden_dim, activation='relu')(x)
        outputs = tf.keras.layers.Dense(self.input_dim, activation='sigmoid')(x)
        return tf.keras.Model(latent_inputs, outputs, name='decoder')

    def sampling(self, args):
        z_mean, z_log_var = args
        epsilon = tf.keras.backend.random_normal(shape=(tf.keras.backend.shape(z_mean)[0], self.latent_dim))
        return z_mean + tf.keras.backend.exp(0.5 * z_log_var) * epsilon

    def vae_loss(self, x, x_pred, z_mean, z_log_var):
        recon_loss = tf.keras.losses.binary_crossentropy(x, x_pred)
        kl_loss = -0.5 * tf.keras.backend.mean(1 + z_log_var - tf.keras.backend.square(z_mean) - tf.keras.backend.exp(z_log_var), axis=-1)
        return tf.keras.backend.mean(recon_loss + kl_loss)

    @tf.function
    def train_step(self, x):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(x)
            x_pred = self.decoder(z)
            loss = self.vae_loss(x, x_pred, z_mean, z_log_var)
        gradients = tape.gradient(loss, self.encoder.trainable_variables + self.decoder.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables + self.decoder.trainable_variables))
        return loss

    def train(self, x_train, x_val=None, batch_size=32, epochs=100):
        train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(len(x_train)).batch(batch_size)
        val_dataset = None
        if x_val is not None:
            val_dataset = tf.data.Dataset.from_tensor_slices(x_val).batch(batch_size)
        for epoch in range(epochs):
            train_loss = tf.keras.metrics.Mean()
            val_loss = tf.keras.metrics.Mean()
            for x_batch in train_dataset:
                loss = self.train_step(x_batch)
                train_loss(loss)
            if val_dataset is not None:
                for x_batch in val_dataset:
                    z_mean, z_log_var, z = self.encoder(x_batch)
                    x_pred = self.decoder(z)
                    loss = self.vae_loss(x_batch, x_pred, z_mean, z_log_var)
