<a href="https://colab.research.google.com/github/lewisdoukas/scene-classification/blob/main/AID_Scene_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Extract the NWPU-RESISC45 test archive into the project's data folder on Google Drive.
# NOTE: both paths live under /content/gdrive, so Drive must already be mounted
# (the drive.mount cell appears after this one) before this command can succeed.
!unzip '/content/gdrive/MyDrive/aidProject/data/NWPU-RESISC45_test.zip' -d '/content/gdrive/MyDrive/aidProject/data'

In [None]:
# Make Google Drive available under /content/gdrive; every project path in this
# notebook is rooted there.
from google.colab import drive

drive.mount(mountpoint="/content/gdrive")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import timm
from sklearn.metrics import precision_score, recall_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import os

In [3]:
# --- Paths (project layout on the mounted Google Drive) ---
ROOT_DIR = "/content/gdrive/MyDrive/aidProject"
DATA_DIR = os.path.join(ROOT_DIR, "data")
OUTPUT_DIR = os.path.join(ROOT_DIR, "output")

# --- Reproducibility ---
# Seed torch's global RNG so that loader shuffling, the random augmentations and
# the initialisation of the new classifier heads repeat across runs.
# NOTE(review): GPU kernels may still be non-deterministic; this only fixes the RNG streams.
SEED = 42
torch.manual_seed(SEED)

# --- Training hyper-parameters ---
BATCH_SIZE = 32
CLASS_NAMES = ["Airport", "Bridge", "Center", "Industrial", "Port", "RailwayStation", "StorageTanks", "Viaduct"]
NUM_CLASSES = len(CLASS_NAMES) # 8
NUM_EPOCHS = 20
LEARNING_RATE = 0.001
# Lower LR for transformers
# LEARNING_RATE = 3e-5

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Data Preprocessing (ImageNet default values)
# Both pipelines resize to 224x224, convert to a tensor and normalise per channel;
# only the "train" pipeline adds a random horizontal flip and a random rotation.
transform = dict(
    train=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
# Load datasets
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=transform["train"])
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=transform["val"])

# ImageFolder derives the label indices from the sub-folder names it finds, while the
# models are built with NUM_CLASSES outputs. Fail early with a readable message instead
# of training or validating against a mismatched label space.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "train/val class folders differ: "
        f"{train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
    )
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Found {len(train_dataset.classes)} class folders {train_dataset.classes} "
        f"but NUM_CLASSES is {NUM_CLASSES}"
    )

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Load Pretrained VGG-16
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is the
# checkpoint that flag selected, so the loaded parameters are the same.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
# Swap the last classifier layer for one with NUM_CLASSES outputs.
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_features, NUM_CLASSES)
model = model.to(DEVICE)

In [None]:
# Load Pretrained ResNet-50
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is the
# checkpoint that flag selected, so the loaded parameters are the same.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Swap the final fully connected layer for one with NUM_CLASSES outputs.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, NUM_CLASSES)
model = model.to(DEVICE)

In [None]:
# Load Pretrained ViT Model
# timm builds the classification head with NUM_CLASSES outputs directly; the
# model is then moved to the selected device in the same expression.
model = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=NUM_CLASSES,
).to(DEVICE)

In [None]:
# Hybrid CNN-Transformer Model
class HybridCNNTransformer(nn.Module):
    """Classifier that concatenates ResNet-50 and ViT features.

    Both backbones receive the same input batch. Their classification layers are
    replaced by ``nn.Identity`` so each returns its feature vector, and a single
    linear layer maps the concatenated features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        pretrained: When True (default, the original behaviour) both backbones
            start from their pretrained weights. Pass False to skip the download
            when a full state dict is loaded right after construction.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        # CNN Backbone (ResNet-50). `weights=` replaces the deprecated
        # `pretrained=True`; IMAGENET1K_V1 is the checkpoint that flag selected.
        cnn_weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.cnn = models.resnet50(weights=cnn_weights)
        cnn_dim = self.cnn.fc.in_features  # read the width before dropping the layer
        self.cnn.fc = nn.Identity()  # Remove final FC layer

        # Transformer Backbone (ViT)
        self.transformer = timm.create_model("vit_base_patch16_224", pretrained=pretrained)
        vit_dim = self.transformer.num_features
        self.transformer.head = nn.Identity()  # Remove final classifier

        # Fully Connected Layer for Classification; input width is taken from the
        # backbones instead of being hard-coded.
        self.fc = nn.Linear(cnn_dim + vit_dim, num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        cnn_features = self.cnn(x)  # Extract CNN features
        transformer_features = self.transformer(x)  # Extract Transformer features
        # Join the two feature vectors along the feature dimension.
        combined = torch.cat((cnn_features, transformer_features), dim=1)
        return self.fc(combined)  # Final classification


# Build the hybrid network and place it on the selected device.
model = HybridCNNTransformer(num_classes=NUM_CLASSES).to(DEVICE)

In [None]:
# Loss and Optimizer
# Adam updates every parameter of whichever network is currently bound to `model`.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
# Alternative for ViT:
# optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

In [None]:
# Train & Validate model
def _save_confusion_matrix(y_true, y_pred, class_names, out_path):
    """Plot the confusion matrix of ``y_pred`` against ``y_true`` and save it to ``out_path``."""
    # Pin the label set so the matrix always has one row/column per class name,
    # even when a class is absent from this evaluation pass.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    fig.savefig(out_path)
    # Render now, then release the figure so figures do not pile up across epochs.
    plt.show()
    plt.close(fig)


def apply_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, class_names, model_filename):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Prints the training loss/accuracy and the validation accuracy plus
    macro-averaged precision, recall and F1 for every epoch. Whenever the
    validation accuracy improves, the model's state dict is written to
    ``model_filename`` and a confusion-matrix image is written next to it with
    the suffix ``_cm.png``. Batches are moved to the module-level ``DEVICE``.

    Args:
        model: Network to optimise; expected to already live on ``DEVICE``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer driving the parameter updates.
        num_epochs: Number of passes over ``train_loader``.
        class_names: Class labels, indexed by label id, used for the plot ticks.
        model_filename: Destination path of the best checkpoint.

    Returns:
        The best validation accuracy observed, as a float.

    Raises:
        ValueError: If either loader yields no samples.
    """
    # Make sure the checkpoint directory exists before the first save.
    out_dir = os.path.dirname(model_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # splitext keeps the plot path distinct from the checkpoint path even when
    # the checkpoint name does not contain ".pth".
    cm_filename = os.path.splitext(model_filename)[0] + "_cm.png"

    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Training phase
        model.train()
        train_loss, correct_train, seen_train = 0.0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            correct_train += (outputs.argmax(dim=1) == labels).sum().item()
            seen_train += labels.size(0)

        if seen_train == 0:
            raise ValueError("train_loader yielded no samples")
        # Divide by the samples actually seen so the figure stays correct when
        # the loader drops the last incomplete batch.
        train_acc = correct_train / seen_train
        print(f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")

        # Validation phase
        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                preds = model(images.to(DEVICE)).argmax(dim=1)
                y_true.extend(labels.tolist())
                y_pred.extend(preds.tolist())

        if not y_true:
            raise ValueError("val_loader yielded no samples")

        # Metrics
        val_acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
        # zero_division=0 keeps the default score (0) for classes that were never
        # predicted but silences the per-epoch warning.
        precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
        recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        print(f"\nValidation Acc: {val_acc:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-score: {f1:.4f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), model_filename)

            # Confusion matrix for the best model so far (overwrites the previous one).
            _save_confusion_matrix(y_true, y_pred, class_names, cm_filename)

            print("Model saved...")

    return best_acc

In [None]:
# Train whichever network is currently bound to `model` and checkpoint its best epoch.
# NOTE(review): the file name mentions "resnet50"; confirm the ResNet-50 cell was the
# last model cell executed, since later model cells rebind `model`.
model_filename = os.path.join(OUTPUT_DIR, "best_aid_resnet50_20.pth")

apply_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
    class_names=CLASS_NAMES,
    model_filename=model_filename,
)

In [5]:
# Test best model using NWPU-RESISC45 overlapping classes
# Paths of the held-out test images and of the checkpoint to evaluate.
TEST_DATA_DIR = os.path.join(DATA_DIR, "NWPU-RESISC45_test")
MODEL_PATH = os.path.join(OUTPUT_DIR, "best_aid_hybrid_10.pth")
# Same device selection as the configuration cell: GPU when available, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Load Trained Model
model = HybridCNNTransformer(num_classes=NUM_CLASSES)
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state dict needs, so a tampered checkpoint cannot execute arbitrary code.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model = model.to(DEVICE)
model.eval()  # Set model to evaluation mode
print("Model Loaded Successfully!")

Model Loaded Successfully!


In [7]:
# Deterministic test-time preprocessing: resize, tensor conversion, per-channel normalisation.
# NOTE: this rebinds `transform`, which the training cells index as a dict
# ("train"/"val"); re-run the preprocessing cell before loading the training data again.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Build the test set; ImageFolder assigns label ids from the sorted sub-folder names.
test_dataset = datasets.ImageFolder(TEST_DATA_DIR, transform=transform)

# The trained model emits NUM_CLASSES logits whose order follows CLASS_NAMES, so the
# test folders must yield the same number of labels for the metrics to be meaningful.
if len(test_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Found {len(test_dataset.classes)} class folders {test_dataset.classes} "
        f"in {TEST_DATA_DIR} but the model expects {NUM_CLASSES}"
    )
# Folder names may legitimately differ between datasets; show the implied pairing
# so a mis-ordered label mapping is visible instead of silently skewing the scores.
if test_dataset.classes != CLASS_NAMES:
    print("Note: test folder names differ from CLASS_NAMES; verify this index pairing:")
    print("\n".join(
        f"  {label_id}: {folder} -> {class_name}"
        for label_id, (folder, class_name) in enumerate(zip(test_dataset.classes, CLASS_NAMES))
    ))

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(f"Loaded {len(test_dataset)} images from common classes.")

In [None]:
# Evaluate Model
# Enforce inference mode here as well, so the cell does not depend on the load cell
# having been the last one to touch `model`.
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images.to(DEVICE)).argmax(dim=1)
        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())

if not y_true:
    raise ValueError("test_loader yielded no samples; check TEST_DATA_DIR")

# Metrics
# `correct_val` / `val_acc` keep their original names; they hold test-set figures here.
correct_val = sum(t == p for t, p in zip(y_true, y_pred))
val_acc = correct_val / len(y_true)
# zero_division=0 keeps the default score (0) for never-predicted classes without the warning.
precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
print("\nModel Accuracy on NWPU-RESISC45 (Common Classes):")
print(f"Test Acc: {val_acc:.4f}")  # was labelled "Validation Acc" although it is measured on the test set
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")


# Calculate confusion matrix
# Pin the label set so the matrix always matches the CLASS_NAMES tick labels.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
cm_filename = os.path.join(OUTPUT_DIR, "test_hybrid_NWPU_RESISC45.png")

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Oranges', xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix")
fig.tight_layout()
fig.savefig(cm_filename)
plt.show()