# CIFAR100

In [None]:
import lightning as L
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR100

from vision.nn.core.residual import ResNet50


In [None]:
class CIFAR100Model(L.LightningModule):
    """ResNet-50 classifier for CIFAR-100 that also owns its data pipeline.

    The module downloads CIFAR-100 into ``DATA_DIR``, splits the official
    training set into train/validation subsets according to ``TRAIN_RATIO``,
    and tracks multiclass accuracy separately for train, val and test.

    Args:
        batch_size: Samples per batch for every dataloader.
        num_workers: Worker processes used by every dataloader.
    """

    DATA_DIR = "./data/"
    TRAIN_RATIO = 0.9
    NUM_CLASSES = 100
    NUM_CHANNELS = 3
    # Per-channel (R, G, B) mean/std of the CIFAR-100 training images, used to
    # normalize inputs to roughly zero mean / unit variance per channel.
    MEAN = (0.5071, 0.4865, 0.4409)
    STD = (0.2673, 0.2564, 0.2762)
    # Fixed seed so the train/val split is identical across runs and across
    # repeated ``setup`` calls.
    SPLIT_SEED = 42

    def __init__(self, batch_size=32, num_workers=10):
        super().__init__()
        # Record constructor args in ``self.hparams`` / checkpoints so the
        # module can be rebuilt with the same settings.
        self.save_hyperparameters()
        self.model = ResNet50(self.NUM_CLASSES, self.NUM_CHANNELS)
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.MEAN, self.STD),
            ]
        )

        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)

    def forward(self, x):
        """Return the logits produced by the wrapped ResNet-50 for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, preds, targets)``."""
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch and log metrics."""
        loss, preds, y = self._shared_step(batch)
        # Call the metric (forward) instead of ``update``: training logs are
        # emitted per step, and Lightning reports a step-logged Metric object
        # from the batch value that only forward computes.
        self.train_accuracy(preds, y)

        self.log("train_loss", loss)
        self.log("train_acc", self.train_accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accumulate accuracy for one validation batch."""
        loss, preds, y = self._shared_step(batch)
        self.val_accuracy.update(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log loss and accumulate accuracy for one test batch."""
        loss, preds, y = self._shared_step(batch)
        self.test_accuracy.update(preds, y)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        """Download the CIFAR-100 train and test splits into ``DATA_DIR``."""
        CIFAR100(self.DATA_DIR, train=True, download=True)
        CIFAR100(self.DATA_DIR, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``fit`` and ``validate`` both need the train/val split (built from the
        official training set). ``test`` needs the official test set. ``None``
        builds everything.
        """
        if stage in ("fit", "validate", None):
            data_full = CIFAR100(self.DATA_DIR, train=True, transform=self.transform)
            len_train = int(len(data_full) * self.TRAIN_RATIO)
            len_val = len(data_full) - len_train
            generator = torch.Generator().manual_seed(self.SPLIT_SEED)
            self.data_train, self.data_val = random_split(
                data_full, [len_train, len_val], generator=generator
            )

        if stage in ("test", None):
            self.data_test = CIFAR100(self.DATA_DIR, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader using the module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        # Reshuffle every epoch so batches are not seen in a fixed order.
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.data_val)

    def test_dataloader(self):
        return self._make_loader(self.data_test)


In [None]:
# Train the model. Early stopping halts `patience` epochs after the last
# val_loss improvement, so keep the checkpoint with the lowest val_loss;
# otherwise the saved (and later tested) weights are those of the final,
# already-degrading epoch.
model = CIFAR100Model(batch_size=32)

checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

trainer = L.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger("./experiments", name="cifar100"),
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=10),
        checkpoint,
        ModelSummary(max_depth=-1),
    ],
)
trainer.fit(model)

In [None]:
# Evaluate the best checkpoint recorded during the preceding fit. Passing
# ckpt_path explicitly avoids relying on Lightning's implicit checkpoint
# selection, which warns when test() is called without a model.
trainer.test(ckpt_path="best")