Using a pretrained network: fine-tuning ResNet-18 on Flowers102

In [None]:
import os

# Project root; the Flowers102 data lives in WORK_DIR/data. The default is the original
# author's local checkout -- set the DLA_WORK_DIR environment variable to run elsewhere.
WORK_DIR = os.environ.get("DLA_WORK_DIR", '/home/iwawiwi/research/22/dla-playground/')

Import PyTorch and torchvision modules

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import Flowers102

Create the model class: a pretrained ResNet-18 with a frozen backbone and a new classification head

In [None]:
class Resnet18(nn.Module):
    """torchvision ResNet-18 whose ``fc`` layer is replaced by a ``num_classes``-way head.

    Args:
        num_classes: number of outputs of the replacement ``fc`` layer.
        pretrained: passed to ``models.resnet18`` to load pretrained weights.
        freeze_backbone: if True (the default, matching the previous behavior), set
            ``requires_grad=False`` on every backbone parameter so only the new head trains.
    """

    def __init__(self, num_classes, pretrained=True, freeze_backbone=True):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; kept here for compatibility with older torchvision releases.
        self.model = models.resnet18(pretrained=pretrained)
        if freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
        # Size the head from the existing layer rather than hard-coding 512, so it always
        # matches the backbone's feature width. The head is created after freezing, so
        # its parameters stay trainable.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)
        # NOTE(review): frozen BatchNorm layers still update their running statistics
        # whenever the module is in train() mode.

    def forward(self, x):
        """Return the head's logits for the input batch ``x``."""
        return self.model(x)

Define the image transforms

In [None]:
# Standard ImageNet per-channel statistics, the normalization used with torchvision's
# pretrained ResNet weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing applied to every split:
# shorter side -> 256 px, central 224x224 crop, tensor conversion, channel normalization.
transformation = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


Load the Flowers102 dataset and define the data loaders

In [None]:
import os

In [None]:
DATA_DIR = os.path.join(WORK_DIR, "data")
BATCH_SIZE = 32
NUM_WORKERS = 4

# download=True fetches Flowers102 into DATA_DIR on a fresh machine; when the files are
# already present torchvision does not download them again.
train_data, val_data, test_data = (
    Flowers102(DATA_DIR, split=split, transform=transformation, download=True)
    for split in ("train", "val", "test")
)

# print train, val, test data sizes
print(len(train_data), len(val_data), len(test_data))

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Train and evaluate the model with a PyTorch Lightning module

In [None]:
# import pytorch lightning
import pytorch_lightning as pl
from torchmetrics.classification.accuracy import Accuracy

In [None]:
class PretrainedModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates the ``Resnet18`` classifier.

    ``Resnet18`` freezes its backbone, so in practice only the new ``fc`` head is
    optimized (SGD). Loss and accuracy are logged once per epoch under ``train/*``,
    ``val/*`` and ``test/*``.

    Args:
        num_classes: number of classes predicted by the head.
        lr: SGD learning rate (default 0.1, as before).
        momentum: SGD momentum (default 0.9, as before).
        label_offset: subtracted from every dataset label before it is used as a class
            index, for both the loss and the accuracy metric. torchvision's
            ``Flowers102`` already yields 0-based labels, so the default is 0; use 1
            only for a dataset whose labels start at 1.
    """

    def __init__(self, num_classes, lr=0.1, momentum=0.9, label_offset=0):
        super().__init__()
        self.net = Resnet18(num_classes=num_classes)  # using pretrained weights
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.momentum = momentum
        self.label_offset = label_offset

        # One metric instance per stage so their accumulated states never mix.
        self.train_acc = self._make_accuracy(num_classes)
        self.val_acc = self._make_accuracy(num_classes)
        self.test_acc = self._make_accuracy(num_classes)

    @staticmethod
    def _make_accuracy(num_classes):
        """Build a multiclass accuracy metric that works across torchmetrics versions.

        Newer torchmetrics releases (>= 0.11) require an explicit ``task``; older ones
        reject that argument, so fall back to the legacy no-argument constructor.
        """
        try:
            return Accuracy(task="multiclass", num_classes=num_classes)
        except TypeError:
            return Accuracy()

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.net(x)

    def step(self, batch):
        """Run the network on one ``(images, labels)`` batch.

        Returns ``(loss, preds, target)``, where ``target`` is the label tensor after
        ``label_offset`` has been subtracted. The loss and the accuracy metric both use
        this same target, so they always agree on class indices.
        """
        x, y = batch
        target = y - self.label_offset
        logits = self(x)
        loss = self.criterion(logits, target)
        preds = logits.argmax(dim=1)
        return loss, preds, target

    def _shared_step(self, batch, stage, metric):
        """Compute loss/predictions for ``batch`` and log ``<stage>/loss`` and ``<stage>/acc``."""
        loss, preds, target = self.step(batch)
        # Update the metric and log the Metric object itself: Lightning computes the
        # accuracy over the whole epoch and resets the state afterwards, so no manual
        # reset hook is needed (the generic ``on_epoch_end`` hook was removed in
        # Lightning 1.8).
        metric.update(preds, target)
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{stage}/acc", metric, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "target": target}

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_acc)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", self.val_acc)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test", self.test_acc)

    def configure_optimizers(self):
        # Hand only the parameters that can actually train (the unfrozen head) to SGD.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.SGD(trainable, lr=self.lr, momentum=self.momentum)


Initialize the trainer and the training module

In [None]:
SEED = 42
# Seed Python/NumPy/torch RNGs so head initialization and data shuffling are repeatable.
pl.seed_everything(SEED)

# Use one GPU when available, otherwise fall back to CPU instead of failing.
# ``gpus=`` is deprecated in Lightning 1.7 and removed in 2.0; accelerator/devices replace it.
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=2,
    log_every_n_steps=30,
)
model = PretrainedModule(num_classes=102)  # Flowers102 has 102 flower categories

Train the model

In [None]:
# Fit on the training split, validating on val_loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Evaluate the model on the test set

In [None]:
# Evaluate the trained model on the held-out test split and display the returned metrics.
test_results = trainer.test(model, dataloaders=test_loader)
test_results