<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments/segformer-b4-TPU-100-epochs-metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Authenticate WandB and Hugging Face; mount Google Drive; install dependencies
import os

!pip install wandb -qU
!wandb login

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

from google.colab import userdata
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')

!pip install opencv-python
!pip install --upgrade torch_xla torch
!pip install evaluate

In [3]:
# @title Load SegmentationDatasets from Drive { display-mode: "code" }

import torch

class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

torch.serialization.add_safe_globals([SegmentationDataset])

# Load the dataset from a binary file
def load_dataset(file_path):
    dataset = torch.load(file_path)
    print(f"Dataset loaded from {file_path}")
    return dataset

# Define file paths for loading
train_data_path = '/content/drive/MyDrive/desCartes/pytorch/train_data.pt'
eval_data_path = '/content/drive/MyDrive/desCartes/pytorch/eval_data.pt'

# Load the datasets from Google Drive
eval_dataset = load_dataset(eval_data_path)
train_dataset = load_dataset(train_data_path)


Dataset loaded from /content/drive/MyDrive/desCartes/pytorch/eval_data.pt
Dataset loaded from /content/drive/MyDrive/desCartes/pytorch/train_data.pt


In [4]:
# @title Safe Mean Intersection-over-Union { display-mode: "code" }

# Copyright 2022 The HuggingFace Evaluate Authors.
# Based on https://huggingface.co/spaces/evaluate-metric/mean_iou/blob/main/mean_iou.py

from typing import Dict, Optional

import datasets
import numpy as np

import evaluate

def intersect_and_union(
    pred_label,
    label,
    num_labels,
    ignore_index: bool,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id

    # turn into Numpy arrays
    pred_label = np.array(pred_label)
    label = np.array(label)

    if reduce_labels:
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = label != ignore_index
    mask = np.not_equal(label, ignore_index)
    pred_label = pred_label[mask]
    label = np.array(label)[mask]

    intersect = pred_label[pred_label == label]

    area_intersect = np.histogram(intersect, bins=num_labels, range=(0, num_labels - 1))[0]
    area_pred_label = np.histogram(pred_label, bins=num_labels, range=(0, num_labels - 1))[0]
    area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0]

    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: bool,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    total_area_intersect = np.zeros((num_labels,), dtype=np.float64)
    total_area_union = np.zeros((num_labels,), dtype=np.float64)
    total_area_pred_label = np.zeros((num_labels,), dtype=np.float64)
    total_area_label = np.zeros((num_labels,), dtype=np.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
            result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels
        )
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, total_area_pred_label, total_area_label


def mean_iou(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: bool,
    nan_to_num: Optional[int] = None,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union(
        results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
    )

    # compute metrics
    metrics = dict()
    eps = 1e-10  # Small constant to prevent division by zero
    min_val, max_val = 1e-5, 1 - 1e-5  # Clip range to avoid extreme values
    round_decimals = 5  # Number of decimal places for rounding

    # Compute metrics with epsilon and clipping
    all_acc = np.clip(total_area_intersect.sum() / (total_area_label.sum() + eps), min_val, max_val)
    iou = np.clip(total_area_intersect / (total_area_union + eps), min_val, max_val)
    acc = np.clip(total_area_intersect / (total_area_label + eps), min_val, max_val)

    # Round values, ensuring that values like 9.9999e-01 are rounded up to 1.0
    iou = np.round(iou, round_decimals)
    acc = np.round(acc, round_decimals)

    # Explicitly round values very close to 1.0 (e.g., 0.99999, 0.999999)
    iou = np.where(iou >= 0.99999, 1.0, iou)
    acc = np.where(acc >= 0.99999, 1.0, acc)

    # Calculate final metrics with rounding
    metrics = {
        "mean_iou": round(np.nanmean(iou), round_decimals),
        "mean_accuracy": round(np.nanmean(acc), round_decimals),
        "overall_accuracy": round(all_acc, round_decimals),
        "per_category_iou": iou,
        "per_category_accuracy": acc,
    }

    if nan_to_num is not None:
        metrics = dict(
            {metric: np.nan_to_num(metric_value, nan=nan_to_num) for metric, metric_value in metrics.items()}
        )

    return metrics

In [None]:
# @title Train Model { display-mode: "code" }

if 'mean_iou' not in globals():
    raise NameError("Function 'mean_iou' is not defined. Run the appropriate cell first.")

if 'train_dataset' not in globals() or 'eval_dataset' not in globals():
    raise NameError("Either 'train_dataset' or 'eval_dataset' is not defined. Run the appropriate cell first.")

if not train_dataset:  # Checks if train_dataset is empty
    raise ValueError("'train_dataset' is empty.")

if not eval_dataset:  # Checks if eval_dataset is empty
    raise ValueError("'eval_dataset' is empty.")

print("All variable checks passed! Proceeding with execution.")

# Import necessary libraries
import os
import numpy as np
import time
import wandb
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch_xla.distributed.parallel_loader as pl
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer, EarlyStoppingCallback
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support
# Tidy up output [ineffective for TPUs]
# import warnings
# warnings.filterwarnings("ignore", message="Some weights of SegformerForSemanticSegmentation were not initialized", category=UserWarning)
# warnings.filterwarnings("ignore", message=".*feature_extractor_type.*", category=UserWarning)

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
results_path = f'{project_path}/results'

# Select Model
model_version = 'b4'

# Define class labels
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]

# Local directory for storing dataset
local_data_dir = "/content/data"

# Training Configuration
per_device_train_batch_size = 2  # Batch size for training
per_device_eval_batch_size = per_device_train_batch_size
gradient_accumulation_steps = 1  # Simulates a batch size of gradient_accumulation_steps * per_device_train_batch_size

###################################################

# Configure label mappings
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}
label2id = {label: i for i, label in id2label.items()}

def compute_metrics(eval_pred):
    try:
        with torch.no_grad():
            logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)

            # Upsample logits to match labels
            logits_tensor = nn.functional.interpolate(
                logits_tensor,
                size=labels.shape[-2:],  # Match height & width of labels
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)  # Convert to predicted class indices

            pred_labels = logits_tensor.detach().cpu().numpy()

            # Call the safe mean_iou function (defined in another cell)
            metrics = mean_iou(
                results=pred_labels,
                gt_seg_maps=labels,
                num_labels=num_classes,
                ignore_index=0, # class 0 is background/ignored
                reduce_labels=False,
            )

            # Extract per-class IoU & accuracy
            per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
            per_category_iou = metrics.pop("per_category_iou").tolist()

            # Compute precision, recall, and F1-score (excluding background)
            pred_flat = pred_labels.flatten()
            labels_flat = labels.flatten()
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels_flat, pred_flat, average="weighted", zero_division=0
            )

            # Store overall metrics
            metrics["overall_accuracy"] = metrics.pop("mean_accuracy")
            metrics["overall_mean_iou"] = metrics.pop("mean_iou")
            metrics["precision"] = precision
            metrics["recall"] = recall
            metrics["f1_score"] = f1

            # Add per-class accuracy & IoU
            metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
            metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

            return metrics

    except Exception as e:
        # Return default zeroed metrics with an error flag
        return {
            "error": True,
            "overall_accuracy": 0.0,
            "overall_mean_iou": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            **{f"accuracy_{id2label[i]}": 0.0 for i in range(num_classes)},
            **{f"iou_{id2label[i]}": 0.0 for i in range(num_classes)},
        }

def _mp_fn(rank, hf_token):
    # Set TPU device inside the function
    device = xm.xla_device()
    world_size = xr.world_size()
    xm.master_print(f"Process {rank}/{world_size} using device {device}")

    # Initialise WandB only for the main TPU process
    if rank == 0:
        print("Initialising WandB...")
        wandb.init(project="tpu-segmentation", name=f"TPU-Training-{model_version}")

    # Load the image processor and model inside _mp_fn
    os.environ['HF_TOKEN'] = hf_token
    image_processor = SegformerImageProcessor.from_pretrained(f'nvidia/segformer-{model_version}-finetuned-ade-512-512')

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/segformer-{model_version}-finetuned-ade-512-512",
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    ).to(device)

    # Distributed samplers (drop_last=True to prevent hanging)
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
    )
    eval_sampler = DistributedSampler(
        eval_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
    )

    # Safe TPU DataLoader setup
    def worker_init_fn(worker_id):
        """Ensures each worker has a different random seed"""
        torch.manual_seed(worker_id + rank)

    train_dataloader = DataLoader(
        train_dataset, batch_size=per_device_train_batch_size, sampler=train_sampler,
        num_workers=4, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
    )
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=per_device_eval_batch_size, sampler=eval_sampler,
        num_workers=4, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
    )

    # Wrap data loaders with MpDeviceLoader for TPU support
    train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
    eval_dataloader = pl.MpDeviceLoader(eval_dataloader, device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"{model_path}/checkpoints",
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        dataloader_num_workers=4,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=5,  # Keep only the last 5 checkpoints
        logging_steps=10,
        logging_strategy="steps",
        report_to=["wandb"] if rank == 0 else [],
        disable_tqdm=(rank != 0),
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=100,
        load_best_model_at_end=True,
        push_to_hub=False,
        fp16=False,
        bf16=True,
        metric_for_best_model="overall_mean_iou",  # Metric to monitor for best model
        greater_is_better=True,  # Set to True to maximize the metric
        run_name=f"desCartes-{model_version}-{per_device_train_batch_size}-{gradient_accumulation_steps}-bf16"
    )

    # Trainer: override standard dataloader methods
    class CustomTrainer(Trainer):
        def get_train_dataloader(self):
            return train_dataloader

        def get_eval_dataloader(self, eval_dataset=None):
            return eval_dataloader

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )

    # Synchronize TPUs before starting
    xm.rendezvous("start_training")  # Ensure all TPU processes sync before proceeding
    if rank == 0:
        print("TPUs synchronised. Starting training...")

    # trainer.train(resume_from_checkpoint=os.path.exists(f"{model_path}/checkpoints"))
    trainer.train(resume_from_checkpoint=False)
    xm.rendezvous("training_complete")  # Ensure all TPU processes sync before exit

    if rank == 0:
        wandb.finish()  # Close WandB properly [Leave open for metrics via API]


# Launch TPU training
if __name__ == "__main__":
    hf_token = os.getenv('HF_TOKEN')
    xmp.spawn(_mp_fn, args=(hf_token,), start_method='fork')


All variable checks passed! Proceeding with execution.
Process 0/8 using device xla:0
Initialising WandB...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).
Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([5, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([5, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in 

TPUs synchronised. Starting training...TPUs synchronised. Starting training...TPUs synchronised. Starting training...TPUs synchronised. Starting training...TPUs synchronised. Starting training...TPUs synchronised. Starting training...TPUs synchronised. Starting training...
TPUs synchronised. Starting training...








Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05


Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6739,0.35413,1e-05,1e-05,0.89918,0.948251,0.923064,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05,1e-05
