In [None]:
import matplotlib.pyplot as matplotlib_pyplot
import numpy as numpy
import umap as umap
import torch as torch
import torch.nn as torch_nn
import torch.optim as torch_optim
import torch.utils.data as torch_utils_data
import torchvision as torchvision
import scanpy as scanpy
import sklearn.cluster as sklearn_cluster

# Single seed shared by every random number generator the notebook relies on.
SEED = 42
torch.manual_seed(SEED)
# umap.UMAP() is built without a random_state further down, so it draws from
# NumPy's global generator; seed it too to keep the projections repeatable.
numpy.random.seed(SEED)
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Autoencoder(torch_nn.Module):
    """Fully connected autoencoder with a small bottleneck.

    The encoder narrows the input through 128, 64 and 12 hidden units down to
    ``latent_dim`` values; the decoder mirrors those widths and ends in a
    sigmoid, so every reconstructed value lies in (0, 1).

    Args:
        input_dim: Size of the last dimension of the input. The default, 784,
            is a flattened 28x28 image.
        latent_dim: Size of the bottleneck representation. Defaults to 3.
    """

    def __init__(self, input_dim=784, latent_dim=3):
        super().__init__()
        # The encoder is created before the decoder and the layer order inside
        # each Sequential is fixed, so seeded initialisation and state_dict
        # keys ("encoder.0.weight", ...) do not depend on the arguments.
        self.encoder = torch_nn.Sequential(
            torch_nn.Linear(input_dim, 128),
            torch_nn.ReLU(),
            torch_nn.Linear(128, 64),
            torch_nn.ReLU(),
            torch_nn.Linear(64, 12),
            torch_nn.ReLU(),
            # No activation on the bottleneck: latent values are unbounded.
            torch_nn.Linear(12, latent_dim)
        )
        self.decoder = torch_nn.Sequential(
            torch_nn.Linear(latent_dim, 12),
            torch_nn.ReLU(),
            torch_nn.Linear(12, 64),
            torch_nn.ReLU(),
            torch_nn.Linear(64, 128),
            torch_nn.ReLU(),
            torch_nn.Linear(128, input_dim),
            torch_nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the latent code.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tuple ``(encoded, decoded)``: the bottleneck representation and
            the reconstruction, which has the same last dimension as ``x``.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [None]:
# Snapshots taken every `save_every` epochs, keyed by the epoch index:
# encoder outputs for the whole training set and the labels in the same order.
latent_spaces = {}
all_labels = {}
# Mean batch loss of each epoch, in epoch order.
losses = []

# Load and preprocess the MNIST dataset
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch_utils_data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the model, loss function, and optimizer
model = Autoencoder().to(device)
criterion = torch_nn.MSELoss()
optimizer = torch_optim.Adam(model.parameters())

# Training loop
num_epochs = 125 + 1  # epochs 0..125, so epoch index 125 is also snapshotted
save_every = 25
for epoch in range(num_epochs):
    take_snapshot = epoch % save_every == 0
    epoch_loss = 0.0
    latent_space = []
    labels = []
    for img, label in train_loader:
        # Flatten each image to a vector before feeding the linear layers.
        img = img.view(img.size(0), -1).to(device)

        encoded, output = model(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Store latent representations and labels. The codes come from the
        # forward pass above, i.e. from the weights as they were before this
        # batch's optimizer step.
        if take_snapshot:
            latent_space.append(encoded.detach().cpu().numpy())
            labels.extend(label.numpy())

    if take_snapshot:
        latent_spaces[epoch] = numpy.concatenate(latent_space, axis=0)
        all_labels[epoch] = labels

    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    # Report the epoch mean rather than the loss of whichever batch came last.
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

In [None]:
# Test the autoencoder
# Reconstruct the first image of the MNIST test split with the trained model.
model.eval()
with torch.no_grad():
    # download=True matches the training dataset call, so this cell does not
    # fail when the data directory has not been populated yet.
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_img = test_dataset[0][0].view(-1, 28 * 28).to(device)
    _, output = model(test_img)
    # Tensor.cpu() returns a tensor instead of moving the original in place,
    # so the result has to be bound back to the name.
    output = output.cpu()

In [None]:
# Reshape the flattened input and its reconstruction back into 28x28 images.
# Both are brought to the CPU first, which is a no-op for tensors already there.
test_image_result = test_img.cpu().view(28, 28)
output_image_result = output.cpu().view(28, 28)

In [None]:
def display_umap(labels, latent_space, title='UMAP projection of the latent space'):
    """Project ``latent_space`` to 2-D with UMAP and draw it as a scatter plot.

    A new 12x10 inch figure is created and left as the current pyplot figure;
    nothing is returned.

    Args:
        labels: One value per row of ``latent_space``, used to colour the
            points through the 'tab10' colormap.
        latent_space: 2-D array-like handed to ``umap.UMAP().fit_transform``.
        title: Figure title. Defaults to the generic title used so far.
    """
    reducer = umap.UMAP()
    embedding = reducer.fit_transform(latent_space)

    matplotlib_pyplot.figure(figsize=(12, 10))
    scatter = matplotlib_pyplot.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='tab10')
    matplotlib_pyplot.colorbar(scatter, label='label')
    matplotlib_pyplot.xlabel('UMAP 1')
    matplotlib_pyplot.ylabel('UMAP 2')
    matplotlib_pyplot.title(title)

In [None]:
# Pair every latent snapshot with its labels through the shared epoch key
# instead of relying on both dicts having the same insertion order.
for epoch, space in latent_spaces.items():
    display_umap(all_labels[epoch], space)
    # display_umap leaves its figure as the current one, so this retitles that
    # figure and makes the per-epoch plots distinguishable.
    matplotlib_pyplot.title(f'UMAP projection of the latent space (epoch {epoch})')