In [1]:
# Import system utilities and set module path
import sys
import os
from pathlib import Path

sys.path.append(os.path.abspath(".."))

# Import PyTorch utilities
import torch
from torch.utils.data import DataLoader

# Import dataset handling
from torchgeo.datasets import NAIP, stack_samples


In [2]:
# Import project configs and data handlers
from configs import config
from data.kc import KaneCounty
from data.sampler import BalancedRandomBatchGeoSampler

# Import visualization and utility tools
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from einops import rearrange
from matplotlib.colors import ListedColormap
from datetime import datetime

In [None]:
# Open the NAIP imagery stored under the shared project directory
naip_dataset = NAIP("/net/projects/cmap/data/KC-images")

# Resolve the shapefile location from the project config
shape_path = Path(config.KC_SHAPE_ROOT).joinpath(config.KC_SHAPE_FILENAME)

# Settings passed to KaneCounty as one tuple: layer, labels and patch size come
# from the config; CRS and resolution are copied from the NAIP imagery
dataset_config = (
    config.KC_LAYER,
    config.KC_LABELS,
    config.PATCH_SIZE,
    naip_dataset.crs,
    naip_dataset.res,
)

# Build the Kane County label dataset from the shapefile
kc_dataset = KaneCounty(shape_path, dataset_config)

# Combine imagery and labels with the `&` operator so samples carry both
train_dataset = naip_dataset & kc_dataset

In [4]:
# Batch sampler over the combined dataset; each batch holds a single patch
train_sampler = BalancedRandomBatchGeoSampler(
    config=dict(dataset=train_dataset, size=config.PATCH_SIZE, batch_size=1)
)

# DataLoader used only to pull samples for visualization
plot_dataloader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    collate_fn=stack_samples,
    num_workers=config.NUM_WORKERS,
)

In [None]:
# Print the image and mask tensor shapes for the first MAX_BATCHES batches.
MAX_BATCHES = 2
EXPECTED_DIM = 3  # mask rank that indicates a leading channel axis is present
SINGLE_CHANNEL = 1  # size of that channel axis when it can be squeezed away
RGB_CHANNELS = 3  # number of leading image channels used for RGB display

# Pairing the loader with range(MAX_BATCHES) stops after exactly MAX_BATCHES
# batches. The previous `batch > MAX_BATCHES` check ran after the print and
# therefore emitted MAX_BATCHES + 2 batches.
for batch, sample in zip(range(MAX_BATCHES), plot_dataloader):
    print(sample["image"].shape, sample["mask"].shape)

In [6]:
# --- Build a ListedColormap from kc_dataset.colors ---
# Keys are sorted so the color order is stable between runs
sorted_keys = sorted(kc_dataset.colors.keys())
# Keep only the first three components of each color and scale them to [0, 1]
color_list = [
    tuple(np.array(kc_dataset.colors[key][:3]) / 255.0) for key in sorted_keys
]
custom_cmap = ListedColormap(colors=color_list)

# --- Pull a single sample from the dataloader ---
sample = next(iter(plot_dataloader))
img_tensor = sample["image"][0]  # first item of the batch: (channels, height, width)
mask_tensor = sample["mask"][0]  # first item of the batch: (height, width) or (1, height, width)

# Drop the leading channel axis when the mask arrives as (1, height, width)
if mask_tensor.dim() == EXPECTED_DIM:
    if mask_tensor.shape[0] == SINGLE_CHANNEL:
        mask_tensor = mask_tensor.squeeze(dim=0)

In [7]:
# Keep the first RGB_CHANNELS channels for display; an image with fewer
# channels is passed through unchanged
if img_tensor.shape[0] >= RGB_CHANNELS:
    img_rgb = img_tensor[:RGB_CHANNELS]
else:
    img_rgb = img_tensor

# Channels-last layout (height, width, channels) as a uint8 NumPy array
img_rgb_np = (
    rearrange(img_rgb, "c h w -> h w c")
    .cpu()
    .numpy()
    .astype("uint8")
)
# NumPy copy of the label mask for plotting
mask_np = mask_tensor.cpu().numpy()

In [None]:
# --- Two side-by-side panels: RGB image and true label mask ---
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

# Left panel: the RGB image on its own (no mask overlay)
axs[0].imshow(img_rgb_np, interpolation="none")
axs[0].set_title("RGB Image")
axs[0].axis("off")

# Right panel: the label mask in solid colors; the color scale spans the
# index range 0 .. len(sorted_keys) - 1 of the custom colormap
axs[1].imshow(
    mask_np,
    cmap=custom_cmap,
    interpolation="none",
    vmin=0,
    vmax=len(sorted_keys) - 1,
)
axs[1].set_title("True Label Mask")
axs[1].axis("off")

In [9]:
# --- Collect legend patches for the true label mask ---
legend_handles = []
# Labels that actually occur in this mask; all others are left out of the legend
unique_vals_in_mask = set(np.unique(mask_np))

for k in sorted_keys:
    if k not in unique_vals_in_mask:
        continue
    # Fall back to the numeric key when no readable name is registered
    label_name = kc_dataset.labels_inverse.get(k, str(k))
    color = tuple(np.array(kc_dataset.colors[k][:3]) / 255.0)
    patch = mpatches.Patch(color=color, label=label_name)
    legend_handles.append(patch)

In [None]:
# Add a legend below the panels (only for labels present in the mask)
if legend_handles:
    fig.legend(
        handles=legend_handles, loc="lower center", ncol=4, bbox_to_anchor=(0.5, -0.1)
    )

# Lay out the figure through its own handle rather than pyplot's implicit
# "current figure": `fig` was created in an earlier cell, and the notebook
# inline backend closes figures after each cell, so plt.tight_layout() could
# end up acting on a new, empty figure instead of this one.
fig.tight_layout()
plt.show()

# Segmentation code:

In [11]:
# Locate the Segment Anything checkout relative to the current user's home so
# that the import path and the checkpoint path always refer to the same place.
home_dir = Path.home()
sam_source_dir = home_dir / "CMAP/segment_anything_source_code"

# The original hard-coded location is kept as a fallback so existing setups
# keep working; entries already on sys.path are not added a second time.
for candidate_dir in (
    str(sam_source_dir),
    "/home/gregoryc25/CMAP/segment_anything_source_code",
):
    if candidate_dir not in sys.path:
        sys.path.append(candidate_dir)

# Import necessary components (must come after the sys.path update above)
from segment_anything.build_sam import sam_model_registry
from segment_anything.predictor import SamPredictor

# Define checkpoint path
sam_checkpoint = sam_source_dir / "sam_vit_h.pth"

# Fail early with a clear message if the checkpoint is missing
if not sam_checkpoint.exists():
    raise FileNotFoundError(f"Checkpoint file not found: {sam_checkpoint}")

# Load the ViT-H model; the checkpoint is passed as a plain string path
sam = sam_model_registry["vit_h"](checkpoint=str(sam_checkpoint))
predictor = SamPredictor(sam)


In [12]:
# 1) Collect the distinct mask values and keep the strictly positive ones
#    (0 is treated as background)
unique_vals = mask_tensor.unique()
valid_labels = torch.masked_select(unique_vals, unique_vals > 0)

if valid_labels.numel() == 0:
    raise ValueError("No valid foreground labels found!")

# 2) Draw one of the remaining labels uniformly at random
random_label_idx = torch.randint(low=0, high=len(valid_labels), size=(1,)).item()
chosen_label = valid_labels[random_label_idx].item()

# 3) Row (ys) and column (xs) indices of every pixel carrying that label
ys, xs = torch.nonzero(mask_tensor == chosen_label, as_tuple=True)

In [None]:
# Stop early when the chosen label has no pixels to sample from
if len(xs) == 0 or len(ys) == 0:
    raise ValueError("No valid pixel found for chosen label!")

# 4) Draw one pixel of the chosen label uniformly at random
random_index = torch.randint(low=0, high=len(xs), size=(1,)).item()
seed_x = xs[random_index].item()
seed_y = ys[random_index].item()

print(f"Selected label: {chosen_label}")
print(f"Seed coordinate: (x={seed_x}, y={seed_y})")

# Sanity check: the mask is indexed [row, column], i.e. [y, x]
if mask_tensor[seed_y, seed_x] != chosen_label:
    raise ValueError("Seed point not inside the chosen label!")

In [None]:
# Report the seed point picked earlier; the same point is the SAM prompt
print(f"Using  segmentation seed coordinate: (x={seed_x}, y={seed_y})")

# Single point prompt as a (1, 2) array in (x, y) order, labelled 1 (positive)
seed_coordinate = np.array([seed_x, seed_y]).reshape(1, 2)
seed_label = np.array([1])

# The predictor needs the image set before predict() is called
predictor.set_image(img_rgb_np)

# --- Run SAM with that seed; multimask_output=False requests a single mask ---
masks, scores, logits = predictor.predict(
    multimask_output=False,
    point_labels=seed_label,
    point_coords=seed_coordinate,
)

# Take the first returned mask together with its score
best_mask, best_score = masks[0], scores[0]

In [None]:
# --- Three side-by-side panels: RGB, true label mask, segmentation mask ---
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18, 6))

# Panel 1: RGB image with the seed point drawn as a red, black-edged dot
axs[0].imshow(img_rgb_np, interpolation="none")
axs[0].scatter(
    seed_x,
    seed_y,
    color="red",
    s=50,
    marker="o",
    edgecolors="black",
)
axs[0].set_title("RGB Image")
axs[0].axis("off")

# Panel 2: true label mask in solid colors (the seed point is drawn separately)
axs[1].imshow(
    mask_np,
    cmap=custom_cmap,
    interpolation="none",
    vmin=0,
    vmax=len(sorted_keys) - 1,
)

In [None]:
# Panel 2 (continued): mark the seed point on the true label mask
axs[1].scatter(
    seed_x,
    seed_y,
    color="red",
    s=50,
    marker="o",
    edgecolors="black",
)
axs[1].set_title("True Label Mask")
axs[1].axis("off")

# Panel 3: the SAM mask in grayscale with the same seed point and its score
axs[2].imshow(best_mask, cmap="gray", interpolation="none")
axs[2].scatter(
    seed_x,
    seed_y,
    color="red",
    s=50,
    marker="o",
    edgecolors="black",
)
axs[2].set_title(f"Segmentation Mask\nScore: {best_score:.2f}")
axs[2].axis("off")

In [None]:
# Folder that collects the saved segmentation figures (created on demand)
output_folder = home_dir / "CMAP/segment-anything/kc_sam_outputs"
output_folder.mkdir(parents=True, exist_ok=True)

# Generate a unique filename (timestamp-based)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_folder / f"segmentation_{timestamp}.png"

# Apply the layout BEFORE saving so the written file matches what is shown.
# Both calls go through `fig` rather than pyplot's implicit "current figure":
# the figure was built in earlier cells, and the notebook inline backend
# closes figures after each cell, so plt.savefig() could write a blank image.
fig.tight_layout()
fig.savefig(output_path, bbox_inches="tight", dpi=300)

plt.show()

# Report the actual file that was written, not just its parent folder
print(f"Saved segmentation result to: {output_path}")

In [None]:
# Per-class & instance IoU helpers live in the project's utils package; the
# parent directory is put on sys.path so that package can be imported
sys.path.append(os.path.abspath(os.pardir))
from utils.kc_visualizations_helper import (
    compute_iou_per_class,
    compute_instance_iou,
)

# Instance IoU computed from the SAM mask, the true label mask and the label
# that was chosen for the seed point
iou_instance = compute_instance_iou(
    best_mask,
    mask_np,
    chosen_label,
)
print(f"Instance {chosen_label} IoU: {iou_instance:.4f}")