# Set up


In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

In [None]:
def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    save_path,
    parallel=False,
):
    """Serialize the model and optimizer state dicts to ``save_path``.

    The file holds a dict with two entries: ``"model"`` and ``"optimizer"``.

    Args:
        model: Network whose parameters are stored.
        optimizer: Optimizer whose state is stored alongside the weights.
        save_path: Destination handed straight to ``torch.save``.
        parallel: When true, the weights are read from ``model.module``
            instead of ``model`` (the attribute ``nn.DataParallel`` uses for
            the wrapped network).
    """
    if parallel:
        weights = model.module.state_dict()
    else:
        weights = model.state_dict()

    checkpoint = {"model": weights, "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, save_path)


def load_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    save_path,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    parallel=False,
):
    """Restore model and optimizer state from a checkpoint file.

    The file is expected to contain a dict with ``"model"`` and
    ``"optimizer"`` entries, each a state dict.

    Args:
        model: Network that receives the stored weights (modified in place).
        optimizer: Optimizer that receives the stored state (modified in place).
        save_path: Checkpoint location handed to ``torch.load``.
        device: Passed as ``map_location`` so stored tensors are placed on
            this device while loading. The default is evaluated once, when
            the function is defined.
        parallel: When true, the weights are loaded into ``model.module``
            instead of ``model`` (the attribute ``nn.DataParallel`` uses for
            the wrapped network).

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.
    """
    # Named ``checkpoint`` rather than ``dict`` so the builtin is not shadowed.
    checkpoint = torch.load(save_path, map_location=device, weights_only=True)

    target = model.module if parallel else model
    target.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer

In [None]:
def load_dataset(root, batch_size=16):
    """Build the train/validation/test ``DataLoader`` objects for ``root``.

    Images are read with ``ImageFolder`` from ``{root}/train``,
    ``{root}/val`` and ``{root}/test``. Every image is resized to 224x224 and
    normalized with the mean/std reported by the
    ``google/vit-base-patch16-224`` image processor; the training split is
    additionally augmented (rotation, horizontal flip, color jitter).

    Note:
        Calls ``torch.manual_seed(42)``, which reseeds torch's global RNG.

    Args:
        root: Directory containing the ``train``, ``val`` and ``test`` folders.
        batch_size: Batch size used for all three loaders.

    Returns:
        ``(train_loader, valid_loader, test_loader)``; only the training
        loader shuffles.
    """
    torch.manual_seed(42)

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    def resize():
        return transforms.Resize(size=(224, 224), antialias=True)

    def tensor_and_normalize():
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
        ]

    # Augmentations are applied to the training split only.
    preprocess_train = transforms.Compose(
        [
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(p=0.5),
            resize(),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
            *tensor_and_normalize(),
        ]
    )
    preprocess_test = transforms.Compose([resize(), *tensor_and_normalize()])

    # One ImageFolder per split; validation and test share the same pipeline.
    split_pipelines = (
        ("train", preprocess_train),
        ("val", preprocess_test),
        ("test", preprocess_test),
    )
    train_set, valid_set, test_set = [
        ImageFolder(f"{root}/{split}", transform=pipeline)
        for split, pipeline in split_pipelines
    ]

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    valid_loader, test_loader = [
        DataLoader(dataset=subset, batch_size=batch_size, shuffle=False)
        for subset in (valid_set, test_set)
    ]

    return train_loader, valid_loader, test_loader

In [None]:
def load_ViT(num_classes=10):
    """Load the pretrained ViT and swap in a fresh MLP classification head.

    The ``google/vit-base-patch16-224`` checkpoint is loaded and its
    ``classifier`` attribute is replaced by
    ``768 -> 512 -> 256 -> num_classes`` linear layers, with ReLU and
    dropout (p=0.2) after each of the first two. The final layer has no bias.

    Args:
        num_classes: Number of output logits of the new head.

    Returns:
        The model with the replaced ``classifier``.
    """
    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    # NOTE(review): 768 presumably matches the backbone's hidden size — confirm
    # before switching to a different checkpoint.
    widths = (768, 512, 256)

    head = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        head.append(nn.Linear(in_features=fan_in, out_features=fan_out))
        head.append(nn.ReLU())
        head.append(nn.Dropout(p=0.2))
    head.append(nn.Linear(in_features=widths[-1], out_features=num_classes, bias=False))

    model.classifier = nn.Sequential(*head)
    return model

In [None]:
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise it is put in eval mode and autograd is
    disabled for the whole pass.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct = 0.0, 0

    with torch.set_grad_enabled(training):
        with tqdm(total=len(loader), desc=desc, unit="batch") as pbar:
            for i, (images, labels) in enumerate(loader):
                images, labels = images.to(device), labels.to(device)
                # The model returns an output object; the loss needs the raw
                # logits tensor it carries.
                logits = model(images).logits
                loss = criterion(logits, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(logits, 1)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({"loss": format(running_loss / (i + 1), ".4f")})
                pbar.update()

            mean_loss = running_loss / len(loader)
            accuracy = 100 * correct / len(loader.dataset)
            pbar.set_postfix(
                {
                    "loss": format(mean_loss, ".4f"),
                    "accuracy": format(accuracy, ".2f"),
                }
            )

    return mean_loss, accuracy


def train(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    save_path: str,
    num_epochs=50,
    resume_training: bool = False,
    lr=0.0001,
    parallel=False,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Per-epoch metrics are appended to ``{save_path}/results.csv`` (columns
    ``epoch``, ``train_loss``, ``train_acc``, ``valid_loss``, ``valid_acc``).
    Checkpoints are written to ``{save_path}/ViT_{epoch}.pth`` after epochs
    1, 6, 11, ... and after the final epoch.

    Args:
        model: Network to train. Its forward output must expose ``.logits``.
        train_loader: Batches of ``(images, labels)`` used for optimization.
        valid_loader: Batches of ``(images, labels)`` used for validation.
        save_path: Directory for ``results.csv`` and checkpoints; created if
            it does not exist.
        num_epochs: Total number of epochs, including epochs already done
            when resuming.
        resume_training: Continue from the most recent logged epoch that has
            a checkpoint in ``save_path``.
        lr: Adam learning rate.
        parallel: Wrap the model in ``nn.DataParallel``.
        device: Device the model and batches are moved to.

    Raises:
        FileNotFoundError: When resuming and ``results.csv`` is missing or no
            logged epoch has a matching checkpoint file.
    """
    os.makedirs(save_path, exist_ok=True)
    results_file = f"{save_path}/results.csv"

    if parallel:
        model = nn.DataParallel(model)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Defaults for a fresh run; both are overwritten when resuming.
    results = []
    start_epoch = 0

    if resume_training:
        logged = pd.read_csv(results_file).to_dict("records")

        # Checkpoints are not written every epoch, so walk back from the last
        # logged epoch to the newest one that actually has a file on disk.
        resume_epoch = None
        for row in reversed(logged):
            candidate = int(row["epoch"])
            if os.path.isfile(f"{save_path}/ViT_{candidate}.pth"):
                resume_epoch = candidate
                break
        if resume_epoch is None:
            raise FileNotFoundError(
                f"No checkpoint matching the epochs logged in {results_file}"
            )
        start_epoch = resume_epoch

        # Drop metrics of epochs that will be re-run so the log has no duplicates.
        results = [row for row in logged if int(row["epoch"]) <= start_epoch]

        model, optimizer = load_checkpoint(
            model,
            optimizer,
            f"{save_path}/ViT_{start_epoch}.pth",
            device=device,
            parallel=parallel,
        )

        # Start one loader iteration per completed epoch.
        # NOTE(review): presumably meant to advance the RNG so shuffling
        # continues as in the interrupted run; this is best-effort only.
        for _ in range(start_epoch):
            for _ in train_loader:
                break

        print(f"Resuming training from epoch {start_epoch}")

    # ``device`` may be a torch.device, which has no ``upper()``; go through str.
    print(f"Start trainning with {str(device).upper()}")
    for epoch in range(start_epoch, num_epochs):
        train_loss, train_acc = _run_epoch(
            model,
            train_loader,
            criterion,
            device,
            desc=f"Train epoch {epoch+1}/{num_epochs}",
            optimizer=optimizer,
        )
        valid_loss, valid_acc = _run_epoch(
            model,
            valid_loader,
            criterion,
            device,
            desc=f"Valid epoch {epoch+1}/{num_epochs}",
        )

        # Save results
        results.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "valid_loss": valid_loss,
                "valid_acc": valid_acc,
            }
        )
        pd.DataFrame(results).to_csv(results_file, index=False)

        # Save checkpoint where resume looks for it; always keep the final epoch.
        if epoch % 5 == 0 or epoch + 1 == num_epochs:
            save_checkpoint(
                model, optimizer, f"{save_path}/ViT_{epoch+1}.pth", parallel
            )

# Train


In [None]:
# Build the three loaders from the "data" directory, 16 images per batch.
train_loader, valid_loader, test_loader = load_dataset(root="data", batch_size=16)

In [None]:
# Pretrained ViT with a new 10-way classification head.
model = load_ViT(10)

In [None]:
# Fresh 20-epoch run; metrics are written under ./vision_transformer.
train(
    model,
    train_loader,
    valid_loader,
    save_path="./vision_transformer",
    num_epochs=20,
    resume_training=False,
)