#### Imports

In [None]:
import lightning as L
import torch
import torchvision
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

from src.dataset.FashionMNISTDataModule import FashionMNISTDataModule
from src.train.wrapper.classifier_wrapper import ClassifierWrapper
from src.utils.constants import Paths
from src.utils.helpers import detect_device

#### Defining batch size

In [None]:
# Mini-batch size handed to FashionMNISTDataModule when it is constructed.
BATCH_SIZE: int = 512

#### Create FashionMNISTDataModule

In [None]:
# Build the data module and prepare whatever it needs for the 'fit' stage.
fashionMNISTDataModule = FashionMNISTDataModule(Paths.DATA_DIR, BATCH_SIZE)
fashionMNISTDataModule.setup('fit')

# Report split sizes from the datasets behind the loaders. len(dataloader) is
# the number of batches, so multiplying it by BATCH_SIZE over-counts whenever
# the final batch is only partially filled.
# NOTE(review): assumes both loaders are torch DataLoaders over map-style
# datasets (i.e. expose a sized `.dataset`) and that no subset sampler is
# used -- confirm against FashionMNISTDataModule.
train_size = len(fashionMNISTDataModule.train_dataloader().dataset)
val_size = len(fashionMNISTDataModule.val_dataloader().dataset)
print(f'Training set has {train_size} instances')
print(f'Validation set has {val_size} instances')

In [None]:
# Single source of truth for the number of target classes: the output layer
# and the wrapper must agree, so both read it from the data module instead of
# one of them hard-coding 10.
num_classes = len(fashionMNISTDataModule.dataset_classes())

# ResNet-18 with randomly initialised weights (no pretrained checkpoint).
model = torchvision.models.resnet18(weights=None)
# Replace the stem convolution with one that accepts 1 input channel; the
# remaining hyper-parameters are unchanged from the values used here before.
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Replace the classification head so it emits one logit per dataset class.
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# NOTE(review): the second argument is presumably the class count expected by
# ClassifierWrapper -- confirm against its definition.
classifier_wrapper = ClassifierWrapper(model, num_classes)

In [None]:
# Both loggers are created with the same save directory, name and version.
run_id = dict(name='classifier_training.logs', version='version-1.0')
tensorboard_logger = TensorBoardLogger(Paths.LOGS_DIR, log_graph=True, **run_id)
csv_logger = CSVLogger(Paths.LOGS_DIR, **run_id)
loggers = [tensorboard_logger, csv_logger]

# Checkpointing: monitor 'val_loss' and keep the top 3 checkpoints, with the
# epoch and monitored value embedded in each file name.
checkpoint_callback = ModelCheckpoint(
    dirpath=Paths.MODEL_CHECKPOINT_DIR,
    filename='classifier-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
)

In [None]:
# Trainer configuration collected in one place, then unpacked into the Trainer.
trainer_options = dict(
    default_root_dir=Paths.MODEL_CHECKPOINT_DIR,
    accelerator=detect_device(),
    max_epochs=50,
    log_every_n_steps=50,
    logger=loggers,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
)
trainer = L.Trainer(**trainer_options)

# Run the training loop; the data module supplies the dataloaders.
trainer.fit(model=classifier_wrapper, datamodule=fashionMNISTDataModule)