In [None]:
# ============================================================
# 0. Setup: installs & imports
# ============================================================

!pip install -q huggingface_hub torch torchvision matplotlib

import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from huggingface_hub import snapshot_download

# Seed the RNGs this notebook draws from (`random` for episode / sample
# selection, NumPy, and PyTorch for shuffling and weight initialisation) so a
# Restart & Run All starts from the same random state every time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# ============================================================
# 1. Download Geneva dataset
# ============================================================

# Download the dataset repository snapshot; keep whatever snapshot_download
# returns as the root used to locate the split folders below.
dataset_root = snapshot_download(
    repo_type="dataset",
    repo_id="raphaelattias/overfitteam-geneva-satellite-images",
)
print("Dataset root:", dataset_root)

In [None]:
# ============================================================
# 2. Base dataset: load images & masks with transforms
# ============================================================


IMAGE_SIZE = 256  # side length (in pixels) every tile and mask is resized to

# RGB tiles: resize, convert to a float tensor in [0, 1], then normalise each
# channel with the standard ImageNet mean / std.
img_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Masks: nearest-neighbour resizing keeps label values intact (no blended
# in-between values at region borders). torchvision's own InterpolationMode
# enum is used instead of the PIL integer constant.
mask_transform = transforms.Compose(
    [
        transforms.Resize(
            (IMAGE_SIZE, IMAGE_SIZE),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.ToTensor(),  # gives float [0,1] for grayscale
    ]
)


class GenevaRooftopDataset(Dataset):
    """
    Regular (image, mask) dataset for one split and category.
    Mask is binary: 1 = rooftop suitable for PV, 0 = background.

    Images are read from ``<root>/<split>/images/<category>`` and masks from
    ``<root>/<split>/labels/<category>``; the mask of ``name.png`` is expected
    to be called ``name_label.png``.

    Args:
        root: dataset root directory.
        split: sub-folder of ``root`` to read (default "train").
        category: sub-folder inside ``images`` / ``labels`` (default "all").
    """

    def __init__(self, root, split="train", category="all"):
        super().__init__()
        self.root = root
        self.split = split
        self.category = category

        self.image_dir = os.path.join(root, split, "images", category)
        self.label_dir = os.path.join(root, split, "labels", category)

        # Sorted so that an index always refers to the same tile.
        self.filenames = sorted(f for f in os.listdir(self.image_dir) if f.endswith(".png"))

    @staticmethod
    def _label_filename(image_filename):
        """Return the mask file name for an image file name.

        Only the final extension is touched, so a name that happens to contain
        ".png" elsewhere (e.g. "a.png_b.png") is not mangled the way a plain
        ``str.replace`` would mangle it.
        """
        stem, ext = os.path.splitext(image_filename)
        return f"{stem}_label{ext}"

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(img, mask)`` for tile ``idx``.

        ``img`` is the output of ``img_transform`` on the RGB image; ``mask`` is
        the output of ``mask_transform`` on the grayscale label, thresholded to
        a float tensor of 0s and 1s.
        """
        fname = self.filenames[idx]

        img_path = os.path.join(self.image_dir, fname)
        mask_path = os.path.join(self.label_dir, self._label_filename(fname))

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply transforms
        img = img_transform(img)
        mask = mask_transform(mask)  # [1, H, W], float in [0, 1]

        # Binarise: pixels brighter than mid-grey (> 0.5 after scaling) -> 1,
        # everything else -> 0.
        mask = (mask > 0.5).float()

        return img, mask


# One dataset per split, all restricted to the 'all' category.
train_base, val_base, test_base = (
    GenevaRooftopDataset(dataset_root, split=split_name, category="all")
    for split_name in ("train", "val", "test")
)

print(f"Train samples: {len(train_base)}, Val: {len(val_base)}, Test: {len(test_base)}")

In [None]:
# ============================================================
# 3. Quick visualisation: image + mask for sanity check
# ============================================================


def show_sample(dataset, idx=None):
    """Plot one (image, mask) pair of ``dataset`` side by side.

    Args:
        dataset: indexable object returning ``(img, mask)`` with ``img`` of
            shape [3,H,W] (normalised) and ``mask`` of shape [1,H,W].
        idx: index of the sample to show; a random one is drawn when None.
    """
    if idx is None:
        idx = random.randint(0, len(dataset) - 1)
    image, mask = dataset[idx]

    # Revert the channel-wise normalisation so the tile displays with natural colours.
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    rgb = np.clip(image.permute(1, 2, 0).numpy() * channel_std + channel_mean, 0, 1)

    # Drop the channel axis of the mask: [1, H, W] -> [H, W]
    mask_2d = mask.squeeze(0).numpy()

    plt.figure(figsize=(8, 4))
    panels = (
        (rgb, "Image", None),
        (mask_2d, "Mask (rooftop)", "gray"),
    )
    for position, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")
    plt.show()


# Sanity check on a randomly chosen training tile.
show_sample(train_base, idx=None)

In [8]:
# ============================================================
# 4. Few-shot / episodic dataset: support + query
# ============================================================


class EpisodeDataset(Dataset):
    """
    Episodic wrapper around an (image, mask) dataset.

    Every item is one episode
    ``(support_image, support_mask, query_image, query_mask)`` built from two
    distinct, randomly drawn entries of the wrapped dataset. The requested
    index is ignored; the reported length is ``episodes_per_epoch``.
    """

    def __init__(self, base_dataset, episodes_per_epoch=500):
        self.base = base_dataset
        self.n = len(base_dataset)
        self.episodes_per_epoch = episodes_per_epoch

    def __len__(self):
        return self.episodes_per_epoch

    def __getitem__(self, idx):
        # random.sample draws without replacement, so support != query.
        (img_s, mask_s), (img_q, mask_q) = (
            self.base[position] for position in random.sample(range(self.n), 2)
        )
        return img_s, mask_s, img_q, mask_q


episodes_per_epoch = 500  # adjust for speed vs quality
episode_dataset = EpisodeDataset(base_dataset=train_base, episodes_per_epoch=episodes_per_epoch)
# batch_size=1: each batch is exactly one support/query episode.
episode_loader = DataLoader(dataset=episode_dataset, batch_size=1, shuffle=True)

In [None]:
# ============================================================
# 5. Encoder backbone for few-shot segmentation
# ============================================================


class Encoder(nn.Module):
    """
    Convolutional encoder using ResNet18 up to layer3, followed by a 1x1 conv
    to get a compact embedding.

    The retained stages (conv1, maxpool, layer2 and layer3 each halve the
    resolution) downsample the input by 16, e.g. a 256x256 tile yields a
    16x16 feature map.

    Args:
        out_channels: number of channels of the returned embedding.
        pretrained: load ImageNet weights into the ResNet18 backbone.
    """

    def __init__(self, out_channels=256, pretrained=True):
        super().__init__()
        if hasattr(models, "ResNet18_Weights"):
            # Current torchvision API; `pretrained=` is deprecated there.
            # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.resnet18(weights=weights)
        else:
            # Older torchvision without the weights enums.
            backbone = models.resnet18(pretrained=pretrained)
        self.features = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )
        # resnet18.layer3 output has 256 channels
        self.proj = nn.Conv2d(256, out_channels, kernel_size=1)

    def forward(self, x):
        """Map images [B, 3, H, W] to embeddings [B, out_channels, H', W']."""
        f = self.features(x)  # [B, 256, H', W']
        f = self.proj(f)  # [B, C,   H', W']
        return f


# Build the encoder, then move its parameters to the selected device.
encoder = Encoder(out_channels=256, pretrained=True)
encoder = encoder.to(device)
print(encoder)

In [10]:
# ============================================================
# 6. Prototype computation & query classification
# ============================================================


def compute_prototypes(feat_support, mask_support):
    """
    Masked-average-pool the support features into one prototype per class.

    feat_support: [K, C, H', W']   (features of the K support images)
    mask_support: [K, 1, H,  W]    (binary masks)
    Returns: prototypes [2, C] (row 0 = background, row 1 = foreground)

    The masks are resized to the feature resolution with nearest-neighbour
    interpolation, then every channel is averaged over all support pixels of
    the class. A small epsilon in the denominator avoids a division by zero
    when a class has no pixel at all.
    """
    n_channels = feat_support.shape[1]

    # Bring the masks down to H' x W' and split them into the two class weights.
    resized = F.interpolate(mask_support, size=feat_support.shape[2:], mode="nearest")
    is_fg = (resized > 0.5).float()  # [K,1,H',W']
    fg_weights = is_fg.view(1, -1)  # [1, K*H'*W']
    bg_weights = (1.0 - is_fg).view(1, -1)  # [1, K*H'*W']

    # Channel-major layout so each row holds one channel over all support pixels.
    flat_feats = feat_support.permute(1, 0, 2, 3).contiguous().view(n_channels, -1)

    def masked_mean(weights):
        # Weighted average per channel -> [C]
        return (flat_feats * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)

    return torch.stack([masked_mean(bg_weights), masked_mean(fg_weights)], dim=0)  # [2,C]


def classify_query(feat_query, prototypes):
    """
    Classify query pixels by distance to prototypes.

    feat_query: [B, C, H', W']  (B = 1 in the 1-query episodes of this notebook)
    prototypes: [P, C]          (P = 2: background, foreground)
    Returns: logits [B, P, H', W'], the negative squared Euclidean distance of
             every pixel embedding to every prototype (closer -> larger logit).

    Any batch size is supported: the batch axis is kept separate from the
    channel axis instead of being folded into it.
    """
    B, C, Hq, Wq = feat_query.shape
    P = prototypes.shape[0]

    # [B,C,H',W'] -> [B, H'*W', C]: one row per pixel embedding
    fq = feat_query.flatten(2).transpose(1, 2).contiguous()

    # Same prototypes for every batch element: [P,C] -> [B,P,C]
    protos = prototypes.unsqueeze(0).expand(B, -1, -1).contiguous()

    # Squared Euclidean distance from each pixel to each prototype
    dists = torch.cdist(fq, protos) ** 2  # [B, H'*W', P]

    # Convert distances to similarity logits: negative distance
    logits = (-dists).transpose(1, 2).reshape(B, P, Hq, Wq)  # [B,P,H',W']

    return logits

In [11]:
# ============================================================
# 7. IoU metric for evaluation
# ============================================================


def iou_from_logits(logits, target_mask, eps=1e-6):
    """
    Foreground intersection-over-union between a prediction and a mask.

    logits: [1,2,H,W] (class 0=background, 1=foreground)
    target_mask: [1,1,H,W], binary 0/1
    eps: added to numerator and denominator, so two empty masks score 1.
    Returns: the IoU as a Python float.
    """
    # Hard prediction: index of the larger logit per pixel, kept as [1,1,H,W]
    predicted = logits.argmax(dim=1, keepdim=True).float()
    truth = (target_mask > 0.5).float()

    overlap = (predicted * truth).sum()
    combined = predicted.sum() + truth.sum() - overlap

    return ((overlap + eps) / (combined + eps)).item()

In [None]:
# ============================================================
# 8. Episodic meta-training loop
# ============================================================

# Adam over every encoder parameter (backbone stages and 1x1 projection).
optimizer = torch.optim.Adam(params=encoder.parameters(), lr=1e-4)


def meta_train(num_epochs=5):
    """Episodically train the global ``encoder`` for ``num_epochs`` epochs.

    Each batch of ``episode_loader`` is one episode: prototypes are computed
    from the support pair, the query pixels are classified against them, and
    the cross-entropy between the upsampled query logits and the query mask is
    back-propagated through the encoder with the global ``optimizer``. The
    average episode loss is printed after every epoch.
    """
    encoder.train()
    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0

        for episode in episode_loader:
            # Order in the episode: support image, support mask, query image, query mask.
            support_img, support_mask, query_img, query_mask = (
                tensor.to(device) for tensor in episode
            )

            optimizer.zero_grad()

            # Embed both tiles with the shared encoder -> [1,C,H',W'] each
            support_feat = encoder(support_img)
            query_feat = encoder(query_img)

            # Class prototypes come from the support tile only -> [2,C]
            prototypes = compute_prototypes(support_feat, support_mask)

            # Coarse per-pixel logits for the query -> [1,2,H',W']
            coarse_logits = classify_query(query_feat, prototypes)

            # Bring the logits back to the mask resolution -> [1,2,H,W]
            full_logits = F.interpolate(
                coarse_logits,
                size=query_mask.shape[2:],
                mode="bilinear",
                align_corners=False,
            )

            # Cross-entropy expects integer class ids of shape [1,H,W]
            target = query_mask.long().squeeze(1)
            loss = F.cross_entropy(full_logits, target)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(episode_loader)
        print(f"Epoch {epoch}/{num_epochs} | avg episode loss: {avg_loss:.4f}")


# Launch episodic training for five epochs.
meta_train(5)

In [18]:
# ============================================================
# 9. Few-shot inference on test images
# ============================================================


def k_shot_predict(encoder, support_imgs, support_masks, query_img):
    """
    Segment one query image from K labelled support images.

    support_imgs:  [K, 3, H, W]
    support_masks: [K, 1, H, W]
    query_img:     [3, H, W]
    Returns: logits [1, 2, H, W] on the CPU

    The encoder is switched to eval mode and the whole computation runs
    without gradient tracking; inputs are moved to the global ``device``.
    """
    encoder.eval()
    with torch.no_grad():
        supports = support_imgs.to(device)  # [K,3,H,W]
        masks = support_masks.to(device)  # [K,1,H,W]
        query = query_img.to(device).unsqueeze(0)  # add batch axis: [1,3,H,W]

        support_feat = encoder(supports)  # [K,C,H',W']
        query_feat = encoder(query)  # [1,C,H',W']

        prototypes = compute_prototypes(support_feat, masks)  # [2,C]
        coarse_logits = classify_query(query_feat, prototypes)  # [1,2,H',W']

        # Upsample to the query's spatial size -> [1,2,H,W]
        full_logits = F.interpolate(
            coarse_logits,
            size=query.shape[2:],
            mode="bilinear",
            align_corners=False,
        )

    return full_logits.cpu()


def one_shot_predict(encoder, support_img, support_mask, query_img):
    """
    1-shot convenience wrapper around ``k_shot_predict``.

    The single support image [3,H,W] and mask [1,H,W] each get a leading
    axis of size 1 (K=1) before being forwarded.
    """
    return k_shot_predict(
        encoder,
        support_img.unsqueeze(0),  # [1,3,H,W]
        support_mask.unsqueeze(0),  # [1,1,H,W]
        query_img,
    )

In [19]:
# ============================================================
# 9a. K-shot IoU evaluation on test images
# ============================================================

import torch


def evaluate_kshot_iou(encoder, train_dataset, test_dataset, K=5, num_samples=None):
    """
    Evaluate K-shot IoU on the test set.

    Args:
        encoder: feature extractor handed to ``k_shot_predict``.
        train_dataset: pool of (image, mask) pairs the supports are drawn from.
        test_dataset: (image, mask) pairs used as queries.
        K: number of *distinct* support images sampled per query.
        num_samples: if None, every test image is evaluated exactly once, in
            order; otherwise that many test images are drawn at random (with
            replacement).

    Returns:
        NumPy array with one IoU per evaluated query (empty if nothing was
        evaluated). A mean ± std summary is printed.

    The random generator is seeded with a fixed value, so repeated calls pick
    the same queries and supports.
    """
    encoder.eval()
    rng = np.random.default_rng(42)

    n_test = len(test_dataset)
    # None means "the whole test set": walk the indices sequentially instead
    # of sampling, which would repeat some images and skip others.
    evaluate_all = num_samples is None
    if evaluate_all:
        num_samples = n_test

    ious = []
    for step in range(num_samples):
        ti = step if evaluate_all else int(rng.integers(0, n_test))
        img_q, mask_q = test_dataset[ti]

        # pick K distinct support indices
        support_indices = rng.choice(len(train_dataset), size=K, replace=False)
        support_pairs = [train_dataset[int(si)] for si in support_indices]
        support_imgs = torch.stack([img for img, _ in support_pairs], dim=0)  # [K,3,H,W]
        support_masks = torch.stack([mask for _, mask in support_pairs], dim=0)  # [K,1,H,W]

        # run K-shot prediction
        logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)  # [1,2,H,W]

        ious.append(iou_from_logits(logits, mask_q.unsqueeze(0)))

    ious = np.array(ious)
    if ious.size == 0:
        # Mean / std of an empty array would only produce NaN and a warning.
        print(f"{K}-shot: no test samples evaluated")
        return ious

    print(f"{K}-shot mean IoU over {num_samples} test samples: {ious.mean():.3f} ± {ious.std():.3f}")
    return ious

In [None]:
# ============================================================
# 9b. Visualise a 1-shot episode (support + query + prediction)
# ============================================================


def tensor_to_rgb(img_tensor):
    """Turn a normalised [3,H,W] image tensor into a displayable [H,W,3] array.

    The channel-wise normalisation is reverted (multiply by std, add mean) and
    the result is clipped to [0, 1].
    """
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    # Channels last, on the CPU, detached from any autograd graph.
    hwc = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    return np.clip(hwc * channel_std + channel_mean, 0, 1)


def visualise_few_shot_example(encoder, train_dataset, test_dataset):
    """Plot one random 1-shot episode.

    A support pair is drawn from ``train_dataset`` and a query pair from
    ``test_dataset``; the figure shows the support image and mask, the query
    image, the query ground-truth mask and the mask predicted by
    ``one_shot_predict``.
    """
    encoder.eval()
    rng = np.random.default_rng()

    # Support comes from the training split, query from the test split.
    support_idx = rng.integers(0, len(train_dataset))
    query_idx = rng.integers(0, len(test_dataset))

    support_img, support_mask = train_dataset[support_idx]
    query_img, query_mask = test_dataset[query_idx]

    logits = one_shot_predict(encoder, support_img, support_mask, query_img)  # [1,2,H,W]
    # Hard prediction per pixel, reduced to [H,W]
    predicted = logits.argmax(dim=1).float().squeeze(0).numpy()

    # (grid position in the 2x3 layout, data, title, colormap)
    panels = (
        (1, tensor_to_rgb(support_img), "Support image", None),
        (2, support_mask.squeeze(0).numpy(), "Support mask", "gray"),
        (3, tensor_to_rgb(query_img), "Query image", None),
        (5, query_mask.squeeze(0).numpy(), "Query GT mask", "gray"),
        (6, predicted, "Predicted mask (1-shot)", "gray"),
    )

    plt.figure(figsize=(12, 6))
    for position, data, title, cmap in panels:
        plt.subplot(2, 3, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


# Show one random 1-shot episode: support from train, query from test.
visualise_few_shot_example(encoder=encoder, train_dataset=train_base, test_dataset=test_base)

In [None]:
# ============================================================
# 9d. Run 1-shot and 5-shot IoU evaluation on the test set
# ============================================================

# 1-shot: the same evaluation routine with a single support image per query
# (num_samples is left at its default, None).
ious_1shot = evaluate_kshot_iou(encoder, train_base, test_base, K=1)

# 5-shot: five distinct support images per query.
ious_5shot = evaluate_kshot_iou(encoder, train_base, test_base, K=5)

In [None]:
# ============================================================
# 9e. Visualise a K-shot episode (support + query + prediction)
# ============================================================


def visualise_kshot_example(encoder, train_dataset, test_dataset, K=5):
    """Plot one random K-shot episode as two figures.

    The first figure shows the K support images (top row) and their masks
    (bottom row); the second shows the query image, its ground-truth mask and
    the mask predicted by ``k_shot_predict``.

    Args:
        encoder: feature extractor handed to ``k_shot_predict``.
        train_dataset: pool of (image, mask) pairs the K supports are drawn
            from (distinct indices).
        test_dataset: (image, mask) pairs the query is drawn from.
        K: number of support images.
    """
    encoder.eval()
    rng = np.random.default_rng()

    # pick query from test
    ti = rng.integers(0, len(test_dataset))
    img_q, mask_q = test_dataset[ti]

    # pick K supports from train
    support_indices = rng.choice(len(train_dataset), size=K, replace=False)
    support_pairs = [train_dataset[int(si)] for si in support_indices]
    support_imgs = torch.stack([img for img, _ in support_pairs], dim=0)  # [K,3,H,W]
    support_masks = torch.stack([mask for _, mask in support_pairs], dim=0)  # [K,1,H,W]

    # prediction: hard class per pixel, reduced to [H,W]
    logits = k_shot_predict(encoder, support_imgs, support_masks, img_q)
    pred_mask = logits.argmax(dim=1).float().squeeze(0).numpy()

    img_q_np = tensor_to_rgb(img_q)
    mask_q_np = mask_q.squeeze(0).numpy()

    # Figure 1: supports. At least 3 columns so panels keep a sensible size
    # for small K.
    cols = max(K, 3)
    fig_support = plt.figure(figsize=(4 * cols, 8))
    for i in range(K):
        # first row: support image
        ax = fig_support.add_subplot(2, cols, i + 1)
        ax.imshow(tensor_to_rgb(support_imgs[i]))
        ax.set_title(f"Support {i+1} image")
        ax.axis("off")

        # second row: matching support mask
        ax = fig_support.add_subplot(2, cols, cols + i + 1)
        ax.imshow(support_masks[i].squeeze(0).numpy(), cmap="gray")
        ax.set_title(f"Support {i+1} mask")
        ax.axis("off")
    # Lay out this figure explicitly; a tight_layout call issued after the
    # next figure is created would no longer reach it.
    fig_support.tight_layout()

    # Figure 2 (separate figure): query image, ground truth and prediction
    fig_query = plt.figure(figsize=(12, 4))
    query_panels = (
        (img_q_np, "Query image", None),
        (mask_q_np, "Query GT mask", "gray"),
        (pred_mask, f"Predicted mask ({K}-shot)", "gray"),
    )
    for position, (data, title, cmap) in enumerate(query_panels, start=1):
        ax = fig_query.add_subplot(1, 3, position)
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    fig_query.tight_layout()

    plt.show()


# Show one random 5-shot episode.
visualise_kshot_example(encoder=encoder, train_dataset=train_base, test_dataset=test_base, K=5)