In [27]:
import geoai

In [28]:
from torch.utils.data import Dataset
import numpy as np
import torch
from typing import Callable, List, Optional, Tuple
import rasterio
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Dataset for semantic segmentation from GeoTIFF, PNG, JPG, and other image formats."""

    def __init__(
        self,
        image_paths: List[str],
        label_paths: List[str],
        transforms: Optional[Callable] = None,
        num_channels: Optional[int] = None,
        target_size: Optional[Tuple[int, int]] = None,
        resize_mode: str = "resize",
        num_classes: int = 2,
    ) -> None:
        """
        Initialize dataset for semantic segmentation.

        Args:
            image_paths (list): List of paths to image files (GeoTIFF, PNG, JPG, etc.).
            label_paths (list): List of paths to label files (GeoTIFF, PNG, JPG, etc.).
            transforms (callable, optional): Transformations to apply to images and masks.
            num_channels (int, optional): Number of channels to use from images. If None,
                auto-detected from the first image.
            target_size (tuple, optional): Target size (height, width) for standardizing images.
                If None, images will keep their original sizes.
            resize_mode (str): How to handle size standardization. Options:
                'resize' - Resize images to target_size (may change aspect ratio)
                'pad' - Pad images to target_size (preserves aspect ratio)
            num_classes (int): Number of classes for segmentation. Used for mask normalization.
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transforms = transforms
        self.target_size = target_size
        self.resize_mode = resize_mode
        self.num_classes = num_classes

        # Auto-detect the number of channels if not specified
        if num_channels is None:
            self.num_channels = self._get_num_channels(self.image_paths[0])
        else:
            self.num_channels = num_channels

    def _is_geotiff(self, file_path: str) -> bool:
        """Check if file is a GeoTIFF based on extension."""
        return file_path.lower().endswith((".tif", ".tiff"))

    def _get_num_channels(self, image_path: str) -> int:
        """Get number of channels from an image file."""
        if self._is_geotiff(image_path):
            with rasterio.open(image_path) as src:
                return src.count
        else:
            # For standard image formats, use PIL
            with Image.open(image_path) as img:
                if img.mode == "RGB":
                    return 3
                elif img.mode == "RGBA":
                    return 4
                elif img.mode == "L":
                    return 1
                else:
                    # Convert to RGB and return 3 channels
                    return 3

    def _resize_image_and_mask(
        self, image: np.ndarray, mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Resize image and mask to target size."""
        if self.target_size is None:
            return image, mask

        target_h, target_w = self.target_size

        if self.resize_mode == "resize":
            # Direct resize (may change aspect ratio)
            image = F.interpolate(
                image.unsqueeze(0),
                size=(target_h, target_w),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

            mask = (
                F.interpolate(
                    mask.unsqueeze(0).unsqueeze(0).float(),
                    size=(target_h, target_w),
                    mode="nearest",
                )
                .squeeze(0)
                .squeeze(0)
                .long()
            )
            # Clamp mask values to ensure they're within valid range [0, num_classes-1]
            mask = torch.clamp(mask, 0, self.num_classes - 1)

        elif self.resize_mode == "pad":
            # Pad to target size (preserves aspect ratio)
            image = self._pad_to_size(image, (target_h, target_w))
            mask = self._pad_to_size(mask.unsqueeze(0), (target_h, target_w)).squeeze(0)
            # Clamp mask values to ensure they're within valid range [0, num_classes-1]
            mask = torch.clamp(mask, 0, self.num_classes - 1)

        return image, mask

    def _pad_to_size(
        self, tensor: torch.Tensor, target_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Pad tensor to target size with zeros."""
        target_h, target_w = target_size

        if tensor.dim() == 3:  # Image [C, H, W]
            _, h, w = tensor.shape
        elif tensor.dim() == 2:  # Mask [H, W]
            h, w = tensor.shape
        else:
            raise ValueError(f"Unexpected tensor dimensions: {tensor.shape}")

        # Calculate padding
        pad_h = max(0, target_h - h)
        pad_w = max(0, target_w - w)

        # Pad equally on both sides
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        # Apply padding (left, right, top, bottom)
        padded = F.pad(tensor, (pad_left, pad_right, pad_top, pad_bottom), value=0)

        # Crop if tensor is larger than target
        if tensor.dim() == 3:
            padded = padded[:, :target_h, :target_w]
        else:
            padded = padded[:target_h, :target_w]

        return padded

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load image
        image_path = self.image_paths[idx]
        if self._is_geotiff(image_path):
            # Load GeoTIFF using rasterio
            with rasterio.open(image_path) as src:
                # Read as [C, H, W] format
                image = src.read().astype(np.float32)
                # Normalize image to [0, 1] range
                image = image / 255.0
        else:
            # Load standard image formats using PIL
            with Image.open(image_path) as img:
                # Convert to RGB if needed
                if img.mode != "RGB":
                    img = img.convert("RGB")
                # Convert to numpy array [H, W, C]
                image = np.array(img, dtype=np.float32)
                # Normalize to [0, 1] range
                image = image / 255.0
                # Convert to [C, H, W] format
                image = np.transpose(image, (2, 0, 1))

        # Handle different number of channels
        if image.shape[0] > self.num_channels:
            image = image[: self.num_channels]  # Keep only specified bands
        elif image.shape[0] < self.num_channels:
            # Pad with zeros if less than specified bands
            padded = np.zeros(
                (self.num_channels, image.shape[1], image.shape[2]),
                dtype=np.float32,
            )
            padded[: image.shape[0]] = image
            image = padded

        # Convert to CHW tensor
        image = torch.as_tensor(image, dtype=torch.float32)

        # Load label mask
        label_path = self.label_paths[idx]
        if self._is_geotiff(label_path):
            # Load GeoTIFF label using rasterio
            with rasterio.open(label_path) as src:
                label_mask = src.read(1).astype(np.int64)
        else:
            # Load standard image format label using PIL
            with Image.open(label_path) as img:
                # Convert to grayscale if needed
                if img.mode != "L":
                    img = img.convert("L")
                label_mask = np.array(img, dtype=np.int64)

        # Normalize mask values to expected class range [0, num_classes-1]
        # This handles cases where masks contain pixel values outside the expected range
        unique_vals = np.unique(label_mask)
        if len(unique_vals) > 2:
            # For multi-class case, we need to map values to proper class indices
            # For now, we'll use a simple thresholding approach for binary segmentation
            if self.num_classes == 2:
                # Binary segmentation: convert to 0 (background) and 1 (foreground)
                label_mask = (label_mask > 0).astype(np.int64)
            else:
                # For multi-class, we could implement more sophisticated mapping
                # For now, just ensure values are in valid range
                label_mask = np.clip(label_mask, 0, self.num_classes - 1)
        elif len(unique_vals) == 2 and unique_vals.max() > 1:
            # Binary mask with values not in [0,1] range - normalize to [0,1]
            label_mask = (label_mask > 0).astype(np.int64)

        # Convert to tensor
        mask = torch.as_tensor(label_mask, dtype=torch.long)

        # Resize image and mask to target size if specified
        image, mask = self._resize_image_and_mask(image, mask)

        # Apply transforms if specified
        if self.transforms is not None:
            image, mask = self.transforms(image, mask)

        return image, mask

In [36]:
import torch
import random
from typing import Callable, List, Optional, Tuple

class SemanticTransforms:
    """Custom transforms for semantic segmentation."""

    def __init__(self, transforms: List[Callable]) -> None:
        self.transforms = transforms

    def __call__(
        self, image: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask


class SemanticToTensor:
    """Convert numpy.ndarray to tensor for semantic segmentation."""

    def __call__(
        self, image: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return image, mask


class SemanticRandomHorizontalFlip:
    """Random horizontal flip transform for semantic segmentation."""

    def __init__(self, prob: float = 0.5) -> None:
        self.prob = prob

    def __call__(
        self, image: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if random.random() < self.prob:
            # Flip image and mask along width dimension
            image = torch.flip(image, dims=[2])
            mask = torch.flip(mask, dims=[1])
        return image, mask

In [37]:
from typing import Any

def get_semantic_transform(train: bool) -> Any:
    """
    Get transforms for semantic segmentation data augmentation.

    Args:
        train (bool): Whether to include training-specific transforms.

    Returns:
        SemanticTransforms: Composed transforms.
    """
    transforms = []
    transforms.append(SemanticToTensor())

    if train:
        transforms.append(SemanticRandomHorizontalFlip(0.5))

    return SemanticTransforms(transforms)

In [46]:
model_folder = "../models/crop_mapping_subset_3"
images_dir = "../data/processed/crop_mapping_subset_3/images"
labels_dir = "../data/processed/crop_mapping_subset_3/labels"

In [47]:
import torch
import segmentation_models_pytorch as smp

model_path = f"{model_folder}/deeplabv3plus_models/best_model.pth"

# Rebuild the model architecture the same way as training
model = smp.create_model(
    arch="deeplabv3plus",
    encoder_name="efficientnet-b3",
    encoder_weights=None,
    in_channels=4,
    classes=11,
)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

DeepLabV3Plus(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      4, 40, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_expand_conv): Identity()
        (_bn0): Identity()
        (_depthwise_conv): Conv2dStaticSamePadding(
          40, 40, kernel_size=(3, 3), stride=[1, 1], groups=40, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          40, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 40, kernel_size=(1, 1), stride=(1, 1)
          (static_paddin

In [48]:
import os

# Get all image and label files
# Support multiple image formats: GeoTIFF, PNG, JPG, JPEG, TIF, TIFF
image_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
label_extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg")

image_files = sorted(
    [
        os.path.join(images_dir, f)
        for f in os.listdir(images_dir)
        if f.lower().endswith(image_extensions)
    ]
)
label_files = sorted(
    [
        os.path.join(labels_dir, f)
        for f in os.listdir(labels_dir)
        if f.lower().endswith(label_extensions)
    ]
)

In [56]:
from sklearn.model_selection import train_test_split

train_imgs, val_imgs, train_labels, val_labels = train_test_split(
    image_files, label_files, test_size=0.2, random_state=42
)

In [57]:
from torch.utils.data import DataLoader
from geoai import train

val_dataset = train.SemanticSegmentationDataset(
    val_imgs,             # from your earlier train/val split
    val_labels,
    transforms=train.get_semantic_transform(train=False),
    num_channels=4,
    target_size=(256, 256),  # must match training
    resize_mode="resize",
    num_classes=11,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [52]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from typing import Any, Dict

def evaluate_semantic(
    model: torch.nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    criterion: Any,
    num_classes: int = 2,
    ignore_index: int = 255,
) -> Dict[str, float]:
    """
    Evaluate the semantic segmentation model with per-class IoU and macro IoU.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader for validation data.
        device (torch.device): Device to evaluate on.
        criterion: Loss function.
        num_classes (int): Number of classes.
        ignore_index (int): Value to ignore in evaluation (e.g., 255).

    Returns:
        dict: Evaluation metrics including loss, Dice, IoU, Macro IoU, Per-class IoU.
    """
    model.eval()
    total_loss = 0.0
    num_batches = len(data_loader)

    # confusion matrix for IoU
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for images, targets in data_loader:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            targets = targets.cpu().numpy()

            for p, t in zip(preds, targets):
                mask = t != ignore_index
                p = p[mask]
                t = t[mask]
                cm = np.bincount(
                    num_classes * t.astype(int) + p.astype(int),
                    minlength=num_classes**2,
                ).reshape(num_classes, num_classes)
                confusion += cm

    # per-class IoU
    TP = np.diag(confusion)
    FP = confusion.sum(axis=0) - TP
    FN = confusion.sum(axis=1) - TP
    per_class_iou = TP / (TP + FP + FN + 1e-6)

    # metrics
    macro_iou = np.nanmean(per_class_iou)
    mean_loss = total_loss / num_batches
    dice = 2 * TP.sum() / (2 * TP.sum() + FP.sum() + FN.sum())

    return {
        "loss": mean_loss,
        "Dice": float(dice),
        "IoU": float(per_class_iou.mean()),   # mean IoU across classes
        "MacroIoU": float(macro_iou),
        "PerClassIoU": per_class_iou.tolist(),
    }

In [60]:
metrics_detail = evaluate_semantic(
    model,
    val_loader,
    device,
    criterion=torch.nn.CrossEntropyLoss(),
    num_classes=11,
    ignore_index=0,
)

In [61]:
print("Validation Loss:", metrics_detail["loss"])
print("Global IoU:", metrics_detail["IoU"])
print("Macro IoU:", metrics_detail["MacroIoU"])
print("Dice:", metrics_detail["Dice"])
print("Per-class IoU:")

class_mapping = {
    0: "Background",
    1: "Corn",
    2: "Rice",
    3: "Winter Wheat",
    4: "Alfalfa",
    5: "Tomatoes",
    6: "Grapes",
    7: "Almonds",
    8: "Walnuts",
    9: "Prunes",
    10: "Olives"
}
for i, iou in enumerate(metrics_detail["PerClassIoU"]):
    print(f"  {class_mapping.get(i, str(i))}: {iou:.4f}")

Validation Loss: 0.603565348174748
Global IoU: 0.4800350773148956
Macro IoU: 0.4800350773148956
Dice: 0.7388469457714606
Per-class IoU:
  Background: 0.0000
  Corn: 0.4391
  Rice: 0.8917
  Winter Wheat: 0.4047
  Alfalfa: 0.5111
  Tomatoes: 0.5722
  Grapes: 0.4012
  Almonds: 0.5950
  Walnuts: 0.6419
  Prunes: 0.3566
  Olives: 0.4668
