In [None]:
""" The following code is a simple implementation of a plain (non-variational) autoencoder, or AE, in PyTorch.
The autoencoder is trained on the MNIST dataset.
The code is organized as follows:
1. Import libraries, load and preprocess the data
2. Define the AE architecture
3. Create an instance of the AE and train it
4. Generate samples from the trained AE
5. Visualize the results of compression and generation
6. Analyze the latent space of the AE
7. Create a k-NN classifier on the latent space
"""

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import random
from keras.datasets import mnist


# Pick the best available backend instead of assuming Apple-silicon 'mps',
# so the notebook also runs on CUDA machines and on CPU-only machines.
# The getattr guard covers torch builds that do not ship the mps backend module.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [2]:
# Fetch the MNIST train/test splits through the Keras dataset helper, then
# unpack each split into its images (x_*) and labels (y_*).
mnist_train, mnist_test = mnist.load_data()
x_train, y_train = mnist_train
x_test, y_test = mnist_test

In [3]:

# Convert the images to float32 tensors scaled by 1/255.
# Only the test images are moved to `device` here; the training images are
# transferred one batch at a time inside the training loop.
x_train = torch.tensor(x_train.astype('float32') / 255.)
x_test = torch.tensor(x_test.astype('float32') / 255.).to(device)

# Flatten every image into a single row vector (one row per image).
x_train_flat = x_train.reshape(x_train.shape[0], -1)
x_test_flat = x_test.reshape(x_test.shape[0], -1).to(device)

In [4]:
from torch.utils.data import DataLoader, TensorDataset


# Number of images per optimisation step.
batch_size = 64

# The dataset wraps only the flattened training images (no labels); the
# loader reshuffles them on every pass.
train_dataset = TensorDataset(x_train_flat)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder(nn.Module):
    """Fully connected autoencoder with an optional hidden layer on each side.

    Args:
        input_dim: Size of each flattened input vector.
        latent_dim: Size of the bottleneck (latent) vector.
        use_hidden_layer: If True, encoder and decoder each contain one hidden
            layer of width ``hidden_dim``; otherwise each is a single linear layer.
        hidden_dim: Width of the hidden layers (unused when
            ``use_hidden_layer`` is False).
        BCE: If truthy, ``forward`` computes a summed binary cross-entropy
            loss; otherwise it computes the (mean) MSE loss.
    """

    def __init__(self, input_dim, latent_dim, use_hidden_layer=False, hidden_dim=32, BCE = False):
        super().__init__()
        self.BCE = BCE
        self.latent_dim = latent_dim

        # The encoder is always built before the decoder so that parameter
        # registration (and initialisation) order stays the same.
        if use_hidden_layer:
            # NOTE: the latent activation is Tanh here but ReLU in the
            # single-layer variant below, so the latent ranges differ.
            self.encoder = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, latent_dim),
                nn.Tanh()
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, input_dim),
                nn.Sigmoid()
            )
        else:
            self.encoder = nn.Sequential(
                nn.Linear(input_dim, latent_dim),
                nn.ReLU()
            )
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, input_dim),
                nn.Sigmoid()
            )

    def forward(self, x, y=None, output_latent=False):
        """Encode ``x`` and, unless only the latent code is wanted, decode it.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.
            y: Optional reconstruction target. When given, the loss between
                the reconstruction and ``y`` is returned as well.
            output_latent: If True, return only the latent code (takes
                precedence over ``y``).

        Returns:
            The latent tensor if ``output_latent`` is True; otherwise the
            reconstruction if ``y`` is None; otherwise ``(reconstruction, loss)``.
        """
        latent = self.encoder(x)
        if output_latent:
            # Return before decoding: the reconstruction is not needed here.
            return latent

        output = self.decoder(latent)
        if y is None:
            return output

        if self.BCE:
            # Summed over all elements, so its scale grows with batch size,
            # unlike the mean-reduced MSE branch.
            loss = F.binary_cross_entropy(output, y, reduction='sum')
        else:
            loss = F.mse_loss(output, y)
        return output, loss


In [6]:
# Build the autoencoder for 784-dimensional inputs (flattened 28x28 images)
# with a 32-dimensional latent code. nn.Module.to moves the parameters in
# place and returns the module, so a single call is sufficient.
model = Autoencoder(
    input_dim=784,
    latent_dim=32,
    use_hidden_layer=True,
    hidden_dim=512,
    BCE=True,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Per-batch training losses; the training loop appends one value per step.
lossi = []


In [7]:
# Train the autoencoder: each batch is used as both the input and the
# reconstruction target, and every step's loss is recorded in `lossi`.
epochs = 50
for e in range(epochs):
    for (x_cpu,) in train_loader:
        x_batch = x_cpu.to(device)

        # Clear gradients left over from the previous step, then run a
        # forward/backward pass and update the weights.
        optimizer.zero_grad()
        output, loss = model(x=x_batch, y=x_batch)
        loss.backward()
        optimizer.step()

        lossi.append(loss.item())

    # Progress marker; raise the modulus to print less often.
    if e % 1 == 0:
        print(f'{e}')

    

0
1
2
3
4
5
6
7
8
9
10
11
12


In [None]:
# `lossi` receives one value per training batch (it is appended to inside the
# inner batch loop), so the x-axis counts optimisation steps, not epochs.
plt.plot(lossi)
plt.xlabel('Training step (batch)')
plt.ylabel('Loss')
plt.title('Loss per Training Step')
plt.show()

# Mean loss over the most recent 1000 training steps (or all of them if
# fewer than 1000 were recorded).
print(np.mean(lossi[-1000:]))


In [None]:
import matplotlib.pyplot as plt

# Pick a random test image; the upper bound follows the actual test-set size
# instead of a hard-coded 9999.
rand = random.randint(0, len(x_test_flat) - 1)

# Inference only: no autograd graph is needed for a reconstruction.
with torch.no_grad():
    output = model(x_test_flat[rand])
print(y_test[rand])

plt.figure(figsize=(10, 4))

# Left panel: the original image
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(x_test_flat[rand].cpu().reshape(28,28), cmap='gray')

# Right panel: the autoencoder's reconstruction of that image
plt.subplot(1, 2, 2)
plt.title("Generated Image")
plt.imshow(output.reshape(28,28).detach().cpu(), cmap='gray')

plt.show()

In [None]:
# import torch
# import matplotlib.pyplot as plt

# # Assuming model, x_test_flat, and y_test are already defined and loaded

# # Calculate losses for all test images
# losses = []
# outputs = []
# for i in range(10000):
#     output, loss = model(x=x_test_flat[i], y=x_test_flat[i])
#     losses.append(loss.item())
#     outputs.append(output)

# # Convert lists to tensors for easier indexing
# losses = torch.tensor(losses)
# outputs = torch.stack(outputs)

# # Find indices of the 10 lowest and 10 highest losses
# lowest_losses_indices = torch.argsort(losses)[:10]
# highest_losses_indices = torch.argsort(losses, descending=True)[:10]

# # Display the original and generated images with lowest losses
# plt.figure(figsize=(20, 8))
# for idx, i in enumerate(lowest_losses_indices):
#     plt.subplot(2, 10, idx + 1)
#     plt.title(f"Orig {idx+1}")
#     plt.imshow(x_test_flat[i].cpu().reshape(28, 28), cmap='gray')
#     plt.axis('off')

#     plt.subplot(2, 10, idx + 11)
#     plt.title(f"Gen {idx+1}")
#     plt.imshow(outputs[i].reshape(28, 28).detach().cpu(), cmap='gray')
#     plt.axis('off')

# plt.suptitle("10 Best Reconstructions (Lowest Losses)")
# plt.show()

# # Display the original and generated images with highest losses
# plt.figure(figsize=(20, 8))
# for idx, i in enumerate(highest_losses_indices):
#     plt.subplot(2, 10, idx + 1)
#     plt.title(f"Orig {idx+1}")
#     plt.imshow(x_test_flat[i].cpu().reshape(28, 28), cmap='gray')
#     plt.axis('off')

#     plt.subplot(2, 10, idx + 11)
#     plt.title(f"Gen {idx+1}")
#     plt.imshow(outputs[i].reshape(28, 28).detach().cpu(), cmap='gray')
#     plt.axis('off')

# plt.suptitle("10 Worst Reconstructions (Highest Losses)")
# plt.show()


In [None]:
# A two-dimensional latent space can be drawn directly: encode every test
# image and scatter the codes, coloured by the image's digit label.
if model.latent_dim == 2:
    test_codes = model.forward(x_test_flat, output_latent=True)
    latent_vectors = test_codes.detach().cpu().numpy()
    latent_x, latent_y = latent_vectors[:, 0], latent_vectors[:, 1]

    plt.figure(figsize=(10, 8))
    plt.scatter(latent_x, latent_y, c=y_test, cmap='tab10')
    plt.colorbar()
    plt.xlabel("Latent X")
    plt.ylabel("Latent Y")
    plt.title("2D Latent Space Visualization")
    plt.show()



In [None]:
# # Make a bar graph of the average loss for each digit
# average_losses = []
# for i in range(10):
#     indices = (y_test == i)
#     average_loss = torch.mean(losses[indices]).item()
#     average_losses.append(average_loss)

# plt.figure(figsize=(10, 5))
# plt.bar(range(10), average_losses)
# plt.xticks(range(10))
# plt.xlabel("Digit")
# plt.ylabel("Average Loss")
# plt.title("Average Loss per Digit")
# plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics.pairwise import cosine_similarity

def _average_pairwise_metric(a, b, metric):
    """Return the mean of `metric` over every (row of a, row of b) pair.

    For 'cosine' the rows are expected to be unit-normalised already, so the
    dot product of two rows is their cosine similarity.
    """
    if metric in ('dot_product', 'cosine'):
        return float((a @ b.T).mean())
    # Broadcasting yields an array of shape (len(a), len(b), dim).
    diff = a[:, None, :] - b[None, :, :]
    if metric == 'euclidean':
        return float(np.linalg.norm(diff, axis=-1).mean())
    if metric == 'manhattan':
        return float(np.abs(diff).sum(axis=-1).mean())
    raise ValueError(f"Unknown metric: {metric}")


def plot_latent_space_distances(model, x_test_flat, y_test, num_instances=10, metric='dot_product', num_classes=10):
    """Plot a heatmap of the average pairwise latent-space metric between classes.

    The first ``num_instances`` samples of every class are encoded with
    ``model.forward(..., output_latent=True)``. For each pair of classes the
    chosen metric is averaged over all instance pairs and shown as a heatmap.

    Args:
        model: Object whose ``forward(x, output_latent=True)`` returns a
            tensor of latent vectors, one row per input row.
        x_test_flat: 2-D tensor of flattened inputs, one row per sample.
        y_test: Integer class labels aligned with the rows of ``x_test_flat``.
        num_instances: Number of samples taken from each class.
        metric: One of 'dot_product', 'euclidean', 'manhattan', 'cosine'.
        num_classes: Number of classes; labels are assumed to be
            ``0 .. num_classes - 1``.

    Raises:
        ValueError: If ``metric`` is unknown, ``num_instances`` is smaller
            than 1, or a class has fewer than ``num_instances`` samples.
    """
    if metric not in ('dot_product', 'euclidean', 'manhattan', 'cosine'):
        raise ValueError(f"Unknown metric: {metric}")
    if num_instances < 1:
        raise ValueError(f"num_instances must be at least 1, got {num_instances}")

    # Keep the labels on the same device as the inputs so the boolean mask
    # can index the tensor directly.
    labels = torch.as_tensor(y_test, device=x_test_flat.device)

    # Collect the first `num_instances` samples of every class.
    instances = []
    for digit in range(num_classes):
        digit_instances = x_test_flat[labels == digit][:num_instances]
        if digit_instances.shape[0] < num_instances:
            # Without this check the reshape below would fail or, worse,
            # silently mix rows of different classes.
            raise ValueError(
                f"Class {digit} has only {digit_instances.shape[0]} samples, "
                f"but num_instances={num_instances} were requested"
            )
        instances.append(digit_instances)

    instances = torch.cat(instances).view(-1, x_test_flat.shape[1])

    # Encode all selected samples in one pass; no gradients are needed.
    with torch.no_grad():
        latent_vectors = model.forward(instances, output_latent=True).detach().cpu().numpy()

    # One axis per class; float64 keeps the averaged values stable.
    latent_vectors = latent_vectors.astype(np.float64).reshape(num_classes, num_instances, -1)

    if metric == 'cosine':
        # Normalise once so cosine similarity reduces to a dot product.
        # Zero vectors are left as zeros, giving them similarity 0.
        norms = np.linalg.norm(latent_vectors, axis=-1, keepdims=True)
        norms[norms == 0] = 1.0
        latent_vectors = latent_vectors / norms

    # Average the metric over all instance pairs for each pair of classes.
    average_distances = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        for j in range(num_classes):
            average_distances[i, j] = _average_pairwise_metric(latent_vectors[i], latent_vectors[j], metric)

    # Plot the heatmap
    plt.figure(figsize=(10, 8))
    plt.imshow(average_distances, cmap='viridis')
    plt.colorbar()
    plt.xticks(range(num_classes))
    plt.yticks(range(num_classes))
    plt.xlabel("Digit")
    plt.ylabel("Digit")
    plt.title(f"Average {metric.replace('_', ' ').capitalize()} Between Digits in Latent Space ({num_instances} Instances Each)")
    plt.show()

# Example: compare the digit classes by the cosine similarity of their latent
# codes, using 100 test images per digit. Expects model, x_test_flat and
# y_test to have been defined by earlier cells.
# Supported metrics: 'dot_product', 'euclidean', 'manhattan', 'cosine'.
plot_latent_space_distances(
    model,
    x_test_flat,
    y_test,
    num_instances=100,
    metric='cosine',
)