In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
from torch.utils.data import random_split

In [16]:
data_root = "datasets"

In [17]:
# for ResNet18:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [18]:
trainval_data = datasets.OxfordIIITPet(
        root=data_root,
        split='trainval',
        target_types='binary-category',
        transform=transform,
        download=False
    )
test_data = datasets.OxfordIIITPet(
        root=data_root,
        split='test',
        target_types='binary-category',
        transform=transform,
        download=False
    )

In [19]:
val_ratio = 0.1  # 10% for validation
train_size = int((1 - val_ratio) * len(trainval_data))
val_size = len(trainval_data) - train_size

train_data, val_data = random_split(trainval_data, [train_size, val_size])

In [56]:
len(val_data), len(train_data), len(test_data)

(368, 3312, 3669)

In [57]:
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ResNet18
network = models.resnet18(weights='DEFAULT')
nf = network.fc.in_features
network.fc = nn.Linear(nf, 2)
network = network.to(device)

In [59]:
def train_network(network, loader, criterion, optimizer, n_epochs, val_loader=None):
    network.train()

    for epoch in range(n_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for X_batch, Y_batch in loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

            optimizer.zero_grad()
            S = network(X_batch)
            loss = criterion(S, Y_batch)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, P = torch.max(S, 1)
            correct += (P == Y_batch).sum().item()
            total += Y_batch.size(0)

        acc = 100 * correct / total
        val_loss, val_acc = compute_val_metrics(network, val_loader, criterion)
        print(f"[Epoch {epoch+1}/{n_epochs}] Train Running Loss: {running_loss:.4f}, Train Accuracy: {acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")


In [60]:
def compute_val_metrics(network, loader, criterion):
    network.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, Y_batch in loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            S = network(X_batch)
            loss = criterion(S, Y_batch)
            val_loss += loss.item()
            _, P = torch.max(S, 1)
            correct += (P == Y_batch).sum().item()
            total += Y_batch.size(0)

    avg_loss = val_loss / len(loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [61]:
def compute_accuracy(network, loader, print_result=True):
    network.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, Y_batch in loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            S = network(X_batch)
            _, P = torch.max(S, 1)
            correct += (P == Y_batch).sum().item()
            total += Y_batch.size(0)

    acc = 100 * correct / total
    if print_result:
        print(f"Test Accuracy: {acc:.2f}%")
    return acc

In [65]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(network.parameters(), lr=1e-4)

In [68]:
# Running for 1 epoch
train_network(network, train_loader, criterion, optimizer, 1, val_loader)

[Epoch 1/1] Train Running Loss: 0.3426, Train Accuracy: 99.94% | Val Loss: 0.0002, Val Accuracy: 100.00%


In [69]:
compute_accuracy(network, test_loader)

Test Accuracy: 98.96%


98.96429544835105