In [2]:
import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms, datasets

import wandb

torch.set_float32_matmul_precision("medium")

In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mbrinashong[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
from torchvision.models import AlexNet_Weights

# get data transforms
weights = AlexNet_Weights.DEFAULT
data_transforms = weights.transforms()
print(data_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [5]:
# Load your custom dataset
data_dir = './dataset'

# Define transforms for the dataset
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'test']}
class_names = image_datasets['train'].classes

In [6]:
class ImageClassifierModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './dataset'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'test': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }        

        self.image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), self.transform[x])
                  for x in ['train', 'test']}
        self.class_names = self.image_datasets['train'].classes
        self.num_classes = len(self.class_names)

        set_train_full = self.image_datasets['train']
        self.set_train, self.set_val = random_split(set_train_full, [0.9, 0.1])
        self.set_test = self.image_datasets['test']

    def train_dataloader(self):
        return DataLoader(self.set_train, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.set_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.set_test, batch_size=self.batch_size, num_workers=4)

In [7]:
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model prediction
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{class_names[pred] if class_names is not None else pred}, Label:{class_names[y] if class_names is not None else y}")
                           for x, pred, y in zip(val_imgs[:self.num_samples],
                                                 preds[:self.num_samples],
                                                 val_labels[:self.num_samples])]
            })


In [8]:
from torchvision import models

class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4, transfer=True):
        super().__init__()
        
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes
        
        # transfer learning if pretrained=True
        self.feature_extractor = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
        
        if transfer:
            # layers are frozen by using eval()
            self.feature_extractor.eval()
            # freeze params
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
        
        n_sizes = self._get_conv_output(input_shape)

        self.classifier = nn.Linear(n_sizes, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy('multiclass', num_classes=num_classes)
  
    # returns the size of the output tensor going into the Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        tmp_input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(tmp_input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = self.feature_extractor(x)
        return x
    
    # will be used during inference
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = self.classifier(x)
        
       return x
    
    def training_step(self, batch):
        batch, gt = batch[0], batch[1]
        out = self.forward(batch)
        loss = self.criterion(out, gt)
        acc = self.accuracy(out, gt)

        self.log("train/loss", loss)
        self.log("train/acc", acc)

        return loss
    
    def validation_step(self, batch, batch_idx):
        batch, gt = batch[0], batch[1]
        out = self.forward(batch)
        loss = self.criterion(out, gt)
        acc = self.accuracy(out, gt)

        self.log("val/loss", loss)
        self.log("val/acc", acc)

        return loss
    
    def test_step(self, batch, batch_idx):
        batch, gt = batch[0], batch[1]
        out = self.forward(batch)
        loss = self.criterion(out, gt)
        acc = self.accuracy(out, gt)
        
        self.log("test/loss", loss)
        self.log("test/acc", acc)
        
        return {"loss": loss, "outputs": out, "gt": gt}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)  


In [9]:
dm = ImageClassifierModule(batch_size=32, data_dir="./dataset")

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

(torch.Size([32, 3, 224, 224]), torch.Size([32]))

In [10]:
model = LitModel((3, 224, 224), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='Retail Image Classification - AlexNet', job_type='train')

# Initialize Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val/loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# Initialize a trainer
trainer = pl.Trainer(max_epochs=100,
                     accelerator="auto",
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# Close wandb run
wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type               | Params
---------------------------------------------------------
0 | feature_extractor | AlexNet            | 61.1 M
1 | classifier        | Linear             | 2.4 M 
2 | criterion         | CrossEntropyLoss   | 0     
3 | accuracy          | MulticlassAccuracy | 0     
---------------------------------------------------------
2.4 M     Trainable params
61.1 M    Non-trainable params
63.5 M    Total params
253.949   Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …



Validation: |                                                                                                 …

Restoring states from the checkpoint path at ./Retail Image Classification - AlexNet/qypw6fpb/checkpoints/epoch=36-step=358900.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./Retail Image Classification - AlexNet/qypw6fpb/checkpoints/epoch=36-step=358900.ckpt


Testing: |                                                                                                    …



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/acc            0.7506399750709534
        test/loss           1.0794569253921509
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


VBox(children=(Label(value='53.126 MB of 53.126 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test/acc,▁
test/loss,▁
train/acc,▅▇▅▅▇▄▆▅▇▆▆▇▇▇▆▇▆▆▇▇▆▇▇▆▆▇▅▇▇▆▇▇▇█▁▆█▇▆▆
train/loss,▅▃▄▃▂▃▃▃▂▂▃▁▃▂▃▁▂▂▂▂▃▂▁▃▃▁▃▂▁▃▂▂▁▂█▃▁▁▃▃
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▁▃▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████████████
val/loss,█▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁

0,1
epoch,37.0
test/acc,0.75064
test/loss,1.07946
train/acc,1.0
train/loss,0.24936
trainer/global_step,358900.0
val/acc,0.61966
val/loss,1.79639
