# Segmentation Visualization Demo

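This notebook demonstrates how to visualize Sam3 segmentation results using the `visualization_utils` module. It expects a Sam3 inference server running at `http://localhost:8000` and a local input image (`image.jpg` by default; change `image_path` below).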

Features:
- Convert RLE/Polygon masks to dense format
- Overlay masks on images
- Create grid visualizations
- Side-by-side comparisons
- Polygon visualization
- Comprehensive montages

In [None]:
# Imports
import sys
sys.path.append('.')  # Add examples directory to path

from visualization_utils import (
    visualize_sam3_results,
    create_mask_montage,
    show_mask_statistics,
    rle_to_dense,
    polygon_to_dense,
    visualize_masks_overlay,
    visualize_masks_grid,
    visualize_polygons
)

import requests
import base64
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

## 1. Load Image and Get Segmentation Results

In [None]:
# Load the input image from disk (edit image_path to point at your own file)
image_path = "image.jpg"  # Update this path
image = Image.open(image_path)

# Preview the untouched input before running segmentation
fig, ax = plt.subplots(figsize=(10, 6))
ax.imshow(image)
ax.set_title("Original Image")
ax.axis('off')
plt.show()

print(f"Image size: {image.size}")

In [None]:
# Encode image to base64 (WebP for efficiency)
import io

def encode_to_webp(img, quality=85):
    """Serialize a PIL image as WebP and return it as a base64 text string.

    Images whose mode is neither RGB nor RGBA are converted to RGB first;
    RGB and RGBA images are encoded as-is.

    Args:
        img: PIL image to encode.
        quality: WebP quality value forwarded to ``Image.save``.

    Returns:
        The base64-encoded WebP bytes, decoded to ``str`` (ready for JSON).
    """
    encodable = img if img.mode in ("RGB", "RGBA") else img.convert("RGB")
    with io.BytesIO() as buf:
        encodable.save(buf, format="WEBP", quality=quality)
        payload = buf.getvalue()
    return base64.b64encode(payload).decode()

image_b64 = encode_to_webp(image, quality=85)  # base64 WebP payload sent in the API requests below

In [None]:
# Get segmentation results from the Sam3 inference API.
# NOTE: later cells pass output_format='rle' explicitly; update them as well
# if you change "output_format" here.
response = requests.post(
    "http://localhost:8000/inference/sam3",
    json={
        "images": [image_b64],
        "prompts": ["food", "plate"],  # Your prompts here
        "threshold": 0.5,
        "output_format": "rle"  # or "polygons" or "dense"
    },
    # requests applies no timeout by default; bound the wait so a stalled
    # server cannot hang the kernel indefinitely.
    timeout=120,  # seconds
)
# Surface HTTP errors directly instead of a confusing KeyError on the JSON body.
response.raise_for_status()

result = response.json()
print(f"Processing time: {result['processing_time']:.2f}s")
print(f"Model: {result['model_id']}")

# Get results for first image
image_results = result['results'][0]

## 2. Show Mask Statistics

In [None]:
# Print mask statistics for the RLE-formatted results
show_mask_statistics(image_results, output_format="rle")

## 3. Overlay Visualization (Most Common)

In [None]:
# Overlay every mask on the image in a single figure (the most common view)
fig = visualize_sam3_results(
    image,
    image_results,
    mode="overlay",
    output_format="rle",
    figsize=(14, 10),
    alpha=0.5,
)
plt.show()

# Uncomment to persist the figure
# fig.savefig('segmentation_overlay.png', dpi=150, bbox_inches='tight')

## 4. Grid Visualization (Individual Masks)

In [None]:
# One panel per mask, laid out as a grid
fig = visualize_sam3_results(
    image,
    image_results,
    mode="grid",
    output_format="rle",
    figsize=(15, 10),
)
plt.show()

## 5. Side-by-Side Comparison

In [None]:
# Original image next to its segmented overlay
fig = visualize_sam3_results(
    image,
    image_results,
    mode="side_by_side",
    output_format="rle",
    alpha=0.6,
)
plt.show()

## 6. Comprehensive Montage (3 Panels)

In [None]:
# Three-panel montage built by the helper module
fig = create_mask_montage(image, image_results, output_format="rle")
plt.show()

## 7. Working with Different Output Formats

### RLE Format

In [None]:
# Example: Convert RLE to dense mask manually.
# Take the first mask from the first prompt entry that actually has RLE masks.
# A prompt's `masks_rle` list may be empty or absent, so indexing [0] on the
# first entry unconditionally could raise IndexError/KeyError.
result_data = next(
    (
        entry['result']
        for entry in (image_results or [])
        if entry.get('result') and entry['result'].get('masks_rle')
    ),
    None,
)

if result_data is None:
    print("No RLE masks found in image_results; nothing to display.")
else:
    rle_mask = result_data['masks_rle'][0]  # First mask

    # Convert to dense
    dense_mask = rle_to_dense(rle_mask)

    # Display
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title('Original')
    axes[0].axis('off')

    axes[1].imshow(dense_mask, cmap='gray')
    axes[1].set_title('Mask (from RLE)')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

    print(f"Mask shape: {dense_mask.shape}")
    # Count foreground pixels (> 0) so the area is correct for bool, 0/1 or 0/255 masks.
    print(f"Mask area: {np.count_nonzero(dense_mask)} pixels")

### Polygon Format

In [None]:
# Get polygon results (a separate request with its own prompt)
response_poly = requests.post(
    "http://localhost:8000/inference/sam3",
    json={
        "images": [image_b64],
        "prompts": ["person"],
        "threshold": 0.5,
        "output_format": "polygons"
    },
    # requests applies no timeout by default; bound the wait so a stalled
    # server cannot hang the kernel indefinitely.
    timeout=120,  # seconds
)
# Surface HTTP errors directly instead of a confusing KeyError on the JSON body.
response_poly.raise_for_status()

poly_results = response_poly.json()['results'][0]

# Visualize polygons
fig = visualize_sam3_results(
    image,
    poly_results,
    output_format='polygons',
    mode='overlay',
    alpha=0.4
)
plt.show()

## 8. Manual Visualization with Custom Colors

In [None]:
# Extract every RLE mask across all prompts as a dense mask, with a
# "<prompt> #<n>" label per mask. The lists are rebuilt from scratch, so
# re-running this cell gives the same result.
all_masks = []
all_labels = []

for entry in image_results:
    prompt = entry['prompt']
    # `result` (or its `masks_rle` list) may be missing/None for some entries;
    # treat that as "no masks" instead of raising KeyError/TypeError.
    masks_rle = (entry.get('result') or {}).get('masks_rle') or []

    for n, rle_mask in enumerate(masks_rle, start=1):
        all_masks.append(rle_to_dense(rle_mask))
        all_labels.append(f"{prompt} #{n}")

print(f"Total masks: {len(all_masks)}")

In [None]:
# Visualize with custom colors
custom_colors = [
    (1.0, 0.0, 0.0),  # Red
    (0.0, 1.0, 0.0),  # Green
    (0.0, 0.0, 1.0),  # Blue
    (1.0, 1.0, 0.0),  # Yellow
    (1.0, 0.0, 1.0),  # Magenta
    (0.0, 1.0, 1.0),  # Cyan
]

if not all_masks:
    print("No masks to visualize.")
else:
    # The palette has a fixed length, but the mask count depends on the prompts
    # and the image. Cycle through it so every mask is assigned a palette color.
    mask_colors = [custom_colors[k % len(custom_colors)] for k in range(len(all_masks))]

    fig = visualize_masks_overlay(
        np.array(image),
        all_masks,
        labels=all_labels,
        colors=mask_colors,
        alpha=0.5,
        show_boxes=True,
        show_labels=True,
        figsize=(14, 10)
    )
    plt.show()

## 9. Export Masks as Individual Images

In [None]:
# Optionally export every mask as a PNG, and preview the first few inline.
# Export and preview are controlled separately: previously the preview cap
# (break after 3) also limited how many masks would be saved.
N_PREVIEW = 3       # number of masks to display, to keep the output short
SAVE_MASKS = False  # set True to write mask_<i>.png for *every* mask

for i, (mask, label) in enumerate(zip(all_masks, all_labels)):
    if SAVE_MASKS:
        # Threshold with `> 0` so bool, 0/1 and 0/255 masks all become a clean
        # 0/255 uint8 image (avoids uint8 overflow from `mask * 255`).
        mask_img = Image.fromarray((np.asarray(mask) > 0).astype(np.uint8) * 255)
        mask_img.save(f'mask_{i}.png')

    if i < N_PREVIEW:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.imshow(mask, cmap='gray')
        ax.set_title(f"Mask {i}: {label}")
        ax.axis('off')
        plt.show()
    elif not SAVE_MASKS:
        break  # preview quota reached and nothing to save: stop early

## 10. Analyze Mask Properties

In [None]:
# Analyze each mask: foreground area, image coverage and tight bounding box.
for mask, label in zip(all_masks, all_labels):
    # Use one foreground definition (> 0) for every statistic.
    foreground = np.asarray(mask) > 0
    area = int(np.count_nonzero(foreground))
    coverage = (area / foreground.size) * 100

    print(f"{label}:")
    if area == 0:
        print("  Empty mask (no foreground pixels)")
        print()
        continue

    rows, cols = np.nonzero(foreground)
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()
    # min/max are inclusive pixel indices, so width/height need +1
    # (a single-pixel mask is 1x1, not 0x0).
    bbox_w = x2 - x1 + 1
    bbox_h = y2 - y1 + 1

    print(f"  Area: {area:,} pixels ({coverage:.2f}% of image)")
    print(f"  Bounding box: ({x1}, {y1}) to ({x2}, {y2})")
    print(f"  BBox size: {bbox_w}x{bbox_h}")
    print()

## Summary

This notebook demonstrated:

1. ✅ Loading images and getting segmentation results
2. ✅ Converting RLE masks to dense format
3. ✅ Overlay visualization
4. ✅ Grid visualization
5. ✅ Side-by-side comparison
6. ✅ Comprehensive montages
7. ✅ Working with different output formats
8. ✅ Custom visualization options
9. ✅ Exporting masks
10. ✅ Analyzing mask properties

**Key Functions:**
- `visualize_sam3_results()` - Quick visualization (recommended)
- `create_mask_montage()` - 3-panel comprehensive view
- `rle_to_dense()` - Convert RLE to binary mask
- `polygon_to_dense()` - Convert polygon to binary mask
- `show_mask_statistics()` - Print mask statistics