## Setup


In [None]:
# Print the Python interpreter version (IPython shell escape).
!python --version

In [None]:
# Install the third-party packages used by this notebook into the running kernel.
# NOTE(review): transformers and matplotlib are not used in the cells visible here — confirm they are still needed.
%pip install torch torchvision transformers matplotlib

In [7]:
import os
import torch
import numpy as np
import torch.nn as nn

from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

## Load the Vision Transformer (ViT) Model


In [None]:
# torchvision's default pretrained weights for ViT-B/16, and the model built with them.
pretrained_weights = models.ViT_B_16_Weights.DEFAULT
pretrained_model = models.vit_b_16(weights=pretrained_weights)

In [None]:
# Freeze the pretrained backbone: no existing parameter will receive gradients.
# A classification head created after this cell is unaffected and stays trainable.
for param in pretrained_model.parameters():
    param.requires_grad_(False)

In [None]:
# Preprocessing pipeline bundled with the pretrained weights; it is passed to the
# dataset loaders below so images are prepared the way the backbone expects.
pretrained_transforms = pretrained_weights.transforms()
# Bare expression so the notebook displays the pipeline.
pretrained_transforms

## Load Dataset


In [None]:
# Dataset root directory.
# NOTE(review): Path("") resolves to the current working directory — point this
# at the real dataset location.
image_path = Path("")

# One sub-directory per data split.
train_dir = image_path / "train"
test_dir = image_path / "test"
validate_dir = image_path / "validate"

In [None]:
# Number of DataLoader worker processes. os.cpu_count() returns None when the
# CPU count cannot be determined, but DataLoader requires an int, so fall back
# to 0 (load data in the main process).
NUM_WORKERS = os.cpu_count() or 0
NUM_WORKERS

In [9]:
def create_dataloader(
    train_dir: str | Path,
    test_dir: str | Path,
    transform: transforms.Compose,
    batch_size: int = 32,
    num_workers: int = NUM_WORKERS,
):
    """Build training and test DataLoaders from two ImageFolder directories.

    Each directory is read with ``datasets.ImageFolder`` and ``transform`` is
    applied to every image. The training loader shuffles every epoch; the test
    loader keeps a fixed order.

    Args:
        train_dir: Root of the training split (one sub-folder per class).
        test_dir: Root of the test split; must contain the same class
            sub-folders as ``train_dir``.
        transform: Transform applied to every image (e.g. the pretrained
            weights' preprocessing pipeline).
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader (0 loads in the main process).

    Returns:
        ``(train_loader, test_loader, classes)`` where ``classes`` is the list
        of class names in label-index order.

    Raises:
        ValueError: If the two splits do not contain the same classes.
    """
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # ImageFolder derives label indices from each directory's own sorted
    # sub-folder names, so a class missing from one split would silently shift
    # the labels of every class after it and corrupt evaluation.
    if train_data.classes != test_data.classes:
        raise ValueError(
            "train and test splits have different classes: "
            f"{train_data.classes} vs {test_data.classes}"
        )

    # Pinned host memory only speeds up host-to-GPU copies; without CUDA,
    # PyTorch just warns that it cannot be used.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, test_loader, train_data.classes

In [None]:
# Build the loaders with the backbone's own preprocessing; `classes` lists the
# class names in label-index order.
train_loader, test_loader, classes = create_dataloader(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=pretrained_transforms,
)

## Train Model

### Set Up Loss Function and Optimizer


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    HARDWARE = "cuda"
else:
    HARDWARE = "cpu"
device = torch.device(HARDWARE)

In [None]:
# Seed before constructing the new head so its random initialization is reproducible.
torch.manual_seed(42)
# Replace the classifier with a fresh, trainable linear layer sized to our classes.
# torchvision's VisionTransformer stores its classifier as an nn.Sequential under
# `heads` (which has no `in_features`); that module consumes the class-token
# embedding of size `hidden_dim`, which also stays valid if this cell is re-run.
pretrained_model.heads = nn.Linear(
    in_features=pretrained_model.hidden_dim, out_features=len(classes)
)
# Move the WHOLE model, not just the new head: train_step/test_step send each
# batch to `device`, so the frozen backbone must live there too. Assign the
# result so the notebook does not echo the module repr.
pretrained_model = pretrained_model.to(device)

In [10]:
def train_step(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
) -> tuple[float, float]:
    """Run one training epoch over ``loader`` and return its metrics.

    Every batch is moved to the module-level ``device``, a forward/backward pass
    is made, and ``optimizer`` takes one step.

    Returns:
        ``(train_accuracy, train_loss)`` — note the order: accuracy first — both
        averaged over every sample seen in the epoch, so a smaller final batch
        is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Assumes loss_fn returns the batch mean (the nn.CrossEntropyLoss
        # default); scaling by the batch size yields a per-sample epoch average.
        total_loss += loss.item() * batch_size
        # argmax of the logits equals argmax of their softmax, so skip the softmax.
        total_correct += (y_pred.argmax(dim=1) == y).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train_step: loader yielded no samples")

    return total_correct / total_samples, total_loss / total_samples

In [None]:
def test_step(
    model: nn.Module, loader: DataLoader, loss_fn: nn.Module
) -> tuple[float, float]:
    """Evaluate ``model`` on ``loader`` without updating it.

    Every batch is moved to the module-level ``device``; gradients are disabled
    via ``torch.inference_mode``.

    Returns:
        ``(test_loss, test_accuracy)`` — note the order: loss first — both
        averaged over every sample, so a smaller final batch is not
        over-weighted.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.inference_mode():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)

            batch_size = y.size(0)
            # Assumes loss_fn returns the batch mean (the nn.CrossEntropyLoss
            # default); scaling by the batch size yields a per-sample average.
            total_loss += loss_fn(logits, y).item() * batch_size
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("test_step: loader yielded no samples")

    return total_loss / total_samples, total_correct / total_samples

In [None]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module | None = None,
    epochs: int = 100,
    early_stopping: "EarlyStopping | None" = None,
) -> dict[str, list[float]]:
    """Train ``model`` for up to ``epochs`` epochs, evaluating after each one.

    Args:
        model: Model to train.
        train_loader: Batches used for the training pass.
        test_loader: Batches used for evaluation after every epoch.
        optimizer: Optimizer that updates the model's trainable parameters.
        loss_fn: Loss criterion; a fresh ``nn.CrossEntropyLoss()`` when None.
        epochs: Maximum number of epochs.
        early_stopping: Optional object whose ``step(test_loss)`` returns True
            when training should stop (e.g. an ``EarlyStopping`` instance).

    Returns:
        Per-epoch history with keys ``"train_accuracy"``, ``"train_loss"``,
        ``"test_accuracy"`` and ``"test_loss"``.
    """
    # Build the default here rather than in the signature so no module instance
    # is created at definition time and shared between calls.
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    result = {
        "train_accuracy": [],
        "train_loss": [],
        "test_accuracy": [],
        "test_loss": [],
    }

    for epoch in range(1, epochs + 1):
        # train_step returns (accuracy, loss) while test_step returns (loss, accuracy).
        train_accuracy, train_loss = train_step(model, train_loader, loss_fn, optimizer)
        test_loss, test_accuracy = test_step(model, test_loader, loss_fn)

        print(
            f"Epoch: {epoch} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_accuracy: {train_accuracy:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_accuracy: {test_accuracy:.4f}"
        )

        result["train_accuracy"].append(train_accuracy)
        result["train_loss"].append(train_loss)
        result["test_accuracy"].append(test_accuracy)
        result["test_loss"].append(test_loss)

        if early_stopping is not None and early_stopping.step(test_loss):
            print("Early stopping triggered at epoch", epoch)
            break

    return result

In [None]:
def save_model(model: nn.Module, target_dir: str, model_name: str):
    """Save ``model.state_dict()`` under ``target_dir`` as a ``.pt`` file.

    ``target_dir`` (and any missing parents) is created if needed. The file is
    named ``<model_name>.pt``, unless ``model_name`` already ends in ``.pt`` or
    ``.pth``, in which case it is used as-is (so ``"classification.pt"`` does
    not become ``classification.pt.pt``).

    Args:
        model: Model whose parameters/buffers are saved.
        target_dir: Output directory.
        model_name: File name, with or without a ``.pt``/``.pth`` suffix.
    """
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    if model_name.endswith((".pt", ".pth")):
        file_name = model_name
    else:
        file_name = f"{model_name}.pt"
    model_save_path = target_dir_path / file_name
    torch.save(obj=model.state_dict(), f=model_save_path)

In [None]:
class EarlyStopping:
    """Decide when to stop training because a monitored metric stopped improving.

    Call :meth:`step` once per epoch with the monitored value; it returns
    ``True`` when training should stop.

    Args:
        mode: ``"min"`` when lower values are better (e.g. a loss), ``"max"``
            when higher values are better (e.g. an accuracy).
        min_delta: Margin by which a value must beat ``best`` to count as an
            improvement. With ``percentage=True`` it is a percentage of
            ``best`` rather than an absolute amount.
        patience: Number of consecutive non-improving steps after which
            :meth:`step` returns ``True``. ``0`` disables early stopping
            (``step`` then always returns ``False``).
        percentage: Interpret ``min_delta`` as a percentage of ``best``.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(self, mode="min", min_delta=0, patience=10, percentage=False) -> None:
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None  # best value seen so far; None until the first step()
        self.num_bad_epochs = 0  # consecutive steps without improvement
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # patience=0 disables early stopping: these instance attributes
            # shadow the comparison and the step() method.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one epoch's metric value; return True if training should stop.

        A NaN value stops training immediately. Otherwise the bad-epoch counter
        is reset on improvement and incremented on no improvement, and ``True``
        is returned once it reaches ``patience``.
        """
        # Check NaN before anything else: if a NaN were stored as `best` on the
        # first call, every later comparison against it would be False and the
        # divergence would only stop training `patience` epochs later.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
            print("Improvement!")
        else:
            self.num_bad_epochs += 1
            print(f"No improvement, bad epochs counter: {self.num_bad_epochs}")

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Set ``self.is_better(a, best)`` from ``mode``, ``min_delta`` and ``percentage``."""
        if mode not in {"min", "max"}:
            raise ValueError(f"mode {mode} is unknown!")
        if percentage:
            # min_delta is a percentage of the current best value.
            if mode == "min":
                self.is_better = lambda a, best: a < best - (best * min_delta / 100)
            else:
                self.is_better = lambda a, best: a > best + (best * min_delta / 100)
        else:
            if mode == "min":
                self.is_better = lambda a, best: a < best - min_delta
            else:
                self.is_better = lambda a, best: a > best + min_delta

In [None]:
# Multi-class classification criterion.
loss_fn = nn.CrossEntropyLoss()
# Adam with learning rate 1e-3. Frozen parameters never receive gradients, so
# Adam skips them and only the trainable head is updated.
optimizer = torch.optim.Adam(pretrained_model.parameters(), lr=1e-3)

In [None]:
# Stop once the monitored value (train() passes the test loss) fails to improve
# for 10 consecutive epochs.
early_stopping = EarlyStopping(patience=10, mode="min")
# NOTE(review): duplicates HARDWARE and is not used in the cells shown here —
# confirm before removing.
devices = "cpu" if not torch.cuda.is_available() else "cuda"

In [None]:
# Fine-tune the model (only the replaced head is trainable); per-epoch metrics
# are collected in model_result, and training may end early via early_stopping.
model_result = train(
    pretrained_model,
    train_loader,
    test_loader,
    optimizer,
    loss_fn=loss_fn,
    epochs=100,
    early_stopping=early_stopping,
)

In [None]:
# save_model appends ".pt" itself, so pass the bare name to get models/classification.pt.
save_model(model=pretrained_model, target_dir="models", model_name="classification")