### Theoretical Backgrounds

### Reconstructing MNIST images

In [21]:
# Load libraries

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import tqdm

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


We now load the data. Because the MNIST dataset is quite large, we import only the training dataset, which consists of 60000 images.

In [22]:
# Load MNIST dataset

mnist_dataset = MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
mnist_dataset


Dataset MNIST
    Number of datapoints: 60000
    Root location: ../../data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [23]:
mnist_dataset.data.shape

torch.Size([60000, 28, 28])

In [4]:
# flatten the images
X_data = mnist_dataset.data.view(-1, 28*28)

In [24]:
X_train, X_test, y_train, y_test = train_test_split(
    X_data, 
    mnist_dataset.targets, 
    test_size=0.2, 
    random_state=42,
)

In [25]:
print(f"X_train.shape: {X_train.shape}")
print(f"X_test.shape: {X_test.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"y_test.shape: {y_test.shape}")


X_train.shape: torch.Size([48000, 784])
X_test.shape: torch.Size([12000, 784])
y_train.shape: torch.Size([48000])
y_test.shape: torch.Size([12000])


In [130]:
# # Scale the data so that they have zero mean and unit variance

# scaler = StandardScaler()
# X_train = scaler.fit_transform(X_train)
# X_test = scaler.transform(X_test)

In [26]:
train_loader = DataLoader(X_train, batch_size=64, shuffle=True)
test_loader = DataLoader(X_test, batch_size=64, shuffle=False)

In [27]:
# Design a sparse autoencoder

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.fc4 = nn.Linear(hidden_dims[2], latent_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dims, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dims[2])
        self.fc2 = nn.Linear(hidden_dims[2], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[0])
        self.fc4 = nn.Linear(hidden_dims[0], output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [28]:
class Autoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dims, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dims, input_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [29]:
input_dim = X_train.shape[1]
hidden_dims = [256, 128, 64]
latent_dim = 16
lr = 1e-3
num_epochs = 10

autoencoder = Autoencoder(input_dim=input_dim, hidden_dims=hidden_dims, latent_dim=latent_dim)
autoencoder

Autoencoder(
  (encoder): Encoder(
    (fc1): Linear(in_features=784, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=64, bias=True)
    (fc4): Linear(in_features=64, out_features=16, bias=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=16, out_features=64, bias=True)
    (fc2): Linear(in_features=64, out_features=128, bias=True)
    (fc3): Linear(in_features=128, out_features=256, bias=True)
    (fc4): Linear(in_features=256, out_features=784, bias=True)
  )
)

In [20]:
# Train the autoencoder

optimizer = optim.Adam(autoencoder.parameters(), lr=lr)
criterion = nn.MSELoss()

loss_history = []

bar_format = '{l_bar}{bar:10}| {n:4}/{total_fmt} [{elapsed:>7}<{remaining:>7}, {rate_fmt}{postfix}]'
progress_bar = tqdm.trange(num_epochs, unit="ep", bar_format=bar_format, ascii=True)

for i in progress_bar:

    total_loss = 0
    for i, x in enumerate(train_loader):
        inputs = x.type(torch.float32).to(device)
        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
        loss_history.append(loss.item())
    loss_avg = np.mean(loss_history)

    progress_bar.set_postfix_str(f"Loss: {loss: 7.2f}, avg loss: {loss_avg: 7.2f}")
    progress_bar.update(0)

100%|##########|   10/10 [  00:46<  00:00,  4.67s/ep, Loss:  1317.40, avg loss:  1748.31]


In [11]:
# Train the autoencoder

optimizer = optim.Adam(autoencoder.parameters(), lr=lr)
criterion = nn.MSELoss()

total_loss = 0
for epoch in range(num_epochs):
    for i, x in enumerate(train_loader):
        inputs = x.type(torch.float32).to(device)
        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/10, Loss: 3142.9654
Epoch 2/10, Loss: 5196.4968
Epoch 3/10, Loss: 7018.3595
Epoch 4/10, Loss: 8682.3663
Epoch 5/10, Loss: 10238.5985
Epoch 6/10, Loss: 11705.2416
Epoch 7/10, Loss: 13121.9792
Epoch 8/10, Loss: 14500.9964
Epoch 9/10, Loss: 15843.4386
Epoch 10/10, Loss: 17153.2524


### Dimensionality reduction 