# Dependencies


In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import ImageFolder
from transformers import ViTForImageClassification, ViTImageProcessor

# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def save_checkpoint(model: nn.Module, optimizer: optim.Optimizer, path: str):
    """Write the model and optimizer state dicts to ``path`` via ``torch.save``.

    The saved object is a dict with the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"`` -- the layout that ``load_checkpoint`` reads.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        ),
        path,
    )


def load_checkpoint(model: nn.Module, optimizer: optim.Optimizer, path: str):
    """Restore ``model`` and ``optimizer`` in place from a ``save_checkpoint`` file.

    Tensors are mapped onto CUDA when available (CPU otherwise), and the file is
    read with ``weights_only=True`` so only tensors/plain containers are unpickled.

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.
    """
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    state = torch.load(path, map_location=target, weights_only=True)

    for component, key in (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    ):
        component.load_state_dict(state[key])

    return model, optimizer

In [None]:
def train_step(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.CrossEntropyLoss,
) -> float:
    """Run one training epoch over ``dataloader`` and return the summed batch loss.

    Each batch is moved to the device holding ``model``'s parameters, passed
    through the model (whose output must expose ``.logits``), scored with
    ``criterion``, and used for one ``optimizer`` step.

    Args:
        model: classifier returning an object with a ``logits`` attribute.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``(logits, labels)``.

    Returns:
        Sum of the per-batch loss values as a Python float (0.0 if the
        loader is empty).
    """
    model.train()
    # Put batches wherever the model's weights live instead of relying on a
    # module-level ``device`` that may not match the model's placement.
    model_device = next(model.parameters()).device

    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(model_device), labels.to(model_device)
        optimizer.zero_grad()
        logits = model(images).logits
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        # Separate name so the running total is not overwritten by each
        # batch's loss tensor; .item() also drops the autograd graph.
        running_loss += batch_loss.item()

    return running_loss


def test_step(model: nn.Module, dataloader: DataLoader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total

# Dataset


In [None]:
def load_train_dataset(
    root: str,
    model_name: str = "google/vit-base-patch16-224",
    batch_size: int = 32,
    *,
    shuffle: bool = True,
) -> DataLoader:
    """Build the augmented training ``DataLoader`` from ``<root>/train``.

    Images are read with ``ImageFolder`` (one sub-directory per class). They
    are augmented with a random rotation (up to 15 degrees), a horizontal flip,
    a random resized crop to 224x224 and colour jitter. They are then
    normalized with the mean/std of ``model_name``'s ``ViTImageProcessor``.

    Args:
        root: dataset root containing a ``train`` directory.
        model_name: Hugging Face checkpoint whose processor supplies the
            normalization statistics.
        batch_size: samples per batch.
        shuffle: reshuffle the samples every epoch (default ``True``).
            ``ImageFolder`` lists samples class directory by class directory,
            so without shuffling consecutive batches come almost entirely
            from a single class.

    Returns:
        ``DataLoader`` over the training ``ImageFolder``.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    transform = transforms.Compose(
        [
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomResizedCrop(size=(224, 224), antialias=True),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]
    )
    dataset = ImageFolder(os.path.join(root, "train"), transform=transform)

    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
def load_test_dataset(
    root: str,
    model_name: str = "google/vit-base-patch16-224",
    batch_size: int = 32,
    *,
    valid_fraction: float = 0.3,
    seed: int = 42,
):
    """Split ``<root>/test`` into a test loader and a validation loader.

    Images are resized to 224x224 and normalized with ``model_name``'s
    ``ViTImageProcessor`` statistics (no augmentation). The sample indices are
    shuffled with a ``seed``-ed NumPy ``RandomState``. The first
    ``floor(valid_fraction * N)`` shuffled indices become the validation
    subset and the rest the test subset. Each loader visits its subset in
    random order via ``SubsetRandomSampler``.

    Args:
        root: dataset root containing a ``test`` directory.
        model_name: Hugging Face checkpoint whose processor supplies the
            normalization statistics.
        batch_size: samples per batch.
        valid_fraction: share of the ``test`` folder used for validation,
            in ``[0, 1]``.
        seed: seed for the index shuffle, so the split is reproducible.

    Returns:
        ``(test_loader, valid_loader)`` -- note the order.

    Raises:
        ValueError: if ``valid_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= valid_fraction <= 1.0:
        raise ValueError(f"valid_fraction must be in [0, 1], got {valid_fraction}")

    processor = ViTImageProcessor.from_pretrained(model_name)
    transform = transforms.Compose(
        [
            transforms.Resize(size=(224, 224), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]
    )
    dataset = ImageFolder(os.path.join(root, "test"), transform=transform)
    indices = list(range(len(dataset)))
    split = int(np.floor(valid_fraction * len(dataset)))

    # A private RandomState seeded with `seed` yields exactly the permutation
    # that np.random.seed(seed); np.random.shuffle(...) produced, without
    # resetting NumPy's global RNG for the rest of the notebook.
    np.random.RandomState(seed).shuffle(indices)

    test_idx, valid_idx = indices[split:], indices[:split]
    test_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler.SubsetRandomSampler(test_idx),
    )
    valid_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler.SubsetRandomSampler(valid_idx),
    )

    return test_loader, valid_loader

# Model


In [None]:
def load_model(num_classes: int, model_name: str = "google/vit-base-patch16-224"):
    """Load a pretrained ViT and replace its classifier with a small MLP head.

    The new head is hidden -> 512 -> 256 -> ``num_classes``, with ReLU and
    dropout (p=0.2) between layers and no bias on the final layer. It is newly
    initialized, so it must be trained (or loaded from a checkpoint) before
    its predictions mean anything.

    Args:
        num_classes: number of output logits.
        model_name: Hugging Face ViT checkpoint to start from.

    Returns:
        The ``ViTForImageClassification`` model with the replaced head.
    """
    model = ViTForImageClassification.from_pretrained(model_name)
    # Read the encoder width from the config instead of hard-coding 768,
    # so non-base checkpoints (e.g. ViT-Large) also get a matching head.
    hidden_size = model.config.hidden_size
    # NOTE(review): model.config.num_labels still describes the original head;
    # that only matters if labels are ever passed to model(...) — confirm.
    model.classifier = nn.Sequential(
        nn.Linear(in_features=hidden_size, out_features=512),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(in_features=512, out_features=256),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(in_features=256, out_features=num_classes, bias=False),
    )

    return model

# Train


In [None]:
# Build the training loader and the test/validation split from the Kaggle dataset.
DATASET_ROOT = "/kaggle/input/pbl6-dataset"
train_loader = load_train_dataset(root=DATASET_ROOT)
test_loader, valid_loader = load_test_dataset(root=DATASET_ROOT)

In [None]:
# Size the classifier head from the training data rather than hard-coding it,
# so every label ImageFolder produces falls inside the model's output range.
num_classes = len(train_loader.dataset.classes)
model = load_model(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): 1e-3 is on the high side for fine-tuning a pretrained ViT with
# Adam; a smaller LR (e.g. 1e-4) may preserve pretrained features better —
# confirm empirically.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 10 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
EPOCHS = 20
# Train, evaluate on the validation split, and step the LR schedule once per
# epoch; write a checkpoint every fifth epoch.
for epoch in range(1, EPOCHS + 1):
    loss = train_step(model, train_loader, optimizer, criterion)
    accuracy = test_step(model, valid_loader)
    print(f"Epoch [{epoch}/{EPOCHS}] | Loss: {loss:.4f}, Accuracy: {accuracy:.2f}%")
    scheduler.step()

    if not epoch % 5:
        save_checkpoint(model, optimizer, path=f"model_{epoch}.pt")

# Test


In [None]:
def test_checkpoint(path: str, num_classes: int = 10) -> float:
    """Evaluate a saved checkpoint on the held-out test split and print its accuracy.

    Args:
        path: checkpoint file written by ``save_checkpoint``.
        num_classes: size of the classifier head; must match the value used
            when the checkpoint was trained, or loading the state dict fails.

    Returns:
        Test accuracy in percent (the same value that is printed).
    """
    model = load_model(num_classes=num_classes).to(device)
    # load_checkpoint also restores optimizer state, so it needs an optimizer
    # to load into even though evaluation never uses it.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model, _ = load_checkpoint(model, optimizer, path)

    accuracy = test_step(model, test_loader)
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
def test_single_image(
    path: str,
    checkpoint_path: "str | None" = None,
    num_classes: int = 10,
) -> int:
    """Predict the class index for the single image at ``path``.

    The image is converted to RGB, resized to 224x224 and normalized with the
    ``google/vit-base-patch16-224`` processor statistics (the same evaluation
    preprocessing as ``load_test_dataset``). Inference runs without gradient
    tracking.

    Args:
        path: image file to classify.
        checkpoint_path: file written by ``save_checkpoint``. Its model weights
            are loaded before predicting. If omitted, the classifier head from
            ``load_model`` is untrained, so the prediction is not meaningful
            and a warning is issued.
        num_classes: size of the classifier head; must match the checkpoint.

    Returns:
        Index of the highest-scoring class (presumably an index into the
        training ``ImageFolder``'s ``classes`` list — confirm against training).
    """
    import warnings

    model = load_model(num_classes=num_classes).to(device)
    if checkpoint_path is None:
        warnings.warn(
            "test_single_image called without checkpoint_path; the classifier "
            "head is untrained, so the prediction is not meaningful."
        )
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=device, weights_only=True
        )
        model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    transform = transforms.Compose(
        [
            transforms.Resize(size=(224, 224), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]
    )
    # `with` closes the file handle. convert("RGB") matches ImageFolder's default
    # loader; grayscale/RGBA files would otherwise have the wrong channel count
    # for the 3-channel Normalize.
    with Image.open(path) as img:
        image = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image).logits

    return logits.argmax(dim=1).item()