# Introduction to PolytopAX

Welcome to PolytopAX! This notebook provides a hands-on introduction to differentiable convex hull computation with JAX.

## What You'll Learn

- Basic convex hull computation
- Understanding differentiable algorithms
- JAX integration and transformations
- Visualizing results

## Prerequisites

- Basic Python knowledge
- Familiarity with NumPy arrays
- Basic understanding of convex geometry (helpful but not required)

## Setup and Imports

In [None]:
# Install PolytopAX if not already installed
# !pip install polytopax

import jax
import jax.numpy as jnp

# For visualization
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

import polytopax as ptx

# Set up matplotlib
plt.style.use('default')
plt.rcParams['figure.figsize'] = (10, 6)

print(f"PolytopAX version: {ptx.__version__}")
print(f"JAX version: {jax.__version__}")

## 1. Your First Convex Hull

Let's start with a simple example: computing the convex hull of a set of 2D points.
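PolytopAX computes an approximate, differentiable hull. For comparison, the exact hull of a 2D point set can be computed with Andrew's monotone chain algorithm: sort the points, then build a lower and an upper chain while discarding any point that does not make a strict left turn. The pure-Python sketch below is a reference for checking results; it is not part of PolytopAX. For the points used in this section, only the four corners of the unit square should survive.

```python
# Exact 2D convex hull via Andrew's monotone chain (pure-Python reference,
# not part of PolytopAX). Returns hull vertices in counter-clockwise order.


def cross(o, a, b):
    """z-component of (a - o) x (b - o); > 0 means o -> a -> b turns left."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def monotone_chain(points):
    # Sort lexicographically by (x, y) and drop duplicate points
    pts = sorted(set(map(tuple, points)))
    if len(pts) <= 2:
        return pts
    lower, upper = [], []
    for p in pts:
        # Pop while the last two points and p do not make a strict left turn
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # The last point of each chain repeats the first point of the other
    return lower[:-1] + upper[:-1]


points_list = [
    (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0),
    (0.5, 0.5), (0.3, 0.7), (0.8, 0.2),
]
exact_hull = monotone_chain(points_list)
print(exact_hull)  # [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
```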

In [None]:
# Create a set of 2D points
points = jnp.array([
    [0.0, 0.0],   # corner
    [1.0, 0.0],   # corner
    [1.0, 1.0],   # corner
    [0.0, 1.0],   # corner
    [0.5, 0.5],   # interior point
    [0.3, 0.7],   # interior point
    [0.8, 0.2],   # interior point
])

print(f"Input points shape: {points.shape}")
print(f"Points:\n{points}")

In [None]:
# Compute the convex hull
hull = ptx.ConvexHull.from_points(points, n_directions=20)

print(f"Hull vertices shape: {hull.vertices.shape}")
print(f"Number of hull vertices: {hull.n_vertices}")
print(f"Hull vertices:\n{hull.vertices}")

### Visualization
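Drawing the hull as a polygon needs its vertices in boundary order, and the hull vertices may not come back in that order. For points in convex position, sorting by polar angle around their mean gives a valid counter-clockwise order, because the mean lies strictly inside the polygon. A pure-Python sketch of that ordering step, with the same logic as the plotting helper:

```python
import math


def order_by_angle(vertices):
    """Sort 2D vertices counter-clockwise by polar angle around their mean.

    Valid for points in convex position (e.g. hull vertices): the mean lies
    inside the polygon, so walking by increasing angle traces the boundary.
    """
    cx = sum(x for x, _ in vertices) / len(vertices)
    cy = sum(y for _, y in vertices) / len(vertices)
    # atan2 returns angles in (-pi, pi]; sorting by it goes counter-clockwise
    return sorted(vertices, key=lambda v: math.atan2(v[1] - cy, v[0] - cx))


shuffled = [(1.0, 1.0), (0.0, 0.0), (0.0, 1.0), (1.0, 0.0)]
print(order_by_angle(shuffled))  # [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
```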

In [None]:
def plot_hull(points, hull, title="Convex Hull"):
    """Helper function to visualize convex hull."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Plot original points
    ax.scatter(points[:, 0], points[:, 1], c='red', s=50, alpha=0.7, label='Input points')

    # Plot hull vertices
    hull_vertices = hull.vertices
    ax.scatter(hull_vertices[:, 0], hull_vertices[:, 1], c='blue', s=100,
               marker='s', label='Hull vertices')

    # Draw hull polygon (approximate)
    # For visualization, we'll connect hull vertices in a reasonable order
    if hull_vertices.shape[0] > 2:
        # Sort by angle for better visualization
        center = jnp.mean(hull_vertices, axis=0)
        angles = jnp.arctan2(hull_vertices[:, 1] - center[1], hull_vertices[:, 0] - center[0])
        order = jnp.argsort(angles)
        ordered_vertices = hull_vertices[order]

        polygon = Polygon(ordered_vertices, alpha=0.3, facecolor='blue', edgecolor='blue')
        ax.add_patch(polygon)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

    return fig, ax

# Plot the result
plot_hull(points, hull, "First Convex Hull Example")
plt.show()

## 2. Hull Properties

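A `ConvexHull` object provides methods for common geometric properties: `volume()` (area in 2D), `surface_area()` (perimeter in 2D), `centroid()`, `diameter()`, and `bounding_box()`.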
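To sanity-check these values, here is a pure-Python sketch of the same quantities for a polygon whose vertices are already in boundary order: the shoelace formula for area, the sum of edge lengths for perimeter, the largest vertex-to-vertex distance for diameter (assuming `diameter()` means the largest distance between two hull points; for a convex polygon this is attained between two vertices), and coordinate-wise min/max for the bounding box. The exact hull of the example points is the unit square, so the exact values are area 1, perimeter 4, and diameter √2; PolytopAX's approximate results may differ slightly.

```python
import math
from itertools import combinations


def polygon_properties(vertices):
    """Area, perimeter, diameter, and bounding box of a convex polygon.

    `vertices` must be ordered around the boundary (CW or CCW).
    """
    n = len(vertices)
    # Shoelace formula: signed area = 1/2 * sum(x_i * y_{i+1} - x_{i+1} * y_i)
    twice_area = sum(
        vertices[i][0] * vertices[(i + 1) % n][1]
        - vertices[(i + 1) % n][0] * vertices[i][1]
        for i in range(n)
    )
    area = abs(twice_area) / 2.0
    # Perimeter: sum of edge lengths, including the closing edge
    perimeter = sum(math.dist(vertices[i], vertices[(i + 1) % n]) for i in range(n))
    # Diameter: largest distance between any two vertices
    diameter = max(math.dist(a, b) for a, b in combinations(vertices, 2))
    # Axis-aligned bounding box as ((x_min, y_min), (x_max, y_max))
    xs, ys = zip(*vertices)
    bbox = ((min(xs), min(ys)), (max(xs), max(ys)))
    return area, perimeter, diameter, bbox


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(polygon_properties(square))
```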

In [None]:
# Compute geometric properties
area = hull.volume()  # In 2D, volume() returns area
perimeter = hull.surface_area()  # In 2D, surface_area() returns perimeter
centroid = hull.centroid()
diameter = hull.diameter()
bbox_min, bbox_max = hull.bounding_box()

print("Geometric Properties:")
print(f"  Area: {area:.4f}")
print(f"  Perimeter: {perimeter:.4f}")
print(f"  Centroid: [{centroid[0]:.4f}, {centroid[1]:.4f}]")
print(f"  Diameter: {diameter:.4f}")
print(f"  Bounding box: [{bbox_min[0]:.2f}, {bbox_min[1]:.2f}] to [{bbox_max[0]:.2f}, {bbox_max[1]:.2f}]")

### Point Inclusion Testing
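For a convex polygon with counter-clockwise vertices, a point is inside exactly when it lies on the left of, or on, every directed edge, which takes one cross product per edge. The pure-Python sketch below is not PolytopAX's implementation; it uses a small tolerance `tol` so that boundary points such as a corner count as inside. How PolytopAX classifies boundary points is not specified in this notebook.

```python
def contains_convex(vertices, p, tol=1e-9):
    """Return True if p lies inside or on a convex polygon.

    `vertices` must be in counter-clockwise order. p is inside exactly when
    it is on the left of (or on) every directed edge.
    """
    n = len(vertices)
    for i in range(n):
        (ax, ay), (bx, by) = vertices[i], vertices[(i + 1) % n]
        # z-component of (b - a) x (p - a); negative means p is right of the edge
        if (bx - ax) * (p[1] - ay) - (by - ay) * (p[0] - ax) < -tol:
            return False
    return True


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
for p in [(0.5, 0.5), (0.0, 0.0), (1.5, 1.5), (-0.1, 0.5)]:
    print(p, contains_convex(square, p))
```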

In [None]:
# Test point inclusion
test_points = jnp.array([
    [0.5, 0.5],   # should be inside
    [0.0, 0.0],   # on the boundary (a hull vertex); result may depend on tolerance
    [1.5, 1.5],   # outside
    [-0.1, 0.5],  # outside
])

print("Point Inclusion Tests:")
for point in test_points:
    is_inside = hull.contains(point)
    distance = hull.distance_to(point)
    print(f"  Point {point}: inside={is_inside}, distance={distance:.4f}")

In [None]:
# Visualize point inclusion
fig, ax = plot_hull(points, hull, "Point Inclusion Testing")

# Add test points
for i, point in enumerate(test_points):
    is_inside = hull.contains(point)
    color = 'green' if is_inside else 'orange'
    marker = 'o' if is_inside else 'x'
    ax.scatter(point[0], point[1], c=color, s=100, marker=marker,
              label=f'Test point {i} ({"inside" if is_inside else "outside"})')

ax.legend()
plt.show()

## 3. Understanding Differentiability

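The key feature of PolytopAX is that its convex hull operations are differentiable: you can take gradients of hull properties, such as area, with respect to the input point coordinates using `jax.grad`.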
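To see where area gradients come from, consider a polygon whose vertex order is fixed. Its area is given by the shoelace formula, and `jax.grad` differentiates it directly. This sketch is independent of PolytopAX: it skips the step of selecting which input points form the hull, which is the part that needs special treatment to be differentiable. For the unit square, each corner's gradient is half its outward diagonal, so pushing any corner outward increases the area. Points strictly inside the hull do not appear in the formula, so for the exact hull their gradient is zero.

```python
import jax
import jax.numpy as jnp


def shoelace_area(vertices):
    """Area of a polygon whose vertices (shape (n, 2)) are in CCW order."""
    x, y = vertices[:, 0], vertices[:, 1]
    # jnp.roll(..., -1) shifts so that index i holds vertex i + 1 (wrapping around)
    x_next, y_next = jnp.roll(x, -1), jnp.roll(y, -1)
    return 0.5 * jnp.sum(x * y_next - x_next * y)


square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
print(shoelace_area(square))  # 1.0

# dA/dx_i = (y_{i+1} - y_{i-1}) / 2 and dA/dy_i = (x_{i-1} - x_{i+1}) / 2
area_grad = jax.grad(shoelace_area)(square)
print(area_grad)
```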

In [None]:
# Define a function that depends on hull properties
def hull_area_function(points):
    """Function that returns the area of the convex hull."""
    hull = ptx.ConvexHull.from_points(points, n_directions=15)
    return hull.volume()

# Compute the area
area = hull_area_function(points)
print(f"Hull area: {area:.4f}")

# Compute gradients with respect to input points
grad_fn = jax.grad(hull_area_function)
gradients = grad_fn(points)

print(f"\nGradients shape: {gradients.shape}")
print(f"Gradients:\n{gradients}")
print("\nGradients show how moving each point affects the hull area")

In [None]:
# Visualize gradients
fig, ax = plot_hull(points, hull, "Gradients of Hull Area")

# Plot gradient vectors
scale = 0.5  # Scale factor for visualization
for i, (point, grad) in enumerate(zip(points, gradients, strict=False)):
    ax.arrow(point[0], point[1], grad[0] * scale, grad[1] * scale,
             head_width=0.03, head_length=0.02, fc='red', ec='red')
    ax.text(point[0] + 0.05, point[1] + 0.05, f'{i}', fontsize=8)

ax.set_title("Gradients of Hull Area w.r.t. Point Positions")
plt.show()

print(f"Red arrows show the area gradient at each point (scaled by {scale} for display):")
print("- Longer arrows indicate larger gradient magnitudes")
print("- Each arrow points in the direction that increases the area fastest")
## 4. JAX Transformations

PolytopAX functions can be composed with JAX transformations such as `jax.jit`, `jax.grad`, and `jax.vmap`.

### JIT Compilation for Performance
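`jax.jit` traces a function on its first call and compiles it with XLA for the given input shapes and dtypes; later calls with matching shapes reuse the compiled code, while a new shape triggers recompilation. Because JAX dispatches work asynchronously, timings should call `block_until_ready()`. A minimal sketch with a standalone shoelace area function (not a PolytopAX function):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def shoelace_area(vertices):
    """Polygon area (CCW vertices, shape (n, 2)), compiled with XLA."""
    x, y = vertices[:, 0], vertices[:, 1]
    return 0.5 * jnp.sum(x * jnp.roll(y, -1) - jnp.roll(x, -1) * y)


square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])

# First call: traces the Python function and compiles it for this shape/dtype
t0 = time.perf_counter()
first = shoelace_area(square).block_until_ready()
t_compile = time.perf_counter() - t0

# Same shape/dtype: reuses the cached executable (translation keeps the area)
t0 = time.perf_counter()
second = shoelace_area(square + 1.0).block_until_ready()
t_cached = time.perf_counter() - t0

print(f"first call (incl. compile): {t_compile * 1e3:.2f} ms")
print(f"cached call:                {t_cached * 1e3:.2f} ms")
```

Exact timings depend on hardware, but the first call is typically much slower because it includes compilation.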

In [None]:
# JIT compile the hull area function
jit_hull_area = jax.jit(hull_area_function)

# Test that results are the same
area_regular = hull_area_function(points)
area_jit = jit_hull_area(points)

print(f"Regular function: {area_regular:.6f}")
print(f"JIT compiled:     {area_jit:.6f}")
print(f"Difference:       {abs(area_regular - area_jit):.2e}")
print(f"Results match (jnp.allclose): {jnp.allclose(area_regular, area_jit)}")

### Vectorization with vmap
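`jax.vmap` turns a function written for a single input into one that maps over a leading batch axis, without a Python loop. A minimal sketch with a standalone shoelace area function (not part of PolytopAX), using the same scales as the cell below: scaling a 2D shape by s multiplies its area by s², so the unit square gives areas 0.25, 1.0, and 2.25.

```python
import jax
import jax.numpy as jnp


def shoelace_area(vertices):
    """Area of one polygon with CCW vertices of shape (n, 2)."""
    x, y = vertices[:, 0], vertices[:, 1]
    return 0.5 * jnp.sum(x * jnp.roll(y, -1) - jnp.roll(x, -1) * y)


square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
square_scales = jnp.array([0.5, 1.0, 1.5])
# Broadcasting builds a (3, 4, 2) batch: one scaled copy of the square per scale
square_batch = square_scales[:, None, None] * square[None, :, :]

# vmap maps over axis 0, as if calling shoelace_area on each (4, 2) slice
square_areas = jax.vmap(shoelace_area)(square_batch)
print(square_areas)  # approximately [0.25, 1.0, 2.25]
```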

In [None]:
# Create multiple point sets
n_sets = 3
scales = jnp.array([0.5, 1.0, 1.5])
batch_points = scales[:, None, None] * points[None, :, :]

print(f"Batch shape: {batch_points.shape}")
print(f"Processing {n_sets} point sets with scales: {scales}")

# Vectorize the hull area function
batch_area_fn = jax.vmap(hull_area_function)
batch_areas = batch_area_fn(batch_points)

print(f"\nBatch areas: {batch_areas}")

# Verify scaling relationship
# Area should scale as scale^2 in 2D
expected_areas = batch_areas[1] * (scales**2)  # Reference is scale=1.0
print(f"Expected areas (scale^2): {expected_areas}")
print(f"Actual areas:             {batch_areas}")
print(f"Relative errors: {jnp.abs(batch_areas - expected_areas) / expected_areas}")

### Visualization of Batch Results

In [None]:
# Visualize the different scaled point sets
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, (scale, points_scaled) in enumerate(zip(scales, batch_points, strict=False)):
    hull_scaled = ptx.ConvexHull.from_points(points_scaled, n_directions=15)

    ax = axes[i]

    # Plot points and hull
    ax.scatter(points_scaled[:, 0], points_scaled[:, 1], c='red', s=50, alpha=0.7)
    hull_vertices = hull_scaled.vertices
    ax.scatter(hull_vertices[:, 0], hull_vertices[:, 1], c='blue', s=100, marker='s')

    # Draw approximate hull
    if hull_vertices.shape[0] > 2:
        center = jnp.mean(hull_vertices, axis=0)
        angles = jnp.arctan2(hull_vertices[:, 1] - center[1], hull_vertices[:, 0] - center[0])
        order = jnp.argsort(angles)
        ordered_vertices = hull_vertices[order]
        polygon = Polygon(ordered_vertices, alpha=0.3, facecolor='blue')
        ax.add_patch(polygon)

    ax.set_title(f'Scale {scale}: Area = {batch_areas[i]:.3f}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

plt.tight_layout()
plt.show()

## 5. Exploring Algorithm Parameters

PolytopAX approximates the convex hull by sampling direction vectors; the `n_directions` argument controls how many directions are used. Let's see how this parameter affects the results.
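To build intuition for direction sampling, here is a minimal sketch of one common construction. PolytopAX's internals are not shown in this notebook, so treat this as an assumption rather than the library's algorithm. For each unit direction d, the support point is the input point with the largest projection ⟨p, d⟩; it lies on the hull boundary (and is a hull vertex unless there are ties), and more directions find more hull vertices. Selecting with `argmax` is piecewise constant, so the sketch also shows a softened variant that replaces `argmax` with a softmax controlled by a hypothetical `temperature` parameter. This yields a convex combination of the points that is differentiable in their coordinates.

```python
import jax
import jax.numpy as jnp

pts = jnp.array([
    [0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0],
    [0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
])


def unit_directions_2d(n):
    """n unit vectors at evenly spaced angles in [0, 2*pi)."""
    angles = jnp.linspace(0.0, 2.0 * jnp.pi, n, endpoint=False)
    return jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)


def hard_support_points(points, directions):
    """For each direction d, the input point maximizing <p, d>."""
    scores = directions @ points.T               # (n_dirs, n_points)
    return points[jnp.argmax(scores, axis=1)]    # (n_dirs, 2)


def soft_support_points(points, directions, temperature=0.05):
    """Softmax-weighted average of the points (hypothetical smoothing).

    Differentiable everywhere; approaches hard_support_points as
    temperature -> 0.
    """
    scores = directions @ points.T / temperature
    weights = jax.nn.softmax(scores, axis=1)     # each row sums to 1
    return weights @ points                      # convex combinations of inputs


dirs = unit_directions_2d(8)
print(hard_support_points(pts, dirs))
print(soft_support_points(pts, dirs))
```

With few directions some hull vertices may never be selected, which is the accuracy and cost trade-off this section explores.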

In [None]:
# Test different numbers of directions
direction_counts = [5, 10, 20, 50]

print("Effect of number of directions:")
for n_dirs in direction_counts:
    hull_test = ptx.ConvexHull.from_points(points, n_directions=n_dirs)
    area_test = hull_test.volume()
    n_vertices = hull_test.n_vertices
    print(f"  {n_dirs:2d} directions: {n_vertices:2d} vertices, area = {area_test:.4f}")

print("\nObservation: More directions generally give better approximations")
print("but with diminishing returns and increased computation cost.")

In [None]:
# Visualize the effect of different direction counts
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for i, n_dirs in enumerate(direction_counts):
    hull_test = ptx.ConvexHull.from_points(points, n_directions=n_dirs)

    ax = axes[i]

    # Plot original points
    ax.scatter(points[:, 0], points[:, 1], c='red', s=50, alpha=0.7, label='Input')

    # Plot hull vertices
    hull_vertices = hull_test.vertices
    ax.scatter(hull_vertices[:, 0], hull_vertices[:, 1], c='blue', s=100,
               marker='s', label='Hull vertices')

    # Draw hull
    if hull_vertices.shape[0] > 2:
        center = jnp.mean(hull_vertices, axis=0)
        angles = jnp.arctan2(hull_vertices[:, 1] - center[1], hull_vertices[:, 0] - center[0])
        order = jnp.argsort(angles)
        ordered_vertices = hull_vertices[order]
        polygon = Polygon(ordered_vertices, alpha=0.3, facecolor='blue')
        ax.add_patch(polygon)

    ax.set_title(f'{n_dirs} directions\n{hull_test.n_vertices} vertices, area={hull_test.volume():.3f}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

plt.tight_layout()
plt.show()

## 6. 3D Example

Let's explore convex hulls in 3D.
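In 3D, evenly spaced angles no longer cover all directions uniformly. A common way to sample directions uniformly on the unit sphere is to normalize standard Gaussian vectors, which works because the Gaussian density is rotation-invariant. The sketch below is an assumption about how directions might be sampled, not necessarily PolytopAX's scheme. It computes support points for the same cube-plus-interior point set used below. With random directions some corners can be missed, so fewer than 8 distinct vertices may be found.

```python
import jax
import jax.numpy as jnp


def sphere_directions(key, n, dim=3):
    """n directions distributed uniformly on the unit sphere."""
    v = jax.random.normal(key, (n, dim))
    return v / jnp.linalg.norm(v, axis=1, keepdims=True)


# Unit cube corners plus three interior points
cube = jnp.array([[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)])
interior = jnp.array([[0.5, 0.5, 0.5], [0.3, 0.3, 0.3], [0.7, 0.7, 0.7]])
cube_points = jnp.concatenate([cube, interior])

dirs_3d = sphere_directions(jax.random.PRNGKey(0), 30)
# Support point for each direction: the input point with the largest projection
support = cube_points[jnp.argmax(dirs_3d @ cube_points.T, axis=1)]
# Number of distinct cube corners hit by the 30 random directions
n_hit = jnp.unique(support, axis=0).shape[0]
print(f"{n_hit} of 8 cube vertices found")
```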

In [None]:
# Create 3D points (vertices of a cube plus some interior points)
points_3d = jnp.array([
    # Cube vertices
    [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0],
    # Interior points
    [0.5, 0.5, 0.5], [0.3, 0.3, 0.3], [0.7, 0.7, 0.7],
])

print(f"3D points shape: {points_3d.shape}")

# Compute 3D convex hull
hull_3d = ptx.ConvexHull.from_points(points_3d, n_directions=30)

print("\n3D Hull Properties:")
print(f"  Vertices: {hull_3d.n_vertices}")
print(f"  Volume: {hull_3d.volume():.4f}")
print(f"  Surface area: {hull_3d.surface_area():.4f}")
print(f"  Centroid: {hull_3d.centroid()}")
print(f"  Diameter: {hull_3d.diameter():.4f}")

In [None]:
# 3D visualization (2D projections)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

projections = [
    ([0, 1], 'XY'),  # xy projection
    ([0, 2], 'XZ'),  # xz projection
    ([1, 2], 'YZ'),  # yz projection
]

for i, (dims, label) in enumerate(projections):
    ax = axes[i]

    # Project points
    points_proj = points_3d[:, dims]
    hull_vertices_proj = hull_3d.vertices[:, dims]

    # Plot projected points
    ax.scatter(points_proj[:, 0], points_proj[:, 1], c='red', s=50, alpha=0.7, label='Input')
    ax.scatter(hull_vertices_proj[:, 0], hull_vertices_proj[:, 1], c='blue', s=100,
               marker='s', label='Hull vertices')

    ax.set_title(f'{label} Projection')
    ax.set_xlabel(label[0])
    ax.set_ylabel(label[1])
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

plt.tight_layout()
plt.show()

print("3D hulls are visualized here as 2D projections onto the coordinate planes.")
print("For reference, the exact hull is the unit cube: volume 1.0, surface area 6.0.")

## Summary

In this notebook, you've learned:

1. **Basic Usage**: How to compute convex hulls with PolytopAX
2. **Properties**: Computing area, perimeter, centroid, and testing point inclusion
3. **Differentiability**: How to compute gradients of hull properties
4. **JAX Integration**: Using JIT, grad, and vmap transformations
5. **Parameters**: Effect of direction count on approximation quality
6. **3D Extensions**: Working with higher-dimensional hulls

## Next Steps

- **Optimization**: Use gradients for shape optimization problems
- **Batch Processing**: Process multiple point sets efficiently
- **Advanced Features**: Explore different sampling methods and algorithms
- **Applications**: Apply to machine learning and scientific computing problems
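As a small example of gradient-based shape optimization, the sketch below maximizes the area of a quadrilateral whose vertices are constrained to the unit circle. It uses projected gradient ascent: a gradient step on the shoelace area, followed by renormalizing each vertex. It differentiates a fixed-order polygon rather than a PolytopAX hull, and the step size `LEARNING_RATE` is an arbitrary choice. The maximum-area quadrilateral inscribed in a circle is the square, whose area in the unit circle is 2, so the iterates should approach that value.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary step size for this sketch


def shoelace_area(vertices):
    """Signed polygon area; positive for counter-clockwise vertex order."""
    x, y = vertices[:, 0], vertices[:, 1]
    return 0.5 * jnp.sum(x * jnp.roll(y, -1) - jnp.roll(x, -1) * y)


@jax.jit
def ascent_step(vertices):
    # Gradient ascent on area, then project each vertex back onto the unit circle
    vertices = vertices + LEARNING_RATE * jax.grad(shoelace_area)(vertices)
    return vertices / jnp.linalg.norm(vertices, axis=1, keepdims=True)


# Start from an irregular CCW quadrilateral inscribed in the unit circle
angles = jnp.array([0.0, 1.0, 2.5, 4.0])
vertices = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)

for _ in range(500):
    vertices = ascent_step(vertices)

print(shoelace_area(vertices))  # approaches 2.0, the area of the inscribed square
```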

Check out the other notebooks for more advanced topics!