# This notebook is for training a new autoencoder to downscale the MNIST (or Fashion-MNIST) dataset.
### To download the MNIST dataset, you need to install PyTorch together with torchvision (used below to fetch the data), or obtain the dataset from an equivalent source.
### Otherwise, you may use our pretrained autoencoder and downscaled data in the folder.

In [None]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from torchvision.datasets import MNIST, FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters
# Each 28x28 MNIST image is flattened into a single vector of this length.
input_dim = 28 * 28
# Width of the bottleneck, i.e. the size of the compressed representation.
latent_dim = 8
#latent_dim = 16 # for Fashion-MNIST, it's better to use higher dimensions
# Optimisation settings: number of passes over the data, samples per
# training batch, and the step size handed to the optimizer.
epochs = 200
batch_size = 256
learning_rate = 1e-3

# Load MNIST
def numpy_collate(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Args:
        batch: Sequence of two-element samples as produced by the dataset.

    Returns:
        A tuple ``(images, labels)`` where each entry is the result of
        stacking the corresponding sample components along a new first axis.
    """
    # Transpose the list of pairs into one tuple of images and one of labels.
    sample_images, sample_labels = zip(*batch)
    stacked_images = np.stack(sample_images)
    stacked_labels = np.stack(sample_labels)
    return stacked_images, stacked_labels

def _flatten_to_unit_range(image):
    """Convert an image to a flat float32 vector with values divided by 255."""
    return np.array(image, dtype=np.float32).reshape(-1) / 255.0


# Training split, downloaded into raw_mnist/ when it is not already there.
train_dataset = MNIST(
    root="raw_mnist/",
    train=True,
    download=True,
    transform=_flatten_to_unit_range,
)
#train_dataset = FashionMNIST(root="raw_mnist/", train=True, download=True, transform=_flatten_to_unit_range)

# Shuffled mini-batches used for optimisation.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
)
# Unshuffled pass over the same data, so encoded samples keep dataset order.
latent_loader = DataLoader(
    train_dataset,
    batch_size=80,
    shuffle=False,
    collate_fn=numpy_collate,
)
# Shuffled batches of 96 (drawn from the training split) for visual checks.
evaluation_loader = DataLoader(
    train_dataset,
    batch_size=96,
    shuffle=True,
    collate_fn=numpy_collate,
)

# Define Autoencoder
class Autoencoder(eqx.Module):
    """Fully connected autoencoder for flattened images.

    The encoder maps ``input_dim -> 128 -> 64 -> latent_dim`` and the decoder
    mirrors it back to ``input_dim``, ending in a sigmoid. Layer sizes are
    taken from the module-level ``input_dim`` and ``latent_dim``.

    Args:
        key: JAX PRNG key used to initialise all layer parameters.
    """

    encoder: eqx.nn.Sequential
    decoder: eqx.nn.Sequential

    def __init__(self, key):
        # Every Linear layer gets its own key. Reusing a key for several
        # layers makes their random initialisations statistically dependent,
        # which JAX's PRNG design explicitly asks callers to avoid.
        enc_k1, enc_k2, enc_k3, dec_k1, dec_k2, dec_k3 = jax.random.split(key, 6)
        # Encoder: 784 -> 128 -> 64 -> latent_dim
        self.encoder = eqx.nn.Sequential([
            eqx.nn.Linear(input_dim, 128, key=enc_k1),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Linear(128, 64, key=enc_k2),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Linear(64, latent_dim, key=enc_k3),  # Bottleneck
        ])
        # Decoder: latent_dim -> 64 -> 128 -> 784
        self.decoder = eqx.nn.Sequential([
            eqx.nn.Linear(latent_dim, 64, key=dec_k1),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Linear(64, 128, key=dec_k2),
            eqx.nn.Lambda(jax.nn.relu),
            eqx.nn.Linear(128, input_dim, key=dec_k3),
            eqx.nn.Lambda(jax.nn.sigmoid),  # MNIST pixels are in [0, 1]
        ])

    def __call__(self, x):
        """Encode a single flattened sample and return its reconstruction."""
        z = self.encoder(x)
        return self.decoder(z)

# Training

In [None]:
# Build a freshly initialised autoencoder from a fixed seed and set up Adam.
key = jax.random.PRNGKey(42)
model = Autoencoder(key)
optimizer = optax.adam(learning_rate=learning_rate)
# The optimizer state is created from the array leaves of the model only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Loss function (MSE)
@eqx.filter_value_and_grad
def compute_loss(model, x):
    """Mean squared reconstruction error of ``model`` on the batch ``x``.

    Because of the decorator, calling this returns the loss together with its
    gradient with respect to ``model``.
    """
    # The model handles one sample at a time; vmap maps it over the batch.
    reconstruct_batch = jax.vmap(model)
    residual = reconstruct_batch(x) - x
    return jnp.mean(residual ** 2)

@eqx.filter_jit
def train_step(model, opt_state, x):
    """Run one optimisation step on a batch.

    Args:
        model: Current autoencoder.
        opt_state: Optimizer state matching ``model``.
        x: Batch of inputs passed to ``compute_loss``.

    Returns:
        A tuple ``(model, opt_state, loss)`` holding the updated model, the
        updated optimizer state, and the loss computed before the update.
    """
    loss, grads = compute_loss(model, x)
    # Hand the optimizer only the array leaves, matching the tree used for
    # optimizer.init. The full model also contains non-array leaves (the
    # activation functions), which optimizers that read params cannot process.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# Training loop: one optimisation step per mini-batch; labels are ignored.
for epoch_index in range(epochs):
    for images, _ in train_loader:
        # Convert the NumPy batch to a JAX array before stepping.
        model, opt_state, loss = train_step(model, opt_state, jnp.array(images))
    # The reported value is the loss of the last batch in the epoch.
    print(f"Epoch {epoch_index + 1}, Loss: {loss:.4f}")

# Persist the trained weights to disk.
eqx.tree_serialise_leaves("autoencoder_mnist_8.eqx", model)

# Sample

In [None]:
import torchvision.transforms.functional as F
from torchvision.utils import make_grid

def show(imgs):
    """Display one image tensor, or a list of them, side by side.

    Args:
        imgs: A single image tensor or a list of image tensors; each one is
            converted to a PIL image before being drawn.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array two-dimensional for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for column, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        panel = axes[0, column]
        panel.imshow(np.asarray(pil_image))
        # Remove ticks and tick labels so only the picture is visible.
        panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Reconstruct a single shuffled batch with the trained autoencoder.
x, _ = next(iter(evaluation_loader))
x = jnp.array(x)  # Convert to JAX array
x_recon = jax.vmap(model)(x)

import numpy as np
import torch
# Turn the flat reconstructions into single-channel 28x28 images. Using -1
# lets the batch dimension follow the loader's batch size instead of
# repeating it here as a hard-coded number.
x_recon = torch.tensor(np.array(x_recon))
x_recon = x_recon.view(-1, 1, 28, 28)

show(make_grid(x_recon))

## Extract latent space

In [None]:
# Rebuild a model with the same structure, then overwrite its leaves with
# the weights stored on disk.
key = jax.random.PRNGKey(42)
model = eqx.tree_deserialise_leaves("autoencoder_mnist_8.eqx", Autoencoder(key))

In [None]:
# Encode the whole dataset batch by batch, keeping dataset order
# (latent_loader does not shuffle).
latent_space = []
labels = []
for x, label in latent_loader:
    z = jax.vmap(model.encoder)(jnp.array(x))  # Encode one batch
    latent_space.append(np.array(z))
    labels.append(label)

# Join the per-batch arrays along the sample axis. Unlike stacking followed by
# a reshape to a fixed size, this works for any dataset length, including a
# smaller final batch.
latent_vector = np.concatenate(latent_space, axis=0)

# np.save appends ".npy", so the file is written as "ae_mnist_8.npy".
np.save("ae_mnist_8", latent_vector)

In [None]:
# Encode a single shuffled batch with the trained encoder.
for images, _ in evaluation_loader:
    x = jnp.array(images)
    x_latent = jax.vmap(model.encoder)(x)
    break

# NOTE: the latent vectors stored on disk replace the batch encoded above.
x_latent = np.load("ae_mnist_8.npy")

# Show the first 20 latent vectors as a heat map, one row per sample.
plt.imshow(x_latent[:20])
plt.colorbar()