# Physics-Aware Multispectral Models (EuroSAT)

Previous experiments treated multispectral bands as independent input channels.
In this notebook, we introduce physics-aware inductive biases by explicitly grouping Sentinel-2 spectral bands according to their physical meaning (visible, red-edge, near-infrared, short-wave infrared).

We evaluate whether such structure improves classification performance, robustness, or interpretability under otherwise identical training conditions.

---

In [1]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models

PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
DATA_DIR = PROJECT_ROOT / "data"
SPLITS_DIR = PROJECT_ROOT / "splits"
ARTIFACTS_DIR = PROJECT_ROOT / "artifacts"
ARTIFACTS_DIR.mkdir(exist_ok=True)

In [2]:
splits = np.load(SPLITS_DIR / "eurosat_ms_splits.npz")

train_idx = splits["train_idx"]
val_idx   = splits["val_idx"]
test_idx  = splits["test_idx"]

print(len(train_idx), len(val_idx), len(test_idx))

12960 1620 1620


In [3]:
from torchgeo.datasets import EuroSAT as EuroSAT_MS

ds_ms = EuroSAT_MS(
    root=str(DATA_DIR / "eurosat_ms"),
    download=False
)

sample = ds_ms[0]
print(sample["image"].shape, sample["label"])

torch.Size([13, 64, 64]) tensor(0)


In [4]:
norm = torch.load(ARTIFACTS_DIR / "eurosat_ms_norm.pt")
ms_mean = norm["mean"]
ms_std  = norm["std"]

def normalize_ms(x):
    # x: [C, H, W]
    return (x - ms_mean[:, None, None]) / ms_std[:, None, None]

class NormalizedMS(torch.utils.data.Dataset):
    def __init__(self, base_ds, indices):
        self.base_ds = base_ds
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        item = self.base_ds[self.indices[i]]
        x = normalize_ms(item["image"].float())
        y = item["label"]
        return x, y

In [5]:
train_ds = NormalizedMS(ds_ms, train_idx)
val_ds   = NormalizedMS(ds_ms, val_idx)
test_ds  = NormalizedMS(ds_ms, test_idx)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0)

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)

torch.Size([64, 13, 64, 64]) torch.Size([64])


### Typical EuroSAT / Sentinel-2 band ordering:

0–2   : Visible (B2, B3, B4)

3–5   : Red-edge

6–7   : NIR

8–12  : SWIR

In [6]:
BAND_GROUPS = {
    "VIS": [0, 1, 2],          # visible (approx B2,B3,B4)
    "RE":  [3, 4, 5],          # red-edge (approx B5,B6,B7)
    "NIR": [6, 7],             # near infrared (approx B8,B8A)
    "SWIR":[8, 9, 10, 11, 12], # short-wave infrared (approx B11,B12 plus others)
}

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PhysicsAwareStem(nn.Module):
    """
    Multi-branch stem that processes physically meaningful band groups separately,
    then fuses features into a single tensor.

    Input:  x [B, C=13, H, W]
    Output: f [B, F, H/2, W/2]  (by default downsample like ResNet's conv1+maxpool-ish behavior)
    """
    def __init__(self, band_groups, out_per_group=16):
        super().__init__()
        self.band_groups = band_groups

        def make_branch(in_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_per_group, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_per_group),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_per_group, out_per_group, kernel_size=3, stride=2, padding=1, bias=False),  # downsample
                nn.BatchNorm2d(out_per_group),
                nn.ReLU(inplace=True),
            )

        self.branches = nn.ModuleDict({
            name: make_branch(len(idxs)) for name, idxs in band_groups.items()
        })

        # Fuse = concat + 1x1 conv to a standard width
        fused_in = out_per_group * len(band_groups)
        self.fuse = nn.Sequential(
            nn.Conv2d(fused_in, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        feats = []
        for name, idxs in self.band_groups.items():
            xi = x[:, idxs, :, :]
            feats.append(self.branches[name](xi))
        f = torch.cat(feats, dim=1)
        return self.fuse(f)


class BasicBlock(nn.Module):
    """A minimal ResNet BasicBlock (same spirit as torchvision)."""
    expansion = 1
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = None
        if stride != 1 or inplanes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = self.relu(out + identity)
        return out


class PhysicsAwareResNet18(nn.Module):
    """
    Physics-aware version of a ResNet18-style network:
    - PhysicsAwareStem handles band grouping + early processing
    - Standard ResNet blocks for deeper feature learning
    """
    def __init__(self, band_groups, num_classes=10):
        super().__init__()
        self.stem = PhysicsAwareStem(band_groups, out_per_group=16)

        # After stem, we have 64 channels at H/2, W/2
        self.layer1 = self._make_layer(64, 64, blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, blocks=2, stride=2)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inplanes, planes, blocks, stride):
        layers = [BasicBlock(inplanes, planes, stride=stride)]
        for _ in range(1, blocks):
            layers.append(BasicBlock(planes, planes, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)              # [B, 64, H/2, W/2]
        x = self.layer1(x)            # [B, 64, ...]
        x = self.layer2(x)            # [B, 128, ...]
        x = self.layer3(x)            # [B, 256, ...]
        x = self.layer4(x)            # [B, 512, ...]
        x = self.pool(x).flatten(1)   # [B, 512]
        return self.fc(x)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PhysicsAwareResNet18(BAND_GROUPS, num_classes=10).to(device)

xb, yb = next(iter(train_loader))  # your MS loader from Notebook 03 (normalized, num_workers=0)
xb = xb.to(device)

with torch.no_grad():
    logits = model(xb)

print("Input:", xb.shape)
print("Logits:", logits.shape)

Input: torch.Size([64, 13, 64, 64])
Logits: torch.Size([64, 10])


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [10]:
def train_one_epoch(model, loader):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return total_loss / total, correct / total

In [11]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

In [12]:
best_val_acc = -1.0
best_state = None

epochs = 5  # match Notebook 03 for fair comparison

for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = evaluate(model, val_loader)

    print(
        f"Epoch {epoch}/{epochs} | "
        f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
        f"val loss {val_loss:.4f} acc {val_acc:.4f}"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

print("Best val acc:", best_val_acc)

Epoch 1/5 | train loss 0.7840 acc 0.7419 | val loss 0.4790 acc 0.8543
Epoch 2/5 | train loss 0.4403 acc 0.8622 | val loss 0.2968 acc 0.8981
Epoch 3/5 | train loss 0.3038 acc 0.9019 | val loss 0.2933 acc 0.9000
Epoch 4/5 | train loss 0.2584 acc 0.9167 | val loss 0.2947 acc 0.9074
Epoch 5/5 | train loss 0.2059 acc 0.9324 | val loss 0.2942 acc 0.9080
Best val acc: 0.9080246913580247


In [13]:
model.load_state_dict(best_state)

test_loss, test_acc = evaluate(model, test_loader)
print("TEST | loss:", round(test_loss, 4), "acc:", round(test_acc, 4))

TEST | loss: 0.2773 acc: 0.9142


In [14]:
import numpy as np

@torch.no_grad()
def predict_all(model, loader):
    model.eval()
    ys, ps = [], []
    for x, y in loader:
        x = x.to(device)
        logits = model(x)
        ps.append(logits.argmax(1).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)

y_true, y_pred = predict_all(model, test_loader)

cm = np.zeros((10, 10), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

cm

array([[157,   0,   1,   1,   0,   3,  17,   0,   0,   0],
       [  0, 172,   1,   0,   0,   5,   0,   0,   0,   0],
       [  0,   1, 165,   6,   0,   4,   2,   0,   2,   0],
       [  0,   0,   0, 145,   0,   0,   1,   4,   0,   0],
       [  0,   0,   0,  27,  97,   0,   0,  26,   0,   0],
       [  1,   0,   2,   8,   0, 109,   0,   0,   0,   0],
       [  2,   1,   0,   2,   0,   3, 138,   1,   1,   0],
       [  0,   0,   0,   1,   0,   0,   0, 186,   0,   0],
       [  0,   0,   0,   6,   0,   0,   0,   0, 140,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  10, 172]])

In [15]:
results = {
    "best_val_acc": float(best_val_acc),
    "test_acc": float(test_acc),
    "band_groups": BAND_GROUPS,
    "epochs": epochs,
    "lr": 1e-3,
}
results

{'best_val_acc': 0.9080246913580247,
 'test_acc': 0.9141975308641975,
 'band_groups': {'VIS': [0, 1, 2],
  'RE': [3, 4, 5],
  'NIR': [6, 7],
  'SWIR': [8, 9, 10, 11, 12]},
 'epochs': 5,
 'lr': 0.001}

In [16]:
import json
out_path = ARTIFACTS_DIR / "physics_aware_results.json"
with open(out_path, "w") as f:
    json.dump(results, f, indent=2)
print("Saved:", out_path)

Saved: C:\Users\ishaa\OneDrive\Desktop\Projects\eurosat-physics-aware-image-classification\artifacts\physics_aware_results.json
