# Dataset Preparation (EuroSAT RGB and Multispectral)

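This notebook loads the EuroSAT RGB and multispectral (13-band) datasets, creates stratified train/validation/test splits for each, and saves the split indices for reproducible experiments. It also computes per-band mean/std normalization statistics for the multispectral inputs from the training split.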

---

## 1. Project setup and paths

In [1]:
from pathlib import Path

# Resolve the project root: when the working directory is the "notebooks"
# folder the root is its parent; otherwise the working directory is the root.
if Path.cwd().name == "notebooks":
    PROJECT_ROOT = Path.cwd().parent
else:
    PROJECT_ROOT = Path.cwd()

# Project folders, all anchored at the project root.
DATA_DIR = PROJECT_ROOT.joinpath("data")
SPLITS_DIR = PROJECT_ROOT.joinpath("splits")
ARTIFACTS_DIR = PROJECT_ROOT.joinpath("artifacts")

# Create each folder if it is missing; existing folders are left untouched.
for d in (DATA_DIR, SPLITS_DIR, ARTIFACTS_DIR):
    d.mkdir(exist_ok=True)

PROJECT_ROOT

WindowsPath('C:/Users/ishaa/OneDrive/Desktop/Projects/eurosat-physics-aware-image-classification')

In [2]:
import torch
from torchvision.datasets import EuroSAT
from torchvision import transforms

# Folder that holds (or will receive) the downloaded RGB dataset.
rgb_root = DATA_DIR / "eurosat_rgb"

# Per-sample preprocessing: resize to 64x64, then convert to a tensor.
rgb_transform = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

# download=True fetches the dataset into rgb_root if it is not already there.
ds_rgb = EuroSAT(
    root=str(rgb_root),
    download=True,
    transform=rgb_transform,
)

print("RGB samples:", len(ds_rgb))
print("Classes:", ds_rgb.classes)

# Peek at the first sample to confirm the tensor shape and label encoding.
x0, y0 = ds_rgb[0]
print("One sample:", x0.shape, y0)

RGB samples: 27000
Classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
One sample: torch.Size([3, 64, 64]) 0


In [4]:
import numpy as np
from sklearn.model_selection import train_test_split

# Indices + labels.
# Read class ids from the dataset's stored label list when it is available
# (folder-based torchvision datasets expose it as `targets`). This avoids
# decoding and transforming every image just to learn its label. If the
# attribute is missing or does not line up with the dataset length, fall back
# to reading labels sample by sample.
idx = np.arange(len(ds_rgb))
_rgb_targets = getattr(ds_rgb, "targets", None)
if _rgb_targets is None or len(_rgb_targets) != len(ds_rgb):
    _rgb_targets = [ds_rgb[i][1] for i in range(len(ds_rgb))]
labels = np.array(_rgb_targets)

# 80% train, 20% temp (stratified so every class keeps its proportion)
train_idx, temp_idx = train_test_split(
    idx, test_size=0.2, stratify=labels, random_state=42
)

# Split temp into 10% val, 10% test (stratified on the temp labels)
temp_labels = labels[temp_idx]
val_idx, test_idx = train_test_split(
    temp_idx, test_size=0.5, stratify=temp_labels, random_state=42
)

# Sanity check: the three splits are disjoint and together cover every sample.
assert np.unique(np.concatenate([train_idx, val_idx, test_idx])).size == idx.size

print("Split sizes:", len(train_idx), len(val_idx), len(test_idx))

Split sizes: 21600 2700 2700


In [5]:
import numpy as np

# Persist the RGB split indices, plus the seed that produced them, in a single
# .npz archive so later experiments can reload the exact same partition.
split_path = SPLITS_DIR / "eurosat_rgb_splits.npz"

rgb_split_arrays = {
    "train_idx": train_idx,
    "val_idx": val_idx,
    "test_idx": test_idx,
    "seed": np.array([42]),
}
np.savez(split_path, **rgb_split_arrays)

print("Saved splits to:", split_path)

Saved splits to: C:\Users\ishaa\OneDrive\Desktop\Projects\eurosat-physics-aware-image-classification\splits\eurosat_rgb_splits.npz


In [6]:
from collections import Counter

def label_counts(indices):
    """Count how often each class id occurs among the given sample indices.

    Args:
        indices: Index array used to select entries from the notebook-level
            ``labels`` array.

    Returns:
        A ``Counter`` mapping class id to its number of occurrences.
    """
    selected = labels[indices]
    return Counter(selected.tolist())

# Show the five most frequent classes in each split to confirm stratification.
for caption, split_indices in [
    ("Train label counts (top 5):", train_idx),
    ("Val   label counts (top 5):", val_idx),
    ("Test  label counts (top 5):", test_idx),
]:
    print(caption, label_counts(split_indices).most_common(5))

Train label counts (top 5): [(0, 2400), (9, 2400), (7, 2400), (2, 2400), (1, 2400)]
Val   label counts (top 5): [(9, 300), (1, 300), (0, 300), (7, 300), (2, 300)]
Test  label counts (top 5): [(9, 300), (7, 300), (2, 300), (1, 300), (0, 300)]


In [7]:
from torch.utils.data import DataLoader, Subset

# Views over the RGB dataset restricted to each split's indices.
train_ds = Subset(ds_rgb, train_idx)
val_ds = Subset(ds_rgb, val_idx)
test_ds = Subset(ds_rgb, test_idx)

# Settings shared by all three loaders; only the training loader shuffles.
rgb_loader_kwargs = dict(batch_size=128, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **rgb_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **rgb_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **rgb_loader_kwargs)

# Pull one batch as a smoke test of the loading pipeline.
xb, yb = next(iter(train_loader))
print("Batch shapes:", xb.shape, yb.shape)



Batch shapes: torch.Size([128, 3, 64, 64]) torch.Size([128])


## 2. EuroSAT multispectral: loading, stratified splits, and normalization

In [8]:
import numpy as np
import torch

from torchgeo.datasets import EuroSAT as EuroSAT_MS

# Folder that holds (or will receive) the downloaded multispectral dataset.
ms_root = DATA_DIR / "eurosat_ms"

# TorchGeo returns dicts: {"image": Tensor[C,H,W], "label": 0-d Tensor}
# NOTE: TorchGeo's EuroSAT loads one predefined partition at a time. The
# "train" partition is requested explicitly here (16,200 samples, versus the
# 27,000 images of the full RGB dataset), so every index derived from `ds_ms`
# later in this notebook is relative to that partition, not the full dataset.
ds_ms = EuroSAT_MS(root=str(ms_root), split="train", download=True)

print("MS samples:", len(ds_ms))
sample = ds_ms[0]
print("Keys:", sample.keys())
print("Image shape:", sample["image"].shape, "Label:", sample["label"])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 314k/314k [00:00<00:00, 8.34MB/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.07G/2.07G [00:58<00:00, 35.4MB/s]


MS samples: 16200
Keys: dict_keys(['image', 'label'])
Image shape: torch.Size([13, 64, 64]) Label: tensor(0)


In [10]:
from sklearn.model_selection import train_test_split

# Indices + labels for the multispectral dataset.
# Read class ids from the dataset's stored label list when it is available
# (folder-based datasets expose it as `targets`). This avoids reading every
# 13-band image from disk just to learn its label. If the attribute is missing
# or does not line up with the dataset length, fall back to reading labels
# sample by sample.
idx = np.arange(len(ds_ms))
_ms_targets = getattr(ds_ms, "targets", None)
if _ms_targets is None or len(_ms_targets) != len(ds_ms):
    _ms_targets = [int(ds_ms[i]["label"]) for i in range(len(ds_ms))]
labels_ms = np.array(_ms_targets, dtype=np.int64)

# 80% train, 20% temp (stratified so every class keeps its proportion)
train_idx_ms, temp_idx_ms = train_test_split(
    idx, test_size=0.2, stratify=labels_ms, random_state=42
)
# Split temp into 10% val, 10% test (stratified on the temp labels)
temp_labels_ms = labels_ms[temp_idx_ms]
val_idx_ms, test_idx_ms = train_test_split(
    temp_idx_ms, test_size=0.5, stratify=temp_labels_ms, random_state=42
)

# Sanity check: the three splits are disjoint and together cover every sample.
assert np.unique(
    np.concatenate([train_idx_ms, val_idx_ms, test_idx_ms])
).size == idx.size

print("MS split sizes:", len(train_idx_ms), len(val_idx_ms), len(test_idx_ms))

# Persist the indices together with the seed that produced them.
split_path = SPLITS_DIR / "eurosat_ms_splits.npz"
np.savez(split_path, train_idx=train_idx_ms, val_idx=val_idx_ms, test_idx=test_idx_ms, seed=np.array([42]))
print("Saved MS splits to:", split_path)

MS split sizes: 12960 1620 1620
Saved MS splits to: C:\Users\ishaa\OneDrive\Desktop\Projects\eurosat-physics-aware-image-classification\splits\eurosat_ms_splits.npz


In [13]:
from torch.utils.data import DataLoader, Subset

def collate_ms(batch):
    """Turn a list of ``{"image", "label"}`` sample dicts into an ``(x, y)`` pair.

    Args:
        batch: Sequence of dicts, each with an ``"image"`` tensor and a
            ``"label"`` entry.

    Returns:
        ``x``: the images stacked along a new leading batch dimension and cast
        to float32. ``y``: the labels gathered into a 1-D int64 tensor.
    """
    images = []
    targets = []
    for item in batch:
        images.append(item["image"])
        targets.append(item["label"])

    stacked = torch.stack(images).to(torch.float32)  # [B, C, H, W]
    return stacked, torch.tensor(targets, dtype=torch.long)

# Views over the multispectral dataset restricted to each split's indices.
train_ms = Subset(ds_ms, train_idx_ms)
val_ms = Subset(ds_ms, val_idx_ms)
test_ms = Subset(ds_ms, test_idx_ms)

# Settings shared by all three loaders; collate_ms converts the per-sample
# dicts into (x, y) tensors. Only the training loader shuffles.
ms_loader_kwargs = dict(
    batch_size=64,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_ms,
)
train_ms_loader = DataLoader(train_ms, shuffle=True, **ms_loader_kwargs)
val_ms_loader = DataLoader(val_ms, shuffle=False, **ms_loader_kwargs)
test_ms_loader = DataLoader(test_ms, shuffle=False, **ms_loader_kwargs)

# Pull one batch as a smoke test of the loading pipeline.
xb, yb = next(iter(train_ms_loader))
print("Batch:", xb.shape, yb.shape)

Batch: torch.Size([64, 13, 64, 64]) torch.Size([64])


In [14]:
from tqdm import tqdm

@torch.no_grad()
def compute_band_mean_std(loader):
    """Compute the per-channel mean and standard deviation over a loader.

    Statistics are built from running sums, so only one batch is held in
    memory at a time. The sums are accumulated in float64: with large pixel
    values, float32 running sums of ``x`` and ``x**2`` lose precision and the
    ``E[x^2] - E[x]^2`` subtraction can cancel badly, which also makes the
    result depend on batch order.

    Args:
        loader: Iterable yielding ``(x, _)`` pairs, where ``x`` is a tensor of
            shape ``[B, C, H, W]``. The second element is ignored.

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``[C]``. ``std`` is the
        population standard deviation; the variance is floored at ``1e-12``
        before the square root so constant channels do not yield zero or NaN.

    Raises:
        ValueError: If the loader yields no pixels (no batches, or only empty
            batches).
    """
    # streaming sums (memory-safe)
    sum_ = None
    sumsq = None
    count = 0

    for x, _ in tqdm(loader, desc="Computing MS normalization"):
        # x: [B, C, H, W]
        b, c, h, w = x.shape
        x = x.double()  # accumulate in float64 to avoid precision loss
        if sum_ is None:
            sum_ = torch.zeros(c, dtype=torch.float64)
            sumsq = torch.zeros(c, dtype=torch.float64)

        # sum over batch and pixels, keeping the channel dimension
        sum_ += x.sum(dim=(0, 2, 3)).cpu()
        sumsq += (x * x).sum(dim=(0, 2, 3)).cpu()
        count += b * h * w

    if sum_ is None or count == 0:
        raise ValueError("compute_band_mean_std: loader yielded no data")

    mean = sum_ / count
    var = sumsq / count - mean ** 2
    std = torch.sqrt(torch.clamp(var, min=1e-12))
    # Return float32, the dtype callers received before this change.
    return mean.float(), std.float()

# Per-band statistics are computed from the training loader only, so no
# validation or test pixels influence the normalization constants.
ms_mean, ms_std = compute_band_mean_std(train_ms_loader)
for caption, stat in (("MS mean:", ms_mean), ("MS std :", ms_std)):
    print(caption, stat)

# Store both tensors in one file under the keys "mean" and "std".
norm_path = ARTIFACTS_DIR / "eurosat_ms_norm.pt"
ms_norm_stats = {"mean": ms_mean, "std": ms_std}
torch.save(ms_norm_stats, norm_path)
print("Saved MS normalization to:", norm_path)

Computing MS normalization: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 203/203 [01:38<00:00,  2.06it/s]

MS mean: tensor([1353.1909, 1116.9252, 1041.3521,  946.2354, 1198.6338, 2001.6254,
        2372.2817, 2299.9617,  732.5312,   12.1109, 1820.4744, 1119.2936,
        2598.0645])
MS std : tensor([ 245.3798,  332.1978,  394.9087,  593.4796,  566.7594,  862.4005,
        1088.6469, 1119.5021,  404.8098,    4.7760, 1002.0773,  761.4476,
        1233.4044])
Saved MS normalization to: C:\Users\ishaa\OneDrive\Desktop\Projects\eurosat-physics-aware-image-classification\artifacts\eurosat_ms_norm.pt





## Completed outputs

- Saved stratified splits for RGB and multispectral EuroSAT to `splits/`

- Computed per-band mean/std (train split only) for multispectral inputs and saved to `artifacts/eurosat_ms_norm.pt`