## Import libraries


In [None]:
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import ImageFolder
from transformers import ViTImageProcessor, ViTForImageClassification

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Prepare model and data loaders


In [None]:
def load_model(num_classes: int, model_name: str = "google/vit-base-patch16-224"):
    """Load a pretrained ViT and replace its classifier with a fresh MLP head.

    The new head is ``hidden -> 512 -> 256 -> num_classes`` with ReLU and
    dropout (p=0.2) between layers; the final projection has no bias.

    Args:
        num_classes: Number of logits produced by the new head.
        model_name: Hugging Face model id or local path passed to
            ``ViTForImageClassification.from_pretrained``.

    Returns:
        The ``ViTForImageClassification`` model with the replaced head.
    """
    model = ViTForImageClassification.from_pretrained(model_name)
    # Read the encoder width from the config instead of hard-coding 768, so
    # other ViT sizes (e.g. vit-large with hidden_size=1024) also work.
    hidden_size = model.config.hidden_size
    model.classifier = nn.Sequential(
        nn.Linear(in_features=hidden_size, out_features=512),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(in_features=512, out_features=256),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(in_features=256, out_features=num_classes, bias=False),
    )
    # Keep the label-count metadata in sync with the new head. The pretrained
    # checkpoint still advertises its original 1000 labels, and transformers'
    # built-in loss (used when ``labels=`` is passed to forward) relies on it.
    model.num_labels = num_classes
    model.config.num_labels = num_classes
    return model

In [None]:
def load_train_dataset(
    root: str, model_name: str = "google/vit-base-patch16-224", batch_size: int = 32
) -> DataLoader:
    """Build the augmented, shuffled training loader from ``<root>/train``.

    Images are read with ``ImageFolder`` (one sub-folder per class). Each image
    is randomly augmented, converted to a tensor and normalized with the
    mean/std of the pretrained model's image processor.

    Args:
        root: Dataset root containing a ``train`` directory.
        model_name: Model id whose ``ViTImageProcessor`` supplies the
            normalization statistics.
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches in shuffled order.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    data_augmentation = transforms.Compose(
        [
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomResizedCrop(size=(224, 224), antialias=True),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
        ]
    )
    transform = transforms.Compose(
        [
            data_augmentation,
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]
    )
    dataset = ImageFolder(os.path.join(root, "train"), transform=transform)
    # ImageFolder lists samples grouped by class. Without shuffling, batches
    # (and long stretches of every epoch) would contain a single class, which
    # badly biases the gradient updates.
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
def load_test_dataset(
    root: str,
    model_name: str = "google/vit-base-patch16-224",
    batch_size: int = 32,
    valid_size: float = 0.4,
    seed: int = 42,
) -> tuple[DataLoader, DataLoader]:
    """Split ``<root>/test`` into test and validation loaders.

    Images are resized to 224x224, converted to tensors and normalized with the
    pretrained model's image-processor statistics (no augmentation). The first
    ``floor(valid_size * N)`` indices of a seeded permutation form the
    validation subset; the rest form the test subset.

    Args:
        root: Dataset root containing a ``test`` directory.
        model_name: Model id whose ``ViTImageProcessor`` supplies the
            normalization statistics.
        batch_size: Number of images per batch for both loaders.
        valid_size: Fraction of the test folder used for validation, in [0, 1].
        seed: Seed of the split permutation. The default 42 reproduces the
            split this function has always produced.

    Returns:
        ``(test_loader, valid_loader)``.

    Raises:
        ValueError: If ``valid_size`` is outside [0, 1].
    """
    if not 0.0 <= valid_size <= 1.0:
        raise ValueError(f"valid_size must be in [0, 1], got {valid_size}")

    processor = ViTImageProcessor.from_pretrained(model_name)
    transform = transforms.Compose(
        [
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]
    )
    dataset = ImageFolder(os.path.join(root, "test"), transform=transform)
    num_test = len(dataset)
    indices = list(range(num_test))
    split = int(np.floor(valid_size * num_test))

    # A private RandomState seeded with the same value yields the same
    # permutation as seeding the global RNG. Unlike np.random.seed, it does not
    # silently reset NumPy's global random state for the rest of the notebook.
    np.random.RandomState(seed).shuffle(indices)
    test_idx, valid_idx = indices[split:], indices[:split]
    test_sampler = sampler.SubsetRandomSampler(test_idx)
    valid_sampler = sampler.SubsetRandomSampler(valid_idx)

    test_loader = DataLoader(
        dataset=dataset, batch_size=batch_size, sampler=test_sampler
    )
    valid_loader = DataLoader(
        dataset=dataset, batch_size=batch_size, sampler=valid_sampler
    )
    return test_loader, valid_loader

## Train


In [None]:
DATA_ROOT = "/kaggle/input/pbl6-dataset"
BATCH_SIZE = 16

train_loader = load_train_dataset(root=DATA_ROOT, batch_size=BATCH_SIZE)
test_loader, valid_loader = load_test_dataset(
    root=DATA_ROOT, batch_size=BATCH_SIZE, valid_size=0.4
)

# ImageFolder assigns label indices from each split's sorted class-folder
# names independently. Both splits must contain exactly the same classes, or
# label i would mean different things in training and evaluation.
train_classes = train_loader.dataset.classes
if test_loader.dataset.classes != train_classes:
    raise ValueError(
        f"train/test class folders differ: {train_classes} vs "
        f"{test_loader.dataset.classes}"
    )

# Size the classification head from the data instead of hard-coding it.
model = load_model(num_classes=len(train_classes)).to(device)

In [None]:
# Adam over every model parameter (pretrained backbone and new head alike),
# with a step schedule that divides the learning rate by 10 every 5 epochs,
# and plain cross-entropy on the logits as the training objective.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

In [None]:
NUM_EPOCHS = 20
CHECKPOINT_EVERY = 5  # save model weights every N epochs
losses, accuracies = [], []  # mean training loss / validation accuracy (%) per epoch
for epoch in range(NUM_EPOCHS):
    # --- training pass ---
    model.train()
    running_loss, num_batches = 0.0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Record the mean per-batch loss, not the sum, so the curve does not scale
    # with the number of batches and stays comparable across batch sizes.
    losses.append(running_loss / max(num_batches, 1))

    # --- validation pass ---
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty validation split (valid_size=0) would otherwise divide by zero.
    accuracies.append(100 * correct / total if total else float("nan"))
    print(
        f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {losses[-1]:.3f}, Accuracy: {accuracies[-1]:.2f}%"
    )
    # StepLR is stepped once per epoch, so step_size is measured in epochs.
    scheduler.step()
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
            },
            f"checkpoint_{epoch + 1}.pt",
        )

In [None]:
# Plot per-epoch training loss and validation accuracy side by side. The
# x-axis is 1-based to match the epoch numbers printed during training.
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(losses) + 1), losses)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.subplot(1, 2, 2)
plt.plot(range(1, len(accuracies) + 1), accuracies)
plt.title("Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")

plt.tight_layout()  # keep the two subplots' labels from overlapping
plt.show()

## Test


In [None]:
def test_checkpoint(path: str, loader: DataLoader, num_classes: int = 10) -> float:
    """Evaluate a saved checkpoint on ``loader`` and print its accuracy.

    Args:
        path: Checkpoint file written by the training loop: a dict holding
            ``"model_state_dict"``, named like ``checkpoint_5.pt``.
        loader: Iterable of ``(images, labels)`` batches to evaluate on.
        num_classes: Output size of the classification head. It must match the
            value the checkpoint was trained with.

    Returns:
        Accuracy in percent, or ``nan`` if ``loader`` yields no samples.
    """
    checkpoint = torch.load(f=path, map_location=device, weights_only=True)
    model = load_model(num_classes=num_classes).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else float("nan")
    # "checkpoint_5.pt" -> "5". Parse only the file name so that underscores
    # in parent directory names do not break the label.
    label = os.path.splitext(os.path.basename(path))[0].rsplit("_", 1)[-1]
    print(f"Checkpoint: {label}, Accuracy: {accuracy:.2f}%")
    return accuracy