# 🤖 Chapter 13: 3D Deep Learning with GeoSAM

Traditional segmentation (clustering) relies on geometry. Deep Learning models like **SAM (Segment Anything Model)** allow us to segment based on semantic features in images, and then project those segmentations to 3D.

**Workflow:**
1.  **Project** the 3D point cloud onto a 2D spherical image (panorama).
2.  **Segment** the image using SAM (Meta's Foundation Model).
3.  **Back-project** the 2D masks to 3D points.

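**Note:** This chapter requires the `torch` and `segment_anything` libraries and the SAM checkpoint weights (`sam_vit_h_4b8939.pth`, downloaded manually). Loading the scan and viewing the result additionally use `laspy` and `open3d`.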

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import laspy
import torch
import time

## 1. Load Data

We load a LiDAR scan.

In [None]:
filename = "../DATA/ITC_BUILDING.las"

# Load the scan into two parallel arrays: `coords` (one XYZ row per point) and
# `colors` (one integer RGB row per point, 0-255). If anything goes wrong the
# cell falls back to random points so the rest of the chapter can still run.
try:
    las = laspy.read(filename)
    coords = np.vstack((las.x, las.y, las.z)).transpose()
    # Extract color if available, else standard grey
    try:
        rgb = np.vstack((las.red, las.green, las.blue)).transpose()
        if rgb.size and rgb.max() > 255:
            # Genuine 16-bit colour: rescale to the 0-255 range.
            colors = (rgb / 65535 * 255).astype(int)
        else:
            # NOTE(review): values that never exceed 255 are treated as 8-bit
            # colour stored in the 16-bit fields; rescaling them would turn
            # the whole cloud almost black. Confirm against your data source.
            colors = rgb.astype(int)
    except Exception:
        # Best-effort: any failure to read colour means "no colour", but
        # unlike a bare `except:` this lets KeyboardInterrupt/SystemExit through.
        colors = np.full(coords.shape, 128, dtype=int)

    print(f"Loaded {len(coords)} points.")
except Exception as e:
    print(f"Error loading data: {e}")
    # Dummy data for demonstration
    coords = np.random.rand(1000, 3) * 100
    # randint's upper bound is exclusive, so 256 is needed to allow 255.
    colors = np.random.randint(0, 256, (1000, 3))

## 2. Generate Spherical Image (3D to 2D)

We obtain a 360° panoramic view from the center of the point cloud.

In [None]:
def generate_spherical_image(center_coordinates, point_cloud, colors, resolution_y=500):
    """Render a point cloud as a spherical (equirectangular) panorama.

    Every point is expressed relative to ``center_coordinates`` and converted
    to two angles: the azimuth around the Z axis selects the image column and
    the polar angle measured from +Z selects the image row. When several
    points land on the same pixel, the one closest to the center wins; exact
    distance ties go to the point with the lowest index.

    Args:
        center_coordinates: XYZ position of the virtual camera; it is
            subtracted from every row of ``point_cloud``.
        point_cloud: Array with one XYZ row per point.
        colors: Array with one RGB row per point, in the same order as
            ``point_cloud``. Values are cast to ``uint8`` when written.
        resolution_y: Image height in pixels. The width is twice the height.

    Returns:
        A tuple ``(image, mapping)``. ``image`` is a
        ``(resolution_y, 2 * resolution_y, 3)`` ``uint8`` array. ``mapping``
        is a ``(resolution_y, 2 * resolution_y)`` integer array holding, for
        each pixel, the index of the point drawn there, or ``-1`` if the
        pixel is empty.
    """
    resolution_x = 2 * resolution_y
    image = np.zeros((resolution_y, resolution_x, 3), dtype=np.uint8)
    mapping = np.full((resolution_y, resolution_x), -1, dtype=int)

    translated_points = point_cloud - center_coordinates
    if translated_points.shape[0] == 0:
        return image, mapping
    colors = np.asarray(colors)

    # Distance to the center doubles as the z-buffer depth.
    dists = np.linalg.norm(translated_points, axis=1)

    # Spherical coordinates (the epsilon avoids 0/0 for a point at the center)
    theta = np.arctan2(translated_points[:, 1], translated_points[:, 0])
    phi = np.arccos(translated_points[:, 2] / (dists + 1e-6))

    # Map to pixels; clipping folds theta == pi and phi == pi onto the last
    # column/row instead of indexing one past the edge.
    ix = np.clip(((theta + np.pi) / (2 * np.pi) * resolution_x).astype(int),
                 0, resolution_x - 1)
    iy = np.clip((phi / np.pi * resolution_y).astype(int),
                 0, resolution_y - 1)

    # Vectorized z-buffer: group points by pixel and, inside each group, order
    # them by distance. lexsort is stable, so equal distances keep ascending
    # point index. The first entry of every group is that pixel's winner.
    flat_pixel = iy * resolution_x + ix
    order = np.lexsort((dists, flat_pixel))
    sorted_pixel = flat_pixel[order]
    is_group_start = np.ones(order.shape[0], dtype=bool)
    is_group_start[1:] = sorted_pixel[1:] != sorted_pixel[:-1]
    winners = order[is_group_start]

    # Each pixel now appears at most once, so the assignment is unambiguous.
    mapping[iy[winners], ix[winners]] = winners
    image[iy[winners], ix[winners]] = colors[winners]

    return image, mapping

# Center of the building (approximate): the centroid of all points is used as
# the viewpoint of the panorama.
center = np.mean(coords, axis=0)

# perf_counter() is a monotonic clock intended for measuring elapsed time;
# time.time() is wall-clock time and can jump if the system clock is adjusted.
start = time.perf_counter()
spherical_img, pixel_to_point_map = generate_spherical_image(center, coords, colors, resolution_y=400)
print(f"Projection took {time.perf_counter() - start:.2f}s")

# Show the panorama that will be handed to SAM.
plt.figure(figsize=(10, 5))
plt.imshow(spherical_img)
plt.title("Spherical Projection")
plt.axis('off')
plt.show()

## 3. Segment with SAM

We feed this 2D image to SAM. SAM returns masks (segmentations) for "everything" it sees.

In [None]:
# Run SAM's automatic mask generator on the panorama.
# `masks` is defined up front so that it exists (as an empty list) even when
# segment_anything is not installed or the checkpoint cannot be loaded; the
# back-projection cell relies on that to skip itself cleanly.
masks = []

try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    print("segment_anything library not installed.")
else:
    # You need to download the checkpoint manually
    CHECKPOINT_PATH = "../../MODELS/sam_vit_h_4b8939.pth"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        print(f"Loading SAM on {DEVICE}...")
        sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT_PATH)
        sam.to(device=DEVICE)

        mask_generator = SamAutomaticMaskGenerator(sam)
        masks = mask_generator.generate(spherical_img)

        print(f"Generated {len(masks)} masks.")

        # Visualize masks on image
        # (Visualization code would go here)
    except Exception as e:
        # Missing weights, out-of-memory, etc.: report and continue with no masks.
        print(f"Could not run SAM: {e}")
        masks = []  # Empty result

## 4. Back-Project to 3D

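We use `pixel_to_point_map` to carry each 2D mask back to the 3D points it covers and give every segment its own random color. Only points that were drawn in the projection (the closest point along each viewing direction) are recolored; all other points keep their original color.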

In [None]:
# Paint every SAM mask onto the 3D points that produced its pixels, then show
# the recolored cloud. Skipped entirely when no masks are available.
if masks:
    # Start from the original colors; points not covered by any mask keep them.
    segmented_colors = np.copy(colors)

    for mask_data in masks:
        segmentation = mask_data['segmentation']  # boolean 2D array

        # Random color for this segment (upper bound is exclusive, so 256
        # is needed for 255 to be reachable).
        rnd_color = np.random.randint(0, 256, 3)

        # Get pixels belonging to this mask
        ys, xs = np.where(segmentation)

        # Find corresponding 3D points; -1 marks pixels no point was drawn on.
        point_indices = pixel_to_point_map[ys, xs]
        valid_indices = point_indices[point_indices != -1]

        # Colorize (a point covered by several masks ends up with the color
        # of the last mask processed).
        segmented_colors[valid_indices] = rnd_color

    # Save/Visualize
    # open3d is only needed for the viewer, so a missing install should not
    # throw away the segmentation computed above.
    try:
        import open3d as o3d
    except ImportError:
        print("open3d not installed; skipping 3D visualization.")
    else:
        pcd_seg = o3d.geometry.PointCloud()
        pcd_seg.points = o3d.utility.Vector3dVector(coords)
        # Open3D expects colors as floats in [0, 1].
        pcd_seg.colors = o3d.utility.Vector3dVector(segmented_colors / 255.0)
        o3d.visualization.draw_geometries([pcd_seg], window_name="GeoSAM Result")

else:
    print("No masks generated (SAM skipped).")