<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments/segformer-b4-TPU-loss-prototype.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Mount Google Drive; install dependencies
import os

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

!pip install opencv-python
!pip install --upgrade torch_xla torch
!pip install evaluate
!pip install wandb -qU

In [None]:
#@title Downgrade Package for Compatibility (required when continuing training if package has been updated)

# !pip uninstall -y transformers
# !pip install transformers==4.49.0


In [None]:
# @title Load SegmentationDatasets from Drive { display-mode: "code" }

import torch
import json

class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

torch.serialization.add_safe_globals([SegmentationDataset])

# Load the dataset from a binary file
def load_dataset(file_path):
    dataset = torch.load(file_path)
    print(f"Dataset loaded from {file_path}")
    return dataset

# Define file paths for loading
train_data_path = '/content/drive/MyDrive/desCartes/pytorch/train_data.pt'
eval_data_path = '/content/drive/MyDrive/desCartes/pytorch/eval_data.pt'

# Load the datasets from Google Drive
eval_dataset = load_dataset(eval_data_path)
train_dataset = load_dataset(train_data_path)


In [3]:
# @title Safe Mean Intersection-over-Union { display-mode: "code" }

# Copyright 2022 The HuggingFace Evaluate Authors.
# Based on https://huggingface.co/spaces/evaluate-metric/mean_iou/blob/main/mean_iou.py

from typing import Dict, Optional

import datasets
import numpy as np

import evaluate

def intersect_and_union(
    pred_label,
    label,
    num_labels,
    ignore_index: bool,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    if label_map is not None:
        for old_id, new_id in label_map.items():
            label[label == old_id] = new_id

    # turn into Numpy arrays
    pred_label = np.array(pred_label)
    label = np.array(label)

    if reduce_labels:
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = label != ignore_index
    mask = np.not_equal(label, ignore_index)
    pred_label = pred_label[mask]
    label = np.array(label)[mask]

    intersect = pred_label[pred_label == label]

    area_intersect = np.histogram(intersect, bins=num_labels, range=(0, num_labels - 1))[0]
    area_pred_label = np.histogram(pred_label, bins=num_labels, range=(0, num_labels - 1))[0]
    area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0]

    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: bool,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    total_area_intersect = np.zeros((num_labels,), dtype=np.float64)
    total_area_union = np.zeros((num_labels,), dtype=np.float64)
    total_area_pred_label = np.zeros((num_labels,), dtype=np.float64)
    total_area_label = np.zeros((num_labels,), dtype=np.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
            result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels
        )
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, total_area_pred_label, total_area_label


def mean_iou(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: bool,
    nan_to_num: Optional[int] = None,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    total_area_intersect, total_area_union, total_area_pred_label, total_area_label = total_intersect_and_union(
        results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
    )

    # compute metrics
    metrics = dict()
    eps = 1e-10  # Small constant to prevent division by zero
    min_val, max_val = eps, 1 - eps  # Clip range to avoid extreme values
    round_decimals = 5  # Number of decimal places for rounding

    # Compute metrics with epsilon and clipping
    all_acc = np.clip(total_area_intersect.sum() / (total_area_label.sum() + eps), min_val, max_val)
    iou = np.clip(total_area_intersect / (total_area_union + eps), min_val, max_val)
    acc = np.clip(total_area_intersect / (total_area_label + eps), min_val, max_val)

    # Round values, ensuring that values like 9.9999e-01 are rounded up to 1.0
    iou = np.round(iou, round_decimals)
    acc = np.round(acc, round_decimals)

    # Explicitly round values very close to 1.0 (e.g., 0.99999, 0.999999)
    iou = np.where(iou >= 0.99999, 1.0, iou)
    acc = np.where(acc >= 0.99999, 1.0, acc)

    # Assign a default value of 1 for empty classes
    non_empty_classes = total_area_label > 0
    iou[~non_empty_classes] = 1.0
    acc[~non_empty_classes] = 1.0

    # Calculate final metrics with rounding
    metrics = {
        "mean_iou": round(np.nanmean(iou), round_decimals),
        "mean_accuracy": round(np.nanmean(acc), round_decimals),
        "overall_accuracy": round(all_acc, round_decimals),
        "per_category_iou": iou,
        "per_category_accuracy": acc,
    }

    if nan_to_num is not None:
        metrics = dict(
            {metric: np.nan_to_num(metric_value, nan=nan_to_num) for metric, metric_value in metrics.items()}
        )

    return metrics

In [None]:
# @title Train Model { display-mode: "code" }

if 'mean_iou' not in globals():
    raise NameError("Function 'mean_iou' is not defined. Run the appropriate cell first.")

if 'train_dataset' not in globals() or 'eval_dataset' not in globals():
    raise NameError("Either 'train_dataset' or 'eval_dataset' is not defined. Run the appropriate cell first.")

if not train_dataset:  # Checks if train_dataset is empty
    raise ValueError("'train_dataset' is empty.")

if not eval_dataset:  # Checks if eval_dataset is empty
    raise ValueError("'eval_dataset' is empty.")

print("All variable checks passed! Proceeding with execution.")

# Import necessary libraries
import os
import shutil
import numpy as np
import time
import wandb
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch_xla.distributed.parallel_loader as pl
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer, EarlyStoppingCallback
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support
from google.colab import userdata

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
results_path = f'{project_path}/results'

# Select Model
model_version = 'b4'
restart_training = False

# Define class labels
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]

# Local directory for storing dataset
local_data_dir = "/content/data"

# Training Configuration
per_device_train_batch_size = 2  # Batch size for training
per_device_eval_batch_size = per_device_train_batch_size
gradient_accumulation_steps = 1  # Simulates a batch size of gradient_accumulation_steps * per_device_train_batch_size

# Loss Function Configuration
loss_gamma = 2.0  # Focal loss gamma
loss_alpha = 0.25  # Focal loss alpha

###################################################

os.environ["WANDB_API_KEY"] = userdata.get('WANDB_TOKEN')
!wandb login
wandb.init(project="tpu-segmentation", name=f"TPU-Training-{model_version}", settings=wandb.Settings(start_method="fork", _service_wait=60))

# Configure label mappings
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}
label2id = {label: i for i, label in id2label.items()}

# Test for existing checkpoints
checkpoint_path = f"{model_path}/checkpoints/{model_version}"
if restart_training and os.path.exists(checkpoint_path):
    shutil.rmtree(checkpoint_path)  # Deletes the folder and its contents
    os.makedirs(checkpoint_path)  # Recreate the empty checkpoint directory
resume_training = os.path.exists(checkpoint_path) and any(os.scandir(checkpoint_path))

def load_or_download_segformer():
    try:
        model_name = f"nvidia/segformer-{model_version}-finetuned-ade-512-512"
        base_model_path = f'{model_path}/base/{model_name}'

        if not os.path.exists(base_model_path):
            print(f"Downloading model from Hugging Face: {model_name}")
            os.makedirs(base_model_path)

            hf_token = userdata.get('HF_TOKEN')
            if hf_token:
                os.environ["HF_TOKEN"] = hf_token

            model = SegformerForSemanticSegmentation.from_pretrained(
                model_name,
                num_labels=num_classes,
                id2label=id2label,
                label2id=label2id,
                ignore_mismatched_sizes=True,
            )
            model.save_pretrained(base_model_path)
        else:
            model = SegformerForSemanticSegmentation.from_pretrained(base_model_path)

        return model
    except Exception as e:
        print(f"Error loading/downloading: {e}")
        raise

def compute_metrics(eval_pred):
    try:
        with torch.no_grad():
            logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)

            # Upsample logits to match labels
            logits_tensor = nn.functional.interpolate(
                logits_tensor,
                size=labels.shape[-2:],  # Match height & width of labels
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)  # Convert to predicted class indices

            pred_labels = logits_tensor.detach().cpu().numpy()

            # Call the safe mean_iou function (defined in another cell)
            metrics = mean_iou(
                results=pred_labels,
                gt_seg_maps=labels,
                num_labels=num_classes,
                ignore_index=True,
                reduce_labels=False,
            )

            # Extract per-class IoU & accuracy
            per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
            per_category_iou = metrics.pop("per_category_iou").tolist()

            # Compute precision, recall, and F1-score (excluding background)
            pred_flat = pred_labels.flatten()
            labels_flat = labels.flatten()
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels_flat, pred_flat, average="weighted", zero_division=0
            )

            # Store overall metrics
            metrics["overall_accuracy"] = metrics.pop("mean_accuracy")
            metrics["overall_mean_iou"] = metrics.pop("mean_iou")
            metrics["precision"] = precision
            metrics["recall"] = recall
            metrics["f1_score"] = f1

            # Add per-class accuracy & IoU
            metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
            metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

            return metrics

    except Exception as e:
        # Return default zeroed metrics with an error flag
        return {
            "error": True,
            "overall_accuracy": 0.0,
            "overall_mean_iou": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            **{f"accuracy_{id2label[i]}": 0.0 for i in range(num_classes)},
            **{f"iou_{id2label[i]}": 0.0 for i in range(num_classes)},
        }

def tpu_worker_process(rank):

    try:
        # Set TPU device
        device = xm.xla_device()
        world_size = xr.world_size()

        # Ensure that all TPUs are available before proceeding
        xm.rendezvous("ready")
        xm.master_print(f"All {world_size} devices are ready!\n", flush=True)

        # Ensure that model is not fetched from HF more than once
        # Note: model cannot successfully be passed into this function and mounted on each device
        if rank == 0:
            model = load_or_download_segformer()
            model.to(device)
            xm.rendezvous("model_ready")
        else:
            xm.rendezvous("model_ready")
            model = load_or_download_segformer()
            model.to(device)

        # Distributed samplers (drop_last=True to prevent hanging)
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
        )
        eval_sampler = DistributedSampler(
            eval_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
        )

        # Safe TPU DataLoader setup
        def worker_init_fn(worker_id):
            """Ensures each worker has a different random seed"""
            torch.manual_seed(worker_id + rank)

        train_dataloader = DataLoader(
            train_dataset, batch_size=per_device_train_batch_size, sampler=train_sampler,
            num_workers=4, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
        )
        eval_dataloader = DataLoader(
            eval_dataset, batch_size=per_device_eval_batch_size, sampler=eval_sampler,
            num_workers=4, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
        )

        # Wrap data loaders with MpDeviceLoader for TPU support
        train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
        eval_dataloader = pl.MpDeviceLoader(eval_dataloader, device)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=checkpoint_path,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            dataloader_num_workers=4,
            eval_strategy="epoch",
            save_strategy="epoch",
            logging_steps=10,
            logging_strategy="steps",
            report_to=["wandb"] if rank == 0 else [],
            disable_tqdm=(rank != 0),
            gradient_accumulation_steps=gradient_accumulation_steps,
            fp16=False,
            bf16=True,
            metric_for_best_model="overall_mean_iou",  # Metric to monitor for best model
            greater_is_better=True,  # Set to True to maximize the metric
            num_train_epochs=100,
            save_total_limit=5,  # Keep only the last 5 checkpoints
            load_best_model_at_end=True,
            push_to_hub=False,
            run_name=f"desCartes-{model_version}-{per_device_train_batch_size}-{gradient_accumulation_steps}-bf16"
        )

        # Ensure that all TPUs are properly loaded before proceeding
        xm.rendezvous("steady")
        xm.master_print("All devices are steady!\n", flush=True)

        # Trainer: override standard methods
        class CustomTrainer(Trainer):
            def get_train_dataloader(self):
                return train_dataloader

            def get_eval_dataloader(self, eval_dataset=None):
                return eval_dataloader

            # TODO: The following function needs to be fixed to optimise class-imbalance training
            # At present, even returning a "safe" value causes TPU-reporting to drop out after less than 1 minute

            # def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, gamma=loss_gamma, alpha=loss_alpha):
            #     """
            #     Custom loss computation for Hugging Face Trainer, preserving compatibility with the original logic.
            #     """

            #     labels = inputs.pop("labels", None)  # Extract labels if present

            #     if self.model_accepts_loss_kwargs:
            #         loss_kwargs = {}
            #         if num_items_in_batch is not None:
            #             loss_kwargs["num_items_in_batch"] = num_items_in_batch
            #         inputs = {**inputs, **loss_kwargs}

            #     outputs = model(**inputs)

            #     if self.args.past_index >= 0:
            #         self._past = outputs[self.args.past_index]

            #     if labels is not None:

            #         loss = torch.tensor(0.0, device=labels.device) # <<< BUGFIXING: Return safe value?

            #         # logits = outputs.logits

            #         # # Ensure device consistency
            #         # if logits.device != labels.device:
            #         #     labels = labels.to(logits.device)

            #         # # Compute cross-entropy loss
            #         # ce_loss = F.cross_entropy(logits, labels, reduction="none").float()

            #         # # Compute focal loss components
            #         # pt = torch.exp(-torch.clamp(ce_loss, min=1e-8, max=1e8))  # Avoid instability
            #         # focal_loss = alpha * (1 - pt) ** gamma * ce_loss

            #         # loss = focal_loss.mean()

            #     else:
            #         # Handle case where model should return loss directly
            #         if isinstance(outputs, dict) and "loss" not in outputs:
            #             raise ValueError(
            #                 "The model did not return a loss from the inputs, only the following keys: "
            #                 f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            #             )
            #         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

            #     # Adjust loss scaling if using multi-device
            #     if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
            #         loss *= self.accelerator.num_processes

            #     return (loss, outputs) if return_outputs else loss

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
        )

        # Synchronize TPUs before starting training
        xm.rendezvous("start_training")  # Ensure all TPU processes sync before proceeding
        xm.master_print(f"All devices are GO! ... training started (resume={resume_training})...\n", flush=True)

        trainer.train(resume_from_checkpoint=resume_training)
        xm.rendezvous("training_complete")  # Ensure all TPU processes sync before exit
        xm.master_print("Training completed!\n", flush=True)

        # Terminate WandB logging
        if rank == 0:
            wandb.finish()

    except Exception as e:
        raise e

    return

# Launch TPU training with WandB logging
xmp.spawn(tpu_worker_process, args=(), start_method='fork')


All variable checks passed! Proceeding with execution.
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


All 8 devices are ready!





All devices are steady!

All devices are GO! ... training started (resume=False)...



Epoch,Training Loss,Validation Loss,Overall Accuracy,Overall Mean Iou,Precision,Recall,F1 Score,Accuracy Background,Accuracy Main Road,Accuracy Minor Road,Accuracy Semi Enclosed Path,Accuracy Unenclosed Path,Iou Background,Iou Main Road,Iou Minor Road,Iou Semi Enclosed Path,Iou Unenclosed Path
1,0.6116,0.507012,0.39955,0.39082,0.901888,0.946136,0.922573,0.99746,1.0,0.00025,0.0,5e-05,0.9538,1.0,0.00025,0.0,4e-05
2,0.2947,0.337616,0.39998,0.39119,0.899766,0.94817,0.923047,0.99991,1.0,0.0,0.0,0.0,0.95595,1.0,0.0,0.0,0.0
3,0.2483,0.259972,0.4,0.3912,0.89918,0.948251,0.923064,1.0,1.0,0.0,0.0,0.0,0.95602,1.0,0.0,0.0,0.0
4,0.244,0.254836,0.4,0.3912,0.89918,0.948251,0.923064,1.0,1.0,0.0,0.0,0.0,0.95602,1.0,0.0,0.0,0.0
5,0.2507,0.253241,0.4,0.3912,0.89918,0.948251,0.923064,1.0,1.0,0.0,0.0,0.0,0.95602,1.0,0.0,0.0,0.0
6,0.2366,0.251325,0.4,0.3912,0.89918,0.948251,0.923064,1.0,1.0,0.0,0.0,0.0,0.95602,1.0,0.0,0.0,0.0
