## Training a Conditional GAN on Fashion-MNIST

#### Enabling hot reloading of modules

In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads all modules (except those
# explicitly excluded) before executing each cell, so edits to the project code
# (e.g. under src/) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

#### Importing modules

In [2]:
import lightning as L
import torchvision
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

from src.dataset.FashionMNISTDataModule import FashionMNISTDataModule
from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.train.wrapper.gan_wrapper import GANWrapper
from src.utils.constants import Paths
from src.utils.helpers import detect_device, matplotlib_imshow

KeyboardInterrupt: 

#### Setting the data-loading parameters

In [None]:
# Number of samples per mini-batch, shared by the data module and the size report below.
BATCH_SIZE = 1_024

#### Downloading and preparing the data

In [None]:
# Build the Fashion-MNIST data module and prepare the splits used for fitting.
fashionMNISTDataModule = FashionMNISTDataModule(Paths.DATA_DIR, BATCH_SIZE)
fashionMNISTDataModule.setup('fit')

# Report split sizes. Count the samples in each loader's underlying dataset instead of
# multiplying the number of batches by BATCH_SIZE: unless a split size is an exact
# multiple of BATCH_SIZE, the final batch is partial and `len(loader) * BATCH_SIZE`
# overstates the split by up to BATCH_SIZE - 1.
# NOTE(review): assumes the *_dataloader() methods return torch DataLoaders, which
# expose the wrapped dataset as `.dataset`.
train_size = len(fashionMNISTDataModule.train_dataloader().dataset)
val_size = len(fashionMNISTDataModule.val_dataloader().dataset)
print('Training set has {} instances'.format(train_size))
print('Validation set has {} instances'.format(val_size))

#### Defining the model hyperparameters

In [None]:
# Dimension of the latent noise vector (z) handed to the generator / GANWrapper.
z_dim = 128
# Fashion-MNIST images are grayscale, hence a single image channel.
input_channels = 1

# Conditional setup: both networks also receive class information, so each input
# width is widened by the number of classes (presumably a one-hot label encoding;
# TODO confirm against GANWrapper).
generator_input_dim = z_dim + fashionMNISTDataModule.num_classes()
discriminator_input_dim = input_channels + fashionMNISTDataModule.num_classes()

#### Visualizing the data

In [None]:
# Preview a few training samples with their class names. The grid and the printed
# labels use the same number of samples, so each printed name matches one tile.
NUM_PREVIEW = 4

images, labels = next(iter(fashionMNISTDataModule.train_dataloader()))
# Slicing (rather than indexing range(NUM_PREVIEW)) also copes with a batch that has
# fewer than NUM_PREVIEW samples.
images, labels = images[:NUM_PREVIEW], labels[:NUM_PREVIEW]

# Create a grid from the images and show them.
# NOTE(review): make_grid expects images shaped (B, C, H, W); assumed true for this loader.
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

class_names = fashionMNISTDataModule.dataset_classes()
print('  '.join(class_names[label] for label in labels))

#### Defining the model

In [None]:
# Build the two adversarial networks, sized by the hyperparameters above.
generator = Generator(generator_input_dim, input_channels)
discriminator = Discriminator(discriminator_input_dim)

# Bundle both networks into GANWrapper, the module passed to `trainer.fit` for training.
gan_wrapper = GANWrapper(
    generator,
    discriminator,
    z_dim,
    fashionMNISTDataModule.num_classes(),
    display_every_n_steps=100,
)

#### Configuring logging and checkpointing

In [None]:
# Both loggers share one run name and version, so TensorBoard events and CSV metrics
# for a run land in the same <LOGS_DIR>/<name>/<version> directory.
_RUN_NAME = 'gan-training.logs'
_RUN_VERSION = 'version-1.0'

# NOTE(review): log_graph=True only records a graph if the LightningModule defines
# `example_input_array`; confirm GANWrapper sets it.
tensorboard_logger = TensorBoardLogger(Paths.LOGS_DIR, name=_RUN_NAME, log_graph=True, version=_RUN_VERSION)
csv_logger = CSVLogger(Paths.LOGS_DIR, name=_RUN_NAME, version=_RUN_VERSION)
loggers = [tensorboard_logger, csv_logger]

# Keep only the single best checkpoint as ranked by 'val_generator_loss'
# (ModelCheckpoint's default mode is 'min'). Assumes GANWrapper logs that metric.
checkpoint_callback = ModelCheckpoint(
    dirpath=Paths.MODEL_CHECKPOINT_DIR,
    filename='gan-wrapper',
    save_top_k=1,
    monitor='val_generator_loss',
)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ../../logs/gan-training.logs

#### Training the model

In [None]:
# Configure the Lightning trainer. Checkpoints are written by `checkpoint_callback`,
# and metrics go to both loggers every 50 steps. The epoch cap is set very high, so
# in practice the run is stopped manually.
trainer = L.Trainer(
    default_root_dir=Paths.MODEL_CHECKPOINT_DIR,
    max_epochs=10_000,
    callbacks=[checkpoint_callback],
    logger=loggers,
    accelerator=detect_device(),
    enable_checkpointing=True,
    log_every_n_steps=50,
)

# Train the GAN on the Fashion-MNIST splits provided by the data module.
trainer.fit(gan_wrapper, datamodule=fashionMNISTDataModule)