<a href="https://colab.research.google.com/github/kampelmuehler/MLKurs/blob/main/MNIST/MNIST_Autoencoder_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title Setup
! pip install --quiet "torchvision" "torch>=1.8" "torchmetrics>=0.7" "pytorch-lightning>=1.4"
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Batch size shared by the dataloaders: the smaller value is the CPU fallback,
# the larger one is used whenever a CUDA device is available.
BATCH_SIZE = 64 if not torch.cuda.is_available() else 256

In [3]:
class MNISTAutoencoder(LightningModule):
    """Autoencoder skeleton for MNIST, packaged as a LightningModule.

    The module bundles the model (``encoder``/``decoder``), the optimisation
    setup, and the MNIST data handling. The encoder, the decoder and the
    training loss are intentionally left as placeholders to be filled in.

    Args:
        bottleneck: Size of the latent code. It is only stored on the
            instance here; the encoder/decoder definitions are expected to
            use it once they are filled in.
        data_dir: Directory where MNIST is downloaded to / read from.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, bottleneck, data_dir="./data", learning_rate=2e-4):
        super().__init__()
        self.bottleneck = bottleneck
        self.data_dir = data_dir
        self.learning_rate = learning_rate

        # Dataset parameters.
        self.num_classes = 10
        self.dims = (1, 28, 28)

        # Training and prediction use the same preprocessing: convert to a
        # tensor, then normalise with mean 0.1307 and std 0.3081.
        self.transform_train = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        self.transform_predict = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )
        # Inverse of the normalisation above (x * 0.3081 + 0.1307), used to
        # map tensors back to displayable pixel intensities.
        self.transform_visualize = transforms.Compose(
            [
                transforms.Normalize((-0.1307 / 0.3081,), (1 / 0.3081,)),
            ]
        )

        self.activation = nn.ReLU()

        # Model definition.
        # TODO: fill in the layers. While both containers are empty they act
        # as the identity and the module has no trainable parameters.
        self.encoder = nn.Sequential(
        )
        self.decoder = nn.Sequential(
        )

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            The reconstruction, reshaped to the shape of ``x``.
        """
        encoding = self.encoder(x)
        reconstruction = self.decoder(encoding)
        # A decoder that emits flattened images must still yield tensors
        # shaped like the input so that the loss and the image grid work.
        # For an output that already matches, this is a no-op.
        return reconstruction.reshape(x.shape)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch (placeholder).

        Returns:
            The loss for the batch; currently ``None``.
        """
        x, _ = batch
        # TODO: compute the reconstruction loss for ``x``.
        # NOTE(review): per the Lightning docs, returning None makes the
        # trainer skip the batch, so nothing is learned until this is done.
        loss = None
        return loss

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        """Show inputs next to their reconstructions for the first batch.

        Only the first validation batch (``batch_idx == 0``) is visualised;
        up to 64 images are arranged in an 8-column grid, with the input grid
        on the left and the reconstruction grid on the right.

        ``dataloader_idx`` defaults to 0 so the hook works both when the
        trainer passes the index and when it omits it.
        """
        if batch_idx != 0:
            return

        inputs = batch[0]
        reconstruction = self(inputs)

        # Pad with the brightest input value; ``make_grid`` documents
        # ``pad_value`` as a float, so convert the 0-dim tensor once.
        pad_value = inputs.max().item()
        input_grid = make_grid(inputs[:64], 8, padding=4, pad_value=pad_value)
        reconstruction_grid = make_grid(
            reconstruction[:64], 8, padding=4, pad_value=pad_value
        )

        # Undo the dataset normalisation before converting to images.
        input_grid = self.transform_visualize(input_grid)
        reconstruction_grid = self.transform_visualize(reconstruction_grid)

        to_pil = transforms.ToPILImage()
        input_grid = to_pil(input_grid)
        # The model output is unconstrained, so clamp it to the valid range.
        reconstruction_grid = to_pil(torch.clamp(reconstruction_grid, 0, 1))

        gap = 20  # horizontal space between the two grids, in pixels
        dst = Image.new("L", (2 * input_grid.width + gap, input_grid.height))
        dst.paste(input_grid, (0, 0))
        dst.paste(reconstruction_grid, (reconstruction_grid.width + gap, 0))
        # ``display`` is not imported in this file; presumably the builtin
        # that IPython/Colab injects into notebook sessions.
        display(dst)

    def validation_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one validation batch.

        Returns:
            The binary cross-entropy (with logits) between the reconstruction
            and the input batch.
        """
        x, _ = batch
        reconstruction = self(x)
        # The signature is (input, target): the model output is the input
        # (logits) and the data is the target.
        # NOTE(review): the targets are normalised (mean 0.1307, std 0.3081)
        # and therefore not confined to [0, 1], which BCE normally expects;
        # confirm that this loss is the intended one.
        loss = F.binary_cross_entropy_with_logits(reconstruction, x)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # Dataset preparation
    ####################

    def prepare_data(self):
        """Download the MNIST training set into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        """Split the 60,000 training images into 55,000 train / 5,000 val."""
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform_train)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

In [None]:
# Build the autoencoder; no bottleneck size is chosen yet (None).
model = MNISTAutoencoder(bottleneck=None)

# Let Lightning choose the accelerator and devices, and train for 10 epochs.
trainer = Trainer(max_epochs=10, accelerator="auto", devices="auto")

# The module supplies its own dataloaders, so only the model is passed.
trainer.fit(model)