In [6]:
import lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [7]:
class ResNetClassifier(pl.LightningModule):
    """Image classifier built on a torchvision ResNet, packaged as a LightningModule.

    The module owns its data pipeline: the train/val/test dataloaders are built
    from ``ImageFolder`` directories given at construction time, so a Trainer
    can be fitted with ``trainer.fit(model)`` and no explicit dataloaders.

    Args:
        num_classes: Number of target classes; sizes the final linear layer
            and the accuracy metric.
        resnet_version: ResNet depth, one of the keys of ``resnets``
            (18, 34, 50, 101, 152). An unknown value raises ``KeyError``.
        train_path: Root directory of the training ``ImageFolder``.
        val_path: Root directory of the validation ``ImageFolder``.
        test_path: Root directory of the test ``ImageFolder``; only needed
            when the test loop is run.
        optimizer: Optimizer name, one of the keys of ``optimizers``
            ("adam" or "sgd"). An unknown value raises ``KeyError``.
        lr: Learning rate handed to the optimizer.
        batch_size: Batch size used by every dataloader.
    """

    # Supported ResNet depths mapped to their torchvision constructors.
    resnets = {
        18: models.resnet18,
        34: models.resnet34,
        50: models.resnet50,
        101: models.resnet101,
        152: models.resnet152,
    }
    # Supported optimizer names mapped to their torch.optim classes.
    optimizers = {"adam": Adam, "sgd": SGD}

    def __init__(
        self,
        num_classes,
        resnet_version,
        train_path,
        val_path,
        test_path=None,
        optimizer="adam",
        lr=1e-3,
        batch_size=16,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.lr = lr
        self.batch_size = batch_size
        self.optimizer = self.optimizers[optimizer]
        self.loss_fn = nn.CrossEntropyLoss()
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)

        # Call the constructor so an actual nn.Module is registered on this
        # LightningModule. Storing the bare constructor function would leave
        # the module without parameters (the optimizer then has nothing to
        # train) and would make forward() pass the image batch to the
        # constructor instead of to a network.
        self.resnet_model = self.resnets[resnet_version]()
        # Replace torchvision's default classification head with one that
        # emits `num_classes` logits.
        self.resnet_model.fc = nn.Linear(self.resnet_model.fc.in_features, num_classes)

    def forward(self, X):
        """Run the ResNet on a batch of images and return its output logits."""
        return self.resnet_model(X)

    def configure_optimizers(self):
        """Build the selected optimizer over all module parameters."""
        return self.optimizer(self.parameters(), lr=self.lr)

    def _step(self, batch):
        """Shared train/val/test step: return ``(loss, accuracy)`` for one batch."""
        x, y = batch
        preds = self(x)

        loss = self.loss_fn(preds, y)
        acc = self.acc(preds, y)
        return loss, acc

    def _dataloader(self, data_path, shuffle=False):
        """Build a DataLoader over the ``ImageFolder`` rooted at ``data_path``.

        Images are resized to 500x500, converted to tensors and normalised.
        """
        # values here are specific to pneumonia dataset and should be updated for custom data
        transform = transforms.Compose(
            [
                transforms.Resize((500, 500)),
                transforms.ToTensor(),
                transforms.Normalize((0.48232,), (0.23051,)),
            ]
        )

        img_folder = ImageFolder(data_path, transform=transform)

        return DataLoader(img_folder, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled dataloader over the training directory."""
        return self._dataloader(self.train_path, shuffle=True)

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/accuracy; return the loss for backprop."""
        loss, acc = self._step(batch)
        # perform logging
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def val_dataloader(self):
        """Unshuffled dataloader over the validation directory."""
        return self._dataloader(self.val_path)

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy for one batch."""
        loss, acc = self._step(batch)
        # perform logging
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, logger=True)

    def test_dataloader(self):
        """Unshuffled dataloader over the test directory.

        Raises:
            ValueError: If the module was constructed without ``test_path``.
        """
        if self.test_path is None:
            # Fail with an actionable message instead of an opaque error from
            # inside ImageFolder when no test directory was configured.
            raise ValueError("test_path was not provided; cannot build a test dataloader")
        return self._dataloader(self.test_path)

    def test_step(self, batch, batch_idx):
        """Compute and log test loss/accuracy for one batch."""
        loss, acc = self._step(batch)
        # perform logging
        self.log("test_loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("test_acc", acc, on_step=True, prog_bar=True, logger=True)


In [None]:
RUN_TEST = False  # Run trained model on test dataset

# Output directory for checkpoints and the exported weights. Defined before
# the checkpoint callback because that callback reads it; assigning it later
# raises NameError on a fresh kernel.
save_path = "./models"

# NOTE(review): `args` is not defined in the visible cells — presumably an
# earlier cell provides an object with train_set / val_set / test_set paths;
# confirm before running on a fresh kernel.
model = ResNetClassifier(
    num_classes=3,
    resnet_version=50,
    train_path=args.train_set,
    val_path=args.val_set,
    test_path=args.test_set,
    optimizer="adam",
    lr=0.0001,
    batch_size=60,
)

# Keep the three checkpoints with the lowest validation loss, plus the last one.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=save_path,
    filename="resnet-model-{epoch}-{val_loss:.2f}-{val_acc:0.2f}",
    monitor="val_loss",
    save_top_k=3,
    mode="min",
    save_last=True,
)

# EarlyStopping needs an explicit metric to monitor; use the same one the
# checkpoint callback tracks.
# NOTE(review): this callback is intentionally NOT attached to the trainer,
# matching the original cell; add it to `callbacks` below to actually enable
# early stopping.
stopping_callback = pl.callbacks.EarlyStopping(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=80,  # must be an int, not the string "80"
    callbacks=[checkpoint_callback],
)

trainer.fit(model)

if RUN_TEST:
    trainer.test(model)

# Export only the backbone weights. Read them from `model` directly rather
# than `trainer.model`, which may be a strategy wrapper around the module.
torch.save(model.resnet_model.state_dict(), save_path + "/trained_model.pt")