<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments%20/segformer-b4-TPU-100-epochs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Authenticate GCS, WandB, and Hugging Face; mount Google Drive; install dependencies

# Mount Google Drive at /content/drive; later cells read and write under it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Log in to Google Cloud and select the project used by the gcloud CLI.
!gcloud auth application-default login
!gcloud config set project descartes-404713

# Install/upgrade the Weights & Biases client, then log in.
!pip install wandb -qU
!wandb login

# Read the Colab secret named HF_TOKEN. The returned value is not stored or
# exported anywhere in this cell.
# NOTE(review): presumably this call only exists to trigger Colab's
# secret-access prompt — confirm, or assign the token where it is needed.
from google.colab import userdata
userdata.get('HF_TOKEN')

# NOTE(review): none of the installs in this cell pin a version, so a fresh
# runtime may resolve different releases than the last successful run used.
!pip install opencv-python
!pip install --upgrade torch_xla torch
!pip install evaluate

In [None]:
# @title Load Data from GCS { display-mode: "code" }
# Import necessary libraries
import torch
from google.cloud import storage
from transformers import SegformerImageProcessor
import io
import os

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'

# Google Cloud Storage (GCS) configuration
# The key file is read from the mounted Drive project folder.
gcs_key_path = f'{project_path}/descartes-404713-cccf7c3921aa.json'
# NOTE(review): gcs_project_id is not referenced again in the visible cells.
gcs_project_id = 'descartes-404713'
gcs_bucket_name = 'descartes'
gcs_data_directory = "training_data"

# Authenticate with your GCS key file
# storage.Client() is created without arguments, so it depends on the
# GOOGLE_APPLICATION_CREDENTIALS variable being set on the line before it.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = gcs_key_path
storage_client = storage.Client()

# Image processor applied to every sample in load_data_from_gcs.
# NOTE(review): this is the segformer-b2 processor, while the training cell
# sets model_version = 'b4' — confirm both share the same preprocessing config.
image_processor = SegformerImageProcessor.from_pretrained('nvidia/segformer-b2-finetuned-ade-512-512')

# Function to check if a .pt file is loadable
def check_loadable(file_data):
    """Report whether ``file_data`` can be deserialized by ``torch.load``.

    Args:
        file_data: Raw bytes of a serialized ``.pt`` payload.

    Returns:
        True when ``torch.load`` succeeds on the bytes; False when it raises,
        in which case the exception is printed first.
    """
    try:
        # The loaded object itself is not needed, only whether loading works.
        torch.load(io.BytesIO(file_data))
    except Exception as load_error:
        print(f"Error loading data: {load_error}")
        return False
    else:
        return True

# Load dataset from GCS into memory
def load_data_from_gcs(bucket_name, data_directory):
    """Download every ``.pt`` blob under a GCS prefix and preprocess it in memory.

    Each blob is expected to deserialize to a mapping with ``'images'`` and
    ``'labels'`` entries. Images go through the module-level
    ``image_processor``; labels are squeezed and cast to ``long``.

    Args:
        bucket_name: Name of the GCS bucket to read from.
        data_directory: Blob-name prefix to list (e.g. ``"training_data/train"``).

    Returns:
        A list of ``{"pixel_values": tensor, "labels": tensor}`` dicts, one per
        blob that could be loaded and processed. Blobs that fail are reported
        on stdout and skipped.
    """
    bucket = storage_client.bucket(bucket_name)
    data = []  # Preprocessed samples, in listing order

    for blob in bucket.list_blobs(prefix=data_directory):
        if not blob.name.endswith(".pt"):
            continue

        print(f"Processing {blob.name}...")

        # Read the blob into memory (without saving it locally)
        file_data = blob.download_as_bytes()

        # Deserialize exactly once. The previous version loaded every payload
        # twice (a validity probe followed by the real load), doubling the
        # deserialization cost for no extra information.
        # NOTE: torch.load unpickles data; only point this at a trusted bucket.
        try:
            file_tensor = torch.load(io.BytesIO(file_data))
        except Exception as e:
            print(f"Error loading data: {e}")
            print(f"Skipping corrupt file: {blob.name}")
            continue

        try:
            inputs = image_processor(images=file_tensor['images'], return_tensors="pt")
            # Drop the leading dimension added by the processor
            # (assumes one image per blob — TODO confirm).
            pixel_values = inputs['pixel_values'].squeeze(0)
            label = file_tensor['labels'].squeeze().long()
            data.append({"pixel_values": pixel_values, "labels": label})
        except Exception as e:
            print(f"Error processing {blob.name}: {e}")

    return data

# Load train and eval data from GCS
train_data = load_data_from_gcs(gcs_bucket_name, f"{gcs_data_directory}/train")
eval_data = load_data_from_gcs(gcs_bucket_name, f"{gcs_data_directory}/eval")

# Convert the loaded data into a custom dataset
class SegmentationDataset(torch.utils.data.Dataset):
    """Map-style dataset over an already-loaded, in-memory sequence of samples.

    Instances are written to disk with ``torch.save`` in a later cell, so the
    class name and the ``data`` attribute are part of what gets serialized.
    """

    def __init__(self, data):
        """Keep a reference to ``data``; no copy is made."""
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` exactly as stored."""
        return self.data[idx]

# Create datasets for training and evaluation
train_dataset = SegmentationDataset(train_data)
eval_dataset = SegmentationDataset(eval_data)

# Now train_dataset and eval_dataset are ready to be used for training


In [6]:
# @title Save SegmentationDatasets to Drive { display-mode: "code" }

import torch
import os

# Save the dataset (train and eval data) to a binary file
def save_dataset(dataset, file_path):
    """Serialize ``dataset`` to ``file_path`` with ``torch.save``.

    Args:
        dataset: Any object ``torch.save`` can serialize.
        file_path: Destination file. Missing parent directories are created.

    Side effects:
        Writes the file and prints a confirmation line.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(dataset, file_path)
    print(f"Dataset saved to {file_path}")

# Define file paths for saving
train_data_path = '/content/drive/MyDrive/desCartes/pytorch/train_data.pt'
eval_data_path = '/content/drive/MyDrive/desCartes/pytorch/eval_data.pt'

# Save the datasets
save_dataset(train_dataset, train_data_path)
save_dataset(eval_dataset, eval_data_path)


Dataset saved to /content/drive/MyDrive/desCartes/pytorch/train_data.pt
Dataset saved to /content/drive/MyDrive/desCartes/pytorch/eval_data.pt


In [3]:
# @title Load SegmentationDatasets from Drive { display-mode: "code" }

import torch

class SegmentationDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of samples.

    This repeats the definition from the data-loading cell.
    NOTE(review): presumably so this cell can run in a fresh kernel without
    re-running the GCS download — confirm, and keep both definitions in sync.
    """

    def __init__(self, data):
        """Keep a reference to ``data``; no copy is made."""
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` exactly as stored."""
        return self.data[idx]

# Register the class with torch's serialization allow-list before the
# torch.load calls below, which read files holding SegmentationDataset objects.
torch.serialization.add_safe_globals([SegmentationDataset])

# Load the dataset from a binary file
def load_dataset(file_path):
    """Read back an object previously written with ``torch.save``.

    Args:
        file_path: Path of the serialized file.

    Returns:
        Whatever ``torch.load`` produces for that file. A confirmation line is
        printed once loading has succeeded.
    """
    loaded = torch.load(file_path)
    print(f"Dataset loaded from {file_path}")
    return loaded

# Define file paths for loading
train_data_path = '/content/drive/MyDrive/desCartes/pytorch/train_data.pt'
eval_data_path = '/content/drive/MyDrive/desCartes/pytorch/eval_data.pt'

# Load the datasets from Google Drive
eval_dataset = load_dataset(eval_data_path)
train_dataset = load_dataset(train_data_path)


Dataset loaded from /content/drive/MyDrive/desCartes/pytorch/eval_data.pt
Dataset loaded from /content/drive/MyDrive/desCartes/pytorch/train_data.pt


In [15]:
# @title Safe Mean IoU { display-mode: "code" }

# Copyright 2022 The HuggingFace Evaluate Authors.
# Based on https://huggingface.co/spaces/evaluate-metric/mean_iou/blob/main/mean_iou.py
"""Mean IoU (Intersection-over-Union) metric."""

from typing import Dict, Optional

import datasets
import numpy as np

import evaluate

def intersect_and_union(
    pred_label,
    label,
    num_labels,
    ignore_index: int,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Compute per-class intersection and union pixel counts for one image.

    Args:
        pred_label: Predicted class ids (array-like, same shape as ``label``).
        label: Ground-truth class ids (array-like).
        num_labels: Number of classes; ids are expected in ``[0, num_labels)``.
        ignore_index: Class id whose ground-truth pixels are excluded.
        label_map: Optional ``{old_id: new_id}`` remapping applied to ``label``.
        reduce_labels: If True, treat 0 as "ignore" (mapped to 255) and shift
            every other id down by one.

    Returns:
        Tuple ``(area_intersect, area_union, area_pred_label, area_label)`` of
        integer histograms with ``num_labels`` bins each.
    """
    # Convert first: np.array copies, so nothing below writes through to the
    # caller's data, and plain nested lists are accepted as well.
    pred_label = np.array(pred_label)
    label = np.array(label)

    if label_map is not None:
        # Match against an untouched snapshot so a map such as {1: 2, 2: 1}
        # swaps the ids instead of collapsing both onto one value.
        original_label = label.copy()
        for old_id, new_id in label_map.items():
            label[original_label == old_id] = new_id

    if reduce_labels:
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    # Keep only the pixels whose ground truth is not the ignored class.
    mask = label != ignore_index
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]

    # One bin per class id; values outside [0, num_labels - 1] are not counted.
    hist_range = (0, num_labels - 1)
    area_intersect = np.histogram(intersect, bins=num_labels, range=hist_range)[0]
    area_pred_label = np.histogram(pred_label, bins=num_labels, range=hist_range)[0]
    area_label = np.histogram(label, bins=num_labels, range=hist_range)[0]

    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


def total_intersect_and_union(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: int,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Sum per-class intersection/union counts over a collection of images.

    Args:
        results: Iterable of predicted segmentation maps.
        gt_seg_maps: Iterable of ground-truth maps, paired with ``results`` in
            order (pairing stops at the shorter iterable).
        num_labels: Number of classes.
        ignore_index: Class id excluded from the counts. This is an integer
            id (callers in this notebook pass 0 and 255), not a flag.
        label_map: Optional ``{old_id: new_id}`` remapping for the labels.
        reduce_labels: Forwarded to ``intersect_and_union``.

    Returns:
        Tuple ``(total_area_intersect, total_area_union, total_area_pred_label,
        total_area_label)`` of float64 arrays of length ``num_labels``.
    """
    total_area_intersect = np.zeros((num_labels,), dtype=np.float64)
    total_area_union = np.zeros((num_labels,), dtype=np.float64)
    total_area_pred_label = np.zeros((num_labels,), dtype=np.float64)
    total_area_label = np.zeros((num_labels,), dtype=np.float64)

    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
            result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels
        )
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label

    return total_area_intersect, total_area_union, total_area_pred_label, total_area_label


def mean_iou(
    results,
    gt_seg_maps,
    num_labels,
    ignore_index: int,
    nan_to_num: Optional[int] = None,
    label_map: Optional[Dict[int, int]] = None,
    reduce_labels: bool = False,
):
    """Compute NaN-free IoU and accuracy metrics for semantic segmentation.

    Every ratio is computed as ``numerator / (denominator + eps)``, clipped to
    ``[1e-5, 1 - 1e-5]``, rounded to five decimals, and snapped to exactly 1.0
    when it reaches 0.99999. Because of the epsilon, a class that never occurs
    yields 1e-5 rather than NaN, so it is still included in the means.

    Args:
        results: Iterable of predicted segmentation maps.
        gt_seg_maps: Iterable of ground-truth segmentation maps.
        num_labels: Number of classes.
        ignore_index: Integer class id excluded from evaluation.
        nan_to_num: If given, replacement value for any NaN in the metrics.
        label_map: Optional ``{old_id: new_id}`` remapping for the labels.
        reduce_labels: Forwarded to ``intersect_and_union``.

    Returns:
        Dict with ``mean_iou``, ``mean_accuracy``, ``overall_accuracy``,
        ``per_category_iou`` and ``per_category_accuracy``.
    """
    total_area_intersect, total_area_union, _, total_area_label = total_intersect_and_union(
        results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
    )

    eps = 1e-10  # Small constant to prevent division by zero
    min_val, max_val = 1e-5, 1 - 1e-5  # Clip range to avoid extreme values
    round_decimals = 5  # Number of decimal places for rounding
    snap_to_one = 0.99999  # Rounded values at or above this are reported as 1.0

    def _bounded_ratio(numerator, denominator):
        """Clip, round and snap one ratio (scalar or per-class array)."""
        ratio = np.clip(numerator / (denominator + eps), min_val, max_val)
        ratio = np.round(ratio, round_decimals)
        return np.where(ratio >= snap_to_one, 1.0, ratio)

    iou = _bounded_ratio(total_area_intersect, total_area_union)
    acc = _bounded_ratio(total_area_intersect, total_area_label)
    # Overall accuracy gets the same snap as the per-class values, so a perfect
    # prediction reports 1.0 everywhere instead of 0.99999 for this one entry.
    all_acc = np.float64(_bounded_ratio(total_area_intersect.sum(), total_area_label.sum()))

    metrics = {
        "mean_iou": round(np.nanmean(iou), round_decimals),
        "mean_accuracy": round(np.nanmean(acc), round_decimals),
        "overall_accuracy": round(all_acc, round_decimals),
        "per_category_iou": iou,
        "per_category_accuracy": acc,
    }

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in metrics.items()
        }

    return metrics


# The decorator and _info() below read these three module-level strings. They
# were never defined in this notebook, so running the cell raised NameError
# when it reached the class definition.
_DESCRIPTION = """\
IoU is the area of overlap between the predicted segmentation and the ground
truth divided by the area of union between them. Mean IoU averages the IoU over
all classes. This "safe" variant adds a small epsilon to every denominator and
clips the ratios, so it never returns NaN.
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: list of predicted segmentation maps (arrays of class ids).
    references: list of ground-truth segmentation maps (arrays of class ids).
    num_labels (int): number of classes.
    ignore_index (int): class id that is ignored during evaluation.
    nan_to_num (int, optional): if given, NaN values are replaced by this number.
    label_map (dict, optional): mapping from old label ids to new label ids.
    reduce_labels (bool, optional): if True, 0 is treated as "ignore" (mapped to
        255) and every other id is shifted down by one. Defaults to False.

Returns:
    dict with the keys mean_iou, mean_accuracy, overall_accuracy,
    per_category_iou and per_category_accuracy.
"""

_CITATION = """\
@software{MMSegmentation_Contributors_OpenMMLab_Semantic_Segmentation_2020,
author = {{MMSegmentation Contributors}},
license = {Apache-2.0},
month = {7},
title = {{OpenMMLab Semantic Segmentation Toolbox and Benchmark}},
url = {https://github.com/open-mmlab/mmsegmentation},
year = {2020}
}
"""


# evaluate.Metric wrapper that forwards to the mean_iou function in this cell.
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class MeanIoU(evaluate.Metric):
    def _info(self):
        """Describe the metric and its input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Image(),
                    "references": datasets.Image(),
                }
            ),
            reference_urls=[
                "https://github.com/open-mmlab/mmsegmentation/blob/71c201b1813267d78764f306a297ca717827c4bf/mmseg/core/evaluation/metrics.py"
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        num_labels: int,
        ignore_index: int,
        nan_to_num: Optional[int] = None,
        label_map: Optional[Dict[int, int]] = None,
        reduce_labels: bool = False,
    ):
        """Forward all arguments to ``mean_iou`` and return its metrics dict."""
        iou_result = mean_iou(
            results=predictions,
            gt_seg_maps=references,
            num_labels=num_labels,
            ignore_index=ignore_index,
            nan_to_num=nan_to_num,
            label_map=label_map,
            reduce_labels=reduce_labels,
        )
        return iou_result

In [16]:
# Assuming the cell with the mean_iou function has been run.
# Smoke test of mean_iou on two hand-written 2x2 maps.

# Example data (replace with your actual data)
# Only class ids 1-4 occur below; class 0 appears in neither list, so its
# per-class values come out as the 1e-05 clipping floor used by mean_iou.
predictions = [np.array([[1, 2], [3, 4]]), np.array([[2, 3], [4, 1]])]
ground_truth = [np.array([[1, 1], [3, 4]]), np.array([[2, 3], [4, 2]])]
num_labels = 5  # Example number of labels
ignore_index = 255  # Example ignore index (no pixel carries this id here)

# Call the mean_iou function
iou_results = mean_iou(
    results=predictions,
    gt_seg_maps=ground_truth,
    num_labels=num_labels,
    ignore_index=ignore_index,
    reduce_labels=False,
)

# Print the results
print(iou_results)

{'mean_iou': np.float64(0.53333), 'mean_accuracy': np.float64(0.6), 'overall_accuracy': np.float64(0.75), 'per_category_iou': array([1.0000e-05, 3.3333e-01, 3.3333e-01, 1.0000e+00, 1.0000e+00]), 'per_category_accuracy': array([1.e-05, 5.e-01, 5.e-01, 1.e+00, 1.e+00])}


In [None]:
# @title Train Model { display-mode: "code" }
# Import necessary libraries
import os
import numpy as np
import time
import wandb
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch_xla.distributed.parallel_loader as pl
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer, EarlyStoppingCallback
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support
# Tidy up output [ineffective for TPUs]
# import warnings
# warnings.filterwarnings("ignore", message="Some weights of SegformerForSemanticSegmentation were not initialized", category=UserWarning)
# warnings.filterwarnings("ignore", message=".*feature_extractor_type.*", category=UserWarning)

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'  # checkpoints and saved models go under here
# NOTE(review): results_path is not referenced again in the visible cells.
results_path = f'{project_path}/results'

# Select Model
# Inserted into the 'nvidia/segformer-{model_version}-finetuned-ade-512-512' checkpoint name.
model_version = 'b4'

# Define class labels
# List position is the class id: "background" is id 0.
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]

# Local directory for storing dataset
# NOTE(review): local_data_dir is not referenced again in the visible cells.
local_data_dir = "/content/data"

# Training Configuration
per_device_train_batch_size = 2  # Batch size for training
per_device_eval_batch_size = per_device_train_batch_size
gradient_accumulation_steps = 1  # Simulates a batch size of gradient_accumulation_steps * per_device_train_batch_size

###################################################

# Configure label mappings
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}  # class id -> name
label2id = {label: i for i, label in id2label.items()}  # name -> class id

def compute_metrics(eval_pred):
    """Turn the Trainer's evaluation output into a flat dict of metrics.

    Args:
        eval_pred: Pair ``(logits, labels)`` of NumPy arrays. The logits are
            upsampled to the labels' height and width before the argmax.

    Returns:
        Dict with ``overall_accuracy``, ``overall_mean_iou``, ``precision``,
        ``recall``, ``f1_score`` plus ``accuracy_<class>`` and ``iou_<class>``
        for every class. If anything raises, the exception is printed and the
        same keys are returned as 0.0 together with ``"error": True``.
    """
    background_id = 0  # class 0 is background/ignored

    try:
        with torch.no_grad():
            logits, labels = eval_pred
            logits_tensor = torch.from_numpy(logits)

            # Upsample logits to match labels
            logits_tensor = nn.functional.interpolate(
                logits_tensor,
                size=labels.shape[-2:],  # Match height & width of labels
                mode="bilinear",
                align_corners=False,
            ).argmax(dim=1)  # Convert to predicted class indices

            pred_labels = logits_tensor.detach().cpu().numpy()

            # Call the safe mean_iou function (defined in another cell)
            metrics = mean_iou(
                results=pred_labels,
                gt_seg_maps=labels,
                num_labels=num_classes,
                ignore_index=background_id,
                reduce_labels=False,
            )

            # Extract per-class IoU & accuracy
            per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
            per_category_iou = metrics.pop("per_category_iou").tolist()

            # Compute precision, recall, and F1-score (excluding background).
            # The background pixels are now actually masked out, matching the
            # ignore_index used for IoU; previously they were included.
            pred_flat = pred_labels.flatten()
            labels_flat = labels.flatten()
            foreground = labels_flat != background_id
            if foreground.any():
                precision, recall, f1, _ = precision_recall_fscore_support(
                    labels_flat[foreground], pred_flat[foreground], average="weighted", zero_division=0
                )
            else:
                # No foreground pixels in the whole evaluation set.
                precision = recall = f1 = 0.0

            # Store overall metrics
            # NOTE(review): this replaces mean_iou's pixel-level
            # "overall_accuracy" with the mean per-class accuracy — confirm
            # that is the intended meaning of the key.
            metrics["overall_accuracy"] = metrics.pop("mean_accuracy")
            metrics["overall_mean_iou"] = metrics.pop("mean_iou")
            metrics["precision"] = precision
            metrics["recall"] = recall
            metrics["f1_score"] = f1

            # Add per-class accuracy & IoU
            metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
            metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

            return metrics

    except Exception as e:
        # Say what went wrong: an all-zero result without a message looks the
        # same as a model that simply performs badly.
        print(f"compute_metrics failed: {type(e).__name__}: {e}")
        # Return default zeroed metrics with an error flag
        return {
            "error": True,
            "overall_accuracy": 0.0,
            "overall_mean_iou": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            **{f"accuracy_{id2label[i]}": 0.0 for i in range(num_classes)},
            **{f"iou_{id2label[i]}": 0.0 for i in range(num_classes)},
        }

def _mp_fn(rank):
    """Per-process training entry point launched by ``xmp.spawn``.

    Builds the model, the distributed samplers and TPU data loaders, and a
    ``Trainer`` that uses them, then trains (resuming from the newest
    checkpoint when one exists).

    Args:
        rank: Index of this process, as passed in by ``xmp.spawn``.

    Reads the module-level ``train_dataset``, ``eval_dataset``,
    ``model_version``, ``model_path``, the label mappings and the batch-size
    settings.
    """
    # Local imports so this block brings its own new names into scope.
    from transformers import TrainerCallback
    from transformers.trainer_utils import get_last_checkpoint

    num_workers = 4  # DataLoader worker processes per TPU process
    checkpoint_dir = f"{model_path}/checkpoints"

    # Set TPU device inside the function
    device = xm.xla_device()
    world_size = xr.world_size()
    xm.master_print(f"Process {rank}/{world_size} using device {device}")

    # Synchronize TPUs before starting
    xm.rendezvous("start_training")  # Ensure all TPU processes sync before proceeding

    # Initialize WandB only for the main TPU process
    if rank == 0:
        wandb.init(project="tpu-segmentation", name=f"TPU-Training-{model_version}")

    # Load the image processor and model inside _mp_fn
    # NOTE(review): image_processor is not used again inside this function.
    image_processor = SegformerImageProcessor.from_pretrained(f'nvidia/segformer-{model_version}-finetuned-ade-512-512')

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/segformer-{model_version}-finetuned-ade-512-512",
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    ).to(device)

    # Distributed samplers (drop_last=True to prevent hanging)
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
    )
    eval_sampler = DistributedSampler(
        eval_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
    )

    # Safe TPU DataLoader setup
    def worker_init_fn(worker_id):
        """Ensures each worker has a different random seed"""
        # rank * num_workers + worker_id is unique per (rank, worker) pair;
        # worker_id + rank gave e.g. (rank 0, worker 1) and (rank 1, worker 0)
        # the same seed.
        torch.manual_seed(rank * num_workers + worker_id)

    train_dataloader = DataLoader(
        train_dataset, batch_size=per_device_train_batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True, persistent_workers=False, worker_init_fn=worker_init_fn
    )
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=per_device_eval_batch_size, sampler=eval_sampler,
        num_workers=num_workers, pin_memory=True, persistent_workers=False, worker_init_fn=worker_init_fn
    )

    # Wrap data loaders with MpDeviceLoader for TPU support
    train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
    eval_dataloader = pl.MpDeviceLoader(eval_dataloader, device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        dataloader_num_workers=num_workers,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=5,  # Keep only the last 5 checkpoints
        logging_steps=10,
        logging_strategy="steps",
        report_to=["wandb"] if rank == 0 else [],
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=100,
        load_best_model_at_end=True,
        push_to_hub=False,
        fp16=False,
        bf16=True,
        metric_for_best_model="overall_mean_iou",  # Metric to monitor for best model
        greater_is_better=True,  # Set to True to maximize the metric
        run_name=f"desCartes-{model_version}-{per_device_train_batch_size}-{gradient_accumulation_steps}-bf16"
    )

    # Trainer: override standard dataloader methods
    class CustomTrainer(Trainer):
        def get_train_dataloader(self):
            return train_dataloader

        def get_eval_dataloader(self, eval_dataset=None):
            return eval_dataloader

    class ReshuffleCallback(TrainerCallback):
        """Advance the training sampler's epoch at the start of every epoch.

        DistributedSampler derives its shuffle order from the epoch number set
        through set_epoch. The Trainer only sees the MpDeviceLoader wrapper, so
        without this hook the epoch never changes and every epoch iterates the
        data in the same order.
        """

        def on_epoch_begin(self, args, state, control, **kwargs):
            train_sampler.set_epoch(int(round(state.epoch or 0)))

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5), ReshuffleCallback()]
    )

    # Resume only when a checkpoint really exists. Passing True merely because
    # the directory exists makes train() raise if the directory holds no
    # checkpoint-* folder yet (for example on a first, interrupted run).
    last_checkpoint = get_last_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else None
    trainer.train(resume_from_checkpoint=last_checkpoint)
    xm.rendezvous("training_complete")  # Ensure all TPU processes sync before exit

    if rank == 0:
        wandb.finish()  # Close WandB properly [Leave open for metrics via API]


# Launch TPU training
# _mp_fn takes no extra arguments (args=()) and reads train_dataset,
# eval_dataset and the configuration values as module-level globals.
# NOTE(review): presumably start_method='fork' is what makes those globals
# visible in the worker processes — confirm before changing the start method.
if __name__ == "__main__":

    xmp.spawn(_mp_fn, args=(), start_method='fork')


Process 0/8 using device xla:0


In [None]:
# @title Visualising Results { display-mode: "code" }

# Function to display images and predicted masks
def plot_predictions(model, dataset, n_samples=3):
    """Show input image, ground-truth mask and predicted mask side by side.

    The previous body used TensorFlow calls (``tf.argmax``, ``dataset.take``)
    although this notebook trains a PyTorch model on a PyTorch dataset, and it
    used ``plt`` without importing it.

    Args:
        model: PyTorch segmentation model whose forward pass accepts
            ``pixel_values`` and returns an object with ``.logits``.
        dataset: Indexable dataset of ``{"pixel_values", "labels"}`` dicts, the
            layout built by ``load_data_from_gcs``.
        n_samples: Maximum number of samples to plot.
    """
    # Imported locally: no other cell in this notebook imports matplotlib.
    import matplotlib.pyplot as plt
    import torch

    model.eval()
    device = next(model.parameters()).device

    for index in range(min(n_samples, len(dataset))):
        sample = dataset[index]
        pixel_values = sample["pixel_values"]  # assumed (C, H, W) — TODO confirm
        label = sample["labels"]  # assumed (H, W) class ids — TODO confirm

        with torch.no_grad():
            logits = model(pixel_values=pixel_values.unsqueeze(0).to(device)).logits
            # Upsample to the label resolution, as compute_metrics does.
            logits = torch.nn.functional.interpolate(
                logits.float(), size=label.shape[-2:], mode="bilinear", align_corners=False
            )
        prediction = logits.argmax(dim=1)[0].cpu().numpy()
        label_map = label.cpu().numpy()

        # Rescale the processed image to [0, 1] so imshow can display it.
        image = pixel_values.detach().float().cpu().permute(1, 2, 0).numpy()
        value_range = image.max() - image.min()
        image = (image - image.min()) / (value_range if value_range > 0 else 1.0)

        # Shared colour scale so the two masks are directly comparable.
        top_class = max(int(label_map.max()), int(prediction.max()), 1)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(image)
        axes[0].set_title('Input Image')
        axes[1].imshow(label_map, cmap='viridis', vmin=0, vmax=top_class)
        axes[1].set_title('True Label')
        axes[2].imshow(prediction, cmap='viridis', vmin=0, vmax=top_class)
        axes[2].set_title('Predicted Mask')
        plt.show()

# Display some predictions
# `model` is only ever assigned inside _mp_fn, which runs in the processes
# started by xmp.spawn, so this kernel has no model object. Reload the weights
# from the newest Trainer checkpoint instead.
# NOTE(review): the newest checkpoint is not necessarily the best one —
# pick a specific checkpoint folder here if that matters.
import os

from transformers import SegformerForSemanticSegmentation
from transformers.trainer_utils import get_last_checkpoint

checkpoint_root = f"{model_path}/checkpoints"
latest_checkpoint = get_last_checkpoint(checkpoint_root) if os.path.isdir(checkpoint_root) else None
if latest_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found under {checkpoint_root}; run the training cell first.")
model = SegformerForSemanticSegmentation.from_pretrained(latest_checkpoint)

# The evaluation split is called eval_dataset in this notebook; val_dataset
# was never defined.
plot_predictions(model, eval_dataset)


In [None]:
# @title Evaluation Metrics { display-mode: "code" }
from sklearn.metrics import classification_report

# Function to calculate metrics for model evaluation
def evaluate_model(model, dataset, max_samples=10):
    """Build a per-class classification report from the first few samples.

    The previous body used TensorFlow calls (``dataset.take``, ``tf.argmax``)
    although this notebook trains a PyTorch model on a PyTorch dataset.

    Args:
        model: PyTorch segmentation model whose forward pass accepts
            ``pixel_values`` and returns an object with ``.logits``.
        dataset: Indexable dataset of ``{"pixel_values", "labels"}`` dicts.
        max_samples: Number of leading samples to evaluate (default 10; the
            earlier version looked at the first 10 batches).

    Returns:
        The dict produced by ``classification_report(..., output_dict=True)``
        over all pixels of the evaluated samples.

    Raises:
        ValueError: If there is no sample to evaluate.
    """
    import numpy as np
    import torch

    model.eval()
    device = next(model.parameters()).device

    all_preds = []
    all_labels = []

    for index in range(min(max_samples, len(dataset))):
        sample = dataset[index]
        labels = sample["labels"]  # assumed (H, W) class ids — TODO confirm

        with torch.no_grad():
            logits = model(pixel_values=sample["pixel_values"].unsqueeze(0).to(device)).logits
            # Upsample to the label resolution, as compute_metrics does.
            logits = torch.nn.functional.interpolate(
                logits.float(), size=labels.shape[-2:], mode="bilinear", align_corners=False
            )

        all_preds.append(logits.argmax(dim=1).cpu().numpy().ravel())
        all_labels.append(labels.cpu().numpy().ravel())

    if not all_preds:
        raise ValueError("evaluate_model needs at least one sample to evaluate")

    # Flatten everything into one long vector per side for classification_report
    report = classification_report(
        np.concatenate(all_labels), np.concatenate(all_preds), output_dict=True
    )
    return report

# Print evaluation metrics
# The evaluation split is called eval_dataset in this notebook; val_dataset
# was never defined. `model` must already be a loaded model object in this
# kernel (the training cell does not leave one behind).
eval_report = evaluate_model(model, eval_dataset)
print("Evaluation Metrics:\n", eval_report)


In [None]:
# @title Model Saving { display-mode: "code" }
# Save the trained model
# NOTE(review): `model` is only assigned inside _mp_fn in the training cell, so
# this line needs a model object to have been loaded into the kernel first.
model.save_pretrained(f'{model_path}/segformer_model')
# Save the image processor
# NOTE(review): the only module-level image_processor in this notebook is the
# segformer-b2 one from the data-loading cell — confirm that is the intended
# processor to ship next to the saved model.
image_processor.save_pretrained(f'{model_path}/image_processor')


In [None]:
# @title Visualising Training Logs { display-mode: "code" }
import os

# Function to plot training logs
def _read_loss_points(log_path):
    """Return ``(steps, losses)`` parsed from one JSON log file.

    Accepts a JSON document (a list of records, a single record, or a dict
    with a ``"log_history"`` list) or a file with one JSON record per line.
    Records lacking a usable ``step`` or ``loss`` value are skipped.
    """
    import json

    with open(log_path, 'r') as f:
        text = f.read()

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Not one JSON document: fall back to one record per line.
        payload = []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                payload.append(json.loads(line))
            except json.JSONDecodeError:
                continue

    if isinstance(payload, dict):
        payload = payload.get("log_history", [payload])
    if not isinstance(payload, list):
        payload = []

    steps, losses = [], []
    for record in payload:
        if not isinstance(record, dict) or 'step' not in record or 'loss' not in record:
            continue
        try:
            step, loss = int(record['step']), float(record['loss'])
        except (TypeError, ValueError):
            continue
        steps.append(step)
        losses.append(loss)
    return steps, losses


def plot_logs(log_dir='./logs'):
    """Plot training loss against step from the first ``.json`` log in ``log_dir``.

    Args:
        log_dir: Directory to search for ``.json`` log files. The first file in
            sorted name order is used.

    Returns:
        None. Prints a message and returns early when the directory is missing,
        holds no ``.json`` file, or the file has no step/loss records.
    """
    # A missing directory is reported like an empty one instead of raising.
    if not os.path.isdir(log_dir):
        print("No log files found.")
        return

    log_files = sorted(f for f in os.listdir(log_dir) if f.endswith('.json'))

    if len(log_files) == 0:
        print("No log files found.")
        return

    log_path = os.path.join(log_dir, log_files[0])

    # Parse the file as JSON; the earlier string splitting on 'step'/'loss'
    # left quotes and colons in the pieces and could not convert them.
    steps, losses = _read_loss_points(log_path)
    if not steps:
        print("No loss entries found.")
        return

    # Imported locally: no other cell in this notebook imports matplotlib.
    import matplotlib.pyplot as plt

    plt.plot(steps, losses)
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    plt.title('Training Loss Progress')
    plt.show()

# Plot the training logs
plot_logs()
