In [None]:
# Import necessary libraries and modules
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms

In [2]:
# Per-sample preprocessing applied by the MNIST datasets below: ToTensor turns each
# PIL image into a float tensor scaled to [0, 1]. Wrapped in Compose so further
# steps can be appended to the list later without touching the dataset cell.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Load the MNIST dataset for training and validation
DATA_ROOT = './data'   # where torchvision downloads / caches MNIST
BATCH_SIZE = 100       # images per mini-batch

train_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
valid_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Reshuffle the training set every epoch. Without shuffle=True the loader yields the
# identical batch sequence on every pass, which is not the usual practice for
# SGD/Adam-style training.
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Define the Encoder class
class Encoder(nn.Module):
    """Fully connected encoder mapping an input vector to a latent code.

    Layer widths: input_size -> hidden_size1 -> hidden_size2 -> z_dim.
    ReLU follows the first two linear layers; the final projection is linear.
    """

    def __init__(self, input_size=28 * 28, hidden_size1=128, hidden_size2=16, z_dim=2):
        super().__init__()
        widths = (input_size, hidden_size1, hidden_size2, z_dim)
        # Builds the three Linear layers in order (fc1, fc2, fc3) from consecutive width pairs.
        self.fc1, self.fc2, self.fc3 = (
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` (last dimension must equal ``input_size``) into a ``z_dim`` code."""
        for hidden in (self.fc1, self.fc2):
            x = self.relu(hidden(x))
        return self.fc3(x)

In [None]:
# Define the Decoder class
class Decoder(nn.Module):
    """Fully connected decoder mapping a latent code back to pixel space.

    Layer widths: z_dim -> hidden_size2 -> hidden_size1 -> output_size (the
    Encoder's widths in reverse). ReLU follows the first two linear layers and a
    sigmoid squashes the output into [0, 1], matching ToTensor-scaled pixels.
    """

    def __init__(self, output_size=28 * 28, hidden_size1=128, hidden_size2=16, z_dim=2):
        # Was `super().__init()`: name mangling turns that into `_Decoder__init`,
        # which does not exist, so constructing a Decoder raised AttributeError.
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_size2)
        self.fc2 = nn.Linear(hidden_size2, hidden_size1)
        self.fc3 = nn.Linear(hidden_size1, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Decode ``x`` (last dimension must equal ``z_dim``) into ``output_size`` values in [0, 1]."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Initialize the Encoder and Decoder on the selected device.
# Seed torch's global RNG first so the random weight initialisation (and any later
# torch-RNG-driven randomness, e.g. DataLoader shuffling if enabled) is
# reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

enc = Encoder().to(device)
dec = Decoder().to(device)

In [None]:
# Reconstruction objective: element-wise squared error averaged over all entries
# ('mean' is MSELoss's default reduction, spelled out here for clarity).
loss_fn = nn.MSELoss(reduction='mean')
# One Adam optimizer (default hyperparameters) per half of the autoencoder.
optimizer_enc, optimizer_dec = (torch.optim.Adam(module.parameters()) for module in (enc, dec))

In [None]:
# Training configuration and loss history.
num_epochs = 100   # number of full passes over train_dl
train_loss = []    # the training loop appends one value per epoch

In [None]:
# Train the autoencoder: each epoch reconstructs every training batch and records the
# mean per-batch MSE, so the curve is independent of how many batches an epoch has.
for epoch in range(num_epochs):
    # Explicit train mode (a no-op for these pure Linear/ReLU nets, but correct if
    # dropout/batch-norm are added, and undoes any earlier .eval() call).
    enc.train()
    dec.train()

    train_epoch_loss = 0.0
    n_batches = 0

    # Iterate through batches of training data; labels are not needed for reconstruction.
    for imgs, _ in train_dl:
        imgs = imgs.to(device).flatten(1)   # (batch, C, H, W) -> (batch, C*H*W)
        output = dec(enc(imgs))
        loss = loss_fn(output, imgs)

        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()
        loss.backward()
        optimizer_enc.step()
        optimizer_dec.step()

        # .item() yields a Python float (double precision); the old
        # .cpu().detach().numpy() accumulated into a float32 numpy scalar.
        train_epoch_loss += loss.item()
        n_batches += 1

    train_loss.append(train_epoch_loss / max(n_batches, 1))

In [None]:
# Plot the training loss over epochs (epochs numbered from 1).
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss)
ax.set(title='Autoencoder training loss', xlabel='epoch', ylabel='training loss (MSE)')
plt.show()

In [None]:
# Encode the whole training set, keeping each latent code aligned with its digit label.
# Batches are collected in a list and concatenated once; calling torch.vstack inside
# the loop re-copied the growing tensor every iteration (quadratic work).
latent_batches = []
all_labels = []

enc.eval()   # inference mode; no gradients are tracked under torch.no_grad() below
with torch.no_grad():
    for imgs, labels in train_dl:
        imgs = imgs.to(device).flatten(1)
        latent_batches.append(enc(imgs).cpu())
        # Was `all_labels.extend(list(labels.numpy())` -- a missing ')' (SyntaxError).
        all_labels.extend(labels.tolist())

# CPU tensor with one row per training image (None if the loader yielded nothing,
# matching the previous behaviour for an empty loader).
values = torch.cat(latent_batches) if latent_batches else None

In [None]:
# Discrete colour map for the scatter plot: 'viridis' resampled to 10 entries,
# one per MNIST digit class.
cmap = plt.get_cmap(name='viridis', lut=10)

In [None]:
# Plot the 2-D latent space, colouring each training image by its digit label.
# New names are used instead of overwriting `values`/`all_labels`, so re-running this
# cell works (previously `values = values.numpy()` failed on the second run).
labels_np = np.asarray(all_labels)
latents_np = np.asarray(values)   # accepts a CPU tensor or an existing ndarray

fig, ax = plt.subplots(figsize=(8, 6))
# vmin/vmax of -0.5/9.5 centre each of the 10 colours on an integer label; small,
# semi-transparent markers keep ~60k points from hiding each other.
pc = ax.scatter(latents_np[:, 0], latents_np[:, 1], c=labels_np, cmap=cmap,
                vmin=-0.5, vmax=9.5, s=2, alpha=0.5)
fig.colorbar(pc, ax=ax, ticks=range(10), label='digit')
ax.set(title='Autoencoder latent space (training set)', xlabel='z[0]', ylabel='z[1]')
plt.show()

In [None]:
# Generate an image by decoding one digit class's mean latent representation.
DIGIT = 8   # which class centroid to decode

# `all_means` was previously used without ever being defined (hidden kernel state).
# Compute it here: row d is the mean latent code of all training images labelled d.
latents_np = np.asarray(values)
labels_np = np.asarray(all_labels)
all_means = np.stack([latents_np[labels_np == d].mean(axis=0) for d in range(10)])

dec.eval()
with torch.no_grad():
    z = torch.as_tensor(all_means[DIGIT], dtype=torch.float32, device=device)[None, ...]
    pred = dec(z).cpu()

# Keep this as the cell's last top-level expression so Jupyter renders the image;
# an expression inside the `with` block is not displayed.
transforms.ToPILImage()(pred.reshape(1, 28, 28))