# Code Challenge 1 — B10 Reconstruction & Inference Pipeline

This notebook trains a lightweight CNN to reconstruct the missing Sentinel-2 B10 band and then applies a pre-trained EuroSAT classifier to Kaggle tiles after inserting the predicted band.

In [None]:
!pip install -q rasterio transformers huggingface_hub accelerate

In [None]:
from google.colab import drive
from pathlib import Path
from datetime import datetime
import os, random, math, time, shutil, glob, re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

import rasterio

from transformers import AutoImageProcessor, AutoModelForImageClassification

def set_seed(seed: int = 42) -> None:
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch via
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``; also writes the
    seed to the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed passed to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed every RNG once, up front, so later cells start from a known state.
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

In [None]:
drive.mount('/content/drive', force_remount=True)

ZIP_PATH = "/content/drive/MyDrive/ML_HSG/EuroSAT_MS.zip"
DATA_ROOT = Path('/content/EuroSAT_MS')
MODELS_DIR = Path('/content/drive/MyDrive/ML_HSG/models')
PLOT_DIR = Path('/content/drive/MyDrive/ML_HSG/plots')
B10_MODEL_DIR = MODELS_DIR / 'cirrus_cnn'
OUTPUT_DIR = Path('/content/drive/MyDrive/ML_HSG/kaggle_submissions')

for path in (MODELS_DIR, PLOT_DIR, B10_MODEL_DIR, OUTPUT_DIR):
    path.mkdir(parents=True, exist_ok=True)

if not DATA_ROOT.exists():
    print(f"Extracting dataset from {ZIP_PATH} ...")
    !unzip -q -o "$ZIP_PATH" -d '/content'
else:
    print(f"Dataset already available at {DATA_ROOT}")

RUN_TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"RUN_TIMESTAMP={RUN_TIMESTAMP}")
print(f"Plots -> {PLOT_DIR}")
print(f"Models -> {MODELS_DIR}")
print(f"Kaggle submissions -> {OUTPUT_DIR}")

In [None]:
# Index the dataset: each sub-directory of DATA_ROOT is one class, visited in
# sorted order, so a class's integer label is its position in that order.
CLASS_NAMES = []
samples = []

if not DATA_ROOT.exists():
    raise FileNotFoundError(f"{DATA_ROOT} not found. Extract the archive first.")

for class_dir in sorted(entry for entry in DATA_ROOT.iterdir() if entry.is_dir()):
    class_label = len(CLASS_NAMES)
    CLASS_NAMES.append(class_dir.name)
    samples.extend((tif_path, class_label) for tif_path in sorted(class_dir.glob('*.tif')))

# Lookup tables in both directions between class name and integer label.
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))
IDX_TO_CLASS = {position: name for name, position in CLASS_TO_IDX.items()}

print(f"Detected {len(CLASS_NAMES)} classes")
print(f"Total samples: {len(samples)}")

DROP_BAND_INDEX = 10  # Sentinel-2 B10 (cirrus)
# Indices of the bands kept as model input: every band except DROP_BAND_INDEX
# for 13-band arrays, and all bands for 12-band arrays.
KEEP_IDX_13 = np.delete(np.arange(13), DROP_BAND_INDEX)
KEEP_IDX_12 = np.arange(12)
# Raw band values are clipped to [0, CIRRUS_SCALE] and divided by it before modelling.
CIRRUS_SCALE = 10000.0
# Filled in by the training cell; synthesize_cirrus refuses to run while it is None.
CIRRUS_MODEL = None

In [None]:
class CirrusCNN(nn.Module):
    """Small convolutional regressor that maps ``in_channels`` bands to one band.

    Layer stack: 1x1 conv -> ReLU -> 3x3 conv (padding 1) -> ReLU -> 1x1 conv.
    None of the layers changes the spatial size of its input.

    Args:
        in_channels: Number of channels in the input tensor.
        hidden_dim: Channel width of the two hidden convolution layers.
    """

    def __init__(self, in_channels: int = 12, hidden_dim: int = 64):
        super().__init__()
        # Layer positions (0, 2, 4 hold the convolutions) determine the
        # state_dict key names, so the order here must stay fixed.
        stages = [
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, 1, kernel_size=1, bias=True),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer stack on ``x`` and return the single-channel output."""
        prediction = self.net(x)
        return prediction


def sample_patches(arr: np.ndarray, num_patches: int, patch_size: int, rng: np.random.Generator):
    """Draw random square crops from a tile and split each into inputs and target.

    Args:
        arr: 3-D array indexed as (band, row, column).
        num_patches: Number of crops to draw.
        patch_size: Requested side length; reduced to the tile's height or
            width when the tile is smaller.
        rng: NumPy generator used to draw the crop offsets.

    Returns:
        ``(inputs, targets)``: two lists of length ``num_patches``. Each input
        holds the bands selected by ``KEEP_IDX_13``; each target holds band
        ``DROP_BAND_INDEX`` with its band axis kept (length 1).

    Note:
        When the crop side equals both the tile height and width, the only
        possible offset is (0, 0), so every draw returns the same window.
    """
    _, height, width = arr.shape
    side = min(patch_size, height, width)
    # Exclusive upper bounds for the crop's top-left corner.
    row_limit = max(1, height - side + 1)
    col_limit = max(1, width - side + 1)

    inputs, targets = [], []
    for _ in range(num_patches):
        # Row offset is drawn before column offset; the order fixes the RNG stream.
        row = int(rng.integers(0, row_limit))
        col = int(rng.integers(0, col_limit))
        window = arr[:, row:row + side, col:col + side]
        inputs.append(window[KEEP_IDX_13])
        targets.append(window[DROP_BAND_INDEX:DROP_BAND_INDEX + 1])
    return inputs, targets


def fit_cirrus_cnn(sample_paths,
                    max_tiles: int = 120,
                    patches_per_tile: int = 48,
                    patch_size: int = 64,
                    epochs: int = 10,
                    batch_size: int = 128,
                    lr: float = 5e-4) -> dict:
    """Train a ``CirrusCNN`` to predict band ``DROP_BAND_INDEX`` from the other bands.

    Tiles are read with rasterio, clipped to ``[0, CIRRUS_SCALE]`` and divided by
    ``CIRRUS_SCALE``; tiles that do not have exactly 13 bands are skipped. Random
    patches from each tile are shuffled and split 90/10 into training and
    validation sets. The weights with the lowest validation loss are restored
    before the final metrics are computed.

    Args:
        sample_paths: Paths of the raster tiles to draw patches from.
        max_tiles: Upper bound on the number of tiles used; a random subset of
            this size is drawn when more paths are given.
        patches_per_tile: Number of patches sampled from each tile.
        patch_size: Requested patch side length (see ``sample_patches``).
        epochs: Number of training epochs.
        batch_size: Mini-batch size for training, validation and final metrics.
        lr: Adam learning rate.

    Returns:
        Dict with keys ``'module'`` (the trained model, moved to CPU),
        ``'mae'`` and ``'rmse'`` (errors over all patches, in scaled units),
        and ``'scale'`` (``CIRRUS_SCALE``).

    Raises:
        ValueError: If ``sample_paths`` is empty or no 13-band tile yields patches.
    """
    if not sample_paths:
        raise ValueError("No samples provided for cirrus CNN fitting")

    rng = np.random.default_rng(RANDOM_SEED)
    paths = list(sample_paths)
    if len(paths) > max_tiles:
        paths = rng.choice(paths, size=max_tiles, replace=False)
    print(f"[CirrusCNN] Using {len(paths)} tiles x {patches_per_tile} patches (patch {patch_size}x{patch_size}).")

    features = []
    targets = []
    for path in paths:
        with rasterio.open(path) as src:
            arr = src.read().astype(np.float32)
        if arr.shape[0] != 13:
            # The target band only exists in 13-band tiles.
            continue
        arr = np.clip(arr, 0.0, CIRRUS_SCALE) / CIRRUS_SCALE
        px, py = sample_patches(arr, patches_per_tile, patch_size, rng)
        features.extend(px)
        targets.extend(py)

    if not features:
        raise ValueError("Cirrus CNN requires at least one 13-band tile for training")

    X = torch.from_numpy(np.stack(features)).float()
    y = torch.from_numpy(np.stack(targets)).float()

    perm = torch.randperm(X.shape[0], generator=torch.Generator().manual_seed(RANDOM_SEED))
    X = X[perm]
    y = y[perm]
    # Keep at least one training patch: a plain int(0.9 * n) is 0 for n == 1,
    # which would leave the training set empty and divide by zero below.
    split = max(1, int(0.9 * X.shape[0]))
    X_train, X_val = X[:split], X[split:]
    y_train, y_val = y[:split], y[split:]

    train_ds = torch.utils.data.TensorDataset(X_train, y_train)
    val_ds = torch.utils.data.TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    has_val = len(val_ds) > 0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CirrusCNN().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_state = None
    best_val = float('inf')
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            optimizer.zero_grad(set_to_none=True)
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * xb.size(0)
        train_loss /= len(train_ds)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                preds = model(xb)
                val_loss += criterion(preds, yb).item() * xb.size(0)
        val_loss /= max(1, len(val_ds))
        print(f"[CirrusCNN] epoch {epoch + 1}/{epochs} train={train_loss:.6f} val={val_loss:.6f}")

        # With an empty validation set val_loss is always 0.0, which would pin
        # the "best" weights to the first epoch; track the training loss instead.
        monitored = val_loss if has_val else train_loss
        if monitored < best_val:
            best_val = monitored
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final metrics over every patch, accumulated batch by batch so that the
    # whole patch tensor is never pushed through the network in one forward pass.
    model.eval()
    abs_err_sum = 0.0
    sq_err_sum = 0.0
    with torch.no_grad():
        for start in range(0, X.shape[0], batch_size):
            xb = X[start:start + batch_size].to(device)
            yb = y[start:start + batch_size].to(device)
            diff = model(xb) - yb
            abs_err_sum += diff.abs().sum().item()
            sq_err_sum += diff.pow(2).sum().item()
    n_values = max(1, y.numel())
    mae = abs_err_sum / n_values
    rmse = math.sqrt(sq_err_sum / n_values)

    total_patches = len(train_ds) + len(val_ds)
    print(f"[CirrusCNN] Fitted on {total_patches} patches | MAE={mae:.6f} | RMSE={rmse:.6f}")

    model.cpu()
    return {
        'module': model,
        'mae': mae,
        'rmse': rmse,
        'scale': CIRRUS_SCALE,
    }


def synthesize_cirrus(arr: np.ndarray) -> np.ndarray:
    """Predict the dropped band for one tile with the fitted ``CIRRUS_MODEL``.

    Args:
        arr: Band-first array with 12 or 13 bands. For 13 bands, the band at
            ``DROP_BAND_INDEX`` is excluded from the model input.

    Returns:
        float32 array with the band axis removed: the model output clipped to
        ``[0, 1]`` and multiplied by the model's ``'scale'``.

    Raises:
        RuntimeError: If ``CIRRUS_MODEL`` has not been set.
        ValueError: If ``arr`` has neither 12 nor 13 bands.
    """
    arr = np.asarray(arr, dtype=np.float32)
    if CIRRUS_MODEL is None:
        raise RuntimeError('CIRRUS_MODEL is None. Train or load the B10 reconstruction model before calling synthesize_cirrus.')

    scale = CIRRUS_MODEL['scale']
    # Same preprocessing as in training: clip to [0, scale], then divide by scale.
    normalized = np.clip(arr, 0.0, scale) / scale

    # Choose the input-band selector by band count.
    selector = {13: KEEP_IDX_13, 12: KEEP_IDX_12}.get(arr.shape[0])
    if selector is None:
        raise ValueError(f"Expected 12 or 13 bands, got {arr.shape}")
    model_input = torch.from_numpy(normalized[selector]).unsqueeze(0)

    net = CIRRUS_MODEL['module']
    net.eval()
    with torch.no_grad():
        raw_output = net(model_input)

    # Drop the batch and channel axes, then map back to the original value range.
    band = raw_output.squeeze(0).squeeze(0).numpy()
    return (np.clip(band, 0.0, 1.0) * scale).astype(np.float32)


def pad_to_13_bands(arr: np.ndarray) -> np.ndarray:
    """Return a 13-band float32 array whose ``DROP_BAND_INDEX`` band is synthesized.

    A 13-band input is copied and its band at ``DROP_BAND_INDEX`` overwritten
    with the prediction; a 12-band input gets the prediction inserted at that
    index.

    Args:
        arr: Band-first array with 12 or 13 bands.

    Returns:
        New 13-band float32 array; the input is not modified.

    Raises:
        ValueError: If ``arr`` has neither 12 nor 13 bands.
    """
    arr = np.asarray(arr, dtype=np.float32)
    band_count = arr.shape[0]
    if band_count not in (12, 13):
        raise ValueError(f"Expected 12 or 13 bands, got {arr.shape}")

    predicted = synthesize_cirrus(arr)
    if band_count == 13:
        filled = arr.copy()
        filled[DROP_BAND_INDEX] = predicted
        return filled

    before, after = arr[:DROP_BAND_INDEX], arr[DROP_BAND_INDEX:]
    return np.concatenate([before, predicted[np.newaxis, ...], after], axis=0)


def robust_normalize(arr: np.ndarray) -> np.ndarray:
    """Rescale each band to ``[0, 1]`` independently, robust to outliers.

    A band is stretched between its 2nd and 98th percentiles. If those
    coincide, the band's minimum and maximum are used instead. If those
    coincide too, the band becomes all zeros. Results are clipped to ``[0, 1]``.

    Args:
        arr: Array whose first axis indexes bands.

    Returns:
        float32 array with the same shape as ``arr``.
    """
    data = arr.astype(np.float32, copy=False)
    result = np.empty_like(data)
    for index, band in enumerate(data):
        lo, hi = np.percentile(band, [2, 98])
        if not hi > lo:
            # Degenerate percentile range: fall back to the full value range.
            lo, hi = band.min(), band.max()
        if hi > lo:
            scaled = (band - lo) / (hi - lo)
        else:
            scaled = np.zeros_like(band)
        result[index] = np.clip(scaled, 0.0, 1.0)
    return result


def load_multispectral(path: Path) -> np.ndarray:
    """Read a raster tile and return it as a normalized 13-band array.

    Args:
        path: Location of a raster file readable by rasterio.

    Returns:
        Output of ``robust_normalize`` applied to the tile after
        ``pad_to_13_bands`` has filled in the synthesized band.
    """
    with rasterio.open(path) as dataset:
        raw_bands = dataset.read()
    return robust_normalize(pad_to_13_bands(raw_bands))


In [None]:
# Fit the B10 reconstruction network on every indexed tile path.
cirrus_paths = [p for p, _ in samples]
CIRRUS_MODEL = fit_cirrus_cnn(cirrus_paths, epochs=12, patches_per_tile=64, patch_size=64)
# synthesize_cirrus feeds the module CPU tensors, so keep it on CPU in eval mode.
CIRRUS_MODEL['module'] = CIRRUS_MODEL['module'].cpu()
CIRRUS_MODEL['module'].eval()

# Checkpoint bundle: weights plus the metadata needed to reuse them.
# The band-index arrays are stored as plain Python lists rather than NumPy
# arrays so the file contains only tensors and built-in types and can be read
# back with `torch.load(..., weights_only=True)`.
b10_bundle = {
    'state_dict': CIRRUS_MODEL['module'].state_dict(),
    'mae': CIRRUS_MODEL['mae'],
    'rmse': CIRRUS_MODEL['rmse'],
    'scale': CIRRUS_MODEL['scale'],
    'drop_band_index': DROP_BAND_INDEX,
    'keep_idx_12': KEEP_IDX_12.tolist(),
    'keep_idx_13': KEEP_IDX_13.tolist(),
}

b10_model_path = B10_MODEL_DIR / f'cirrus_cnn_{RUN_TIMESTAMP}.pt'
torch.save(b10_bundle, b10_model_path)
print(f"Saved B10 reconstruction weights to {b10_model_path}")

In [None]:
MODEL_ID = 'Rhodham96/EuroSatCNN'
MODEL_SOURCE = 'auto'  # auto -> local first, fallback to HF
LOCAL_MODEL_DIR_CANDIDATES = [
    Path('./local_models/Rhodham96-EuroSatCNN'),
    Path('/content/drive/MyDrive/ML_HSG/models/Rhodham96-EuroSatCNN'),
    Path('/content/drive/MyDrive/ML_HSG/models/Rhodham96- EuroSatCNN'),
]

# The first candidate directory that exists wins; if none exists, the first
# entry is used so the paths printed below are still meaningful.
for candidate_dir in LOCAL_MODEL_DIR_CANDIDATES:
    if candidate_dir.exists():
        LOCAL_MODEL_DIR = candidate_dir
        break
else:
    LOCAL_MODEL_DIR = LOCAL_MODEL_DIR_CANDIDATES[0]
LOCAL_MODEL_DEF = LOCAL_MODEL_DIR / 'model_def.py'
LOCAL_MODEL_WEIGHTS = LOCAL_MODEL_DIR / 'pytorch_model.bin'

print(f"Loading classifier {MODEL_ID} (source={MODEL_SOURCE})")
print(f"Resolved local directory: {LOCAL_MODEL_DIR}")
model = None
processor = None
image_size = 224

# A local load needs both the model definition file and the weights file.
local_available = LOCAL_MODEL_DEF.exists() and LOCAL_MODEL_WEIGHTS.exists()

if local_available and MODEL_SOURCE in ('auto', 'local'):
    import importlib.util
    # NOTE(review): this executes model_def.py and unpickles the weights file;
    # only point LOCAL_MODEL_DIR at files you trust.
    spec = importlib.util.spec_from_file_location('eurosat_local_model', LOCAL_MODEL_DEF)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if hasattr(module, 'EuroSATCNN'):
        EuroSATCNN = module.EuroSATCNN
    else:
        raise AttributeError(f"EuroSATCNN class not found in {LOCAL_MODEL_DEF}")
    model = EuroSATCNN(num_classes=len(CLASS_NAMES))
    state_dict = torch.load(LOCAL_MODEL_WEIGHTS, map_location='cpu')
    # Accept a checkpoint that nests the weights under a 'state_dict' key.
    wrapped = isinstance(state_dict, dict) and 'state_dict' in state_dict
    if wrapped:
        state_dict = state_dict['state_dict']
    model.load_state_dict(state_dict)
    # Prefer the model's own `image_size`, then `input_resolution`, then the default.
    fallback_size = getattr(model, 'input_resolution', image_size)
    image_size = getattr(model, 'image_size', fallback_size)
    print(f"Loaded local weights from {LOCAL_MODEL_WEIGHTS}")

if model is None:
    if MODEL_SOURCE in ('auto', 'hf'):
        print(f"Falling back to Hugging Face hub for {MODEL_ID}")
        model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
        image_size = getattr(model.config, 'image_size', image_size)
        # A missing or unusable processor is tolerated; a fallback is built later.
        try:
            processor = AutoImageProcessor.from_pretrained(MODEL_ID)
        except (OSError, IndexError) as hf_proc_err:
            print(f"Processor load failed: {hf_proc_err}")
            processor = None

if processor is None:
    print('Using fallback processor.')

    class EuroSatFallbackImageProcessor:
        """Minimal processor: converts images to clamped, resized float tensors.

        Calling an instance returns ``{'pixel_values': batch}``, where ``batch``
        has one leading batch axis.

        Args:
            image_size: Side length every image is resized to.
        """

        def __init__(self, image_size: int = 224):
            self.image_size = image_size

        def _prep_single(self, image):
            """Convert one image to a float tensor of size ``image_size`` x ``image_size``.

            Raises:
                ValueError: If the image does not have exactly three dimensions.
            """
            if torch.is_tensor(image):
                tensor = image.float()
            elif isinstance(image, np.ndarray):
                tensor = torch.from_numpy(image).float()
            else:
                tensor = torch.from_numpy(np.asarray(image, dtype=np.float32))

            if tensor.ndim != 3:
                raise ValueError(f"Expected image with 3 dims, got {tuple(tensor.shape)}")

            # Heuristic: a leading axis of length 3 or 13 is taken to be the
            # channel axis; anything else is treated as channels-last.
            channels_first = tensor.shape[0] in (3, 13)
            if not channels_first:
                tensor = tensor.permute(2, 0, 1)

            tensor = tensor.clamp(0.0, 1.0)

            target_hw = (self.image_size, self.image_size)
            if tensor.shape[1:] != target_hw:
                # interpolate expects a batch axis; add it, resize, remove it.
                resized = torch.nn.functional.interpolate(
                    tensor[None],
                    size=target_hw,
                    mode='bilinear',
                    align_corners=False
                )
                tensor = resized[0]

            return tensor

        def __call__(self, images, return_tensors=None):
            """Prepare one image or a list/tuple of images.

            Raises:
                ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
            """
            if isinstance(images, (list, tuple)):
                batch = torch.stack(list(map(self._prep_single, images)))
            else:
                batch = self._prep_single(images).unsqueeze(0)

            if return_tensors not in (None, 'pt'):
                raise ValueError(f"Unsupported return_tensors value: {return_tensors!r}")

            return {'pixel_values': batch}

    processor = EuroSatFallbackImageProcessor(image_size=image_size)

# Fail with a clear message when neither loading path produced a model (for
# example MODEL_SOURCE='local' with the local files missing, or an unknown
# MODEL_SOURCE value) instead of an AttributeError on None further down.
if model is None:
    raise RuntimeError(
        f"No classifier was loaded (MODEL_SOURCE={MODEL_SOURCE!r}, "
        f"local files available: {local_available}). "
        "Provide the local model files or set MODEL_SOURCE to 'auto' or 'hf'."
    )

# Models loaded from a local definition may lack a `config` object; give them
# an empty namespace so the label metadata below can be attached uniformly.
if not hasattr(model, 'config'):
    from types import SimpleNamespace
    model.config = SimpleNamespace()

# Label metadata follows the class order found when indexing the dataset.
model.config.label2id = {cls: idx for idx, cls in enumerate(CLASS_NAMES)}
model.config.id2label = {idx: cls for cls, idx in model.config.label2id.items()}
model.config.num_labels = len(CLASS_NAMES)
model.config.image_size = image_size

model.to(DEVICE)
model.eval()
print('Classifier ready for inference.')

In [None]:
class EuroSATNPYDataset(Dataset):
    """Dataset over ``.npy`` tiles that yields classifier-ready inputs.

    Each item is a dict with ``'pixel_values'`` (the processor output for one
    tile, batch axis removed) and ``'id'`` (an integer parsed from the file
    name).

    Args:
        paths: Sequence of path objects exposing ``.stem``.
        processor: Callable accepting ``images=`` and ``return_tensors=`` and
            returning a mapping with a ``'pixel_values'`` entry.
    """

    def __init__(self, paths, processor):
        self.paths = paths
        self.processor = processor

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        arr = np.load(path, allow_pickle=False)

        # Move bands to the front when the array looks channels-last.
        channels_last = (
            arr.ndim == 3
            and arr.shape[0] not in (12, 13)
            and arr.shape[-1] in (12, 13)
        )
        if channels_last:
            arr = np.moveaxis(arr, -1, 0)

        normalized = robust_normalize(pad_to_13_bands(arr))
        # The processor is handed a channels-last array.
        hwc = np.moveaxis(normalized, 0, -1)
        encoded = self.processor(images=hwc, return_tensors='pt')
        return {
            'pixel_values': encoded['pixel_values'].squeeze(0),
            'id': self._extract_id(path.stem),
        }

    @staticmethod
    def _extract_id(stem: str) -> int:
        """Parse an integer id from a file stem.

        Trailing digits are preferred; otherwise all digits in the stem are
        joined. Returns -1 when the stem contains no digits.
        """
        trailing = re.search(r'(\d+)$', stem)
        if trailing is not None:
            return int(trailing.group(1))
        collected = ''.join(filter(str.isdigit, stem))
        if not collected:
            return -1
        return int(collected)


In [None]:
TEST_ROOT = Path('/content/drive/MyDrive/ML_HSG/kaggle_data/testset/testset')
if not TEST_ROOT.exists():
    raise FileNotFoundError(f"Expected test directory at {TEST_ROOT}")

# Look for tiles directly in TEST_ROOT first; search recursively only when
# the top level has none.
npy_paths = sorted(TEST_ROOT.glob('*.npy')) or sorted(TEST_ROOT.glob('**/*.npy'))

if not npy_paths:
    raise FileNotFoundError('No .npy files found in Kaggle test directory')

print(f"Found {len(npy_paths)} inference tiles")

BATCH_SIZE = 32
NUM_WORKERS = 2

kaggle_dataset = EuroSATNPYDataset(npy_paths, processor)
# shuffle=False keeps batches in the sorted path order.
kaggle_loader = DataLoader(
    kaggle_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Run the classifier over every test tile, collecting predicted class indices
# and the matching tile ids in loader order.
model.eval()
pred_ids = []
pred_indices = []

with torch.no_grad():
    for batch in tqdm(kaggle_loader, desc='Kaggle inference'):
        pixel_values = batch['pixel_values'].to(DEVICE, non_blocking=True)
        # Pass the batch positionally: a locally defined nn.Module may not
        # name its forward() argument `pixel_values`, in which case the
        # keyword form raises TypeError.
        outputs = model(pixel_values)
        # Accept output objects with `.logits`, dicts with a 'logits' key, and
        # plain tensors.
        if hasattr(outputs, 'logits'):
            logits = outputs.logits
        elif isinstance(outputs, dict) and 'logits' in outputs:
            logits = outputs['logits']
        else:
            logits = outputs
        preds = logits.argmax(dim=1).cpu().tolist()
        pred_indices.extend(preds)
        # The batched ids may arrive as a tensor or as a plain sequence.
        if isinstance(batch['id'], torch.Tensor):
            pred_ids.extend(batch['id'].cpu().tolist())
        else:
            pred_ids.extend(batch['id'])

# Translate predicted indices to class names and build the submission table,
# ordered by tile id.
pred_labels = list(map(IDX_TO_CLASS.__getitem__, pred_indices))
submission = (
    pd.DataFrame({'test_id': pred_ids, 'label': pred_labels})
    .sort_values('test_id')
    .reset_index(drop=True)
)

# Write the CSV to the local runtime first, then copy it to Drive.
submission_name = f'submission_with_cirrus_{RUN_TIMESTAMP}.csv'
submission_path = Path('/content') / submission_name
submission.to_csv(submission_path, index=False)
print(f"Saved submission to {submission_path}")
print(submission.head())

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
drive_submission_path = OUTPUT_DIR / submission_name
shutil.copy2(submission_path, drive_submission_path)
print(f"Copied submission to {drive_submission_path}")