Домашнее задание:
- Получить точность 97.5% на валидации MNIST.
- Реализовать морфинг автоэнкодером (без формальных критериев — просто получите красивую гифку).
- Визуализировать MNIST автоэнкодером (обучить автоэнкодер с латентным пространством размерности 2 и вывести латентные представления через scatter, раскрасив точки по классам цифр).

Ссылка на ноутбук на Kaggle: [kaggle.com](https://www.kaggle.com/hashshes/training-mnist-97-val-accuracy)

In [None]:
# If running on colab

# ! pip install -qqq pytorch-lightning torch torchvision torchmetrics

In [2]:
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
from torch.optim import Adam

from torchmetrics.classification import Accuracy

# Fix the global RNG seeds so that weight init and shuffling are repeatable between runs.
pl.seed_everything(seed=0)

In [3]:
# Hyper-parameters shared by the data module, the classifier and the trainer.
config = dict(
    batch_size=500,    # samples per batch, used for both train and val loaders
    lr=0.001,          # Adam learning rate
    max_epochs=100,    # number of training epochs given to the Trainer
    lr_step=30,        # StepLR step_size (scheduler steps; epochs under Lightning's default interval)
    # NOTE(review): weight_decay is not read by any code visible here; a value of 10
    # would be extreme if it were ever passed to Adam — confirm before wiring it in.
    weight_decay=10,
)

## Defining the data loaders for the MNIST model

In [4]:
# Only converts images to tensors (ToTensor); no normalization or augmentation is applied.
default_transform = transforms.Compose([transforms.ToTensor()])

class MNISTDataloader(pl.LightningDataModule):
    """Lightning data module for MNIST.

    The MNIST train split feeds training and the MNIST test split is used as
    the validation set. Both splits are downloaded if missing and loaded
    eagerly in ``__init__`` with ``default_transform`` applied, so the loaders
    can be requested without calling ``setup`` first.

    Args:
        batch_size: samples per batch for both loaders.
        root: directory torchvision downloads MNIST into / reads it from.
        num_workers: number of worker processes for each DataLoader.
    """

    def __init__(self, batch_size: int, root: str = "../data/raw", num_workers: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_dataset = datasets.MNIST(root=root, download=True, train=True, transform=default_transform)
        self.test_dataset = datasets.MNIST(root=root, download=True, train=False, transform=default_transform)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size and worker count."""
        return DataLoader(dataset=dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled loader over the MNIST train split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the loader over the MNIST test split, in fixed order."""
        return self._make_loader(self.test_dataset, shuffle=False)

## Defining the dense model for MNIST

In [5]:
# Fully connected classifier: 28*28 inputs -> 100 hidden units -> 10 outputs.
# The last layer emits raw scores (no softmax); indices 0..4 define the state_dict keys.
model = nn.Sequential(
    nn.Flatten(),              # flatten each image into 28 * 28 = 784 features
    nn.Linear(28 * 28, 100),
    nn.ReLU(),
    nn.Dropout(0.5),           # active only in train() mode
    nn.Linear(100, 10),
)


## Defining the training loop for our model

In [6]:
class MNISTClassifier(pl.LightningModule):
    """Lightning wrapper that trains ``_model`` with cross-entropy and tracks accuracy.

    Args:
        _model: network whose output is fed to ``nn.CrossEntropyLoss`` together
            with the integer labels of the batch.
        _config: hyper-parameters; ``'lr'`` and ``'lr_step'`` are read in
            :meth:`configure_optimizers`.
    """

    def __init__(self, _model: nn.Module, _config: dict):
        super().__init__()
        self.model = _model
        self.config = _config

        # NOTE(review): torchmetrics >= 0.11 requires
        # Accuracy(task="multiclass", num_classes=10); this no-argument form
        # matches the older torchmetrics API used by this notebook.
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()

        self.loss_fn = nn.CrossEntropyLoss()

        # Count of finished training epochs. It is kept for backward compatibility,
        # but the printed epoch numbers come from ``self.current_epoch``, which
        # stays correct when training resumes from a checkpoint.
        self.count_epoch = 0

    def forward(self, x):
        """Return the wrapped model's output for ``x`` (enables ``classifier(x)``)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch

        outputs = self(x)
        loss = self.loss_fn(outputs, y)

        self.log("train/loss_step", loss.item())
        # Calling the metric updates its epoch state and returns this batch's accuracy.
        self.log('train/acc_step', self.train_accuracy(outputs, y))

        return loss

    def training_epoch_end(self, outputs) -> None:
        self.count_epoch += 1
        # Only a tensor value is logged above, not the metric object, so
        # Lightning does not reset the metric. Without reset() compute() would
        # report a running total over every epoch so far.
        accuracy = self.train_accuracy.compute()
        self.train_accuracy.reset()
        print(f'Train accuracy on {self.current_epoch + 1} epoch: {accuracy}')

    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        self.log('val/acc_step', self.val_accuracy(outputs, y))

    def validation_epoch_end(self, outputs) -> None:
        # Reset every time, including after the pre-training sanity check, so
        # those batches do not leak into the first real epoch's accuracy.
        accuracy = self.val_accuracy.compute()
        self.val_accuracy.reset()
        # The sanity check is not a real epoch; do not report it.
        if getattr(self.trainer, "sanity_checking", False):
            return
        # current_epoch is 0-based; +1 matches the numbering of the train report.
        print(f"Val accuracy on {self.current_epoch + 1} epoch: {accuracy}")

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (default gamma=0.1) every ``config['lr_step']`` steps."""
        # NOTE(review): config['weight_decay'] is intentionally not passed to Adam here.
        optimizer = Adam(self.parameters(), lr=self.config['lr'])
        lr_scheduler = StepLR(optimizer=optimizer, step_size=self.config['lr_step'])

        return [optimizer], [lr_scheduler]


In [None]:
def train(_config: dict, _model: "nn.Module | None" = None,
          checkpoint_path: str = "../weights/mnist.ckpt"):
    """Fit an MNISTClassifier on MNIST and save a weights-only checkpoint.

    Args:
        _config: hyper-parameters. ``'max_epochs'`` and ``'batch_size'`` are
            read here, and the whole dict is handed to the classifier. An
            optional ``'gpus'`` entry (default 0, i.e. CPU only) is forwarded
            to the Trainer.
        _model: network to train. Defaults to the module-level ``model``.
            Pass a fresh network to avoid continuing from weights left by an
            earlier call.
        checkpoint_path: destination of the weights-only checkpoint.

    Returns:
        The trained MNISTClassifier.
    """
    if _model is None:
        _model = model

    trainer = pl.Trainer(gpus=_config.get('gpus', 0), max_epochs=_config['max_epochs'])

    data_loader = MNISTDataloader(batch_size=_config['batch_size'])
    classifier = MNISTClassifier(_model=_model, _config=_config)
    # Passed positionally: (model, train loader(s), val loader(s)).
    trainer.fit(classifier, data_loader.train_dataloader(), data_loader.val_dataloader())
    trainer.save_checkpoint(checkpoint_path, weights_only=True)
    return classifier

# Entry point: train with the notebook-level config when run as a script or notebook kernel.
if __name__ == "__main__":
    train(_config=config)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

  | Name           | Type             | Params
----------------------------------------------------
0 | model          | Sequential       | 79.5 K
1 | train_accuracy | Accuracy         | 0     
2 | val_accuracy   | Accuracy         | 0     
3 | loss_fn        | CrossEntropyLoss | 0     
----------------------------------------------------
79.5 K    Trainable params
0         Non-trainable params
79.5 K    Total params
0.318     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Train accuracy on 1 epoch: 0.8823000192642212
