<a href="https://colab.research.google.com/github/docuracy/desCartes/blob/main/experiments/segformer-TPU-3-epochs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title Authenticate GCS, mount Google Drive

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

!gcloud auth application-default login
!gcloud config set project descartes-404713

!pip install wandb -qU
!wandb login

!pip install opencv-python

Mounted at /content/drive
Go to the following link in your browser, and complete the sign-in prompts:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fapplicationdefaultauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=c7S0Q8eIORwwIzSoMdD4y0Ivw1oFGF&prompt=consent&token_usage=remote&access_type=offline&code_challenge=H9Bb6BXk94PvHjFx_piosBzqjmAXjTi70emQbi5yQQQ&code_challenge_method=S256

Once finished, enter the verification code provided in your browser: <redacted>

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Def

In [None]:
# @title Copy Data from GCS { display-mode: "code" }
import os
import shutil
from google.cloud import storage
import torch
import random

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
results_path = f'{project_path}/results'

# GCS Configuration
gcs_key_path = f'{project_path}/descartes-404713-cccf7c3921aa.json'
gcs_project_id = 'descartes-404713'
gcs_bucket_name = 'descartes'
gcs_data_directory = "training_data"

# Local directory for storing dataset
local_data_dir = "/content/data"

# Local split folders; files that fail to load are moved to `local_corrupt_dir`.
local_train_dir = f"{local_data_dir}/train"
local_eval_dir = f"{local_data_dir}/eval"
local_corrupt_dir = f"{local_data_dir}/eval_corrupt"
for split_folder in (local_train_dir, local_eval_dir, local_corrupt_dir):
    os.makedirs(split_folder, exist_ok=True)

# **Connect to GCS and list blobs**
# Authenticate with the service-account key file stored on Drive.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = gcs_key_path
# Use the configured project explicitly rather than relying on inference.
storage_client = storage.Client(project=gcs_project_id)

# Every blob under the training-data prefix.
bucket = storage_client.bucket(gcs_bucket_name)
blobs = list(bucket.list_blobs(prefix=gcs_data_directory))

# Function to check if a .pt file is loadable
def check_loadable(file_path):
    """Return True if ``file_path`` can be deserialised with ``torch.load``.

    Used as a cheap integrity check for downloaded ``.pt`` files. Any exception
    raised while loading (missing file, truncated download, non-torch content)
    is printed and reported as False.

    Tensors are mapped to CPU, so the check does not depend on the device the
    file was saved from being available on this machine.
    """
    try:
        torch.load(file_path, map_location="cpu")
    except Exception as e:
        print(f"Error loading file {file_path}: {e}")
        return False
    return True

# Download every .pt blob into its split folder: blobs whose path contains
# "/eval/" go to the eval folder, all others to train. Files already present
# locally are not re-downloaded, but every file is re-checked for loadability.
quarantined = []  # (file name, split folder it belongs to) for files that failed to load
for blob in blobs:
    if not blob.name.endswith(".pt"):
        continue

    file_name = os.path.basename(blob.name)
    split_dir = local_eval_dir if "/eval/" in blob.name else local_train_dir
    local_path = os.path.join(split_dir, file_name)

    if not os.path.exists(local_path):
        print(f"Downloading {blob.name}...")
        blob.download_to_filename(local_path)

    # Quarantine files torch cannot load, remembering which split they came from
    if not check_loadable(local_path):
        print(f"Moving corrupted file {local_path} to {local_corrupt_dir}")
        shutil.move(local_path, os.path.join(local_corrupt_dir, file_name))
        quarantined.append((file_name, split_dir))

print("✅ Train and eval files downloaded to local storage.")

# Replace each file quarantined in this run with a copy of a random valid file
# from the SAME split (saved under the corrupt file's name), so train samples
# are never duplicated into eval or vice versa.
if not quarantined:
    print("No corrupt files found in this run.")
else:
    for file_name, split_dir in quarantined:
        valid_files = [f for f in os.listdir(split_dir) if f.endswith(".pt")]
        if not valid_files:
            print(f"⚠️ No valid files left in {split_dir} to replace {file_name}; skipping.")
            continue
        valid_file_name = random.choice(valid_files)
        shutil.copy(os.path.join(split_dir, valid_file_name), os.path.join(split_dir, file_name))
        print(f"Replaced corrupted file {file_name} with {valid_file_name}.")
    print("✅ Corrupted files have been replaced with duplicates of valid ones.")


In [1]:
# @title Train Model { display-mode: "code" }
# Import necessary libraries
import os
import wandb
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch_xla.distributed.parallel_loader as pl
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer

# Read the Hugging Face Hub user access token from Colab secrets and export it
# as HF_TOKEN. The TPU worker processes spawned below cannot read Colab secrets
# themselves (the vault request times out in them), but forked workers inherit
# the parent's environment variables.
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')
if hf_token:
    os.environ['HF_TOKEN'] = hf_token

# Google Drive Path Configuration
project_path = '/content/drive/MyDrive/desCartes'
model_path = f'{project_path}/models'
results_path = f'{project_path}/results'

# Select Model: SegFormer size used in the `nvidia/segformer-{model_version}-...` checkpoint name
model_version = 'b2'

# Define class labels (list index = class id; presumably the pixel values in the label masks — TODO confirm)
class_labels = ["background", "main_road", "minor_road", "semi_enclosed_path", "unenclosed_path"]

# Local directory for storing dataset
local_data_dir = "/content/data"

# Training Configuration
per_device_train_batch_size = 2  # Batch size per TPU process
per_device_eval_batch_size = per_device_train_batch_size
# Effective batch per optimizer step =
#   per_device_train_batch_size * gradient_accumulation_steps * number of TPU processes
gradient_accumulation_steps = 1

###################################################

# Configure label mappings
num_classes = len(class_labels)
id2label = {i: label for i, label in enumerate(class_labels)}
label2id = {label: i for i, label in id2label.items()}

def _mp_fn(rank):
    """Per-process TPU training entry point, run once per TPU core by ``xmp.spawn``.

    Steps:
    1. Load the pretrained SegFormer and its image processor.
    2. Build train/eval datasets from the local ``.pt`` files and shard them
       across processes with ``DistributedSampler``.
    3. Train with a Hugging Face ``Trainer`` whose dataloaders are replaced by
       the TPU loaders built here.

    Checkpoints are written to ``results_path``. The final model and the image
    processor are saved under ``model_path`` so they outlive the worker
    processes.

    Args:
        rank: process index supplied by ``xmp.spawn``.
    """
    from transformers import TrainerCallback

    # Set TPU device inside the function
    device = xm.xla_device()
    world_size = xr.world_size()
    # `rank` from xmp.spawn is presumably the local index; shard by the global
    # ordinal so sharding is also correct on multi-host TPUs.
    global_rank = xr.global_ordinal()
    xm.master_print(f"Process {rank}/{world_size} using device {device}")

    # Initialize WandB only for the main TPU process
    if xm.is_master_ordinal():
        os.environ["WANDB_DISABLED"] = "false"
        wandb.init(project="tpu-segmentation", name=f"TPU-Training-{model_version}")
    else:
        os.environ["WANDB_DISABLED"] = "true"

    # Load the image processor and model inside _mp_fn (one copy per process)
    pretrained_name = f"nvidia/segformer-{model_version}-finetuned-ade-512-512"
    image_processor = SegformerImageProcessor.from_pretrained(pretrained_name)

    # ignore_mismatched_sizes: the checkpoint's classifier head has a different
    # number of classes, so it is re-initialised with `num_classes` outputs.
    model = SegformerForSemanticSegmentation.from_pretrained(
        pretrained_name,
        num_labels=num_classes,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    ).to(device)

    class SegmentationDataset(Dataset):
        """Samples stored as ``.pt`` dicts with ``'images'`` and ``'labels'`` entries."""

        def __init__(self, local_dir):
            self.local_dir = local_dir
            # Sorted so every TPU process sees the same index -> file mapping,
            # which DistributedSampler relies on.
            self.files = sorted(
                os.path.join(local_dir, f) for f in os.listdir(local_dir) if f.endswith('.pt')
            )
            self.image_processor = image_processor  # Use processor loaded inside _mp_fn

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            local_path = self.files[idx]

            if not os.path.exists(local_path):
                raise FileNotFoundError(f"File not found: {local_path}")

            try:
                data = torch.load(local_path)
            except Exception as e:
                # Returning None would only fail later inside the DataLoader's
                # collate step with an unrelated-looking error; fail here instead.
                raise RuntimeError(f"Error loading file {local_path}: {e}") from e

            inputs = self.image_processor(images=data['images'], return_tensors="pt")
            pixel_values = inputs['pixel_values'].squeeze(0)
            label = data['labels'].squeeze().long()
            return {"pixel_values": pixel_values, "labels": label}

    # Create dataset instances inside _mp_fn
    train_dataset = SegmentationDataset(f"{local_data_dir}/train")
    eval_dataset = SegmentationDataset(f"{local_data_dir}/eval")

    # Distributed samplers shard each dataset across TPU processes
    # (drop_last=True so every process gets the same number of batches and none hangs)
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=True
    )
    eval_sampler = DistributedSampler(
        eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=True
    )

    num_workers = 4  # DataLoader worker processes per TPU process

    def worker_init_fn(worker_id):
        """Give every DataLoader worker, across all TPU processes, a distinct torch seed."""
        torch.manual_seed(global_rank * num_workers + worker_id)

    train_dataloader = DataLoader(
        train_dataset, batch_size=per_device_train_batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
    )
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=per_device_eval_batch_size, sampler=eval_sampler,
        num_workers=num_workers, pin_memory=True, persistent_workers=True, worker_init_fn=worker_init_fn
    )

    # Wrap data loaders with MpDeviceLoader for TPU support
    train_dataloader = pl.MpDeviceLoader(train_dataloader, device)
    eval_dataloader = pl.MpDeviceLoader(eval_dataloader, device)

    class SamplerEpochCallback(TrainerCallback):
        """Reshuffle the training shards every epoch.

        DistributedSampler only changes its shuffle order when ``set_epoch`` is
        called. Nothing here calls it on the sampler hidden inside the custom
        loader, so without this callback every epoch would see the same order.
        """

        def __init__(self):
            self.epoch = 0

        def on_epoch_begin(self, args, state, control, **kwargs):
            train_sampler.set_epoch(self.epoch)
            self.epoch += 1

    # Training arguments
    training_args = TrainingArguments(
        output_dir=results_path,  # checkpoints on Drive instead of the Trainer's default location
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        dataloader_num_workers=num_workers,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        logging_strategy="steps",
        report_to=["wandb"] if xm.is_master_ordinal() else [],
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=3,
        load_best_model_at_end=True,
        push_to_hub=False,
        fp16=False,
        bf16=True,
        run_name=f"desCartes-{model_version}-{per_device_train_batch_size}-{gradient_accumulation_steps}-bf16"
    )

    class CustomTrainer(Trainer):
        """Trainer that uses the TPU dataloaders built above instead of creating its own."""

        def get_train_dataloader(self):
            return train_dataloader

        def get_eval_dataloader(self, eval_dataset=None):
            return eval_dataloader

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[SamplerEpochCallback()],
    )

    trainer.train(resume_from_checkpoint=False)

    # Persist the final model (the best checkpoint, since load_best_model_at_end=True).
    # Called from every process: on XLA the Trainer's save path synchronises the
    # processes and only the master writes (calling it on the master alone would hang).
    trainer.save_model(f"{model_path}/segformer_model")
    if xm.is_master_ordinal():
        image_processor.save_pretrained(f"{model_path}/image_processor")

    xm.rendezvous("training_complete")  # Ensure all TPU processes sync before exit

    if xm.is_master_ordinal():
        wandb.finish()  # Close WandB properly

# Launch TPU training: one `_mp_fn` process per TPU core. `fork` lets the
# workers inherit this notebook's globals (paths, labels, batch sizes).
if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), start_method='fork')



Process 0/8 using device xla:0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdocuracy[0m ([33mdocuracy-university-of-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).
Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b2-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([5, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b2-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([5, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in 

Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


Epoch,Training Loss,Validation Loss
1,0.7869,0.62173
2,0.4156,0.336733
3,0.3427,0.34712


0,1
eval/loss,█▁▁
eval/runtime,█▁▁
eval/samples_per_second,▁▇█
eval/steps_per_second,▁▇█
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
train/grad_norm,▅▅▅▅█▅▅▄▄▄▃▄▃▄▂▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▆
train/learning_rate,███▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁
train/loss,█▇▇▆▆▅▅▄▄▄▃▃▃▂▃▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁

0,1
eval/loss,0.34712
eval/runtime,2.0303
eval/samples_per_second,18.717
eval/steps_per_second,1.478
total_flos,8.324092922167296e+16
train/epoch,3.0
train/global_step,324.0
train/grad_norm,5.99757
train/learning_rate,0.0
train/loss,0.3427


In [None]:
# @title Visualizing Results { display-mode: "code" }

# Function to display images and predicted masks
def plot_predictions(model, dataset, n_samples=3):
    """Plot input image, true mask and predicted mask for the first ``n_samples`` samples.

    Args:
        model: segmentation model called as ``model(pixel_values=...)`` whose
            output exposes ``.logits`` of shape (batch, num_classes, h, w),
            e.g. ``SegformerForSemanticSegmentation``.
        dataset: indexable PyTorch dataset yielding dicts with
            ``"pixel_values"`` (C, H, W) and ``"labels"`` (H, W) tensors, like
            the ``SegmentationDataset`` in the training cell.
        n_samples: number of samples to plot (one figure each).
    """
    import matplotlib.pyplot as plt
    import torch
    import torch.nn.functional as F

    param = next(model.parameters())
    model.eval()

    for idx in range(min(n_samples, len(dataset))):
        sample = dataset[idx]
        pixel_values = sample["pixel_values"]
        label = sample["labels"]

        with torch.no_grad():
            batch = pixel_values.unsqueeze(0).to(device=param.device, dtype=param.dtype)
            logits = model(pixel_values=batch).logits
        # Logits can be coarser than the mask (SegFormer predicts at 1/4
        # resolution), so upsample to the label size before taking the argmax.
        logits = F.interpolate(logits.float(), size=label.shape[-2:], mode="bilinear", align_corners=False)
        prediction = logits.argmax(dim=1)[0].cpu().numpy()
        num_classes = logits.shape[1]

        # pixel_values are normalised; min-max rescale to [0, 1] purely for display.
        image = pixel_values.detach().float().permute(1, 2, 0).cpu()
        image = (image - image.min()) / (image.max() - image.min()).clamp_min(1e-8)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(image.numpy())
        axes[0].set_title('Input Image')
        # Labels are class-index masks (not one-hot), so plot them directly.
        axes[1].imshow(label.cpu().numpy(), cmap='viridis', vmin=0, vmax=num_classes - 1)
        axes[1].set_title('True Label')
        axes[2].imshow(prediction, cmap='viridis', vmin=0, vmax=num_classes - 1)
        axes[2].set_title('Predicted Mask')
        plt.show()
        plt.close(fig)

# Display some predictions
# NOTE(review): neither `model` nor `val_dataset` is defined at notebook scope.
# The trained model and the eval dataset (named `eval_dataset` there) are
# created inside `_mp_fn` in the training cell and discarded when the TPU
# worker processes exit.
# TODO: load the trained model and build the eval dataset before this call.
plot_predictions(model, val_dataset)


In [None]:
# @title Evaluation Metrics { display-mode: "code" }
from sklearn.metrics import classification_report

# Function to calculate metrics for model evaluation
def evaluate_model(model, dataset, max_batches=10, batch_size=2):
    """Pixel-level classification report for a segmentation model.

    Args:
        model: model called as ``model(pixel_values=...)`` whose output exposes
            ``.logits`` of shape (batch, num_classes, h, w).
        dataset: PyTorch dataset yielding dicts with ``"pixel_values"`` and
            ``"labels"`` tensors, like ``SegmentationDataset`` in the training cell.
        max_batches: only the first ``max_batches`` batches are evaluated.
        batch_size: samples per batch.

    Returns:
        dict: ``classification_report(..., output_dict=True)`` computed over
        every evaluated pixel.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader

    param = next(model.parameters())
    model.eval()

    all_preds, all_labels = [], []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            if batch_idx >= max_batches:
                break
            labels = batch["labels"]
            pixel_values = batch["pixel_values"].to(device=param.device, dtype=param.dtype)
            logits = model(pixel_values=pixel_values).logits
            # Upsample logits to mask resolution so predictions and labels align pixel-for-pixel.
            logits = F.interpolate(logits.float(), size=labels.shape[-2:], mode="bilinear", align_corners=False)
            all_preds.append(logits.argmax(dim=1).cpu().numpy().ravel())
            all_labels.append(labels.cpu().numpy().ravel())

    if not all_preds:
        raise ValueError("dataset yielded no batches to evaluate")

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)

    id2label = getattr(getattr(model, "config", None), "id2label", None)
    if id2label:
        # Report every configured class by name (zero rows for classes absent from this subset).
        names = {int(k): v for k, v in id2label.items()}
        class_ids = sorted(names)
        return classification_report(
            y_true, y_pred, labels=class_ids,
            target_names=[names[i] for i in class_ids],
            output_dict=True, zero_division=0,
        )
    return classification_report(y_true, y_pred, output_dict=True, zero_division=0)

# Print evaluation metrics
# NOTE(review): `model` and `val_dataset` are not defined at notebook scope.
# Both the trained model and the eval dataset (named `eval_dataset`) exist only
# inside `_mp_fn` in the training cell.
# TODO: load the trained model and build the eval dataset before this call.
eval_report = evaluate_model(model, val_dataset)
print("Evaluation Metrics:\n", eval_report)


In [None]:
# @title Model Saving { display-mode: "code" }
# The trained model only ever exists inside `_mp_fn` (the TPU worker processes),
# so `model` / `image_processor` are not defined in this notebook's namespace.
# Instead, reload the checkpoint written by the Trainer, then save it and the
# matching image processor to Drive.
import json
import os

from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from transformers.trainer_utils import get_last_checkpoint


def find_best_checkpoint(search_dirs):
    """Return the Trainer checkpoint to export, from the first directory that has one.

    Prefers ``best_model_checkpoint`` recorded in the latest checkpoint's
    ``trainer_state.json`` (the training cell sets load_best_model_at_end=True).
    Falls back to the latest checkpoint itself.

    Raises:
        FileNotFoundError: if none of ``search_dirs`` contains a checkpoint.
    """
    for directory in search_dirs:
        if not os.path.isdir(directory):
            continue
        last_checkpoint = get_last_checkpoint(directory)
        if last_checkpoint is None:
            continue
        state_path = os.path.join(last_checkpoint, "trainer_state.json")
        if os.path.isfile(state_path):
            with open(state_path, encoding="utf-8") as f:
                best_checkpoint = json.load(f).get("best_model_checkpoint")
            if best_checkpoint and os.path.isdir(best_checkpoint):
                return best_checkpoint
        return last_checkpoint
    raise FileNotFoundError(f"No Trainer checkpoint found in any of: {search_dirs}")


# Search `results_path` first. "trainer_output" is presumably the Trainer's
# default output_dir when none is given (relative to the working directory) — TODO confirm.
checkpoint_dir = find_best_checkpoint([results_path, "trainer_output"])
print(f"Exporting model from {checkpoint_dir}")

model = SegformerForSemanticSegmentation.from_pretrained(checkpoint_dir)
# No processor is passed to the Trainer, so checkpoints presumably do not
# contain one; reload the processor that training used from the pretrained name.
image_processor = SegformerImageProcessor.from_pretrained(f'nvidia/segformer-{model_version}-finetuned-ade-512-512')

# Save the trained model
model.save_pretrained(f'{model_path}/segformer_model')
# Save the image processor
image_processor.save_pretrained(f'{model_path}/image_processor')


In [None]:
# @title Visualizing Training Logs { display-mode: "code" }
import os

# Function to plot training logs
def _read_log_records(log_path):
    """Parse a JSON or JSON-lines log file into a list of records (usually dicts)."""
    import json

    with open(log_path, 'r', encoding='utf-8') as f:
        text = f.read()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # JSON-lines: one object per non-empty line; unparsable lines are skipped.
        records = []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                continue
        return records
    if isinstance(parsed, dict):
        # Hugging Face trainer_state.json keeps its log entries under "log_history".
        return parsed.get("log_history", [parsed])
    if isinstance(parsed, list):
        return parsed
    return []


def plot_logs(log_dir='./logs'):
    """Plot training loss against step from the first (alphabetical) ``.json`` log in ``log_dir``.

    Accepts three formats:
    - a ``trainer_state.json``-style file (records under ``"log_history"``);
    - a JSON list of records;
    - JSON-lines.

    Records lacking either ``"step"`` or ``"loss"`` are ignored. Prints a message
    and returns without plotting if ``log_dir`` is missing or holds no usable
    loss entries.
    """
    if not os.path.isdir(log_dir):
        print("No log files found.")
        return
    log_files = sorted(f for f in os.listdir(log_dir) if f.endswith('.json'))
    if not log_files:
        print("No log files found.")
        return

    log_path = os.path.join(log_dir, log_files[0])
    points = [
        (int(record["step"]), float(record["loss"]))
        for record in _read_log_records(log_path)
        if isinstance(record, dict) and "step" in record and "loss" in record
    ]
    if not points:
        print(f"No loss entries found in {log_path}.")
        return
    steps, losses = zip(*points)

    # Imported only when there is something to plot.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot(steps, losses)
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Progress')
    plt.show()

# Plot the training logs
# NOTE(review): nothing in this notebook writes JSON logs to './logs' (the
# default `log_dir`), and the training cell sets no logging_dir.
# TODO: point this at the directory holding the Trainer's logs /
# trainer_state.json.
plot_logs()
