📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Load Dataset](#toc2_)    
- [Model](#toc3_)    
- [Set up remaining Hyperparameters](#toc4_)    
- [Train Loop](#toc5_)    
- [Feature Extraction](#toc6_)    
  - [model.feature_extractor.0](#toc6_1_)    
  - [model.feature_extractor.5](#toc6_2_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchinfo import summary
from torchmetrics.classification import MulticlassAccuracy
from torchvision.datasets import CIFAR10
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import v2

In [2]:
# reproducibility: fixed RNG seed, cuDNN autotuner off, deterministic cuDNN algorithms on
seed = 42
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)

In [None]:
# prefer the GPU when PyTorch reports CUDA as available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# log
device

# <a id='toc2_'></a>[Load Dataset](#toc0_)

In [4]:
# preprocessing pipeline: image tensor -> float32 (scale=True) -> normalize with mean 0.5 / std 0.5
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# train split first, then test split; same root directory and transform, no download
trainset, testset = (
    CIFAR10("../datasets", train=is_train, transform=transform, download=False)
    for is_train in (True, False)
)

In [5]:
# unpack the raw training array shape as (samples, height, width, depth)
num_trainset, height, width, depth = trainset.data.shape

# class names as a NumPy array, and how many of them there are
classes = np.array(trainset.classes)
num_classes = classes.shape[0]

In [6]:
# number of samples per mini-batch, shared by both loaders
batch_size = 128

# only the training split is shuffled; the test loader reads the held-out testset
# (wrapping trainset here would make any evaluation run on training data)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# <a id='toc3_'></a>[Model](#toc0_)

In [None]:
class CustomModel(nn.Module):
    """Small CNN: stacked Conv-BatchNorm-ReLU-MaxPool stages plus one linear classifier.

    Args:
        layers: channel sizes; each consecutive pair ``(layers[k], layers[k + 1])``
            gives the input/output channels of stage ``k``.
        num_classes: number of output logits. When ``None`` (default) it falls back
            to ``len(classes)`` from the surrounding notebook scope.
        input_size: ``(height, width)`` of the input images, used to size the
            classifier. With the default ``(32, 32)`` and ``layers=[3, 16, 32]`` the
            classifier receives 1152 features.

    Raises:
        ValueError: if the stages shrink the feature map to an empty spatial size.
    """

    def __init__(self, layers, num_classes=None, input_size=(32, 32)):
        super().__init__()

        if num_classes is None:
            num_classes = len(classes)

        map_h, map_w = input_size
        feature_extractor_layers = []
        for i, o in zip(layers, layers[1:]):
            feature_extractor_layers.append(nn.Conv2d(i, out_channels=o, kernel_size=3))
            feature_extractor_layers.append(nn.BatchNorm2d(o))
            feature_extractor_layers.append(nn.ReLU())
            feature_extractor_layers.append(nn.MaxPool2d(kernel_size=2))

            # an unpadded 3x3 convolution trims 2 pixels per axis, then the 2x2 pool floor-halves it
            map_h, map_w = (map_h - 2) // 2, (map_w - 2) // 2

        if map_h < 1 or map_w < 1:
            raise ValueError(
                f"input_size={tuple(input_size)} is too small for {len(layers) - 1} conv/pool stages"
            )

        self.feature_extractor = nn.Sequential(*feature_extractor_layers)
        self.flatten = nn.Flatten(start_dim=1)

        # flattened feature count = channels of the last stage x remaining spatial area
        self.classifier = nn.Linear(layers[-1] * map_h * map_w, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.feature_extractor(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x


# build the network (input depth -> 16 -> 32 channels), then move it to the selected device
model = CustomModel(layers=[depth, 16, 32])
model = model.to(device)

# log
model

In [None]:
# channels-first input shape (batch, depth, height, width) read from the channels-last test array
summary(model, input_size=(batch_size, testset.data.shape[3], *testset.data.shape[1:3]))

# <a id='toc4_'></a>[Set up remaining Hyperparameters](#toc0_)

In [9]:
# optimisation setup: epoch budget, learning rate, optimizer over all model parameters, loss
num_epochs = 10
lr = 0.01
optimizer = Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# <a id='toc5_'></a>[Train Loop](#toc0_)

In [10]:
# per-epoch history lists, filled by the training loop
train_loss_per_epoch = []
train_acc_per_epoch = []

# top-1 multiclass accuracy metric, placed on the same device as the model
train_acc = MulticlassAccuracy(top_k=1, num_classes=len(testset.classes))
train_acc = train_acc.to(device)

In [None]:
for epoch in range(num_epochs):

    # train loop
    model.train()
    train_loss = 0

    for x, y in trainloader:

        # send data to the selected device
        x, y_true = x.to(device), y.to(device)

        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        # backward
        loss.backward()

        # update parameters, then clear the gradients for the next iteration
        optimizer.step()
        optimizer.zero_grad()

        # accumulate the loss weighted by the batch size, so dividing by the dataset
        # size below gives a per-sample mean
        train_loss += loss.item() * len(x)
        train_acc.update(y_pred, y_true)

    # store loss and accuracy per epoch
    train_loss_per_epoch.append(train_loss / len(trainset))
    train_acc_per_epoch.append(train_acc.compute().item())
    train_acc.reset()

    # log the entries just appended; indexing with [-1] instead of [epoch] keeps the
    # printout correct when the history lists already hold values (e.g. the cell is re-run)
    print(
        f"epoch {epoch+1:0{len(str(num_epochs))}}/{num_epochs} -> train[loss: {train_loss_per_epoch[-1]:7.5f} - acc: {train_acc_per_epoch[-1]*100:5.2f}%]"
    )

# <a id='toc6_'></a>[Feature Extraction](#toc0_)

In [None]:
# list every named parameter of the model together with its requires_grad flag
for name, param in model.named_parameters():
    # the name is left-aligned and padded to 26 characters so the flags line up in a column
    print(f"{name:<26} - requires_grad: {param.requires_grad}")

In [13]:
# map node names inside the model to the keys used in the extractor's output dict
# NOTE(review): CustomModel builds each stage as Conv2d, BatchNorm2d, ReLU, MaxPool2d, so
# index 5 of feature_extractor is the second stage's BatchNorm2d (index 4 is its Conv2d);
# the "conv2" key therefore holds post-BatchNorm activations — confirm this is intended
nodes = {"feature_extractor.0": "conv1", "feature_extractor.5": "conv2"}

# wrap the trained model so that a forward pass returns the requested intermediate outputs
feature_extractor = create_feature_extractor(model, return_nodes=nodes)

In [None]:
# first test image, transformed and given a leading batch dimension of 1
frog = transform(testset.data[0]).to(device)[None, :, :, :]

# inference only: the model was last put in train mode by the training loop, so switch the
# extractor to eval mode (BatchNorm then uses its running statistics instead of this single
# image's statistics and no longer updates them) and skip autograd bookkeeping
feature_extractor.eval()
with torch.no_grad():
    feature_maps = feature_extractor(frog)
feature_maps.keys()

In [None]:
# plot the raw test image next to its transformed version
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 4), layout="compressed")

# left: the untouched image array from the dataset
axs[0].imshow(testset.data[0])
axs[0].axis("off")
axs[0].set_title("Original")  # title typo fixed (was "Orignal")

# right: the transformed tensor, moved to the CPU and permuted from channels-first to
# channels-last for imshow
# NOTE: the normalized tensor can hold values outside [0, 1], which imshow clips for display
axs[1].imshow(frog.detach().cpu()[0].permute(1, 2, 0))
axs[1].axis("off")
axs[1].set_title("Transformed")
plt.show()

## <a id='toc6_1_'></a>[model.feature_extractor.0](#toc0_)
   - Feature maps : 16x30x30

In [None]:
# plot: one grayscale panel per channel of the "conv1" output, laid out on a 4x4 grid
total_rows = 4
total_cols = 4
fig, axs = plt.subplots(
    nrows=total_rows, ncols=total_cols, figsize=(total_cols * 2, total_rows * 2), layout="compressed"
)
fig.suptitle("model.feature_extractor.0 feature maps")

# axs.flat walks the grid row by row, so the running index equals row * total_cols + col
for channel, ax in enumerate(axs.flat):
    ax.imshow(feature_maps["conv1"][0, channel].detach().cpu(), cmap="gray")
    ax.axis("off")
    ax.set(title=channel)
plt.show()

## <a id='toc6_2_'></a>[model.feature_extractor.5](#toc0_)
   - Feature maps : 32x13x13

In [None]:
# plot: one grayscale panel per channel of the "conv2" output, laid out on a 4x8 grid
total_rows = 4
total_cols = 8
fig, axs = plt.subplots(
    nrows=total_rows, ncols=total_cols, figsize=(total_cols * 2, total_rows * 2), layout="compressed"
)
# the "conv2" key is taken from node feature_extractor.5, so the title names that node
# (it previously repeated the feature_extractor.0 title of the preceding figure)
fig.suptitle("model.feature_extractor.5 feature maps")
for row in range(total_rows):
    for col in range(total_cols):
        # row-major position in the grid doubles as the channel index
        axs[row, col].imshow(feature_maps["conv2"][0, row * total_cols + col].detach().cpu(), cmap="gray")
        axs[row, col].axis("off")
        axs[row, col].set(title=row * total_cols + col)
plt.show()