# Week 3 – Water Segmentation with U-Net

This notebook explores the provided multispectral water dataset and implements a full U-Net segmentation pipeline.

Steps we will follow:
- **Inspect the raw data** (image and label formats, shapes, basic stats)
- **Visualize bands** and example masks
- **Prepare a PyTorch dataset & data loaders**
- **Define a U-Net model** for 12-channel input and 1-channel water mask output
- **Train the model** and track IoU / precision / recall / F1 for the water class
- **Visualize predictions** vs ground-truth masks

In [3]:
# Imports and basic paths
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import tifffile  # for reading multi-band .tif

# NOTE: The .tif tiles are multispectral and PIL cannot decode them correctly.
# We will *always* use tifffile for reading the images. Please make sure
# `tifffile` is installed in your Python environment: `pip install tifffile`.

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn

# Dataset locations, resolved from the parent of the current working directory
# (edit BASE_DIR if the notebook is moved; it is expected to be `Task 3`).
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
IMAGES_DIR = os.path.join(BASE_DIR, "data", "images")
LABELS_DIR = os.path.join(BASE_DIR, "data", "labels")

print("Images dir:", IMAGES_DIR)
print("Labels dir:", LABELS_DIR)

# Collect the tile and mask paths in lexicographic order.
image_files = glob.glob(os.path.join(IMAGES_DIR, "*.tif"))
image_files.sort()
label_files = glob.glob(os.path.join(LABELS_DIR, "*.png"))
label_files.sort()

print(f"Found {len(image_files)} image tiles")
print(f"Found {len(label_files)} label masks (including augmented variants with underscores)")

tifffile is not installed. Run `pip install tifffile` in this environment.
Images dir: d:\Cellula_Internship\Task 3\data\images
Labels dir: d:\Cellula_Internship\Task 3\data\labels
Found 306 image tiles
Found 456 label masks (including augmented variants with underscores)


In [4]:
# Explore a single example (image + label)

# We will use only label files whose name is an integer (no underscore),
# so that they map directly to the corresponding image index.

def is_base_label(fname: str) -> bool:
    """Return True if the label filename's stem is a plain ASCII integer (e.g. '123').

    Args:
        fname: Path or bare filename of a label mask.

    Returns:
        True when the name without directory and extension is non-empty and
        consists only of the ASCII digits 0-9; False otherwise.

    ``str.isdigit`` on its own also accepts characters such as superscript
    digits ('²'), which ``int()`` cannot parse, so the stem is additionally
    required to be ASCII.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    return stem.isascii() and stem.isdigit()

# Keep only the numerically named label masks, ordered by their integer index.
base_label_files = sorted(
    filter(is_base_label, label_files),
    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]),
)

print(f"Using {len(base_label_files)} base label masks (1:1 with image indices)")

# Tile index to inspect
example_idx = 0
example_img_path = os.path.join(IMAGES_DIR, f"{example_idx}.tif")
example_lbl_path = os.path.join(LABELS_DIR, f"{example_idx}.png")

print("Example image:", example_img_path)
print("Example label:", example_lbl_path)

# The multispectral tile is read with tifffile only
img = tifffile.imread(example_img_path)  # expected shape: (H, W, C) or (C, H, W)

print("Raw image shape:", img.shape)

# Bring the tile into (C, H, W) layout; a 12- or 13-long first or last axis is
# taken to be the band axis.
if img.ndim != 3 or (img.shape[0] not in (12, 13) and img.shape[-1] not in (12, 13)):
    raise ValueError(f"Unexpected image shape {img.shape}; please inspect this cell.")
img_chw = img if img.shape[0] in (12, 13) else np.moveaxis(img, -1, 0)

num_channels, H, W = img_chw.shape
print(f"Image has {num_channels} bands, height={H}, width={W}")

# Matching label mask
label = np.array(Image.open(example_lbl_path))
print("Raw label shape:", label.shape, "dtype:", label.dtype)

# Report which values occur in the mask; binarization itself happens later in
# the Dataset class if the values are not already 0/1.
unique_vals = np.unique(label)
print("Unique label values:", unique_vals)

More samples per pixel than can be decoded: 12


Using 306 base label masks (1:1 with image indices)
Example image: d:\Cellula_Internship\Task 3\data\images\0.tif
Example label: d:\Cellula_Internship\Task 3\data\labels\0.png


UnidentifiedImageError: cannot identify image file 'D:\\Cellula_Internship\\Task 3\\data\\images\\0.tif'

In [None]:
# Visualize a few bands and the label mask

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
axes = axes.ravel()

# Up to 7 bands (0-6) go into the first panels; the final panel holds the label
max_bands_to_show = 7
bands_shown = min(max_bands_to_show, num_channels)
for i in range(bands_shown):
    ax, band = axes[i], img_chw[i]
    im = ax.imshow(band, cmap="gray")
    ax.set_title(f"Band {i}")
    ax.axis("off")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Label mask in the last subplot
ax_lbl = axes[-1]
ax_lbl.imshow(label, cmap="gray")
ax_lbl.set_title("Label mask")
ax_lbl.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# PyTorch Dataset for water segmentation

class WaterSegmentationDataset(Dataset):
    """Paired multispectral tiles and binary water masks.

    ``__getitem__`` returns ``(image, mask)``: ``image`` is a float32 tensor in
    (C, H, W) layout and ``mask`` is a float32 tensor of 0.0/1.0 values with a
    leading channel axis added.

    Args:
        images_dir: Directory holding the ``<stem>.tif`` tiles.
        labels_dir: Directory holding the label masks (stored on the instance;
            pairing itself uses the paths in ``base_label_files``).
        base_label_files: Label mask paths. For each one, the tile with the
            same stem is looked up in ``images_dir``; labels without a
            matching tile are skipped.
        normalize: If True, each tile is divided by its own maximum value
            (when that maximum is positive).

    Raises:
        RuntimeError: If no (image, label) pair is found.
    """

    def __init__(self, images_dir, labels_dir, base_label_files, normalize=True):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.normalize = normalize

        # Build pairs (image_path, label_path) using base (numeric) label filenames
        self.samples = []
        for lbl_path in base_label_files:
            stem = os.path.splitext(os.path.basename(lbl_path))[0]  # e.g. '123'
            img_path = os.path.join(images_dir, f"{stem}.tif")
            if os.path.exists(img_path):
                self.samples.append((img_path, lbl_path))

        if len(self.samples) == 0:
            raise RuntimeError("No (image, label) pairs found. Check your folder structure.")

        print(f"Dataset initialized with {len(self.samples)} samples.")

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def _load_image(self, path):
        """Read one tile with tifffile and return it as float32 in (C, H, W) layout.

        A 3-D array whose first or last axis has length 12 or 13 is accepted;
        that axis is treated as the band axis.

        Raises:
            ValueError: If the array has any other shape.
        """
        # Always use tifffile for multispectral tiles; PIL cannot read these .tif files reliably.
        img = tifffile.imread(path)

        if img.ndim == 3 and img.shape[0] in (12, 13):
            img_chw = img
        elif img.ndim == 3 and img.shape[-1] in (12, 13):
            img_chw = np.transpose(img, (2, 0, 1))
        else:
            raise ValueError(f"Unexpected image shape {img.shape} for {path}")

        img_chw = img_chw.astype(np.float32)

        # Simple normalization: scale by the tile's own maximum. The maximum is
        # computed once; non-positive maxima leave the tile unchanged.
        if self.normalize:
            peak = img_chw.max()
            if peak > 0:
                img_chw = img_chw / peak

        return img_chw

    def _load_label(self, path):
        """Read one label mask and return a float32 array of 0.0/1.0 values."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the array.
        with Image.open(path) as lbl_img:
            lbl = np.array(lbl_img)

        # Ensure we have a single channel mask
        if lbl.ndim == 3:
            # if RGB, convert to single channel (assuming water is white)
            lbl = lbl[..., 0]

        # Binarize: any non-zero is water
        return (lbl > 0).astype(np.float32)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        img_path, lbl_path = self.samples[idx]
        img = self._load_image(img_path)
        lbl = self._load_label(lbl_path)

        # Convert to torch tensors
        img_tensor = torch.from_numpy(img)              # (C, H, W)
        lbl_tensor = torch.from_numpy(lbl).unsqueeze(0) # (1, H, W) for a 2-D mask

        return img_tensor, lbl_tensor


# Instantiate dataset
full_dataset = WaterSegmentationDataset(IMAGES_DIR, LABELS_DIR, base_label_files)

# Simple train/val split
val_fraction = 0.2
val_size = int(len(full_dataset) * val_fraction)
train_size = len(full_dataset) - val_size

# Seed the split so the same tiles land in train/val on every run; without a
# fixed generator the validation metrics are not comparable between runs.
split_seed = 42
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# U-Net model definition (12-channel input -> 1-channel water mask)

class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by batch norm and ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out


class UNet(nn.Module):
    """U-Net with four down/up levels and a 1024-channel bottleneck.

    Args:
        in_ch: Number of input channels (default 12).
        out_ch: Number of output channels (default 1).

    ``forward`` takes a (B, in_ch, H, W) tensor and returns raw logits of
    shape (B, out_ch, H, W); no sigmoid is applied. H and W do not need to be
    multiples of 16: when pooling rounds an odd size down, the upsampled
    decoder map is zero-padded to the size of its skip connection before
    concatenation. Inputs smaller than 16 pixels on a side are not supported,
    because four successive 2x poolings need at least that many pixels.
    """

    def __init__(self, in_ch=12, out_ch=1):
        super().__init__()

        # Encoder
        self.down1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each level upsamples 2x, then fuses with the encoder map of
        # the same level (hence the doubled input channels of each dec block).
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.out_conv = nn.Conv2d(64, out_ch, kernel_size=1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its last two dims equal those of ``skip``.

        MaxPool2d(2) floors odd sizes, so after a 2x upsample the decoder map
        can be one pixel short of the skip map. When sizes already agree the
        tensor is returned untouched.
        """
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh == 0 and dw == 0:
            return up
        # Pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        # Encoder
        c1 = self.down1(x)
        p1 = self.pool1(c1)

        c2 = self.down2(p1)
        p2 = self.pool2(c2)

        c3 = self.down3(p2)
        p3 = self.pool3(c3)

        c4 = self.down4(p3)
        p4 = self.pool4(c4)

        # Bottleneck
        bn = self.bottleneck(p4)

        # Decoder with skip connections
        u4 = self._match_size(self.up4(bn), c4)
        u4 = torch.cat([u4, c4], dim=1)
        c5 = self.dec4(u4)

        u3 = self._match_size(self.up3(c5), c3)
        u3 = torch.cat([u3, c3], dim=1)
        c6 = self.dec3(u3)

        u2 = self._match_size(self.up2(c6), c2)
        u2 = torch.cat([u2, c2], dim=1)
        c7 = self.dec2(u2)

        u1 = self._match_size(self.up1(c7), c1)
        u1 = torch.cat([u1, c1], dim=1)
        c8 = self.dec1(u1)

        logits = self.out_conv(c8)
        return logits


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Take the channel count from an actual dataset sample rather than from the
# exploratory `img_chw` variable: that global is reassigned by the prediction
# visualization cell, so relying on it ties the model to cell execution order.
in_channels = full_dataset[0][0].shape[0]

model = UNet(in_ch=in_channels, out_ch=1).to(device)
print("Model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Loss, optimizer, and metrics

# Binary cross-entropy on raw logits (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


def compute_metrics(preds, targets, eps: float = 1e-7):
    """Compute IoU, precision, recall and F1 for the water class.

    Args:
        preds: Tensor of hard predictions with values 0/1, e.g. (B, 1, H, W).
        targets: Tensor of ground-truth values 0/1 with the same number of
            elements as ``preds``.
        eps: Small constant added to every denominator so that empty classes
            give 0.0 instead of a division by zero.

    Returns:
        Dict with float entries ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"iou"``, computed over all elements pooled together.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in element count.
    """
    # reshape (not view) so that non-contiguous inputs, e.g. transposed or
    # sliced tensors, are flattened instead of raising.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)

    if preds.numel() != targets.numel():
        raise ValueError(
            f"preds and targets must have the same number of elements, "
            f"got {preds.numel()} and {targets.numel()}"
        )

    tp = torch.sum((preds == 1) & (targets == 1)).float()
    fp = torch.sum((preds == 1) & (targets == 0)).float()
    fn = torch.sum((preds == 0) & (targets == 1)).float()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)

    return {
        "precision": precision.item(),
        "recall": recall.item(),
        "f1": f1.item(),
        "iou": iou.item(),
    }

In [None]:
# Training loop

num_epochs = 10  # adjust as needed

for epoch in range(1, num_epochs + 1):
    # ---- Training ----
    model.train()
    train_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)
        loss = criterion(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch value is a
        # per-sample mean even when the last batch is smaller.
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    val_preds = []
    val_targets = []

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            logits = model(images)
            loss = criterion(logits, masks)
            val_loss += loss.item() * images.size(0)

            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            val_preds.append(preds.cpu())
            val_targets.append(masks.cpu())

    val_loss /= max(1, len(val_loader.dataset))

    # Metrics are computed once over all validation pixels. Averaging
    # per-batch ratios would over-weight a smaller final batch and would score
    # any batch without water as 0 regardless of prediction quality.
    if val_preds:
        all_metrics = compute_metrics(torch.cat(val_preds), torch.cat(val_targets))
    else:
        all_metrics = {"precision": 0.0, "recall": 0.0, "f1": 0.0, "iou": 0.0}

    print(f"Epoch {epoch}/{num_epochs} | "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"IoU: {all_metrics['iou']:.4f} | F1: {all_metrics['f1']:.4f} | "
          f"Prec: {all_metrics['precision']:.4f} | Rec: {all_metrics['recall']:.4f}")

In [None]:
# Visualize predictions vs ground-truth for a few validation samples

model.eval()

# Number of validation *batches* to visualize; every sample in each of those
# batches gets its own figure.
n_visualize = 3

with torch.no_grad():
    for i, (images, masks) in enumerate(val_loader):
        if i >= n_visualize:
            break
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Move to CPU for plotting
        images_np = images.cpu().numpy()
        masks_np = masks.cpu().numpy()
        preds_np = preds.cpu().numpy()

        batch_size_vis = images_np.shape[0]

        for b in range(batch_size_vis):
            # Use a dedicated name here: reusing `img_chw` would overwrite the
            # notebook-level tile loaded in the exploration cell.
            sample_chw = images_np[b]
            gt = masks_np[b, 0]
            pr = preds_np[b, 0]

            # Simple visualization: take 3 bands to form an RGB-like image (if at least 3 bands)
            if sample_chw.shape[0] >= 3:
                rgb = np.stack([
                    sample_chw[0],
                    sample_chw[1],
                    sample_chw[2],
                ], axis=-1)
                # Normalize to [0,1] for display
                rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-7)
            else:
                rgb = sample_chw[0]

            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            axes[0].imshow(rgb)
            axes[0].set_title("Input (3 bands)")
            axes[0].axis("off")

            axes[1].imshow(gt, cmap="gray")
            axes[1].set_title("Ground truth mask")
            axes[1].axis("off")

            axes[2].imshow(pr, cmap="gray")
            axes[2].set_title("Predicted mask")
            axes[2].axis("off")

            plt.tight_layout()
            plt.show()