In [27]:
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
from os import path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import (
    cli_lightning_logo,
    LightningDataModule,
    LightningModule,
    Trainer,
)


In [29]:
TransformTypes = Optional[Union[A.Compose, T.Compose]]
EPOCHS = 2
BATCH_SIZE = 32
LEARNING_RATE = 0.003
IMAGE_SIZE = 224
TRAIN_DATA_PATH: str = path.join("..", "data", "train")
VALIDATION_DATA_PATH: str = path.join("..", "data", "val")
TEST_DATA_PATH: str = path.join("..", "data", "test")


TRAIN_TRANSFORM_IMG = T.Compose(
    [
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
TEST_TRANSFORM_IMG = T.Compose(
    [
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# TRAIN_TRANSFORM_IMG = A.Compose(
#     [
#         A.augmentations.crops.transforms.RandomResizedCrop(
#             height=IMAGE_SIZE, width=IMAGE_SIZE, scale=[0.9, 1], ratio=[1, 1]
#         ),
#         A.augmentations.transforms.Normalize(
#             mean=[0.4913997551666284, 0.48215855929893703, 0.4465309133731618],
#             std=[0.24703225141799082, 0.24348516474564, 0.26158783926049628],
#         ),
#         ToTensorV2(),
#     ]
# )
# TEST_TRANSFORM_IMG = A.Compose(
#     [
#         A.augmentations.geometric.resize.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
#         A.augmentations.transforms.Normalize(
#             mean=[0.4913997551666284, 0.48215855929893703, 0.4465309133731618],
#             std=[0.24703225141799082, 0.24348516474564, 0.26158783926049628],
#         ),
#         ToTensorV2(),
#     ]
# )

In [30]:
def set_parameter_requires_grad(model, feature_extracting=True):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def create_model(num_classes: int = 2):
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    set_parameter_requires_grad(model, feature_extracting=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model


In [31]:
class Backbone(nn.Module):
    def __init__(self, num_classes: int = 2, dropout: float = 0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(16 * 1 * 1, 120),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(120, 84),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # x = self.avgpool(x)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.classifier(x)
        return x


In [32]:
class ImageClassifier(LightningModule):
    """
    >>> LitClassifier(Backbone())  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    LitClassifier(
      (backbone): ...
    )
    """

    def __init__(
        self, backbone: Optional[Backbone] = None, learning_rate: float = 0.0001
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["backbone"])
        if backbone is None:
            backbone = Backbone()
        self.backbone = backbone

    def forward(self, x):
        # use forward for inference/predictions
        embedding = self.backbone(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("valid_loss", loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)


In [33]:
class ImageClassificationDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.train_dataset = torchvision.datasets.ImageFolder(
            root=TRAIN_DATA_PATH, transform=TRAIN_TRANSFORM_IMG
        )
        self.validation_dataset = torchvision.datasets.ImageFolder(
            root=VALIDATION_DATA_PATH, transform=TEST_TRANSFORM_IMG
        )
        self.test_dataset = torchvision.datasets.ImageFolder(
            root=TEST_DATA_PATH, transform=TEST_TRANSFORM_IMG
        )

        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


In [34]:
trainer = Trainer(
    max_epochs=EPOCHS
)  # , callbacks=[EarlyStopping(monitor="valid_loss", mode="min")]
model = ImageClassifier(backbone=create_model())
data_module = ImageClassificationDataModule()
trainer.fit(model, datamodule=data_module)
trainer.test(ckpt_path="best", datamodule=data_module)
predictions = trainer.predict(ckpt_path="best", datamodule=data_module)
print(predictions[0])


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type   | Params
------------------------------------
0 | backbone | ResNet | 23.5 M
------------------------------------
4.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.049    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 1: 100%|██████████| 163/163 [01:00<00:00,  2.68it/s, v_num=2]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 163/163 [01:01<00:00,  2.67it/s, v_num=2]


Restoring states from the checkpoint path at /Users/furyhawk/Explainable-Neural-Networks/notebooks/lightning_logs/version_2/checkpoints/epoch=1-step=326.ckpt
Loaded model weights from the checkpoint at /Users/furyhawk/Explainable-Neural-Networks/notebooks/lightning_logs/version_2/checkpoints/epoch=1-step=326.ckpt
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 20/20 [00:06<00:00,  3.20it/s]


Restoring states from the checkpoint path at /Users/furyhawk/Explainable-Neural-Networks/notebooks/lightning_logs/version_2/checkpoints/epoch=1-step=326.ckpt
Loaded model weights from the checkpoint at /Users/furyhawk/Explainable-Neural-Networks/notebooks/lightning_logs/version_2/checkpoints/epoch=1-step=326.ckpt
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 20/20 [00:08<00:00,  2.25it/s]
tensor([[ 0.0389, -0.3110],
        [-0.3541,  0.2206],
        [ 0.4866, -0.4310],
        [-0.4913,  0.5149],
        [ 0.2809, -0.2115],
        [ 0.2992, -0.5055],
        [-0.1557,  0.2244],
        [ 0.0430, -0.1161],
        [ 0.1052, -0.1405],
        [ 0.1736, -0.1959],
        [ 0.1231, -0.0484],
        [-0.4312,  0.4116],
        [ 0.2958, -0.2707],
        [-0.0429,  0.1251],
        [-0.1399,  0.2113],
        [-0.0820,  0.2951],
        [-1.0355,  1.1248],
        [ 0.0376,  0.1329],
        [ 0.7288, -0.5405],
        [ 0.0058,  0.2159],
        [ 0.2130, -0.2014],
        [-0.1570,  0.0375],
        [ 0.3130, -0.2778],
        [ 0.0967, -0.3230],
        [ 0.1634, -0.0106],
        [ 0.1893, -0.1183],
        [ 0.1806, -0.0971],
        [ 0.1242, -0.0677],
        [-0.0372, -0.0669],
        [-0.3038,  0.3310],
        [ 0.3326, -0.2008],
        [-0.0511, -0.0314]])


In [35]:
ImageClassifier()

ImageClassifier(
  (backbone): Backbone(
    (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (classifier): Sequential(
      (0): Linear(in_features=16, out_features=120, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=120, out_features=84, bias=True)
      (4): ReLU(inplace=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=84, out_features=2, bias=True)
    )
  )
)