In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)

# =====================
# Configuration
# =====================
# --- Data ---
DATASET_ROOT = "./dataset"  # root folder handed to ImageFolder
NUM_CLASSES = 10            # size of the classifier output / metric class count
BATCH_SIZE = 32
NUM_WORKERS = 4             # DataLoader worker processes

# --- Optimisation ---
EPOCHS = 20
LR = 1e-3
WEIGHT_DECAY = 1e-4

# --- Reproducibility / hardware ---
SEED = 42
torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Per-epoch histories consumed by the plotting cell ---
# Three independent lists (deliberately not a chained assignment, which
# would alias them to a single list object).
loss_history = []
acc_history = []
f1_history = []

# =====================
# Transforms
# =====================
# Values shared by the train and test pipelines.
# NOTE(review): these look like the commonly used ImageNet channel
# statistics — confirm they are appropriate for this dataset.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)
_IMAGE_SIZE = (256, 256)


def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps (the tail of both pipelines)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]


# Training pipeline: resize, random horizontal flip, then tensor + normalise.
train_transforms = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        *_to_normalized_tensor(),
    ]
)

# Test pipeline: identical except that no random augmentation is applied.
test_transforms = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        *_to_normalized_tensor(),
    ]
)

# =====================
# Datasets
# =====================
# The same image folder is opened twice so that the training subset gets the
# augmenting transforms while the held-out subset gets the deterministic ones.
full_train_dataset = ImageFolder(DATASET_ROOT, transform=train_transforms)
full_test_dataset = ImageFolder(DATASET_ROOT, transform=test_transforms)

train_size = int(0.8 * len(full_train_dataset))
test_size = len(full_train_dataset) - train_size

generator = torch.Generator().manual_seed(SEED)

train_dataset, _ = random_split(
    full_train_dataset, [train_size, test_size], generator=generator
)

# random_split draws its permutation from `generator`, which advances the
# generator's state. Re-seed before the second call so both splits use the
# SAME permutation; otherwise the "test" indices come from a different
# shuffle and overlap with the training indices (train/test leakage).
generator.manual_seed(SEED)
_, test_dataset = random_split(
    full_test_dataset, [train_size, test_size], generator=generator
)

# Guard against any future regression of the split logic.
assert set(train_dataset.indices).isdisjoint(test_dataset.indices), (
    "train/test split overlap: evaluation data leaked into training"
)
assert len(train_dataset) + len(test_dataset) == len(full_train_dataset)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Evaluation order is fixed.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# =====================
# Model
# =====================
class ConvNeuralNet(nn.Module):
    """Five-stage convolutional image classifier.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU. The first
    four stages end with a 2x2 max-pool; the last ends with an adaptive
    average pool to 1x1, so the classifier head always receives 512 features.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Channel widths at the boundaries of the five stages.
        channel_plan = (3, 32, 64, 128, 256, 512)
        final_stage = len(channel_plan) - 2

        layers = []
        for stage, (c_in, c_out) in enumerate(
            zip(channel_plan[:-1], channel_plan[1:])
        ):
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            if stage == final_stage:
                # Collapse whatever spatial extent is left to a single value.
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.MaxPool2d(2))

        # Flat Sequential: sub-module indices run 0..19 in construction order.
        self.features = nn.Sequential(*layers)

        width = channel_plan[-1]
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, then the classifier head; returns logits."""
        pooled = self.features(x)
        logits = self.classifier(pooled)
        return logits


# Build the network and place it on the selected device.
model = ConvNeuralNet(num_classes=NUM_CLASSES)
model = model.to(DEVICE)

# =====================
# Optimization
# =====================
# AdamW over every model parameter; cross-entropy on the raw logits.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
)
criterion = nn.CrossEntropyLoss()

# =====================
# TorchMetrics
# =====================
# Metric objects live on the same device as the predictions they receive.
# NOTE(review): `accuracy` is built without an explicit `average`, unlike the
# three metrics below — torchmetrics' default for MulticlassAccuracy appears
# to be macro averaging, not plain sample accuracy; confirm this is intended.
_macro_kwargs = {"num_classes": NUM_CLASSES, "average": "macro"}

accuracy = MulticlassAccuracy(num_classes=NUM_CLASSES).to(DEVICE)
precision = MulticlassPrecision(**_macro_kwargs).to(DEVICE)
recall = MulticlassRecall(**_macro_kwargs).to(DEVICE)
f1 = MulticlassF1Score(**_macro_kwargs).to(DEVICE)

# =====================
# Training & Evaluation
# =====================
# Each epoch: one optimisation pass over train_loader, then one evaluation
# pass over test_loader. Exactly ONE value per epoch is appended to each
# history list, so the histories can be plotted against the epoch number.
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = running_loss / len(train_loader)

    # ---- evaluation pass ----
    model.eval()
    # Clear the accumulated metric state so the values below cover only
    # this epoch's evaluation pass.
    accuracy.reset()
    precision.reset()
    recall.reset()
    f1.reset()

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            outputs = model(images)
            preds = outputs.argmax(dim=1)

            # update() only accumulates; the aggregate is computed once
            # after the whole loader has been consumed.
            accuracy.update(preds, labels)
            precision.update(preds, labels)
            recall.update(preds, labels)
            f1.update(preds, labels)

    # Record once per epoch (previously this ran once per test batch, which
    # made the histories len(test_loader) times too long).
    loss_history.append(avg_loss)
    acc_history.append(accuracy.compute().item())
    f1_history.append(f1.compute().item())

    print(
        f"Epoch [{epoch + 1}/{EPOCHS}] "
        f"Loss: {avg_loss:.4f} "
        f"Acc: {acc_history[-1]:.4f} "
        f"F1: {f1_history[-1]:.4f}"
    )

Epoch [1/20] Loss: 1.6692 Acc: 0.2092 F1: 0.1908
Epoch [2/20] Loss: 1.5658 Acc: 0.2331 F1: 0.2197


In [None]:
def plot_history(epochs, values, ylabel, title):
    """Draw one line chart of a recorded metric against the epoch number.

    Args:
        epochs: Sequence of x-axis values (1-based epoch numbers).
        values: Recorded metric values, same length as ``epochs``.
        ylabel: Label for the y-axis.
        title: Figure title.
    """
    fig, ax = plt.subplots()
    ax.plot(epochs, values)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    plt.show()


# Derive the x-axis from what was actually recorded rather than from EPOCHS,
# so a run that stopped early still plots instead of raising a
# shape-mismatch error inside matplotlib.
epochs = range(1, len(loss_history) + 1)

plot_history(epochs, loss_history, "Loss", "Training Loss")
plot_history(epochs, acc_history, "Accuracy", "Validation Accuracy")
plot_history(epochs, f1_history, "F1 Score", "Validation F1 Score")