In [None]:
import torch
from torch import nn
import lightning as L

class Encoder(nn.Module):
    """MLP encoder mapping 28x28 inputs to a ``latent_dim``-sized code.

    Args:
        latent_dim: Size of the output code.
        hidden_dim: Width of the single hidden layer. Both Linear layers are
            built from this one value so their shapes always agree.
    """

    def __init__(self, latent_dim: int = 20, hidden_dim: int = 256):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x):
        """Flatten every dimension after the batch dimension, then encode.

        Args:
            x: Tensor whose per-sample elements total 28*28,
                e.g. (batch, 1, 28, 28) or (batch, 784).

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        # flatten(1) also handles non-contiguous inputs, where .view would raise.
        return self.model(x.flatten(1))


class Decoder(nn.Module):
    """MLP decoder mapping a latent code back to a single-channel 28x28 image.

    Args:
        latent_dim: Size of the incoming code.
    """

    def __init__(self, latent_dim: int = 20):
        super().__init__()
        # Keep the exact layer order so Sequential indices (and state_dict keys) are stable.
        layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid(),  # squashes pixel values into [0, 1]
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` of shape (batch, latent_dim) into (batch, 1, 28, 28)."""
        batch_size = x.size(0)
        flat_pixels = self.model(x)
        return flat_pixels.view(batch_size, 1, 28, 28)

In [None]:
class AutoEncoder(L.LightningModule):
    """Autoencoder trained with a pixel-wise MSE reconstruction loss.

    Args:
        latent_dim: Size of the bottleneck code shared by ``Encoder`` and ``Decoder``.
        lr: Learning rate for the Adam optimizer (default 1e-3).
    """

    def __init__(self, latent_dim: int = 20, lr: float = 1e-3):
        super().__init__()
        # Records latent_dim and lr in self.hparams so they are saved with checkpoints.
        self.save_hyperparameters()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        """Encode and then decode ``x``, returning its reconstruction."""
        return self.decoder(self.encoder(x))

    def _reconstruction_loss(self, batch):
        """MSE between the inputs and their reconstruction; the batch's labels are ignored."""
        x, _ = batch
        x_hat = self(x)
        return nn.functional.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Without this hook Lightning warns and skips the validation loop even
        # when validation dataloaders are passed to Trainer.fit().
        loss = self._reconstruction_loss(batch)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Configuration: everything a reader might want to tune.
SEED = 42
DATA_DIR = "../data"
BATCH_SIZE = 32
MAX_EPOCHS = 10

# Seed Python/NumPy/torch so weight init and shuffling are reproducible on Restart & Run All.
L.seed_everything(SEED)

# Prepare the data
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = MNIST(DATA_DIR, train=True, download=True, transform=transform)
mnist_val = MNIST(DATA_DIR, train=False, download=True, transform=transform)

# Shuffle only the training set; validation order does not affect the averaged loss.
train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=BATCH_SIZE)

# Initialize our model
autoencoder = AutoEncoder()

# Use the Trainer from the same `lightning` package as the LightningModule: the
# standalone `pytorch_lightning.Trainer` rejects `lightning.LightningModule` instances.
# `gpus=` was removed in Lightning 2.0. `accelerator="auto"` uses a GPU when one is
# available and otherwise falls back to CPU; devices=1 keeps a single-process run.
trainer = L.Trainer(max_epochs=MAX_EPOCHS, accelerator="auto", devices=1)

# Train the model
trainer.fit(autoencoder, train_loader, val_loader)