In [None]:
import os
from pathlib import Path

import lightning as L
import torch
import torchmetrics
from timm import create_model

from dummy_problems.dataloaders import LettersDataModule

class ClassificationModel(L.LightningModule):
    """Letter classifier: a timm backbone trained with cross-entropy loss.

    All arguments default to the values this notebook previously hard-coded,
    so ``ClassificationModel()`` builds exactly the same network as before.

    Args:
        model_name: timm architecture name passed to ``create_model``.
        num_classes: Size of the classifier head. The accuracy metric is built
            from the same value so the two cannot drift apart.
        lr: AdamW learning rate.
        pretrained: Forwarded to ``create_model``; ``False`` trains from scratch.
    """

    def __init__(self, model_name="mobilenetv3_large_100", num_classes=26, lr=0.005, pretrained=False):
        super().__init__()
        # Store the constructor arguments in self.hparams and in checkpoints, so
        # ClassificationModel.load_from_checkpoint(...) can rebuild this architecture.
        self.save_hyperparameters()
        self.model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, images):
        """Return the backbone's raw logits for a batch of ``images``."""
        return self.model(images)

    def training_step(self, batch):
        """Compute and log the cross-entropy loss for one ``(images, targets)`` batch."""
        images, targets = batch
        outputs = self(images)
        loss = self.loss_fn(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch):
        """Log the loss and accumulate accuracy for one ``(images, targets)`` test batch."""
        images, targets = batch
        outputs = self(images)
        self.log("test_loss", self.loss_fn(outputs, targets))

        # Updates the metric state; logging the Metric object lets Lightning
        # handle compute()/reset(). Inside test_step, self.log defaults to
        # on_step=False, on_epoch=True, so despite its name "test_acc_step"
        # holds the epoch-level accuracy (key kept for compatibility).
        self.accuracy(outputs, targets)
        self.log("test_acc_step", self.accuracy)

    def on_test_epoch_end(self):
        """Log the accumulated test accuracy under the epoch-level key."""
        self.log("test_acc_epoch", self.accuracy)

    def configure_optimizers(self):
        """AdamW over the backbone parameters with the configured learning rate."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr)

In [2]:
# Seed python/numpy/torch RNGs (and, via workers=True, dataloader workers)
# *before* building the model so weight initialisation and the training run
# that follows are reproducible.
L.seed_everything(42, workers=True)
model = ClassificationModel()

In [None]:
# Training
# The dataset location can be overridden with the LETTERS_DATASET_DIR
# environment variable instead of editing this machine-specific path.
MAX_EPOCHS = 40

train_settings = {
    "dataset_dir": Path(os.environ.get("LETTERS_DATASET_DIR", "/home/ubuntu/data/letters_dataset")),
    "stage": "fit",
}

train_datamodule = LettersDataModule(train_settings)
# NOTE: fit() updates `model` in place, so re-running this cell without first
# re-running the model-construction cell continues training the same weights.
# `trainer` is kept as the name: the test cell reads its checkpoint callback.
trainer = L.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(model, train_datamodule)


In [None]:
# Testing
# Evaluate the checkpoint written by the training run above. A hard-coded
# "lightning_logs/version_N/..." path silently goes stale: after Restart & Run
# All the new run lands in a new version directory and the old weights would be
# tested instead. Set CKPT_PATH to an explicit file only to re-evaluate an
# earlier run.
CKPT_PATH = None

test_settings = {
    "dataset_dir": Path(os.environ.get("LETTERS_DATASET_DIR", "/home/ubuntu/data/letters_dataset")),
    "stage": "test",
}

ckpt_path = CKPT_PATH or trainer.checkpoint_callback.best_model_path
if not ckpt_path:
    raise RuntimeError("No checkpoint found: run the training cell first or set CKPT_PATH.")

test_datamodule = LettersDataModule(test_settings)
# A separate name keeps `trainer` pointing at the fit trainer, so this cell can
# be re-run without losing the checkpoint reference.
test_trainer = L.Trainer()
test_trainer.test(model=model, datamodule=test_datamodule, ckpt_path=ckpt_path)