<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments/SegFormer-b3-multispectral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Mount Google Drive; install dependencies
import os

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

!pip install opencv-python
!pip install --upgrade torch_xla torch
!pip install evaluate
!pip install wandb -qU
!pip install torchmetrics

In [None]:
# @title Load SegmentationDatasets from Drive { display-mode: "code" }

import torch
import json

training_data_directory = '/content/drive/MyDrive/desCartes/training_data/'

channel_count = 21 # @param {type:"integer"}
sample = False # @param {type:"boolean"}

input_channels = None
num_classes = None

class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.augmentations_per_sample = 8  # 4 orientations + flips
        self.length = len(base_dataset) * self.augmentations_per_sample

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        base_idx = idx // self.augmentations_per_sample
        aug_idx = idx % self.augmentations_per_sample

        sample = self.base_dataset[base_idx]

        orientation = aug_idx % 4
        flip = aug_idx >= 4

        def augment(tensor):
            if flip:
                tensor = tensor.flip(-1)
            if orientation:
                tensor = torch.rot90(tensor, orientation, dims=(-2, -1))
            return tensor

        return {
            'pixel_values': augment(sample['pixel_values']),
            'labels': augment(sample['labels'])
        }

torch.serialization.add_safe_globals([SegmentationDataset])

# Load the dataset from a binary file
def load_dataset(file_path):
    global input_channels, num_classes

    dataset = torch.load(file_path)
    print(f"Dataset of {len(dataset)} items loaded from {file_path}")

    # Reduce dataset size to 10% if sample=True
    if sample:
        reduced_size = max(4, int(len(dataset) * 0.1))  # Ensure at least 4 samples
        dataset.data = dataset.data[:reduced_size]
        print(f"Reduced dataset size: {len(dataset)} samples.")

    # Wrap dataset in SegmentationDataset class
    dataset = SegmentationDataset(dataset)

    # Iterate through the data and print the type and shape of the first sample
    first_sample = dataset[0]

    input_channels = first_sample['pixel_values'].shape[0]
    num_classes = first_sample['labels'].max().item() + 1

    if isinstance(first_sample, torch.Tensor):
        print(f"First sample shape: {first_sample.shape}")
    else:
        print(f"First sample: {first_sample}")

    return dataset

# Define file paths for loading
pytorch_path = f"{training_data_directory}pytorch/{channel_count}-channel"
train_data_path = f'{pytorch_path}/train_dataset.pt'
eval_data_path = f'{pytorch_path}/eval_dataset.pt'

# Load the datasets from Google Drive
eval_dataset = load_dataset(eval_data_path)
train_dataset = load_dataset(train_data_path)

print(f"Input channels: {input_channels}")
print(f"Number of classes: {num_classes}")


In [3]:
# @title Safe Intersection-over-Union { display-mode: "code" }

from typing import Dict, Optional

import datasets
import numpy as np

import evaluate

def clip_round(value, decimals=5):
    """
    Efficient rounding that avoids precision issues near 0 and 1.
    """
    eps = 10 ** -decimals  # Calculate epsilon
    rounded_value = np.round(value, decimals)

    if np.isscalar(rounded_value):
        return 1.0 if rounded_value >= 1 - eps else (0.0 if rounded_value <= eps else rounded_value)

    rounded_value[rounded_value >= 1 - eps] = 1.0
    rounded_value[rounded_value <= eps] = 0.0
    return rounded_value

def intersect_and_union(
    pred_label,
    label,
    num_labels,
    ignore_index
):
    """
    Computes the intersection and union of predictions and labels.
    """
    if ignore_index >= 0:
        mask = label != ignore_index
        pred_label, label = pred_label[mask], label[mask]

    intersect = pred_label[pred_label == label]

    area_intersect = np.histogram(intersect, bins=num_labels, range=(0, num_labels - 1))[0].astype(np.uint32)
    area_pred_label = np.histogram(pred_label, bins=num_labels, range=(0, num_labels - 1))[0].astype(np.uint32)
    area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0].astype(np.uint32)

    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label

def total_intersect_and_union(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index
):
    """
    Accumulates intersection and union over all samples.
    """
    total_area_intersect = np.zeros(num_labels, dtype=np.uint32)
    total_area_union = np.zeros(num_labels, dtype=np.uint32)
    total_area_pred_label = np.zeros(num_labels, dtype=np.uint32)
    total_area_label = np.zeros(num_labels, dtype=np.uint32)

    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
            result, gt_seg_map, num_labels, ignore_index
        )
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

    return total_area_intersect, total_area_union, total_area_pred_label, total_area_label


def iou(
    results,  # pred_labels (the predicted class indices)
    gt_seg_maps,  # labels (the ground truth class indices)
    num_labels,  # The total number of classes in the segmentation task
    ignore_index,
    nan_to_num=None,
):
    """
    Computes per-category and overall IoU.
    """
    total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union(
        results, gt_seg_maps, num_labels, ignore_index
    )

    eps = 1e-10
    round_decimals = 5

    total_label_sum = total_area_label.sum()  # Store sum to avoid redundant computations

    all_acc = np.clip(total_area_intersect.sum() / (total_label_sum + eps), eps, 1 - eps)
    iou = np.clip(total_area_intersect / (total_area_union + eps), eps, 1 - eps)

    all_acc = clip_round(all_acc, round_decimals)
    iou = clip_round(iou, round_decimals)

    # Assign IoU = 1 for empty classes
    iou[total_area_label == 0] = 1.0

    # Set ignored index to 0
    if ignore_index >= 0:
        iou[ignore_index] = 0.0

    metrics = {"overall_accuracy": all_acc, "per_category_iou": iou}

    if nan_to_num is not None:
        metrics = {k: np.nan_to_num(v, nan=nan_to_num) for k, v in metrics.items()}

    return metrics

In [4]:
# @title Multispectral SegFormer Model { display-mode: "code" }

import os
from transformers import SegformerConfig, SegformerForSemanticSegmentation

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
results_path = f'{project_path}/results'

# Select Model (must match image size of samples)
# Original Publication: https://github.com/NVlabs/SegFormer/tree/master/local_configs/segformer
# Largest model which can be trained with Colab TPU v2-4 memory limit is b4 (512 x 512): ~75% capacity at batch size of 2
model_version = "b3" # @param ['b0', 'b1', 'b2', 'b3', 'b3n', 'b4', 'b5'] {type:'string'}

# Define class labels
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]

# Configure label mappings
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}
label2id = {label: i for i, label in id2label.items()}

def create_or_fetch_segformer(model_size=model_version, input_channels=input_channels, num_classes=num_classes):
    try:
        model_name = f"docuracy/segformer-{model_size}-{input_channels}-{num_classes}-512-512"
        base_model_path = f'{model_path}/base/{model_name}'

        if not os.path.exists(base_model_path):
            print(f"Creating model: {model_name}\n", flush=True)
            os.makedirs(base_model_path)

            # SegFormer base configurations
            base_configs = {
                "b0": dict(hidden_sizes=[32, 64, 160, 256], decoder_hidden_size=128, depths=[2, 2, 2, 2], mlp_ratios=[4, 4, 4, 4]),
                "b1": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=256, depths=[2, 2, 2, 2], mlp_ratios=[4, 4, 4, 4]),
                "b2": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=768, depths=[3, 4, 6, 3], mlp_ratios=[4, 4, 4, 4]),
                "b3": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=768, depths=[3, 4, 18, 3], mlp_ratios=[4, 4, 4, 4]),
                "b3n": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=768, depths=[3, 4, 18, 3], mlp_ratios=[4, 4, 4, 4],
                          patch_sizes=[3, 3, 3, 3], strides=[2, 2, 2, 2]),  # Narrower variant
                "b4": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=768, depths=[3, 8, 27, 3], mlp_ratios=[4, 4, 4, 4]),
                "b5": dict(hidden_sizes=[64, 128, 320, 512], decoder_hidden_size=768, depths=[3, 6, 40, 3], mlp_ratios=[4, 4, 4, 4]),
            }

            if model_size not in base_configs:
                raise ValueError(f"Unsupported model_size '{model_size}'. Choose from 'b0' to 'b5'.")

            config_args = base_configs[model_size]

            # Create a randomly initialised model of the required specification
            # See https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
            config = SegformerConfig(
                num_channels=input_channels,
                num_labels=num_classes,
                hidden_sizes=config_args["hidden_sizes"],
                decoder_hidden_size=config_args["decoder_hidden_size"],
                depths=config_args["depths"],
                mlp_ratios=config_args["mlp_ratios"],
                patch_sizes=config_args.get("patch_sizes", [7, 3, 3, 3]),  # fallback to default
                strides=config_args.get("strides", [4, 2, 2, 2]),           # fallback to default
                # Optional extras:
                hidden_act="gelu",
                classifier_dropout=0.1,
                backbone_type="mit",
                id2label=id2label,
                label2id=label2id,
            )

            model = SegformerForSemanticSegmentation(config)
            model.save_pretrained(base_model_path)
            print(f"Model saved to: {base_model_path}\n", flush=True)
        else:
            model = SegformerForSemanticSegmentation.from_pretrained(base_model_path)

        return model
    except Exception as e:
        print(f"Error loading/downloading: {e}\n", flush=True)
        raise

# Ensure that model exists before proceeding
create_or_fetch_segformer()

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(21, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True

In [None]:
# @title Train Model { display-mode: "code" }

if 'iou' not in globals():
    raise NameError("Function 'mean_iou' is not defined. Run the appropriate cell first.")

if 'create_or_fetch_segformer' not in globals():
    raise NameError("Function 'create_or_fetch_segformer' is not defined. Run the appropriate cell first.")

if 'train_dataset' not in globals() or 'eval_dataset' not in globals():
    raise NameError("Either 'train_dataset' or 'eval_dataset' is not defined. Run the appropriate cell first.")

if not train_dataset:  # Checks if train_dataset is empty
    raise ValueError("'train_dataset' is empty.")

if not eval_dataset:  # Checks if eval_dataset is empty
    raise ValueError("'eval_dataset' is empty.")

print("All variable checks passed! Proceeding with execution.")

# Import necessary libraries
import os
import sys
import shutil
import numpy as np
import gc
import time
import math
import wandb
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch_xla.distributed.parallel_loader as pl
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer, EarlyStoppingCallback
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support
from google.colab import userdata

restart_training = True # @param {type:"boolean"}
checkpoint_number = 0 # @param {type: "integer"}

class_weights_path = f'/content/drive/MyDrive/desCartes/models/class_weights.json'

# Local directory for storing dataset
local_data_dir = "/content/data"

# Training Configuration: by trial and error, these are found to be the maximum that can be accommodated without TPU stalling
per_device_train_batch_size = 2  # Batch size for training
per_device_eval_batch_size = per_device_train_batch_size

# Number of workers used by DataLoaders
num_workers = 4
persistent_workers = num_workers > 0

# Raising the following radically increases training time: unnecessary as 8x TPUs effectively multiply the batch size
gradient_accumulation_steps = 1  # Simulates a batch size of gradient_accumulation_steps * per_device_train_batch_size

# Loss Function Configuration
loss_gamma = 1.5  # Focal loss gamma
# loss_alpha = 0.40  # Focal loss alpha (now handled automatically)
dice_weight = 0.7
loss_weight = 1 - dice_weight

###################################################

!wandb --version
os.environ["WANDB_API_KEY"] = userdata.get('WANDB_TOKEN')
!wandb login
wandb.init(project="tpu-segmentation", name=f"TPU-Training-{model_version}", settings=wandb.Settings(_service_wait=60))

hf_token = userdata.get('HF_TOKEN')

# Load class weights from JSON
with open(class_weights_path, 'r') as f:
    weights_dict = json.load(f)
CLASS_WEIGHTS_NP = [float(weights_dict[str(i)]) for i in range(len(weights_dict))]
class_weights = torch.tensor(
    CLASS_WEIGHTS_NP,
    dtype=torch.bfloat16 # Consistent with TrainingArguments
)
print(f"Class weights loaded: {CLASS_WEIGHTS_NP}")

# Test for existing checkpoints
checkpoint_path = f"{model_path}/checkpoints/{model_version}"

"""
This codeblock simply removes any checkpoints greater than the given value, which
may be necessary if the last training run crashed before all the required files
had been saved. It also allows for rewinding training.
"""
if checkpoint_number > 0 and os.path.exists(f"{checkpoint_path}/checkpoint-{checkpoint_number}"):
    # Delete checkpoint directories with greater indices
    if os.path.exists(checkpoint_path):
        for checkpoint_dir in os.scandir(checkpoint_path):
            # Check if the item is a directory and starts with 'checkpoint-'
            if checkpoint_dir.is_dir() and checkpoint_dir.name.startswith("checkpoint-"):
                current_checkpoint_number = int(checkpoint_dir.name.split("-")[1])
                if current_checkpoint_number > checkpoint_number:
                    # Delete the directory
                    shutil.rmtree(checkpoint_dir.path)
else:
    if checkpoint_number > 0:
        # Stop execution if the desired checkpoint doesn't exist
        print(f"Checkpoint {checkpoint_number} does not exist. Stopping execution.")
        sys.exit(1)

    if restart_training and os.path.exists(checkpoint_path):
        # Delete the folder and its contents if restarting training
        shutil.rmtree(checkpoint_path)
        os.makedirs(checkpoint_path)  # Recreate the empty checkpoint directory

# Set the checkpoint to the latest remaining one (if any exists)
checkpoint = os.path.exists(checkpoint_path) and any(os.scandir(checkpoint_path))

def compute_metrics(eval_pred, batch_size=16):
    ignore_index = -1
    logits, labels = eval_pred
    num_samples = logits.shape[0]

    all_metrics = {
        "overall_accuracy": [],
        "precision": [],
        "recall": [],
        "f1_score": [],
        "weighted_mean_iou": [],
    }
    per_category_ious = [[] for _ in range(num_classes)]

    try:
        for start_idx in range(0, num_samples, batch_size):
            end_idx = min(start_idx + batch_size, num_samples)

            # Use uint8 for labels since the values range 0-4
            batch_logits = torch.from_numpy(logits[start_idx:end_idx]).cpu()  # Offload early
            batch_labels = labels[start_idx:end_idx].astype(np.uint8)  # Ensure efficient storage

            with torch.no_grad():
                logits_tensor = nn.functional.interpolate(
                    batch_logits, size=batch_labels.shape[-2:], mode="nearest"
                ).argmax(dim=1).numpy().astype(np.uint8)  # Convert to uint8

            batch_metrics = iou(
                results=logits_tensor,
                gt_seg_maps=batch_labels,
                num_labels=num_classes,
                ignore_index=ignore_index,
            )

            all_metrics["overall_accuracy"].append(
                np.float32(batch_metrics["overall_accuracy"])
            )
            per_category_iou = batch_metrics["per_category_iou"]

            pred_flat = logits_tensor.ravel()  # Use ravel() instead of flatten() to avoid copies
            labels_flat = batch_labels.ravel()  # Same here

            precision, recall, f1, _ = precision_recall_fscore_support(
                labels_flat, pred_flat, average="weighted", zero_division=0
            )

            all_metrics["precision"].append(np.float32(precision))
            all_metrics["recall"].append(np.float32(recall))
            all_metrics["f1_score"].append(np.float32(f1))

            for i, v in enumerate(per_category_iou):
                if i != ignore_index:
                    per_category_ious[i].append(np.float32(v))

            gc.collect()  # Explicit garbage collection to free memory

        # Compute final metrics
        all_metrics["overall_accuracy"] = np.mean(all_metrics["overall_accuracy"], dtype=np.float32).item()
        all_metrics["precision"] = np.mean(all_metrics["precision"], dtype=np.float32).item()
        all_metrics["recall"] = np.mean(all_metrics["recall"], dtype=np.float32).item()
        all_metrics["f1_score"] = np.mean(all_metrics["f1_score"], dtype=np.float32).item()

        final_per_category_ious = [np.mean(ious, dtype=np.float32).item() for ious in per_category_ious]

        for i, v in enumerate(final_per_category_ious):
            if i != ignore_index:
                all_metrics[f"iou_{id2label[i]}"] = v

        all_metrics["unweighted_mean_roads_iou"] = (
            (final_per_category_ious[1] + final_per_category_ious[2]) / 2
        )
        all_metrics["unweighted_mean_paths_iou"] = (
            (final_per_category_ious[3] + final_per_category_ious[4]) / 2
        )
        all_metrics["unweighted_mean_iou"] = np.average(final_per_category_ious).item()

        if ignore_index >= 0:
            reduced_class_weights = np.delete(CLASS_WEIGHTS_NP, ignore_index)
        else:
            reduced_class_weights = CLASS_WEIGHTS_NP

        all_metrics["weighted_mean_iou"] = np.average(
            final_per_category_ious, weights=reduced_class_weights
        ).item()

        return all_metrics

    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        return {
            "error": True,
            "overall_accuracy": 0.0,
            "weighted_mean_iou": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            **{f"iou_{id2label[i]}": 0.0 for i in range(num_classes)},
        }

def tpu_worker_process(rank):

    try:
        # Set TPU device
        device = xm.xla_device()
        world_size = xr.world_size()

        # Calculate effective batch size (across all TPUs)
        effective_batch_size = per_device_train_batch_size * world_size * gradient_accumulation_steps
        # Calculate steps per epoch (without dropping the last batch)
        steps_per_epoch = math.ceil(train_dataset.length / effective_batch_size)
        xm.master_print(f"Steps per epoch: {steps_per_epoch}")

        # Ensure that all TPUs are available before proceeding
        xm.rendezvous("ready")
        xm.master_print(f"All {world_size} devices are ready!\n", flush=True)

        model = create_or_fetch_segformer()
        model.to(device)
        class_weights.to(device)

        # Distributed samplers (drop_last=True to prevent hanging)
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
        )
        eval_sampler = DistributedSampler(
            eval_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
        )

        # Safe TPU DataLoader setup
        def worker_init_fn(worker_id):
            """Ensures each worker has a different random seed"""
            torch.manual_seed(worker_id + rank)

        train_dataloader = DataLoader(
            train_dataset, batch_size=per_device_train_batch_size, sampler=train_sampler,
            num_workers=num_workers, pin_memory=True, persistent_workers=persistent_workers, worker_init_fn=worker_init_fn
        )
        eval_dataloader = DataLoader(
            eval_dataset, batch_size=per_device_eval_batch_size, sampler=eval_sampler,
            num_workers=num_workers, pin_memory=True, persistent_workers=persistent_workers, worker_init_fn=worker_init_fn
        )

        # Wrap data loaders with MpDeviceLoader for TPU support
        train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
        eval_dataloader = pl.MpDeviceLoader(eval_dataloader, device)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=checkpoint_path,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            dataloader_num_workers=num_workers,
            eval_strategy="epoch",
            save_strategy="epoch",
            logging_steps=max(steps_per_epoch // 8, 1),
            logging_strategy="steps",
            report_to=["wandb"] if rank == 0 else [],
            disable_tqdm=(rank != 0),
            gradient_accumulation_steps=gradient_accumulation_steps,
            fp16=False,
            bf16=True,
            metric_for_best_model="weighted_mean_iou",  # Metric to monitor for best model
            greater_is_better=True,  # Set to True to maximize the metric
            num_train_epochs=30,
            save_total_limit=10,  # Keep only the last `n` checkpoints
            load_best_model_at_end=True,
            push_to_hub=False,
            run_name=f"desCartes-{model_version}-{per_device_train_batch_size}-{gradient_accumulation_steps}-bf16",
            ## Perhaps re-enable EarlyStoppingCallback in CustomTrainer
            ## The following are now overridden by `optimizers` in CustomTrainer
            # lr_scheduler_type="cosine_with_restarts",
            # learning_rate=1e-4,  # Slightly higher than default 5e-5
            # warmup_steps=steps_per_epoch * 2,  # Warm-up steps to stabilise learning
            # optim="adamw_torch",  # Ensure TPU-optimized optimizer
        )

        # Ensure that all TPUs are properly loaded before proceeding
        xm.rendezvous("steady")
        xm.master_print("All devices are steady!\n", flush=True)

        # Trainer: override standard methods
        class CustomTrainer(Trainer):
            def __init__(self, *args, class_weights=None, **kwargs):
                super().__init__(*args, **kwargs)
                self.class_weights = class_weights

            def optimizer_step(self, model, optimizer, optimizer_idx=None, **kwargs):
                # Ensure TPU sync before stepping
                xm.optimizer_step(optimizer, barrier=True)

            def get_train_dataloader(self):
                return train_dataloader

            def get_eval_dataloader(self, eval_dataset=None):
                return eval_dataloader

            def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

                def dice_loss(logits, target, class_weights, smooth=1e-6):
                    try:
                        pred = F.softmax(logits, dim=1)  # Convert logits to probabilities
                        target_one_hot = F.one_hot(target.long(), num_classes=num_classes).permute(0, 3, 1, 2).float()

                        intersection = (pred * target_one_hot).sum(dim=(2, 3), keepdim=True)
                        denominator = (pred + target_one_hot).sum(dim=(2, 3), keepdim=True)

                        dice_per_class = (2.0 * intersection + smooth) / (denominator + smooth)

                        weighted_dice = (dice_per_class * class_weights).sum(dim=1) / (class_weights.sum() + smooth)
                        return 1 - weighted_dice.mean()

                    except Exception as e:
                        xm.master_print(f"Error calculating Dice loss: {e}")
                        return torch.tensor(0.0, device=logits.device)

                # Override default method to incorporate class weights
                if self.model_accepts_loss_kwargs:
                    loss_kwargs = {}
                    if num_items_in_batch is not None:
                        loss_kwargs["num_items_in_batch"] = num_items_in_batch
                    inputs = {**inputs, **loss_kwargs}

                outputs = model(**inputs)

                if isinstance(outputs, dict) and "loss" not in outputs:
                    raise ValueError(
                        "The model did not return a loss from the inputs, only the following keys: "
                        f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                    )

                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

                try:
                    logits = outputs.logits.to(torch.bfloat16)

                    # Ensure device consistency
                    device = logits.device
                    labels = inputs.pop("labels").to(device).to(torch.long)
                    if self.class_weights is not None:
                        self.class_weights = self.class_weights.to(device)
                    else:
                        xm.master_print("Warning: self.class_weights is None.")
                        self.class_weights = torch.ones(num_classes, device=device)

                    # Reshape labels to match logits
                    labels = F.interpolate(labels.unsqueeze(1).float().to(torch.bfloat16), size=(128, 128), mode="nearest").squeeze(1).long()

                    # Compute cross-entropy loss
                    ce_loss = F.cross_entropy(logits, labels, reduction="none", weight=self.class_weights).float().to(torch.bfloat16)

                    # Compute dynamic alpha based on class frequencies in the batch
                    try:
                        class_freq = labels.view(-1).float().histc(bins=num_classes, min=0, max=num_classes - 1) + 1e-6
                        alpha_t = (1 / class_freq).div_(class_freq.sum())  # Normalize alpha_t
                    except Exception as e:
                        xm.master_print(f"Error computing class frequencies: {e}")
                        alpha_t = 0.35 * torch.ones(num_classes, device=logits.device)

                    # Compute focal loss components
                    try:
                        pt = torch.exp(-torch.clamp(ce_loss, min=1e-8, max=1e8))
                        alpha_t_gathered = alpha_t.gather(0, labels.view(-1)).view(labels.shape)
                        focal_loss = (ce_loss * ((1 - pt) ** loss_gamma) * alpha_t_gathered).mean()

                    except Exception as e:
                        xm.master_print(f"Error calculating focal loss: {e}")
                        focal_loss = ce_loss.mean()

                    # Dice loss component
                    dice = dice_loss(logits, labels, self.class_weights)

                    # Weighted sum of losses
                    loss = loss_weight * focal_loss + dice_weight * dice

                except Exception as e:
                    xm.master_print(f"Error calculating loss - defaulting to unweighted cross-entropy: {e}")

                if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
                    loss *= self.accelerator.num_processes

                return (loss, outputs) if return_outputs else loss

        # Optimizer
        initial_learning_rate = 1.4e-4
        optimizer = torch.optim.AdamW(model.parameters(), lr=initial_learning_rate, weight_decay=1e-4)

        # Scheduler
        total_steps = training_args.num_train_epochs * steps_per_epoch

        warmup_steps = int(total_steps * 0.05)  # 5% warmup
        T_max = total_steps - warmup_steps  # Remaining steps for cosine annealing

        # Warmup Scheduler
        warmup_scheduler = LambdaLR(
            optimizer, lr_lambda=lambda step: step / warmup_steps if step < warmup_steps else 1.0
        )

        # Cosine Annealing Scheduler
        cosine_scheduler = CosineAnnealingLR(
            optimizer, T_max=T_max, eta_min=initial_learning_rate * 0.05
        )

        # Sequential Scheduler
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

        xm.master_print(f"Scheduler set up: warmup_steps={warmup_steps}", flush=True)

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
            class_weights=class_weights,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=25)],
            optimizers=(optimizer, scheduler),
        )

        # Synchronize TPUs before starting training
        xm.rendezvous("start_training")  # Ensure all TPU processes sync before proceeding
        xm.master_print(f"All devices are GO! ... training started (resume={checkpoint})...\n", flush=True)

        trainer.train(resume_from_checkpoint=checkpoint)
        xm.rendezvous("training_complete")  # Ensure all TPU processes sync before exit
        xm.master_print("Training completed!\n", flush=True)

        # Terminate WandB logging
        if rank == 0:
            wandb.finish()

    except ValueError as e:
        xm.master_print(f"Error calculating loss (ValueError): {e}")
    except RuntimeError as e:
        xm.master_print(f"Error calculating loss (RuntimeError): {e}")
    except Exception as e:
        print(f"Error in tpu_worker_process {rank}: {e}")
        sys.exit(1)

    return

# Launch TPU training with WandB logging
xmp.spawn(tpu_worker_process, args=(), start_method='fork')


All variable checks passed! Proceeding with execution.
wandb, version 0.19.9
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Class weights loaded: [0.010366307282922722, 0.23456335194160186, 0.14726239675182454, 0.30696693283280224, 0.3008410111908487]
Steps per epoch: 161
All 8 devices are ready!





All devices are steady!

Scheduler set up: warmup_steps=241
All devices are GO! ... training started (resume=False)...



Epoch,Training Loss,Validation Loss,Overall Accuracy,Precision,Recall,F1 Score,Weighted Mean Iou,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path,Unweighted Mean Roads Iou,Unweighted Mean Paths Iou,Unweighted Mean Iou
1,0.5142,0.502729,0.914309,0.960074,0.91431,0.932232,0.495551,0.92105,0.474937,0.718198,0.361651,0.524603,0.596567,0.443127,0.600088
2,0.4772,0.482723,0.947728,0.928239,0.947728,0.92345,0.374059,0.947731,0.264773,0.622244,0.334063,0.358825,0.443509,0.346444,0.505527
3,0.4608,0.437054,0.954604,0.953791,0.954604,0.952606,0.548298,0.952419,0.546316,0.737342,0.42861,0.565505,0.641829,0.497058,0.646039
4,0.4508,0.435609,0.955801,0.95071,0.955802,0.948345,0.534823,0.954296,0.441426,0.716463,0.469621,0.570805,0.578944,0.520213,0.630522
5,0.4462,0.427907,0.957684,0.954645,0.957684,0.953842,0.58038,0.955618,0.524318,0.736873,0.517906,0.598305,0.630595,0.558105,0.666604


In [None]:
# @title Inference { display-mode: "code" }

import os
import numpy as np
import torch
import torch_xla.core.xla_model as xm
from transformers import SegformerForSemanticSegmentation
from PIL import Image
from tqdm.notebook import tqdm
from safetensors.torch import load_file

# Define class colors (with alpha)
class_colors = {
    0: (0, 0, 0, 0),      # Background (Black, Transparent)
    1: (255, 0, 0, 255),  # Main Road (Red, Opaque)
    2: (0, 255, 0, 255),  # Minor Road (Green, Opaque)
    3: (255, 165, 0, 255), # Orange (Opaque)
    4: (255, 255, 0, 255)  # Yellow (Opaque)
}

if 'model_version' not in globals():
    model_version = 'b3'
if 'input_channels' not in globals():
    input_channels = 10

project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
checkpoints_dir = f"{model_path}/checkpoints/{model_version}"
data_dir = f'{project_path}/training_data/tiles/inference'  # Directory with .pt files
output_dir = f'{project_path}/inference_output/{model_version}'  # Directory to save PNGs

# Get the list of checkpoint subdirectories
checkpoint_subdirs = [
    os.path.join(checkpoints_dir, d)
    for d in os.listdir(checkpoints_dir)
    if d.startswith("checkpoint-") and os.path.isdir(os.path.join(checkpoints_dir, d))
]

# Sort by checkpoint number (assuming names like checkpoint-12345)
checkpoint_subdirs.sort(key=lambda x: int(x.split('-')[-1]))

# Get the most recent (highest numbered) checkpoint
latest_checkpoint = checkpoint_subdirs[-1]
trainer_state_path = os.path.join(latest_checkpoint, "trainer_state.json")

# Load trainer state
with open(trainer_state_path, 'r') as f:
    trainer_state = json.load(f)

best_checkpoint_path = trainer_state.get("best_model_checkpoint")

if best_checkpoint_path:
    fine_tuned_model_path = os.path.join(best_checkpoint_path, "model.safetensors")
    print(f"✅ Best model checkpoint: {fine_tuned_model_path}")
else:
    raise ValueError("No 'best_model_checkpoint' key found in trainer_state.json")

# Define class labels (same as in training)
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}
label2id = {label: i for i, label in id2label.items()}

def process_and_save_prediction_overlay(model, tensor_path, output_dir):
    """Loads a PyTorch tensor, makes a prediction, overlays it on the original image, and saves a PNG."""

    try:
        # Load the PyTorch tensor and move it to the TPU
        input_tensor = torch.load(tensor_path).unsqueeze(0).to(device)

        # Perform inference (no gradients needed)
        with torch.no_grad():
            output = model(input_tensor).logits
            predictions = torch.argmax(output, dim=1).numpy()[0]

        # Create a colored mask (now RGBA)
        color_mask = np.zeros((*predictions.shape, 4), dtype=np.uint8)
        for class_id, color in class_colors.items():
            color_mask[predictions == class_id] = color

        # Load the original JPG image
        jpg_path = tensor_path.replace(".segformer_input.pt", ".jpg")
        original_image = Image.open(jpg_path).convert("RGBA")  # Ensure it's RGBA

        # Resize the color mask to match the original image dimensions
        color_mask_image = Image.fromarray(color_mask, 'RGBA').resize(original_image.size)

        # Overlay the color mask on the original image
        # You can adjust the alpha value (0-1) to control the transparency of the overlay
        overlayed_image = Image.alpha_composite(original_image, color_mask_image)

        # Save the overlayed image as PNG
        filename = os.path.basename(tensor_path).replace(".segformer_input.pt", "_overlay.png")
        output_path = os.path.join(output_dir, filename)
        overlayed_image.save(output_path)
        print(f"Saved: {output_path}")

    except Exception as e:
        print(f"Error processing {tensor_path}: {e}")

def process_folder(data_dir, output_dir, model):
    """Processes all .pt tensors in a folder and saves overlayed PNGs."""

    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Get a list of all .pt files in the directory
    tensor_files = [f for f in os.listdir(data_dir) if f.endswith(".segformer_input.pt")]

    # Initialize tqdm progress bar
    progress_bar = tqdm(total=len(tensor_files), desc="Processing Tensors")

    for tensor_file in tensor_files:
        tensor_path = os.path.join(data_dir, tensor_file)
        process_and_save_prediction_overlay(model, tensor_path, output_dir)
        progress_bar.update(1)

    progress_bar.close()

if __name__ == "__main__":
    # Set device
    device = torch.device("cpu")

    # Load the trained SegFormer model
    model = create_or_fetch_segformer(model_size=model_version, input_channels=input_channels, num_classes=num_classes)
    state_dict = load_file(fine_tuned_model_path)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # Set to evaluation mode

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Process the tensors in the folder
    process_folder(data_dir, output_dir, model)