In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# --------------------------
# 1. Define utility functions
# --------------------------

def train_model(model, train_loader, criterion, optimizer, device, epochs=5):
    """
    Train the model for a specified number of epochs.

    Every batch is moved to ``device``, passed through ``model``, scored with
    ``criterion`` and used for exactly one ``optimizer`` step. After each
    epoch the sample-weighted mean loss is printed.

    Args:
        model: Module to optimize. It is switched to training mode here; it
            is assumed to already live on ``device`` (this function only
            moves the batches).
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor. The value is multiplied by the batch size before
            averaging, so it is assumed to be a per-batch mean.
        optimizer: Optimizer used for ``zero_grad()`` / ``step()``.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[float]: The mean loss of each epoch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no samples during an epoch.
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        samples_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        if samples_seen == 0:
            raise ValueError(
                "train_loader yielded no samples; cannot compute an epoch loss"
            )

        # Divide by the samples actually processed rather than
        # len(train_loader.dataset): the two differ with drop_last=True or a
        # custom sampler, and plain iterables have no .dataset at all.
        epoch_loss = running_loss / samples_seen
        epoch_losses.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

    return epoch_losses


def evaluate_model(model, test_loader, device):
    """
    Evaluate the model on the test set.

    The model is put into evaluation mode (and left in it), gradients are
    disabled, and the predicted class of each sample is taken as the index of
    the largest value along dimension 1 of the model output.

    Args:
        model: Module to evaluate; assumed to already live on ``device``
            (this function only moves the batches).
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        float: Top-1 accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``test_loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(
            "test_loader yielded no samples; accuracy is undefined"
        )

    accuracy = 100.0 * correct / total
    return accuracy

# -------------------------
# 2. Define our CNN baseline
# -------------------------

class SimpleCNN(nn.Module):
    """
    Small three-stage convolutional baseline classifier.

    ``features`` stacks three Conv(3x3, padding 1) -> ReLU -> MaxPool(2)
    stages with 32, 64 and 128 output channels; each stage halves the spatial
    size. ``classifier`` is a two-layer MLP with dropout that maps the
    flattened ``128 * 28 * 28`` feature map to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # After 3 max-pools a 224x224 input becomes 224/(2^3)=28, i.e. a
        # 128x28x28 feature map. The adaptive pool pins the map to 28x28 for
        # any other input resolution; for 224x224 inputs it is an identity
        # (every output cell averages a single element). It has no
        # parameters, so existing state_dicts still load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((28, 28))
        self.classifier = nn.Sequential(
            nn.Linear(128 * 28 * 28, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """
        Compute class logits for a batch of 3-channel images.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw logits.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.classifier(x)
        return x

# --------------------------------------------------
# 3. Main guard to avoid multiprocessing issues on Windows
# --------------------------------------------------

if __name__ == "__main__":
    # End-to-end experiment: train SimpleCNN from scratch, fine-tune an
    # ImageNet-pretrained ResNet18 on the same data, and compare test accuracy.

    # 3.1. Device check
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # 3.2. Data paths (ImageFolder layout: one sub-directory per class)
    train_dir = "data/seg_train/seg_train"
    test_dir  = "data/seg_test/seg_test"

    # 3.3. Transforms
    # Both pipelines resize to 224x224 and normalize with the ImageNet
    # mean/std expected by the pretrained ResNet; only the training pipeline
    # adds random augmentation.
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # some augmentation
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    # 3.4. Create Datasets
    train_dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=train_transforms)
    test_dataset  = torchvision.datasets.ImageFolder(root=test_dir, transform=test_transforms)

    # 3.5. DataLoaders
    batch_size = 32
    # NOTE: num_workers > 0 requires the __main__ guard on Windows
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 3.6. Check classes
    classes = train_dataset.classes
    num_classes = len(classes)
    print("Classes:", classes)

    # ------------------------------------
    # 4. Baseline Model (trained from scratch)
    # ------------------------------------
    baseline_model = SimpleCNN(num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(baseline_model.parameters(), lr=0.001)

    print("\nTraining Baseline Model...")
    train_model(baseline_model, train_loader, criterion, optimizer, device, epochs=5)
    baseline_acc = evaluate_model(baseline_model, test_loader, device)
    print(f"Baseline Model Test Accuracy: {baseline_acc:.2f}%")

    # ------------------------------------
    # 5. Fine-Tuning Pretrained ResNet
    # ------------------------------------
    print("\nFine-Tuning ResNet18...")

    # Load pretrained ResNet18.
    # `pretrained=True` is deprecated since torchvision 0.13 in favour of the
    # `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True`
    # loads. Fall back to the legacy flag on older torchvision releases that
    # do not ship the enum.
    resnet18_weights = getattr(models, "ResNet18_Weights", None)
    if resnet18_weights is not None:
        resnet = models.resnet18(weights=resnet18_weights.IMAGENET1K_V1)
    else:
        resnet = models.resnet18(pretrained=True)

    # Replace the 1000-way ImageNet head with one sized for this dataset.
    resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
    resnet = resnet.to(device)

    # Full fine-tuning: every layer stays trainable. This is already the
    # default; the loop just makes the intent explicit.
    for param in resnet.parameters():
        param.requires_grad = True

    # Smaller learning rate than the baseline so the pretrained weights are
    # only nudged.
    optimizer = optim.Adam(resnet.parameters(), lr=1e-4)
    train_model(resnet, train_loader, criterion, optimizer, device, epochs=5)
    resnet_acc = evaluate_model(resnet, test_loader, device)
    print(f"Fine-Tuned ResNet Test Accuracy: {resnet_acc:.2f}%")

    # ------------------------------------
    # 6. Compare results
    # ------------------------------------
    print("\n=== Final Results ===")
    print(f"Baseline CNN Accuracy:    {baseline_acc:.2f}%")
    print(f"Fine-Tuned ResNet18 Accuracy: {resnet_acc:.2f}%")

Using device: cuda
Classes: ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']

Training Baseline Model...
Epoch [1/5], Loss: 0.9533
Epoch [2/5], Loss: 0.6380
Epoch [3/5], Loss: 0.5063
Epoch [4/5], Loss: 0.4286
Epoch [5/5], Loss: 0.3531
Baseline Model Test Accuracy: 84.63%

Fine-Tuning ResNet18...




Epoch [1/5], Loss: 0.2776
Epoch [2/5], Loss: 0.1629
Epoch [3/5], Loss: 0.1121
Epoch [4/5], Loss: 0.0737
Epoch [5/5], Loss: 0.0618
Fine-Tuned ResNet Test Accuracy: 92.23%

=== Final Results ===
Baseline CNN Accuracy:    84.63%
Fine-Tuned ResNet18 Accuracy: 92.23%
