# Aerial Scene Classification with ResNet-18
Using Transfer Learning in PyTorch

## 1. Import Required Libraries

In [1]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision import datasets, models, transforms
from torchvision.models import ResNet18_Weights, resnet18

## 2. Set Device

In [None]:
# Pick the compute device once for the whole notebook: the default CUDA GPU
# when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

## 3. Define Transforms and Load Dataset

In [None]:
input_size = 224
batch_size = 32

# NOTE(review): the training folders (flip/blur/brightness/crop/rotate/original)
# appear to hold pre-augmented copies, which is presumably why no random
# augmentation is applied here. Train and test therefore share the same
# deterministic resize + ImageNet mean/std normalisation.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
}
train_dir = r"D:\hmzhao\UNSW\Courses\COMP9517\Group project\data\train"
test_dir = r"D:\hmzhao\UNSW\Courses\COMP9517\Group project\data\test"

# One ImageFolder root per training variant; all of them are concatenated.
train_variants = ["flip", "blur", "brightness", "crop", "rotate", "original"]
train_dirs = [os.path.join(train_dir, variant) for variant in train_variants]

train_datasets = [
    datasets.ImageFolder(train_sub_dir, transform=data_transforms['train'])
    for train_sub_dir in train_dirs
]
test_dataset = datasets.ImageFolder(test_dir, transform=data_transforms['test'])

# ImageFolder derives label indices from each root's class sub-folders
# independently. If any root is missing a class (or has an extra one), the same
# integer label would refer to different classes across the concatenated
# datasets and training would silently learn wrong labels, so fail loudly.
for train_sub_dir, sub_dataset in zip(train_dirs, train_datasets):
    if sub_dataset.class_to_idx != test_dataset.class_to_idx:
        raise ValueError(
            f"Class folders in {train_sub_dir} do not match {test_dir}: "
            f"{sub_dataset.classes} vs {test_dataset.classes}"
        )

train_dataset = ConcatDataset(train_datasets)

# Page-locked host memory speeds up host-to-GPU copies; it is useless on CPU.
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=use_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=use_pin_memory)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples:  {len(test_dataset)}")
print("Classes:", test_dataset.classes)

## 4. Define and Modify ResNet-18 Model

In [None]:
# Transfer learning: start from ImageNet-pretrained weights instead of a random
# initialisation. The input pipeline already normalises with the ImageNet
# mean/std that these weights were trained with.
model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Swap the 1000-way ImageNet classifier for a head sized to this dataset,
# derived from the class folders rather than a hard-coded count.
num_classes = len(test_dataset.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)
print(f"Model ready on {device} with {num_classes} output classes")


## 5. Define Loss Function and Optimizer

In [None]:
# Multi-class objective applied to the raw logits produced by model.fc.
criterion = nn.CrossEntropyLoss()

# Adam over every trainable parameter with a small learning rate.
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

## 6. Train the Model

In [None]:
max_epochs = 50
best_acc = 0.0
patience = 3  # consecutive non-improving epochs tolerated before stopping
trigger_times = 0
model_save_path = r"D:\hmzhao\UNSW\Courses\COMP9517\Group project\resnet18_model-4.pth"

# Per-epoch metric histories, read later by the plotting cell. They must be
# created before the loop appends to them; re-running this cell resets them.
train_loss_history = []
train_acc_history = []
val_acc_history = []

# NOTE(review): test_loader is used here as the validation set for early
# stopping and checkpoint selection, so the test metrics reported afterwards
# are optimistically biased. Hold out a separate validation split if unbiased
# numbers are required.

print("Starting training with early stopping...\n")

for epoch in range(max_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is the batch mean (default CrossEntropyLoss reduction);
        # weight by batch size so the epoch average is per-sample even when
        # the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    train_loss_history.append(epoch_loss)
    train_acc_history.append(epoch_acc)
    print(f"Epoch [{epoch+1}/{max_epochs}] Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    # ---- Evaluation pass (eval mode, no gradient tracking) ----
    model.eval()
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for val_inputs, val_labels in test_loader:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_outputs = model(val_inputs)
            _, val_predicted = val_outputs.max(1)
            val_correct += val_predicted.eq(val_labels).sum().item()
            val_total += val_labels.size(0)

    val_acc = val_correct / val_total
    val_acc_history.append(val_acc)
    print(f"→ Validation Acc: {val_acc:.4f}")

    # ---- Checkpointing / early stopping on validation accuracy ----
    if val_acc > best_acc:
        best_acc = val_acc
        trigger_times = 0
        torch.save(model.state_dict(), model_save_path)
        print(f"Model improved! Saved to {model_save_path}")
    else:
        trigger_times += 1
        print(f"No improvement. Trigger times: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("\n Early stopping triggered.")
            break

## 7. Visualize Training Curves

In [None]:
def _epoch_axis(history):
    """Return 1-based epoch numbers for the entries of ``history``.

    Matches the "Epoch [k/N]" numbering printed during training, so the plotted
    curves line up with the log.
    """
    return range(1, len(history) + 1)


plt.figure(figsize=(12, 5))

# Left panel: mean per-sample training loss for each epoch.
plt.subplot(1, 2, 1)
plt.plot(_epoch_axis(train_loss_history), train_loss_history, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()

# Right panel: training vs. validation accuracy for each epoch.
plt.subplot(1, 2, 2)
plt.plot(_epoch_axis(train_acc_history), train_acc_history, label='Training Accuracy')
plt.plot(_epoch_axis(val_acc_history), val_acc_history, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Curve')
plt.legend()

plt.tight_layout()
plt.savefig('training_curves.png')
plt.show()

## 8. Evaluate Model on Test Set

In [None]:
# Rebuild the trained architecture and load the best checkpoint saved during
# training. The head size comes from the dataset, not a hard-coded count.
num_classes = len(test_dataset.classes)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only box.
state_dict = torch.load(model_save_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Pin the label set to every class index so the report and confusion matrix
# always have one row/column per class name, even if some class never appears
# in either the ground truth or the predictions.
class_indices = list(range(num_classes))

acc = accuracy_score(all_labels, all_preds)
# zero_division=0: a class that is never predicted scores 0 precision rather
# than emitting an UndefinedMetricWarning.
precision = precision_score(all_labels, all_preds, labels=class_indices,
                            average='macro', zero_division=0)
recall = recall_score(all_labels, all_preds, labels=class_indices,
                      average='macro', zero_division=0)
f1 = f1_score(all_labels, all_preds, labels=class_indices,
              average='macro', zero_division=0)

print("\nEvaluation Metrics:")
print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-score:  {f1:.4f}")

print("\nPer-Class Report:")
print(classification_report(all_labels, all_preds, labels=class_indices,
                            target_names=test_dataset.classes, digits=4,
                            zero_division=0))

cm = confusion_matrix(all_labels, all_preds, labels=class_indices)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=test_dataset.classes,
            yticklabels=test_dataset.classes)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# Right-align rotated tick labels so each one ends under its own column.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()