#### Imports

In [1]:
import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

from src.dataset.FashionMNISTDataModule import FashionMNISTDataModule
from src.models.classifier import ResNet18Classifier
from src.models.vae import VAE
from src.train.wrapper.vae_augmented_training_wrapper import VAEAugmentedTrainingWrapper
from src.train.wrapper.vae_wrapper import VAEWrapper
from src.utils.constants import Paths
from src.utils.helpers import detect_device

#### Defining batch size

In [2]:
BATCH_SIZE = 1024

#### Create FashionMNISTDataModule

In [3]:
fashionMNISTDataModule = FashionMNISTDataModule(Paths.DATA_DIR, BATCH_SIZE)
fashionMNISTDataModule.setup('fit')

# Report split sizes
print('Training set has {} instances'.format(len(fashionMNISTDataModule.train_dataloader()) * BATCH_SIZE))
print('Validation set has {} instances'.format(len(fashionMNISTDataModule.val_dataloader()) * BATCH_SIZE))

Training set has 60416 instances
Validation set has 10240 instances


#### Defining the VAE parameters

In [4]:
input_dim = 28 * 28
hidden_dim = 400
latent_dim = 200

#### Assemble the model

In [5]:
model = ResNet18Classifier(fashionMNISTDataModule.num_classes())

vae = VAE(input_dim, hidden_dim, latent_dim, detect_device()).to(detect_device())

vae_wrapper = VAEWrapper(vae, display_every_n_steps=100)
vae_wrapper.load_state_dict(torch.load(Paths.VAE_WRAPPER_CHECKPOINT_FILE_PATH)['state_dict'])

classifier_wrapper = VAEAugmentedTrainingWrapper(model, vae_wrapper.vae, train_img_original=fashionMNISTDataModule.train_dataloader_unaltered())

  vae_wrapper.load_state_dict(torch.load(Paths.VAE_WRAPPER_CHECKPOINT_FILE_PATH)['state_dict'])


FileNotFoundError: [Errno 2] No such file or directory: 'D:\\Projects\\Github\\augment-aid-ml\\assets\\model_checkpoints\\vae-wrapper.ckpt'

#### Adding logging and checkpointing

In [None]:
loggers = [
    TensorBoardLogger(Paths.LOGS_DIR, name='classifier_vae_training.logs', log_graph=True, version='version-1.0'),
    CSVLogger(Paths.LOGS_DIR, name='classifier_vae_training.logs', version='version-1.0')
]

checkpoint_callback = ModelCheckpoint(dirpath=Paths.MODEL_CHECKPOINT_DIR,
                                      filename='classifier-vae-{epoch:02d}-{val_loss:.2f}', save_top_k=3,
                                      monitor='val_loss')

In [None]:
trainer = L.Trainer(default_root_dir=Paths.MODEL_CHECKPOINT_DIR, max_epochs=50, callbacks=[checkpoint_callback], logger=loggers, accelerator=detect_device(), enable_checkpointing=True, log_every_n_steps=50)

trainer.fit(classifier_wrapper, datamodule=fashionMNISTDataModule)