In [3]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm as tqdm
from torchvision import transforms

from ml_zoo import MNISTDataModuleConfig, MNISTDataModule

# Build the data module and pull out its train / validation loaders.
# Every image is resized to 32x32 and converted to a tensor.
# NOTE(review): use_qmnist presumably switches the dataset to QMNIST — confirm in ml_zoo.
image_transforms = [transforms.Resize((32, 32)), transforms.ToTensor()]

dm_config = MNISTDataModuleConfig(
    data_dir="data",
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    transforms=image_transforms,
    use_qmnist=True,
)

dm = MNISTDataModule(dm_config)
dm.prepare_data()
dm.setup()
train_loader, val_loader = dm.train_dataloader(), dm.val_dataloader()

In [4]:
# Save the first validation batch as individual PNG files.
images, labels = next(iter(val_loader))
path = "blog/2-mnist/imgs/"
# save_image does not create missing directories, so make sure the target
# folder exists before writing (no-op when it is already there).
os.makedirs(path, exist_ok=True)
for i, image in enumerate(images):
    torchvision.utils.save_image(image, path + f"img_{i}.png")

In [None]:
class SparseAutoencoder(nn.Module):
    """Fully connected autoencoder that also exposes its latent code.

    The encoder maps a flattened input through 1024 and 2048 hidden units to
    ``latent`` units; the decoder mirrors it back to the input width. The
    latent code passes through a ReLU (so it is non-negative) and the
    reconstruction through a Sigmoid (so it lies in [0, 1]).

    Args:
        latent: Number of units in the latent code.
        input_dim: Number of features in each flattened input sample.
            Defaults to ``32 * 32``, i.e. a flattened 32x32 single-channel image.
    """

    def __init__(self, latent, input_dim=32 * 32):
        super().__init__()
        # Layer order is kept fixed so state-dict keys (encoder.0, encoder.2, ...)
        # and parameter initialisation order stay the same as before.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, latent),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tuple ``(decoded, latent)``: the reconstruction and the latent code.
        """
        latent = self.encoder(x)
        decoded = self.decoder(latent)
        return decoded, latent


# The 4096-unit code is wider than the 32*32 = 1024 input features, so the
# sparsity term of the loss is what constrains the representation.
model = SparseAutoencoder(4096)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
def criterion(x_hat, x, latent, l1_weight=1e-3):
    """Reconstruction loss plus an L1 sparsity penalty on the latent code.

    Args:
        x_hat: Reconstructed input.
        x: Original input, same shape as ``x_hat``.
        latent: Latent activations to be pushed towards zero.
        l1_weight: Multiplier for the mean absolute latent activation.
            Defaults to ``1e-3``.

    Returns:
        Scalar tensor ``mse(x_hat, x) + l1_weight * mean(|latent|)``.
    """
    reconstruction = F.mse_loss(x_hat, x)
    sparsity = latent.abs().mean()
    return reconstruction + sparsity * l1_weight

@torch.no_grad()
def eval():
    """Return the mean per-batch loss of the global ``model`` on ``val_loader``.

    Switches ``model`` to eval mode and leaves it there; callers that go on
    training must call ``model.train()`` again. Gradients are disabled by the
    decorator.

    Note: within this notebook the name shadows Python's builtin ``eval``.
    """
    model.eval()
    progress = tqdm.tqdm(val_loader, desc="Val")
    running_total = 0
    for inputs, _ in progress:
        # Flatten each image to a vector before feeding the MLP.
        flat = inputs.view(inputs.size(0), -1)
        reconstruction, code = model(flat)
        batch_loss = criterion(reconstruction, flat, code).item()
        running_total += batch_loss
        progress.set_postfix({"loss": batch_loss})

    n_batches = len(val_loader)
    return running_total / n_batches

In [None]:
# Training: 10 epochs. The progress bar shows the current batch loss next to
# the validation loss measured after the previous epoch (before the first
# epoch, the loss of the untrained model).
val_loss = eval()
for epoch in range(10):
    # eval() leaves the model in eval mode, so switch back every epoch.
    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, _ in pbar:
        x = x.view(x.size(0), -1)
        x_hat, enc = model(x)
        loss = criterion(x_hat, x, enc)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": loss.item(), "val_loss": val_loss})

    val_loss = eval()

In [None]:
# Compare the first five validation images (top row) with their
# reconstructions (bottom row).
x, y = next(iter(val_loader))
x = x.view(x.size(0), -1)

# Pure inference: no autograd graph is needed, so the outputs can be turned
# into NumPy arrays directly.
model.eval()
with torch.no_grad():
    x_hat, _ = model(x)

fig, ax = plt.subplots(2, 5, figsize=(20, 8))
for i in range(5):
    ax[0, i].imshow(x[i].view(32, 32).numpy())
    ax[1, i].imshow(x_hat[i].view(32, 32).numpy())
    ax[0, i].set_title("Input")
    ax[1, i].set_title("Reconstruction")
    ax[0, i].axis("off")
    ax[1, i].axis("off")

plt.show()

In [None]:
# Extract latent space: encode the whole validation set and keep the codes
# together with their class labels.
model.eval()
latent_batches = []
label_batches = []
# Without no_grad every stored code would keep its batch's autograd graph
# alive until the final concatenation.
with torch.no_grad():
    for x, y in val_loader:
        x = x.view(x.size(0), -1)
        _, enc = model(x)
        latent_batches.append(enc)
        label_batches.append(y)

latents = torch.cat(latent_batches, dim=0)
class_labels = torch.cat(label_batches, dim=0)

In [None]:
# Mean and standard deviation over every latent activation of the validation set.
flat_latents = latents.reshape(-1)
flat_latents.mean().item(), flat_latents.std().item()

In [None]:
# Latent code of the first validation sample. `latent` is not defined at
# notebook scope; the tensor extracted from the validation set is `latents`.
latents[0]

In [None]:
# PCA: project the latent codes onto their first three principal components
# and colour each point by its class label.
from sklearn.decomposition import PCA
import plotly.express as px

# Hand plain NumPy arrays to scikit-learn and Plotly instead of relying on
# their implicit handling of torch tensors.
latents_np = latents.detach().cpu().numpy()
class_labels_np = class_labels.detach().cpu().numpy()

pca = PCA(n_components=3)
pca_latents = pca.fit_transform(latents_np)

fig = px.scatter_3d(
    x=pca_latents[:, 0],
    y=pca_latents[:, 1],
    z=pca_latents[:, 2],
    color=class_labels_np,
    labels={"color": "Class"},
)
fig.show()