#### Imports

In [1]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

from src.dataset.FashionMNISTDataModule import FashionMNISTDataModule
from src.models.classifier import ResNet18Classifier
from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.train.wrapper.gan_augmented_training_wrapper import GANAugmentedTrainingWrapper
from src.train.wrapper.gan_wrapper import GANWrapper
from src.utils.constants import Paths
from src.utils.helpers import determine_state_dict, detect_device

#### Defining batch size

In [2]:
BATCH_SIZE = 512

#### Creating the FashionMNISTDataModule

In [3]:
fashionMNISTDataModule = FashionMNISTDataModule(Paths.DATA_DIR, BATCH_SIZE)
fashionMNISTDataModule.setup('fit')

# Report split sizes
print('Training set has {} instances'.format(len(fashionMNISTDataModule.train_dataloader()) * BATCH_SIZE))
print('Validation set has {} instances'.format(len(fashionMNISTDataModule.val_dataloader()) * BATCH_SIZE))

Training set has 60416 instances
Validation set has 10240 instances


#### Defining the GAN parameters

In [4]:
z_dim = 128
generator_input_dim = z_dim + fashionMNISTDataModule.num_classes()
input_channels = 1
discriminator_input_dim = input_channels + fashionMNISTDataModule.num_classes()

#### Assemble the model

In [5]:


model = ResNet18Classifier(fashionMNISTDataModule.num_classes())
generator = Generator(generator_input_dim, input_channels)
discriminator = Discriminator(discriminator_input_dim)

gan_wrapper = GANWrapper(generator, discriminator, z_dim, fashionMNISTDataModule.num_classes(), 10)
gan_wrapper.load_state_dict(determine_state_dict(Paths.MODEL_CHECKPOINT_DIR, 'gan', ))

classifier_wrapper = GANAugmentedTrainingWrapper(model, gan_wrapper.generator, z_dim,
                                                 fashionMNISTDataModule.num_classes())

  gan_wrapper.load_state_dict(torch.load(Paths.GAN_WRAPPER_CHECKPOINT_FILE_PATH)['state_dict'])


FileNotFoundError: [Errno 2] No such file or directory: 'D:\\Projects\\Github\\augment-aid-ml\\assets\\model_checkpoints\\gan-wrapper.ckpt'

#### Adding logging and checkpointing

In [None]:
loggers = [
    TensorBoardLogger(Paths.LOGS_DIR, name='classifier_gan_training.logs', log_graph=True, version='version-1.0'),
    CSVLogger(Paths.LOGS_DIR, name='classifier_gan_training.logs', version='version-1.0')
]
checkpoint_callback = ModelCheckpoint(dirpath=Paths.MODEL_CHECKPOINT_DIR,
                                      filename='classifier-gan-{epoch:02d}-{val_loss:.2f}', save_top_k=3,
                                      monitor='val_loss')

#### Training the model

In [None]:
trainer = L.Trainer(default_root_dir=Paths.MODEL_CHECKPOINT_DIR, max_epochs=50, callbacks=[checkpoint_callback],
                    logger=loggers, accelerator=detect_device(), enable_checkpointing=True, log_every_n_steps=50)

trainer.fit(classifier_wrapper, datamodule=fashionMNISTDataModule)