In [1]:
import os, random
from collections import defaultdict

import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# ----------------------------
# Config
# ----------------------------
# Every tunable value of the notebook is declared here; later cells only read these names.
SEED = 42                # passed to set_seed() and to the stratified split
DATA_ROOT = "data"       # root directory handed to torchvision.datasets.EuroSAT
IMG_SIZE = 224           # side length (pixels) produced by the train and val/test transforms
BATCH_SIZE = 32          # batch size shared by the train, val and test DataLoaders
NUM_WORKERS = 0          # safest in Colab (no dataloader worker issues)
EPOCHS_HEAD = 2          # stage 1: epochs with only the classifier head trainable
EPOCHS_FT = 5            # stage 2: epochs fine-tuning layer4 + head
LR_HEAD = 1e-3           # stage 1 learning rate for the head
LR_FT_BACKBONE = 1e-5    # stage 2 learning rate for layer4
LR_FT_HEAD = 5e-4        # stage 2 learning rate for the head
WEIGHT_DECAY = 1e-4      # AdamW weight decay used in both stages

BEST_WEIGHTS_PATH = "best_resnet50_eurosat_3cls.pt"   # bare state_dict of the best validation epoch
BEST_CKPT_PATH    = "best_resnet50_eurosat_3cls.pth"  # includes metadata

# Plain string ("cuda" or "cpu"); later cells compare against "cuda" directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

def set_seed(seed=42, deterministic=True):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), then configures cuDNN.

    Args:
        seed: Integer seed applied to all generators.
        deterministic: When True (default), disable cuDNN auto-tuning and
            request deterministic kernels so seeded GPU runs are repeatable.
            When False, re-enable auto-tuning, which is faster but lets the
            selected kernels vary between runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Auto-tuning (benchmark=True) picks kernels by timing them, so the choice
    # can differ from run to run and undo the effect of seeding.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = bool(deterministic)

# Seed all generators once, before the data split and model construction in later cells.
set_seed(SEED)

Device: cpu


In [2]:
# Per-channel normalisation statistics; the same values are stored in the saved
# checkpoint under "mean"/"std" so inference code can reproduce the preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/colour augmentation first, then tensor
# conversion and normalisation. Output is IMG_SIZE x IMG_SIZE.
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation/test pipeline: no random augmentation, only a fixed resize followed
# by the same tensor conversion and normalisation as training.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load EuroSAT base dataset (PIL images)
# transform=None leaves the samples untransformed here; the per-split transforms
# are applied later by the EuroSAT3Class wrapper.
# NOTE(review): the recorded run of this cell ended with "HTTP Error 403: Forbidden",
# presumably raised by the download=True request - confirm the download source used
# by the installed torchvision version, or place the data under DATA_ROOT manually.
base_ds = torchvision.datasets.EuroSAT(root=DATA_ROOT, download=True, transform=None)
class_names = base_ds.classes
# Original integer label -> original class name, used to look up the 3-class mapping.
name_by_idx = {i: n for i, n in enumerate(class_names)}
print("EuroSAT classes:", class_names)

# Map EuroSAT -> 3 classes
# Urban = Residential + Industrial + Highway
# Forest = Forest
# Water = River + SeaLake
# Keys are original EuroSAT class names; values are positions in THREE_CLASS_NAMES.
# Any EuroSAT class that is not a key here is dropped by the index filter below.
EUROSAT_TO_3 = {
    "Residential": 0,
    "Industrial": 0,
    "Highway": 0,
    "Forest": 1,
    "River": 2,
    "SeaLake": 2,
}
# Display names, indexed by the values of EUROSAT_TO_3.
THREE_CLASS_NAMES = ["Urban", "Forest", "Water"]

# Filter usable indices
# Labels are read from the dataset's precomputed `targets` list when it exists:
# indexing `base_ds[i]` would open and decode every image only to discard it.
_all_targets = getattr(base_ds, "targets", None)
if _all_targets is None:
    # Slow fallback for dataset objects that do not expose `targets`.
    _all_targets = [base_ds[i][1] for i in range(len(base_ds))]

kept_indices = []   # positions in base_ds that belong to one of the 3 merged classes
kept_targets = []   # merged (0/1/2) label for each kept position, same order
for i, y in enumerate(_all_targets):
    cname = name_by_idx[int(y)]
    if cname in EUROSAT_TO_3:
        kept_indices.append(i)
        kept_targets.append(EUROSAT_TO_3[cname])

print(f"Keeping {len(kept_indices)} images for 3-class task.")

# Stratified split into train/val/test
def stratified_split_3way(indices, labels, val_ratio=0.15, test_ratio=0.15, seed=42):
    """Split ``indices`` into train/val/test so each label keeps the same proportions.

    Items are grouped by label; every group is shuffled and its first
    ``int(n * val_ratio)`` items go to validation, the next
    ``int(n * test_ratio)`` to test and the remainder to train. Each of the
    three resulting lists is shuffled once more before being returned.

    A private ``random.Random(seed)`` is used, so the result depends only on
    the arguments and the global ``random`` state is left untouched. The
    generated split is the same one that seeding the global generator with
    ``seed`` would produce.

    Args:
        indices: Iterable of item identifiers.
        labels: Iterable of hashable labels, aligned with ``indices``.
        val_ratio: Fraction of each label group assigned to validation.
        test_ratio: Fraction of each label group assigned to test.
        seed: Seed for the private random generator.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of lists.

    Raises:
        ValueError: If ``indices`` and ``labels`` differ in length, or the
            ratios are negative or sum to more than 1.
    """
    indices = list(indices)
    labels = list(labels)
    if len(indices) != len(labels):
        raise ValueError(
            f"indices and labels must have the same length, got {len(indices)} and {len(labels)}"
        )
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError(
            f"val_ratio and test_ratio must be >= 0 and sum to at most 1, got {val_ratio} and {test_ratio}"
        )

    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, lab in zip(indices, labels):
        by_class[lab].append(idx)

    train_idx, val_idx, test_idx = [], [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n = len(idxs)
        # int() truncates, so very small groups may contribute nothing to val/test.
        n_val = int(n * val_ratio)
        n_test = int(n * test_ratio)
        val_idx.extend(idxs[:n_val])
        test_idx.extend(idxs[n_val:n_val + n_test])
        train_idx.extend(idxs[n_val + n_test:])

    # Groups were appended one after another; shuffle to interleave the labels.
    rng.shuffle(train_idx)
    rng.shuffle(val_idx)
    rng.shuffle(test_idx)
    return train_idx, val_idx, test_idx

# 15% of each merged class goes to validation, 15% to test and the remainder to train.
# The returned lists hold positions in base_ds, not positions in kept_indices.
train_indices, val_indices, test_indices = stratified_split_3way(
    kept_indices, kept_targets, val_ratio=0.15, test_ratio=0.15, seed=SEED
)

print("Train:", len(train_indices), "Val:", len(val_indices), "Test:", len(test_indices))

# Dataset wrapper (apply tfms + remap labels)
class EuroSAT3Class(torch.utils.data.Dataset):
    """Subset view of a base dataset that yields merged 3-class labels.

    Item ``i`` is read from ``base_dataset[indices[i]]``. The base label is
    translated to its class name and then to a merged label through the
    module-level ``EUROSAT_TO_3`` mapping; the optional transform is applied
    to the sample afterwards.

    Args:
        base_dataset: Indexable dataset returning ``(sample, label)`` and
            exposing a ``classes`` list of class names.
        indices: Positions in ``base_dataset`` that make up this split.
        tfms: Callable applied to each sample, or ``None`` for no transform.
    """

    def __init__(self, base_dataset, indices, tfms):
        self.base, self.indices, self.tfms = base_dataset, indices, tfms
        # Base integer label -> base class name.
        self.name_by_idx = dict(enumerate(self.base.classes))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        sample, base_label = self.base[self.indices[i]]
        merged_label = EUROSAT_TO_3[self.name_by_idx[base_label]]
        if self.tfms is None:
            return sample, merged_label
        return self.tfms(sample), merged_label

# One wrapped dataset per split; only the training split uses the augmenting transforms.
train_ds = EuroSAT3Class(base_ds, train_indices, train_tfms)
val_ds = EuroSAT3Class(base_ds, val_indices, val_tfms)
test_ds = EuroSAT3Class(base_ds, test_indices, val_tfms)  # test uses val transforms

# Pinned host memory is requested only when batches will be copied to a GPU.
pin = device == "cuda"
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=pin)

# Only the training loader reshuffles; val/test keep a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

HTTPError: HTTP Error 403: Forbidden

In [None]:
# Start from a ResNet-50 built with torchvision's DEFAULT weights enum.
weights = torchvision.models.ResNet50_Weights.DEFAULT
model = torchvision.models.resnet50(weights=weights)
# Replace the final fully connected layer with a fresh 3-output layer,
# one output per entry in THREE_CLASS_NAMES.
model.fc = nn.Linear(model.fc.in_features, 3)
model = model.to(device)

# Loss shared by train_one_epoch() and evaluate(); both feed it the model's raw outputs.
criterion = nn.CrossEntropyLoss()

def accuracy_from_logits(logits, targets):
    """Return the fraction of rows whose arg-max over dim 1 equals the target.

    Args:
        logits: Tensor of per-class scores, one row per sample.
        targets: Tensor of class indices, one per row of ``logits``.

    Returns:
        The batch accuracy as a Python float.
    """
    hits = logits.argmax(dim=1).eq(targets)
    return hits.float().mean().item()

@torch.no_grad()
def evaluate(model, loader):
    """Compute the sample-weighted mean loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left in that mode. Gradients are
    disabled by the decorator. Batches are moved to the module-level
    ``device`` and scored with the module-level ``criterion``.

    Args:
        model: Module mapping an input batch to per-class scores.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as an unexplained ZeroDivisionError).
    """
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        bs = x.size(0)
        # Per-batch means are re-weighted by batch size so a short final
        # batch does not skew the average.
        total_loss += loss.item() * bs
        total_acc  += accuracy_from_logits(out, y) * bs
        n += bs
    if n == 0:
        raise ValueError("evaluate() received a loader that produced no samples")
    return total_loss / n, total_acc / n

def train_one_epoch(model, loader, optimizer, scaler=None):
    """Run one optimisation pass over ``loader`` and report the running averages.

    The model is switched to train mode and left in that mode. Batches are
    moved to the module-level ``device`` and scored with the module-level
    ``criterion``. The reported loss/accuracy are measured on each batch
    before that batch's optimizer step.

    Args:
        model: Module mapping an input batch to per-class scores.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer holding the parameters to update.
        scaler: Optional gradient scaler. When given, the forward pass runs
            under CUDA autocast and the backward pass / step go through the
            scaler; when ``None``, a plain full-precision step is used.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats, weighted by
        batch size.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as an unexplained ZeroDivisionError).
    """
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = criterion(out, y)
            # Scale the loss before backward; the scaler then drives the step
            # and adjusts its scale factor for the next iteration.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += accuracy_from_logits(out, y) * bs
        n += bs

    if n == 0:
        raise ValueError("train_one_epoch() received a loader that produced no samples")
    return total_loss / n, total_acc / n

In [None]:
# Best validation accuracy seen so far; shared by both training stages.
best_val_acc = 0.0
# Mixed-precision gradient scaling is only set up when running on CUDA.
scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

# ---------- Stage 1: train head only ----------
# Freeze the whole network, then re-enable gradients for the classifier head alone.
model.requires_grad_(False)
model.fc.requires_grad_(True)

optimizer = torch.optim.AdamW(model.fc.parameters(), lr=LR_HEAD, weight_decay=WEIGHT_DECAY)

print("\n=== Stage 1: Training classifier head ===")
for epoch in range(1, EPOCHS_HEAD + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    va_loss, va_acc = evaluate(model, val_loader)

    print(f"[Head] Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}")

    # Keep a weights file only for epochs that strictly improve validation accuracy.
    if not va_acc > best_val_acc:
        continue
    best_val_acc = va_acc
    torch.save(model.state_dict(), BEST_WEIGHTS_PATH)

# ---------- Stage 2: fine-tune layer4 + head ----------
# Re-freeze everything, then unfreeze only the last residual stage and the head.
for p in model.parameters():
    p.requires_grad = False
for p in model.layer4.parameters():
    p.requires_grad = True
for p in model.fc.parameters():
    p.requires_grad = True

# Two parameter groups so the pretrained layer4 moves with a much smaller
# learning rate than the newly initialised head.
params = [
    {"params": model.layer4.parameters(), "lr": LR_FT_BACKBONE},
    {"params": model.fc.parameters(),     "lr": LR_FT_HEAD},
]
optimizer = torch.optim.AdamW(params, weight_decay=WEIGHT_DECAY)
# Stepped once per epoch below; T_max equals the number of fine-tuning epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_FT)

print("\n=== Stage 2: Fine-tuning layer4 + head ===")
for epoch in range(1, EPOCHS_FT + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    va_loss, va_acc = evaluate(model, val_loader)
    scheduler.step()

    # Update the running best BEFORE logging so the "best" column already
    # includes the current epoch instead of lagging one epoch behind.
    improved = va_acc > best_val_acc
    if improved:
        best_val_acc = va_acc

    print(f"[FT]   Epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f} | best {best_val_acc:.4f}")

    if improved:
        torch.save(model.state_dict(), BEST_WEIGHTS_PATH)

# `model` still holds the weights of the LAST epoch, which need not be the best
# one on validation. Load the saved best weights back so that the reported
# accuracy, the checkpoint and any later inference all refer to the same model.
best_state = torch.load(BEST_WEIGHTS_PATH, map_location="cpu")
model.load_state_dict(best_state)
model.eval()

print(f"\n✅ Training done. Best Val Acc: {best_val_acc:.4f}")
print(f"✅ Best weights saved to: {BEST_WEIGHTS_PATH}")

# The held-out test split is scored exactly once, with the selected weights.
te_loss, te_acc = evaluate(model, test_loader)
print(f"Test loss {te_loss:.4f} acc {te_acc:.4f} (best validation weights)")

# Save a full checkpoint with metadata (recommended)
# Besides the weights it stores what is needed to rebuild the preprocessing and
# to interpret the three outputs.
ckpt = {
    "model_state": best_state,
    "class_names": THREE_CLASS_NAMES,
    "img_size": IMG_SIZE,
    "mean": IMAGENET_MEAN,
    "std": IMAGENET_STD,
    "eurosat_to_3": EUROSAT_TO_3,
}
torch.save(ckpt, BEST_CKPT_PATH)
print(f"✅ Full checkpoint saved to: {BEST_CKPT_PATH}")

In [None]:
@torch.no_grad()
def predict_pil(pil_img):
    """Classify a single PIL image with the module-level ``model``.

    The image is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    normalised with ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and run through the
    model as a batch of one on the module-level ``device``.

    Args:
        pil_img: PIL image in any mode.

    Returns:
        Tuple ``(pred, conf, probs)``: the arg-max class index (int), its
        softmax probability (float) and the NumPy array of all class
        probabilities.
    """
    tfm = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    # Force inference mode here instead of relying on whichever cell ran last:
    # train_one_epoch() leaves the model in train mode, in which layers such as
    # batch normalisation behave differently for a single-image batch.
    model.eval()
    x = tfm(pil_img.convert("RGB")).unsqueeze(0).to(device)
    logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    pred = int(np.argmax(probs))
    conf = float(probs[pred])
    return pred, conf, probs

# Draw 20 distinct positions from the test split; the export cell reuses `sample`.
sample = random.sample(test_indices, 20)

# 5 x 4 grid, one panel per sampled image.
plt.figure(figsize=(16, 14))
for i, idx in enumerate(sample, start=1):
    # Untransformed PIL image plus its original EuroSAT label.
    img, y_orig = base_ds[idx]
    true = EUROSAT_TO_3[name_by_idx[y_orig]]

    pred, conf, probs = predict_pil(img)

    ax = plt.subplot(5, 4, i)
    ax.imshow(img)
    ax.axis("off")

    # Green title for a correct prediction, red for a miss.
    title_colour = "green" if pred == true else "red"
    ax.set_title(
        f"P:{THREE_CLASS_NAMES[pred]} ({conf*100:.1f}%)\nT:{THREE_CLASS_NAMES[true]}",
        color=title_colour,
        fontsize=10,
    )

plt.tight_layout()
plt.show()

In [None]:
import zipfile, shutil

sample_dir = "sample_test_images"
os.makedirs(sample_dir, exist_ok=True)

# Remember exactly which files this run writes. The true label is part of the
# file name, so a previous run with a different sample leaves differently named
# files behind; zipping the whole directory would mix those into the archive.
written_paths = []
for i, idx in enumerate(sample):
    img, y_orig = base_ds[idx]
    cname = name_by_idx[y_orig]
    true = EUROSAT_TO_3[cname]
    out_path = os.path.join(sample_dir, f"{i+1:02d}_TRUE_{THREE_CLASS_NAMES[true]}.png")
    img.save(out_path)
    written_paths.append(out_path)

zip_path = "sample_test_images.zip"
with zipfile.ZipFile(zip_path, "w") as z:
    for out_path in written_paths:
        # Store files flat (no directory prefix) inside the archive.
        z.write(out_path, arcname=os.path.basename(out_path))

print("✅ Created:", zip_path)

# The browser download helper only exists inside Google Colab; elsewhere the
# archive is simply left on disk.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; archive left at:", os.path.abspath(zip_path))
else:
    files.download(zip_path)

In [None]:
import os
from google.colab import files
from IPython.display import clear_output, display
from PIL import Image

# Directory to store temporary uploaded images
# NOTE(review): absolute path under /content - presumably chosen for Google Colab;
# confirm or make it relative before running this cell elsewhere.
UPLOAD_DIR = "/content/upload_temp"
os.makedirs(UPLOAD_DIR, exist_ok=True)

def remove_uploaded():
    """Delete the files stored in ``UPLOAD_DIR`` and clear the cell output.

    Only regular files and symlinks directly inside ``UPLOAD_DIR`` are removed;
    sub-directories are left alone, because ``os.remove`` would raise on them
    and abort the clean-up half way. A missing ``UPLOAD_DIR`` is treated as
    already empty.
    """
    if os.path.isdir(UPLOAD_DIR):
        for entry in os.listdir(UPLOAD_DIR):
            target = os.path.join(UPLOAD_DIR, entry)
            if os.path.isfile(target) or os.path.islink(target):
                os.remove(target)
    clear_output()
    print("🗑️ All uploaded images removed. You can upload new ones now.")

# Interactive menu: keeps prompting until the user picks "3".
while True:
    print("📌 Options:")
    print("1️⃣  Upload Image")
    print("2️⃣  Remove Uploaded Images")
    print("3️⃣  Exit")
    # Strip so that "1 " or " 2" typed by accident is still accepted.
    choice = input("Enter choice (1/2/3): ").strip()

    # ============================
    # OPTION 1 — UPLOAD IMAGE
    # ============================
    if choice == "1":
        clear_output()
        uploaded = files.upload()  # upload any image

        for fn, payload in uploaded.items():
            # Keep only the last path component so an unusual upload name
            # cannot point outside UPLOAD_DIR.
            safe_name = os.path.basename(fn) or "upload"
            path = os.path.join(UPLOAD_DIR, safe_name)
            with open(path, 'wb') as f:
                f.write(payload)

            # Open inside a context manager so the file handle is released
            # (it would otherwise stay open and can block later deletion),
            # and skip files that cannot be decoded instead of ending the loop.
            try:
                with Image.open(path) as opened:
                    img = opened.convert("RGB")
            except OSError as err:
                print(f"⚠️ Could not read '{fn}' as an image: {err}")
                continue
            display(img)

            pred, conf, probs = predict_pil(img)

            print("\n==========================")
            print("🛰️  PREDICTION RESULT")
            print("==========================")
            print(f"Predicted Class: {THREE_CLASS_NAMES[pred]}")
            print(f"Confidence: {conf*100:.2f}%")

            print("\n📊 Class Probabilities:")
            for i, c in enumerate(THREE_CLASS_NAMES):
                print(f"  {c:>6}: {probs[i]*100:.2f}%")

        print("\n👉 You can choose again (upload / remove / exit).")

    # ============================
    # OPTION 2 — REMOVE / RESET
    # ============================
    elif choice == "2":
        remove_uploaded()

    # ============================
    # OPTION 3 — EXIT LOOP
    # ============================
    elif choice == "3":
        print("👋 Exiting. You can run this cell again anytime.")
        break

    else:
        clear_output()
        print("⚠️ Invalid choice! Please enter 1, 2, or 3.")