# Mask Visualization Tool

This notebook provides an interactive tool to visualize segmentation masks and their scales from the dataset created by `make_segmentation_dataset.py`.

## Features:
- **Image Selection**: Dropdown to choose which image/camera view to display
- **Interactive Navigation**: Slider and buttons to navigate through masks sorted by scale (smallest to largest)
- **Scale Display**: Shows the current mask's scale value and position in the sequence
- **Visual Overlay**: Displays masks as semi-transparent red overlays on the original images

## Usage:
1. Load your segmentation dataset file (`.pt` file created by `make_segmentation_dataset.py`)
2. Use the dropdown to select an image
3. Use the slider or navigation buttons to browse through masks
4. Observe how the scale values change as you navigate from smallest to largest masks


In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output

# Enable widget support and test immediately
print("🔧 Setting up widget environment...")

# Test widget display immediately to verify it works.
# A styled green message means the frontend renders ipywidgets; if only the
# plain-text repr of the HTML widget shows up, widget rendering is presumably
# not available in this frontend and the interactive controls below will not work.
test_widget = widgets.HTML(value="<b style='color: green;'>✅ Widgets are working!</b>")
display(test_widget)

# Set matplotlib backend
# (IPython line magic: figures are rendered as static images in the cell output.)
%matplotlib inline
# Default size, in inches, for figures that do not pass their own figsize.
plt.rcParams['figure.figsize'] = (12, 8)

print("📦 All imports completed successfully")
print("If you can see the green checkmark above, widgets are working properly!")


🔧 Setting up widget environment...


HTML(value="<b style='color: green;'>✅ Widgets are working!</b>")

📦 All imports completed successfully
If you can see the green checkmark above, widgets are working properly!


In [2]:
class MaskVisualizer:
    """Interactive browser for the per-image segmentation masks of a dataset file.

    The dataset is a dict loaded with ``torch.load`` that provides the keys
    ``scales``, ``mask_ids``, ``mask_cdfs``, ``images``, ``cam_to_worlds`` and
    ``intrinsics``. Only ``images``, ``mask_ids`` and ``scales`` are used for
    the visualization; the other entries are just exposed as attributes.

    Typical use::

        visualizer = MaskVisualizer(path)
        controls = visualizer.create_widgets()
    """

    def __init__(self, segmentation_data_path):
        """Initialize the mask visualizer with segmentation dataset.

        Args:
            segmentation_data_path: Path of the dataset file to load.

        Raises:
            FileNotFoundError: If the file does not exist (raised by ``torch.load``).
            KeyError: If one of the expected dataset keys is missing.
        """
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary
        # code. Only open dataset files you created yourself or otherwise trust.
        # NOTE(review): if the file holds only tensors and plain containers,
        # passing weights_only=True would be safer - confirm against the writer.
        self.data = torch.load(segmentation_data_path, map_location='cpu')
        self.scales = self.data['scales']
        self.mask_ids = self.data['mask_ids']
        self.mask_cdfs = self.data['mask_cdfs']
        self.images = self.data['images']
        self.cam_to_worlds = self.data['cam_to_worlds']
        self.intrinsics = self.data['intrinsics']

        self.num_images = len(self.images)
        self.current_image_idx = 0
        self.current_mask_idx = 0

        # Masks of the most recently requested image as (image_idx, masks, scales).
        # Only one image is kept so memory stays bounded while slider moves,
        # which all hit the same image, do not recompute the masks every time.
        self._mask_cache = None
        # True while update_display() itself writes to the slider, so that the
        # slider observer does not trigger a nested redraw.
        self._syncing_slider = False

        # UI components
        self.image_selector = None
        self.mask_slider = None
        self.scale_label = None
        self.output_area = None

        print(f"Loaded segmentation data with {self.num_images} images")

    def get_masks_for_image(self, image_idx):
        """Extract individual masks from pixel_to_mask_id tensor for given image.

        Args:
            image_idx: Index into ``self.mask_ids`` / ``self.scales``.

        Returns:
            Tuple ``(masks, mask_scales)`` of equally long lists, ordered by
            ascending scale (ties keep ascending mask-ID order). ``masks`` holds
            one boolean tensor per non-negative mask ID, reduced over the last
            dimension of the ID tensor; ``mask_scales`` holds the matching
            scale as a Python float, or ``0.0`` when the ID has no scale entry.
            The lists are cached for the last requested image and returned
            as-is on repeated calls, so callers should not mutate them.
        """
        cached = self._mask_cache
        if cached is not None and cached[0] == image_idx:
            return cached[1], cached[2]

        pixel_to_mask_id = self.mask_ids[image_idx]  # [H, W, MM]
        scales = self.scales[image_idx]  # [M]

        # Get unique mask IDs (excluding negative values such as -1 = no mask)
        unique_mask_ids = torch.unique(pixel_to_mask_id)
        unique_mask_ids = unique_mask_ids[unique_mask_ids >= 0]

        masks = []
        mask_scales = []

        for mask_id in unique_mask_ids.tolist():
            # A pixel belongs to the mask if any of its ID slots equals mask_id
            masks.append((pixel_to_mask_id == mask_id).any(dim=-1))  # [H, W]

            # Get the scale for this mask ID
            if mask_id < len(scales):
                mask_scales.append(scales[mask_id].item())
            else:
                mask_scales.append(0.0)  # Fallback

        # Sort by scale (smallest to largest); stable so equal scales keep a
        # deterministic order.
        if len(mask_scales) > 0:
            sorted_indices = np.argsort(mask_scales, kind='stable')
            masks = [masks[i] for i in sorted_indices]
            mask_scales = [mask_scales[i] for i in sorted_indices]

        self._mask_cache = (image_idx, masks, mask_scales)
        return masks, mask_scales

    def _sync_slider(self, num_masks):
        """Make the slider range and position agree with ``current_mask_idx``.

        The writes are wrapped in the ``_syncing_slider`` guard because changing
        ``max`` or ``value`` notifies ``on_mask_change``; without the guard that
        would overwrite ``current_mask_idx`` and start a nested redraw.
        """
        self._syncing_slider = True
        try:
            # Set the range first so the following value always fits into it.
            self.mask_slider.max = max(0, num_masks - 1)
            self.mask_slider.value = self.current_mask_idx
        finally:
            self._syncing_slider = False

    def update_display(self, change=None):
        """Update the visualization when image or mask selection changes.

        Redraws the current image with the current mask as a semi-transparent
        red overlay, refreshes the scale label and keeps the slider in sync.
        ``change`` is accepted so the method can be used as an observer
        callback; it is not read.
        """
        with self.output_area:
            clear_output(wait=True)

            # Get current image and masks
            image = self.images[self.current_image_idx].numpy()
            masks, scales = self.get_masks_for_image(self.current_image_idx)
            num_masks = len(masks)

            # Clamp the selection into the valid range (0 when there are no
            # masks) and mirror it on the slider.
            self.current_mask_idx = max(0, min(self.current_mask_idx, num_masks - 1))
            self._sync_slider(num_masks)

            fig, ax = plt.subplots(1, 1, figsize=(12, 8))
            ax.set_title(f"Image {self.current_image_idx + 1}/{self.num_images}")

            if num_masks == 0:
                ax.text(0.5, 0.5, 'No masks found for this image',
                        transform=ax.transAxes, ha='center', va='center', fontsize=16)
                self.scale_label.value = "<b style='color: orange;'>No masks available</b>"
            else:
                ax.imshow(image)

                mask = masks[self.current_mask_idx].numpy()
                scale = scales[self.current_mask_idx]

                # RGBA overlay: fully transparent except for the mask pixels
                masked_overlay = np.zeros((*mask.shape, 4))
                masked_overlay[mask] = [1, 0, 0, 0.4]  # Semi-transparent red
                ax.imshow(masked_overlay)

                # Update scale label
                self.scale_label.value = f"<b>Scale: {scale:.4f}</b> (Mask {self.current_mask_idx + 1}/{len(masks)})"

                ax.axis('off')
                plt.tight_layout()

            plt.show()
            # Release the figure so repeated redraws do not accumulate open
            # figures on backends that keep them alive after show().
            plt.close(fig)

    def on_image_change(self, change):
        """Handle image selection change."""
        self.current_image_idx = change['new']
        self.current_mask_idx = 0  # Reset to first mask
        self.update_display()

    def on_mask_change(self, change):
        """Handle mask selection change."""
        if self._syncing_slider:
            # The slider was moved by update_display itself, not by the user.
            return
        self.current_mask_idx = change['new']
        self.update_display()

    def create_widgets(self):
        """Create and display the interactive widgets.

        Returns:
            The ``VBox`` holding the controls and the plot output area.

        Raises:
            ValueError: If the dataset contains no images.
        """
        if self.num_images == 0:
            raise ValueError("Segmentation dataset contains no images; nothing to visualize")

        # Image selector dropdown
        self.image_selector = widgets.Dropdown(
            options=[(f"Image {i+1}", i) for i in range(self.num_images)],
            value=self.current_image_idx,
            description='Image:',
            style={'description_width': 'initial'}
        )
        self.image_selector.observe(self.on_image_change, names='value')

        # Get initial masks to set up slider
        initial_masks, _ = self.get_masks_for_image(self.current_image_idx)
        initial_max = max(0, len(initial_masks) - 1)

        # Mask slider
        self.mask_slider = widgets.IntSlider(
            value=max(0, min(self.current_mask_idx, initial_max)),
            min=0,
            max=initial_max,
            step=1,
            description='Mask:',
            continuous_update=True,
            layout=widgets.Layout(width='400px'),
            style={'description_width': 'initial'}
        )
        self.mask_slider.observe(self.on_mask_change, names='value')

        # Scale label
        self.scale_label = widgets.HTML(value="<b>Scale: Loading...</b>")

        # Navigation buttons
        prev_button = widgets.Button(
            description="◀ Previous",
            layout=widgets.Layout(width='100px'),
            button_style='info'
        )
        next_button = widgets.Button(
            description="Next ▶",
            layout=widgets.Layout(width='100px'),
            button_style='info'
        )

        # The buttons only move the slider; the slider observer does the redraw.
        def prev_mask(b):
            if self.mask_slider.value > 0:
                self.mask_slider.value -= 1

        def next_mask(b):
            if self.mask_slider.value < self.mask_slider.max:
                self.mask_slider.value += 1

        prev_button.on_click(prev_mask)
        next_button.on_click(next_mask)

        # Output area for plots
        self.output_area = widgets.Output()

        # Layout
        controls = widgets.VBox([
            widgets.HTML("<h3>🎛️ Controls</h3>"),
            self.image_selector,
            widgets.HBox([prev_button, self.mask_slider, next_button]),
            self.scale_label,
            widgets.HTML("<hr>"),
            self.output_area
        ])

        display(controls)

        # Initial display
        self.update_display()

        return controls

# Confirmation that this cell ran to the end, i.e. the class definition above was executed.
print("✅ MaskVisualizer class defined successfully")


✅ MaskVisualizer class defined successfully


In [10]:
# Configuration - Update this path to your segmentation dataset file
# (the .pt file written by make_segmentation_dataset.py). A relative path is
# resolved against the directory the notebook kernel runs in.
SEGMENTATION_DATA_PATH = "../segmentation_ramen3.pt"

print(f"📁 Will load segmentation data from: {SEGMENTATION_DATA_PATH}")
# The hint names the same extension as the configured path and the notebook intro.
print("📋 Make sure this path points to a .pt file created by make_segmentation_dataset.py")


📁 Will load segmentation data from: ../segmentation_ramen3.pt
📋 Make sure this path points to a .pth file created by make_segmentation_dataset.py


In [None]:
# Create and run the mask visualizer.
# Loading the dataset and building the UI are handled as two separate steps so
# that a failure is reported against the step that actually caused it.
try:
    # Initialize the visualizer
    print("🚀 Loading segmentation data...")
    visualizer = MaskVisualizer(SEGMENTATION_DATA_PATH)
except FileNotFoundError:
    print(f"❌ Error: Could not find segmentation data file at {SEGMENTATION_DATA_PATH}")
    print("Please update SEGMENTATION_DATA_PATH to point to your segmentation dataset file.")
except Exception as e:
    # Include the exception type: e.g. a KeyError on its own prints only the
    # missing key, which is hard to interpret.
    print(f"❌ Error loading segmentation data: {type(e).__name__}: {e}")
    print("Make sure the file was created by make_segmentation_dataset.py and contains the expected data structure.")
else:
    try:
        # Create and display the interactive widgets
        print("🎨 Creating interactive interface...")
        controls = visualizer.create_widgets()
    except Exception as e:
        print(f"❌ Error creating interactive interface: {type(e).__name__}: {e}")
        print("The dataset was loaded, but the widgets or the initial plot could not be created.")
    else:
        print("\n✅ Interactive mask visualizer loaded successfully!")
        print("Use the controls above to navigate through images and masks.")
        print("Masks are sorted from smallest to largest scale values.")


🚀 Loading segmentation data...


  self.data = torch.load(segmentation_data_path, map_location='cpu')


Loaded segmentation data with 131 images
🎨 Creating interactive interface...


VBox(children=(HTML(value='<h3>🎛️ Controls</h3>'), Dropdown(description='Image:', options=(('Image 1', 0), ('I…


✅ Interactive mask visualizer loaded successfully!
Use the controls above to navigate through images and masks.
Masks are sorted from smallest to largest scale values.
