# Fine-Tuning SegFormer (PyTorch) for Stone vs. Separation

This notebook fine-tunes a SegFormer model (e.g., `nvidia/mit-b4`) from the Hugging Face Hub for the task of segmenting stone vs. separation regions.

**Dataset Structure:**
- **Images:** `../data/augmented/images`
- **Masks:**  `../data/augmented/masks`

**Details:**
- Input images are resized to **256 x 256** pixels.
- The model outputs logits that are upsampled to the original image size for visualization.

In [None]:
# Import necessary libraries
import os
import numpy as np
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.model_selection import train_test_split
from datasets import Dataset, Image
from transformers import (
    SegformerForSemanticSegmentation, 
    SegformerImageProcessor, 
    TrainingArguments, 
    Trainer
)
import evaluate

# Use a 12 x 12 default figure size for every plot created below
plt.rcParams.update({'figure.figsize': (12, 12)})

## 1. Settings and Configurations

Define dataset directories, choose the pre-trained model checkpoint, and set training parameters.

In [None]:
# --- Data locations -------------------------------------------------------
# Folders holding the input images and their segmentation masks
IMAGE_DIR = "../data/augmented/images"
MASK_DIR = "../data/augmented/masks"

# --- Model ----------------------------------------------------------------
# Pre-trained checkpoint on the Hugging Face Hub to fine-tune.
# Available sizes: "nvidia/mit-b0", "nvidia/mit-b1", "nvidia/mit-b2",
# "nvidia/mit-b3", "nvidia/mit-b4", "nvidia/mit-b5"
MODEL_CHECKPOINT = "nvidia/mit-b4"

# --- Training hyper-parameters --------------------------------------------
IMG_SIZE = 256       # Side length (pixels) images are resized to: 256 x 256
BATCH_SIZE = 2       # Per-device batch size used for training and evaluation
EPOCHS = 1           # Number of passes over the training split
LR = 6e-5            # Learning rate
VAL_SIZE = 0.1       # Fraction of the data held out for validation (10%)

## 2. Data Preparation

We load image and mask file paths, perform a train/validation split, and convert them into Hugging Face datasets.
Both the image and mask columns are cast to the `Image()` type so the files are automatically loaded during transformation.

In [None]:
# List image and mask filenames (ignoring hidden files). Both lists are sorted
# because the i-th image is paired with the i-th mask purely by position.
image_files = sorted(f for f in os.listdir(IMAGE_DIR) if not f.startswith('.'))
mask_files = sorted(f for f in os.listdir(MASK_DIR) if not f.startswith('.'))

# Build full paths for images and masks
images = [os.path.join(IMAGE_DIR, f) for f in image_files]
masks = [os.path.join(MASK_DIR, f) for f in mask_files]

print(f"Total images found: {len(images)}")
print(f"Total masks found:  {len(masks)}")

# Positional pairing cannot be right when the folders differ in size, so fail
# here with an explicit message instead of deeper inside the split.
if len(images) != len(masks):
    raise ValueError(
        f"Image/mask count mismatch: {len(images)} images in {IMAGE_DIR} "
        f"vs {len(masks)} masks in {MASK_DIR}"
    )

# Equal counts can still hide a misalignment (e.g. one file missing on each
# side). Differing stems may be a legitimate naming convention, so only warn.
mismatched_pairs = [
    (img, msk)
    for img, msk in zip(image_files, mask_files)
    if os.path.splitext(img)[0] != os.path.splitext(msk)[0]
]
if mismatched_pairs:
    first_img, first_msk = mismatched_pairs[0]
    print(
        f"Warning: {len(mismatched_pairs)} image/mask pairs have different file stems "
        f"(first: {first_img} vs {first_msk}); verify that sorted order pairs them correctly."
    )

# Split the data into training and validation sets (same shuffle for both lists)
train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=VAL_SIZE, random_state=42, shuffle=True
)
print(f"Training images: {len(train_images)}")
print(f"Validation images: {len(val_images)}")

In [None]:
def create_dataset(image_paths, mask_paths):
    """Build a Hugging Face dataset that pairs each image path with its mask path.

    The paths are stored under the ``pixel_values`` and ``label`` columns, and
    both columns are cast to the ``Image()`` feature to enable auto-loading.
    """
    columns = {'pixel_values': image_paths, 'label': mask_paths}
    dataset = Dataset.from_dict(columns)
    # Cast in column order: 'pixel_values' first, then 'label'
    for column_name in columns:
        dataset = dataset.cast_column(column_name, Image())
    return dataset

# Wrap the training and validation splits as Hugging Face datasets
ds_train, ds_valid = (
    create_dataset(train_images, train_masks),
    create_dataset(val_images, val_masks),
)

## 3. Data Transformation

We initialize the `SegformerImageProcessor` to handle resizing, normalization, and label encoding.
The `apply_transforms` function processes each batch of images and masks.

In [None]:
# Load the image processor configuration for the chosen model checkpoint
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)


def apply_transforms(batch):
    """
    Preprocess one batch of examples for the model.

    The image processor receives the batch's images and masks together with a
    target size of (IMG_SIZE, IMG_SIZE) and returns PyTorch tensors. The
    resulting ``pixel_values`` and ``labels`` tensors are made contiguous
    before the encoded batch is returned.
    """
    encoded = feature_extractor(
        batch['pixel_values'],
        batch['label'],
        size=(IMG_SIZE, IMG_SIZE),
        return_tensors="pt",
    )
    for key in ("pixel_values", "labels"):
        encoded[key] = encoded[key].contiguous()
    return encoded

# Attach the batch-level preprocessing function to both splits
for split in (ds_train, ds_valid):
    split.set_transform(apply_transforms)

## 4. Model Setup

We define the SegFormer model for our segmentation task. In this case:
- **Class 0:** Stone
- **Class 1:** Separation

Label mappings are specified via `id2label` and `label2id`.

In [None]:
# Class ids and their names, plus the reverse lookup
id2label = {0: "stone", 1: "separation"}
label2id = dict(zip(id2label.values(), id2label.keys()))
num_labels = len(id2label)

# Load the pre-trained SegFormer checkpoint with a head sized for our classes
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_CHECKPOINT,
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
    # Accept a head whose dimensions differ from the ones in the checkpoint
    ignore_mismatched_sizes=True,
)

## 5. Define Evaluation Metrics

We use the Mean Intersection over Union (Mean IoU) metric to evaluate performance.
The logits output by the model are upsampled to the ground truth mask size before computing the metric.

In [None]:
# Load the Mean IoU metric
metric = evaluate.load('mean_iou')


def compute_metrics(eval_pred):
    """Compute Mean IoU statistics for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The logits are converted from
            NumPy and treated as a 4-D batch with class scores on axis 1; the
            last two axes of ``labels`` give the target height and width.

    Returns:
        The dictionary produced by the Mean IoU metric, with the
        ``per_category_accuracy`` and ``per_category_iou`` arrays replaced by
        one scalar entry per class (``accuracy_<name>`` and ``iou_<name>``).
    """
    with torch.no_grad():
        logits, labels = eval_pred
        # torch.from_numpy shares memory with the array, so make it contiguous first
        logits_tensor = torch.from_numpy(np.ascontiguousarray(logits))
        print(f"Logits shape: {logits_tensor.shape}")
        print("Strides:", logits_tensor.stride())

        # Upsample the logits to the size of the ground-truth mask
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )
        # Get predicted class labels
        pred_labels = logits_tensor.argmax(dim=1).detach().cpu().numpy()

        # Compute metrics using the Mean IoU metric.
        # ignore_index must not collide with a real class id: 0 is "stone"
        # here, so ignoring 0 would drop every stone pixel from the scores.
        # 255 lies outside the class ids (0 = stone, 1 = separation).
        # NOTE(review): assumes the masks store class ids 0/1 - confirm.
        results = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=feature_extractor.do_reduce_labels,
        )

        # Include per-category metrics in the results as flat scalar entries
        per_category_accuracy = results.pop("per_category_accuracy").tolist()
        per_category_iou = results.pop("per_category_iou").tolist()
        results.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        results.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return results



## 6. Training Setup

We configure the training using Hugging Face's `TrainingArguments` and initialize the `Trainer`.  
Evaluation and checkpointing occur every 20 steps.

In [None]:
# Training configuration, grouped by concern
training_args = TrainingArguments(
    output_dir="segformer_stone_finetuned",
    # Optimisation
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate and checkpoint on the same 20-step schedule
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Logging / reporting
    logging_steps=1,
    report_to="none",
    push_to_hub=False,
)

# Wire the model, both dataset splits and the metric function into a Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics,
)

# Run the fine-tuning loop
trainer.train()

# Write the fine-tuned model to the "segformer_stone" directory
model.save_pretrained("segformer_stone")

## 7. Inference on the Validation Set

We run inference on the validation set and display:
- The original image
- The ground truth mask
- The predicted mask

In [None]:
# Run inference in eval mode so that dropout is disabled and outputs are deterministic
model.eval()

for i, (image_path, mask_path) in enumerate(zip(val_images, val_masks)):
    # Load image and mask using OpenCV (imread returns None when a file cannot be read)
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    print(f"Validation image #{i + 1}")
    if image is None or mask is None:
        print(f"  Skipped: could not read {image_path} or {mask_path}")
        continue

    # OpenCV decodes colour images as BGR, whereas the training pipeline
    # loaded them through the datasets Image() feature (RGB). Convert once
    # and reuse the RGB array for both the model input and the plot.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Prepare input with the same target size that was used during training,
    # then move it to whichever device the model ended up on.
    inputs = feature_extractor(
        images=image_rgb, size=(IMG_SIZE, IMG_SIZE), return_tensors="pt"
    )
    inputs = inputs.to(model.device)

    # No gradients are needed for inference
    with torch.no_grad():
        logits = model(**inputs).logits

    # Upsample logits to the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.shape[:2],
        mode="bilinear",
        align_corners=False
    )
    pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Plot the original image, ground truth mask, and predicted mask side-by-side
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes[0].imshow(image_rgb)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    axes[1].imshow(mask, cmap="gray")
    axes[1].set_title("Ground Truth Mask")
    axes[1].axis("off")

    # Fix the colour range to the class ids so a single-class prediction is
    # not rescaled and remains distinguishable from the other class.
    axes[2].imshow(pred_mask, cmap="gray", vmin=0, vmax=num_labels - 1)
    axes[2].set_title("Predicted Mask")
    axes[2].axis("off")

    plt.show()

## 8. Inference on the Test Set

If a test set exists, this section processes each test image and saves the predicted masks to disk.

In [None]:
TEST_DIR = "../data/augmented/test_images"
if os.path.exists(TEST_DIR):
    test_images = sorted([os.path.join(TEST_DIR, f) for f in os.listdir(TEST_DIR) if not f.startswith('.')])
    output_dir = "test_predictions"
    os.makedirs(output_dir, exist_ok=True)

    # Run inference in eval mode so that dropout is disabled and outputs are deterministic
    model.eval()

    for img_path in test_images:
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        # imread returns None for unreadable or non-image files; skip them
        # instead of failing on image.shape further down.
        if image is None:
            print(f"Skipped (not a readable image): {img_path}")
            continue

        # OpenCV decodes colour images as BGR, whereas the training pipeline
        # loaded them through the datasets Image() feature (RGB).
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Use the training resolution and the model's device for the input
        inputs = feature_extractor(
            images=image_rgb, size=(IMG_SIZE, IMG_SIZE), return_tensors="pt"
        )
        inputs = inputs.to(model.device)

        # No gradients are needed for inference
        with torch.no_grad():
            logits = model(**inputs).logits

        # Upsample logits to original image dimensions
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False
        )
        pred_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

        # Save the predicted mask using the original image filename.
        # vmin/vmax pin the grey levels to the class ids; without them a
        # single-class mask is autoscaled and all-0 looks the same as all-1.
        filename = os.path.basename(img_path)
        save_path = os.path.join(output_dir, f"mask_{filename}")
        plt.imsave(save_path, pred_mask, cmap="gray", vmin=0, vmax=num_labels - 1)

    print(f"Predicted masks saved to {output_dir}/")
else:
    print("Test directory not found. Skipping test inference.")