In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Encoder class

In [2]:
class Encoder(nn.Module):
    """Conditional encoder: maps an image and its class label to the mean and
    log-variance of a diagonal Gaussian over the latent space.

    Args:
        latent_dim: Size of the latent vector.
        input_dim: Number of features per sample after flattening
            (defaults to 28 * 28).
        num_classes: Number of classes used for the one-hot label encoding
            (defaults to 10).
    """

    def __init__(self, latent_dim, input_dim=28 * 28, num_classes=10):
        super(Encoder, self).__init__()
        self.num_classes = num_classes
        # Define the layers for the encoder
        self.flatten = nn.Flatten()  # Flatten every dimension except the batch one
        self.dense1 = nn.Linear(input_dim + num_classes, 512)  # First dense layer
        self.dense2 = nn.Linear(512, 256)  # Second dense layer
        self.z_mean = nn.Linear(256, latent_dim)  # Mean of the latent distribution
        self.z_log_var = nn.Linear(256, latent_dim)  # Log variance of the latent distribution

    def forward(self, x, label):
        """Encode a batch.

        Args:
            x: Batch of inputs whose per-sample element count equals
                ``input_dim``; the batch dimension must come first.
            label: Integer class indices, one per sample.

        Returns:
            Tuple ``(z_mean, z_log_var)``, each of shape ``(batch, latent_dim)``.
        """
        # F.one_hot only accepts int64 index tensors, so cast other integer dtypes.
        label = F.one_hot(label.long(), num_classes=self.num_classes).float()
        # nn.Flatten also handles non-contiguous inputs, where .view() would raise.
        x = torch.cat([self.flatten(x), label], dim=-1)  # Concatenate label with the image
        x = F.relu(self.dense1(x))  # Pass through the first dense layer
        x = F.relu(self.dense2(x))  # Pass through the second dense layer
        z_mean = self.z_mean(x)  # Mean of the latent space
        z_log_var = self.z_log_var(x)  # Log variance of the latent space

        return z_mean, z_log_var

## Decoder class

In [3]:
class Decoder(nn.Module):
    """Maps latent vectors back to single-channel 28x28 images.

    Args:
        latent_dim: Width of the latent vectors passed to ``forward``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Fully connected stack: latent_dim -> 256 -> 512 -> 784 (28 * 28 pixels).
        self.dense1 = nn.Linear(latent_dim, 256)
        self.dense2 = nn.Linear(256, 512)
        self.output_layer = nn.Linear(512, 28 * 28)

    def forward(self, z):
        """Decode latent vectors ``z`` into images.

        Returns:
            Tensor of shape ``(-1, 1, 28, 28)`` with values squashed into
            [0, 1] by a sigmoid.
        """
        # NOTE(review): only z is consumed here; the class label never reaches
        # the decoder. If label-conditioned generation is intended, confirm
        # whether the label should be concatenated to z.
        hidden = torch.relu(self.dense1(z))
        hidden = torch.relu(self.dense2(hidden))
        pixels = torch.sigmoid(self.output_layer(hidden))
        # Restore the (batch, channel, height, width) image layout.
        return pixels.view(-1, 1, 28, 28)

## CVAE class

In [4]:
# Reparameterization trick (sampling from the latent space)
def sampling(z_mean, z_log_var):
    """Draw a latent sample as ``z_mean + std * noise``.

    ``std`` is ``exp(0.5 * z_log_var)`` and ``noise`` is standard-normal with
    the same shape, dtype and device as ``z_mean``. Keeping the randomness in
    ``noise`` leaves the result differentiable with respect to both inputs.
    """
    noise = torch.randn_like(z_mean)
    std = torch.exp(z_log_var * 0.5)  # log-variance -> standard deviation
    return z_mean + std * noise

In [5]:
class CVAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        latent_dim: Size of the latent space shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x, label):
        """Encode ``x`` with ``label``, sample a latent vector, and decode it.

        Returns:
            Tuple ``(reconstructed, z_mean, z_log_var)``.
        """
        # The encoder yields the parameters of the latent distribution.
        mu, log_var = self.encoder(x, label)
        # Reparameterized sample from that distribution.
        latent = sampling(mu, log_var)
        # NOTE(review): the label is passed to the encoder only; the decoder
        # receives just the latent sample. Confirm this is intended.
        reconstruction = self.decoder(latent)
        return reconstruction, mu, log_var

# Loss function: Reconstruction loss + KL divergence
def compute_loss(x, reconstructed, z_mean, z_log_var):
    """Return the summed VAE loss: reconstruction BCE plus KL divergence.

    Args:
        x: Target images with values in [0, 1]. Any layout is accepted as long
            as it holds the same number of elements as ``reconstructed``
            (e.g. ``(B, 784)``, ``(B, 28, 28)`` or ``(B, 1, 28, 28)``).
        reconstructed: Reconstruction with values in [0, 1].
        z_mean: Mean of the approximate posterior.
        z_log_var: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor ``BCE + KL``, summed (not averaged) over all elements.

    Raises:
        ValueError: From ``F.binary_cross_entropy`` when ``x`` and
            ``reconstructed`` hold a different number of elements.
    """
    # binary_cross_entropy requires identical shapes. Align the target layout
    # with the reconstruction when only the layout differs; a genuine size
    # mismatch is left untouched so the original error still surfaces.
    if x.shape != reconstructed.shape and x.numel() == reconstructed.numel():
        x = x.reshape(reconstructed.shape)

    # Binary cross-entropy loss for image reconstruction
    BCE = F.binary_cross_entropy(reconstructed, x, reduction='sum')

    # KL divergence between the learned distribution and the standard-normal prior
    KL = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

    return BCE + KL

---