# Multispectral Models (EuroSAT)

This notebook evaluates image classification models trained on 13-band Sentinel-2 multispectral inputs from EuroSAT. Using the same fixed train/validation/test splits as previous experiments, we isolate the effect of physically meaningful spectral information on classification performance.

---

In [1]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import models

# Resolve the project root whether the kernel was started inside
# <root>/notebooks or directly in <root>.
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()

DATA_DIR = PROJECT_ROOT / "data"
SPLITS_DIR = PROJECT_ROOT / "splits"
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"

# Seed the RNGs up front so that weight initialisation and DataLoader
# shuffling are repeatable under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# np.load on an .npz archive returns a lazy NpzFile that keeps the file handle
# open; the context manager closes it once the index arrays have been read
# into memory.
with np.load(SPLITS_DIR / "eurosat_ms_splits.npz") as splits:
    train_idx = splits["train_idx"]
    val_idx   = splits["val_idx"]
    test_idx  = splits["test_idx"]

# Quick sanity check of the split sizes.
print(len(train_idx), len(val_idx), len(test_idx))

12960 1620 1620


In [3]:
from torchgeo.datasets import EuroSAT as EuroSAT_MS

# Open the multispectral EuroSAT copy under DATA_DIR; download=False means the
# files are expected to be on disk already.
ms_root = DATA_DIR / "eurosat_ms"
ds_ms = EuroSAT_MS(root=str(ms_root), download=False)

# Inspect the first sample: image tensor shape and its label.
sample = ds_ms[0]
print(sample["image"].shape, sample["label"])

torch.Size([13, 64, 64]) tensor(0)


In [4]:
# Normalisation statistics stored as a dict with "mean" and "std" entries
# (presumably one value per spectral band -- confirm against the artifact).
# map_location="cpu" keeps the statistics on the CPU even if the artifact was
# saved from a GPU session, so they can be combined with the CPU image tensors
# produced by the dataset.
norm = torch.load(ARTIFACTS_DIR / "eurosat_ms_norm.pt", map_location="cpu")
ms_mean = norm["mean"]
ms_std  = norm["std"]

def normalize_ms(x):
    """Standardise an image with the notebook-level ``ms_mean`` / ``ms_std``.

    ``x`` is expected in ``[C, H, W]`` layout. The statistics are given two
    trailing singleton axes so that they broadcast across the last two
    dimensions of ``x``. Returns ``(x - mean) / std``.
    """
    mean = ms_mean[:, None, None]
    std = ms_std[:, None, None]
    centered = x - mean
    return centered / std

class NormalizedMS(torch.utils.data.Dataset):
    """Index-restricted view of ``base_ds`` yielding normalised ``(image, label)`` pairs.

    ``base_ds`` must return dict samples with ``"image"`` and ``"label"``
    entries; ``indices`` selects which of its items this view exposes.
    """

    def __init__(self, base_ds, indices):
        self.base_ds, self.indices = base_ds, indices

    def __len__(self):
        # The view is exactly as long as the index list it wraps.
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the view-local position into an index of the base dataset.
        record = self.base_ds[self.indices[i]]
        # Cast to float before normalising.
        image = record["image"].float()
        return normalize_ms(image), record["label"]

In [5]:
# One normalising dataset view per split, all backed by the same base dataset.
train_ds, val_ds, test_ds = (
    NormalizedMS(ds_ms, idx) for idx in (train_idx, val_idx, test_idx)
)

# Settings shared by all three loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=64, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader   = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader  = DataLoader(test_ds, shuffle=False, **_loader_opts)

# Pull a single training batch to check the tensor shapes.
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)

torch.Size([64, 13, 64, 64]) torch.Size([64])


In [6]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10

# ResNet-18 built without pretrained weights.
model = models.resnet18(weights=None)

# Replace the first conv so the network accepts 13 input channels (3 -> 13).
model.conv1 = nn.Conv2d(13, 64, kernel_size=7, stride=2, padding=3, bias=False)

# New classification head with one logit per class.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

print(model.conv1)

Conv2d(13, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


In [7]:
# Cross-entropy loss on the model's logits, optimised with Adam (lr = 1e-3)
# over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
def train_one_epoch(model, loader, optim=None, loss_fn=None, device=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; it is switched to training mode.
        loader: iterable yielding ``(x, y)`` batches.
        optim: optimizer to step. Defaults to the notebook-level ``optimizer``.
        loss_fn: callable invoked as ``loss_fn(logits, y)``. Defaults to the
            notebook-level ``criterion``.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_loss, accuracy)``: the sample-weighted average of the batch
        losses and the fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Fall back to the notebook globals so existing two-argument calls still work.
    if optim is None:
        optim = optimizer
    if loss_fn is None:
        loss_fn = criterion
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optim.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optim.step()

        # Weight each batch loss by its size so that a smaller final batch
        # does not skew the epoch average.
        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return total_loss / total, correct / total

In [10]:
@torch.no_grad()
def evaluate(model, loader, device=None):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable yielding ``(x, y)`` batches.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function does not depend on a notebook-level ``device``.

    Returns:
        Fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    return correct / total

In [11]:
epochs = 5

for epoch in range(epochs):
    # One optimisation pass over the training loader, then accuracy on the
    # validation loader.
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)

    # Build the one-line epoch summary first, then emit it.
    epoch_summary = (
        f"Epoch {epoch+1}/{epochs} | "
        f"Train loss: {train_loss:.4f} | "
        f"Train acc: {train_acc:.4f} | "
        f"Val acc: {val_acc:.4f}"
    )
    print(epoch_summary)

Epoch 1/5 | Train loss: 0.8093 | Train acc: 0.7299 | Val acc: 0.7438
Epoch 2/5 | Train loss: 0.4763 | Train acc: 0.8449 | Val acc: 0.8543
Epoch 3/5 | Train loss: 0.3438 | Train acc: 0.8868 | Val acc: 0.8951
Epoch 4/5 | Train loss: 0.3097 | Train acc: 0.8969 | Val acc: 0.9086
Epoch 5/5 | Train loss: 0.2507 | Train acc: 0.9159 | Val acc: 0.9130


In [12]:
# Accuracy of the trained model on the test loader, computed after the
# training loop has finished.
test_acc = evaluate(model, test_loader)
print("Test accuracy:", test_acc)

Test accuracy: 0.9191358024691358
