# Analysis, Robustness, and Interpretability (EuroSAT)

This notebook compares three EuroSAT models — an RGB baseline, a multispectral baseline, and a physics-aware multispectral model — using the same metrics, confusion matrices, and robustness tests.

---

In [1]:
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

# Locate the project root: use the working directory, stepping up one level
# when the kernel was started from inside the "notebooks" folder.
PROJECT_ROOT = Path.cwd()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent

# Project sub-directories used by the cells below, all relative to the root.
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
SPLITS_DIR = PROJECT_ROOT / "splits"
DATA_DIR = PROJECT_ROOT / "data"

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cpu


In [2]:
from torchvision.datasets import EuroSAT
from torchvision import transforms
from torch.utils.data import DataLoader, Subset

# Read the saved index arrays for the RGB train/val/test partition.
rgb_splits = np.load(SPLITS_DIR / "eurosat_rgb_splits.npz")
train_idx_rgb = rgb_splits["train_idx"]
val_idx_rgb = rgb_splits["val_idx"]
test_idx_rgb = rgb_splits["test_idx"]

# Preprocessing applied to every RGB sample: resize to 64x64, then convert to a tensor.
rgb_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)

# download=False: the dataset is expected to be on disk under DATA_DIR already.
ds_rgb = EuroSAT(
    root=str(DATA_DIR / "eurosat_rgb"),
    download=False,
    transform=rgb_transform,
)

# Only the test partition is materialised here; shuffle=False keeps a fixed
# sample order across evaluation passes.
test_rgb = Subset(ds_rgb, test_idx_rgb)
test_rgb_loader = DataLoader(test_rgb, batch_size=128, shuffle=False, num_workers=0)

# Class metadata is taken from the RGB dataset object.
class_names = ds_rgb.classes
num_classes = len(class_names)
print("RGB test size:", len(test_rgb))

RGB test size: 2700


In [3]:
from torchgeo.datasets import EuroSAT as EuroSAT_MS
from torch.utils.data import Dataset

# Read the saved index arrays for the multispectral train/val/test partition.
ms_splits = np.load(SPLITS_DIR / "eurosat_ms_splits.npz")
train_idx_ms, val_idx_ms, test_idx_ms = ms_splits["train_idx"], ms_splits["val_idx"], ms_splits["test_idx"]

# Normalization statistics stored under the "mean" and "std" keys.
# map_location="cpu" pins the loaded tensors to the CPU no matter which device
# they were saved from: without it, a file written from a GPU session cannot be
# deserialized on a CPU-only machine, and on a GPU machine the stats would stay
# on CUDA while the dataset normalizes samples before any device transfer.
norm = torch.load(ARTIFACTS_DIR / "eurosat_ms_norm.pt", map_location="cpu")
ms_mean, ms_std = norm["mean"].float(), norm["std"].float()

# download=False: the dataset is expected to be on disk under DATA_DIR already.
ds_ms = EuroSAT_MS(root=str(DATA_DIR / "eurosat_ms"), download=False)

class NormalizedMSDataset(Dataset):
    """Index-restricted view of a dict-style dataset that standardizes each image.

    Position ``i`` maps to ``base_ds[int(indices[i])]``. The sample's ``"image"``
    entry is cast to float, shifted by ``mean`` and scaled by ``std``; the
    ``"label"`` entry is converted to a Python ``int``.

    Args:
        base_ds: Dataset whose items are mappings with "image" and "label" keys.
        indices: Sequence of positions into ``base_ds`` exposed by this view.
        mean: Tensor of offsets; indexed with ``[:, None, None]`` before use.
        std: Tensor of scales; indexed with ``[:, None, None]`` before use.

    Note:
        Assumes images are channel-first (C, H, W) and ``mean``/``std`` hold one
        value per channel; neither is checked here.
    """

    def __init__(self, base_ds, indices, mean, std):
        self.base_ds, self.indices = base_ds, indices
        self.mean, self.std = mean, std

    def __len__(self):
        """Number of samples exposed, i.e. the length of ``indices``."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return ``(normalized_image, label)`` for the i-th selected sample."""
        sample = self.base_ds[int(self.indices[i])]
        pixels = sample["image"].float()
        # Two trailing singleton axes let the statistics broadcast across the
        # remaining image dimensions.
        offset = self.mean[:, None, None]
        scale = self.std[:, None, None]
        normalized = (pixels - offset) / scale
        return normalized, int(sample["label"])

# Normalized view over the multispectral test indices only.
test_ms = NormalizedMSDataset(
    base_ds=ds_ms,
    indices=test_idx_ms,
    mean=ms_mean,
    std=ms_std,
)
# shuffle=False keeps a fixed sample order across evaluation passes.
test_ms_loader = DataLoader(
    dataset=test_ms,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)

print("MS test size:", len(test_ms))

MS test size: 1620
