# 🏞 Convolutional Neural Network

In this notebook, we'll walk through the steps required to train your own convolutional neural network (CNN) on the CIFAR-10 dataset.

In [None]:
import os, sys
from dotenv import load_dotenv

load_dotenv()

# Configuration comes from the environment (optionally populated from a .env
# file by load_dotenv); either value is None when the variable is unset.
data_path = os.getenv('DATA_PATH')
python_path = os.getenv('PYTHONPATH')

# Make every PYTHONPATH entry importable, skipping entries already on sys.path.
if python_path:
    for path in python_path.split(os.pathsep):
        if path in sys.path:
            continue
        sys.path.append(path)

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm, trange

from notebooks.pt_utils import display, Trainer

## 0. Parameters <a name="parameters"></a>

In [None]:
# CIFAR-10 has ten target classes.
NUM_CLASSES = 10
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## 1. Prepare the Data <a name="prepare"></a>

In [None]:
# Convert PIL images into float tensors of shape (C, H, W) scaled to [0, 1].
cifar10_transform = transforms.ToTensor()

# download=True only fetches the archive when it is missing or fails its
# integrity check; an existing verified copy under data_path is reused.
full_train_dataset = datasets.CIFAR10(
    root=data_path, train=True, download=True, transform=cifar10_transform
)
# Hold out 20% of the training images for validation. The seeded generator
# makes the split identical across runs and kernel restarts.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
)

test_dataset = datasets.CIFAR10(
    root=data_path, train=False, download=True, transform=cifar10_transform
)

In [None]:
BATCH_SIZE = 256
# Up to 24 loader worker processes, but never more than the machine's CPUs.
NUM_WORKERS = min(24, os.cpu_count() or 1)
NUM_WOERKERS = NUM_WORKERS  # misspelled name kept as an alias for older cells
# Pinned (page-locked) host memory only helps host->GPU copies, and pinning to
# CUDA is impossible without a GPU, so enable it only when training on CUDA.
PIN_MEMORY = device.type == 'cuda'

# Settings shared by all three loaders; only shuffling differs.
loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY,
}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Sanity-check the input pipeline: pull one (shuffled) batch from the training
# loader. The temporary iterator is discarded right away, so its workers exit.
images, labels = next(iter(train_loader))
# display is a project helper (notebooks.pt_utils); presumably it renders the
# batch as an image grid — confirm in pt_utils.
display(images)
# Integer class indices of the first ten images in the batch.
print(labels[:10])

## 2. Build the model <a name="build"></a>

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier: four conv blocks followed by an MLP head.

    Each conv block is Conv2d -> BatchNorm2d -> LeakyReLU. The 2nd and 4th
    convolutions use stride 2, each halving the spatial size (rounding up).
    The head is Flatten -> Linear(128) -> BatchNorm1d -> LeakyReLU ->
    Dropout(0.5) -> Linear(output_size), and the network returns raw logits
    (no softmax).

    Args:
        input_size: ``(channels, height, width)`` of one input image, e.g.
            ``(3, 32, 32)`` for CIFAR-10. It sets the first convolution's
            input channels and the width of the first fully connected layer.
        output_size: number of output logits (classes).
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        in_channels, height, width = input_size
        # Spatial size that reaches Flatten after the two stride-2 convolutions.
        out_h = self._strided_size(self._strided_size(height))
        out_w = self._strided_size(self._strided_size(width))

        self.seq = nn.Sequential(

            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),

            nn.Flatten(),

            nn.Linear(in_features=out_h * out_w * 64, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.LeakyReLU(),
            nn.Dropout(0.5),

            nn.Linear(in_features=128, out_features=output_size)
        )

    @staticmethod
    def _strided_size(size):
        """Return the output side length of a kernel-3, stride-2, padding-1 conv."""
        # Conv arithmetic: floor((size + 2*padding - kernel) / stride) + 1.
        return (size + 2 * 1 - 3) // 2 + 1

    def forward(self, x):
        """Map a batch of images ``(N, C, H, W)`` to logits ``(N, output_size)``."""
        return self.seq(x)
    
# CIFAR-10 images are 3x32x32; reuse NUM_CLASSES rather than repeating the literal.
model = CNN(input_size=(3, 32, 32), output_size=NUM_CLASSES).to(device)

## 3. Train the model <a name="train"></a>

In [None]:
# The model outputs raw logits, so cross-entropy is applied to them directly.
loss_fn = F.cross_entropy


def pred_fn(logits):
    """Return the index of the highest-scoring class for each row of *logits*."""
    return logits.argmax(dim=1)


optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# Trainer comes from notebooks.pt_utils; all arguments are passed by keyword.
trainer = Trainer(
    model=model,
    device=device,
    optimizer=optimizer,
    loss_fn=loss_fn,
    pred_fn=pred_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    metrics=['accuracy'],
)

In [None]:
# Number of passes over the training set. `history` keeps whatever
# Trainer.fit returns (presumably per-epoch metrics — confirm in pt_utils).
EPOCHS = 10
history = trainer.fit(epochs=EPOCHS)

## 4. Evaluation <a name="evaluate"></a>

In [None]:
# Score the trained model on the held-out test split. This is kept as the
# cell's last bare expression so the notebook displays the returned value.
trainer.evaluate(test_loader)

In [None]:
# Take the first 100 test samples: a list of image tensors and a list of labels.
first_samples = [test_dataset[i] for i in range(100)]
x_test = [image for image, _ in first_samples]
y_test = [label for _, label in first_samples]

# Human-readable names, indexed by CIFAR-10 class id.
CLASSES = np.array([
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
])

# Map ground-truth and predicted class ids to their names.
actual_single = CLASSES[y_test]
preds = trainer.predict(x_test)
preds_single = CLASSES[preds]

In [None]:
# Plot a random selection of the sampled test images, each captioned with its
# predicted and actual class name.
n_to_show = 10
# replace=False so no image is drawn twice in the same figure.
indices = np.random.choice(len(x_test), n_to_show, replace=False)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i, idx in enumerate(indices):
    # ToTensor images are (C, H, W); matplotlib's imshow expects (H, W, C).
    img = x_test[idx].permute(1, 2, 0)
    ax = fig.add_subplot(1, n_to_show, i + 1)
    ax.axis("off")
    ax.text(
        0.5,
        -0.35,
        "pred = " + str(preds_single[idx]),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.text(
        0.5,
        -0.7,
        "act = " + str(actual_single[idx]),
        fontsize=10,
        ha="center",
        transform=ax.transAxes,
    )
    ax.imshow(img)