# 06 - Geometric Skeleton Visualization

This notebook introduces the **spatial tree representation** - a major upgrade over the scalar-compartment model.

Instead of scalar biomass compartments, we now have:
- A fixed binary tree topology (depth-4 = 15 segments)
- Each segment has continuous-valued length, thickness, and alive (a soft gate) attributes
- Tips have leaf and flower area
- Positions computed in 2D for visualization

This enables:
- **Spatial light capture**: higher leaves get more light
- **Self-shading**: overlapping canopy blocks light
- **Wind exposure**: exposed branches take more damage
- **Stained-glass aesthetics**: geometric, symbolic tree forms

In [None]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection, PatchCollection
import numpy as np

from sim.skeleton import (
    SkeletonState,
    compute_segment_positions_2d,
    compute_light_capture,
    compute_wind_exposure,
    get_tip_indices,
)

plt.style.use('seaborn-v0_8-whitegrid')
# Default figure size and font size for every plot in this notebook.
plt.rcParams.update({
    'figure.figsize': (10, 10),
    'font.size': 11,
})

## 1. Basic Skeleton Visualization

In [None]:
def draw_skeleton(
    skeleton: SkeletonState,
    ax=None,
    show_leaves: bool = True,
    show_flowers: bool = True,
    title: str = "",
    color_by: str = "alive",  # "alive", "thickness", "light"
):
    """
    Draw tree skeleton in stained-glass style.

    Each segment with ``alive >= 0.1`` is drawn as a line from its parent's
    end point (the origin for segment 0) to its own end point. Line width
    grows with ``thickness * alive`` and opacity with ``alive``.

    Args:
        skeleton: SkeletonState to visualize.
        ax: Matplotlib axes (creates a new 10x10 figure if None).
        show_leaves: Draw leaf ellipses at tips.
        show_flowers: Draw flower circles at tips.
        title: Plot title (omitted when empty).
        color_by: How to color branches:
            - "alive": uniform saddle brown (opacity still tracks alive).
            - "thickness": copper colormap over thickness / max thickness.
            - "light": viridis colormap over ``compute_light_capture`` output
              / its max (assumes that output is indexed per segment, as the
              light plots in this notebook do).

    Returns:
        The axes drawn on.

    Raises:
        ValueError: If ``color_by`` is not one of the options above.
    """
    # Resolve branch coloring first so an invalid option fails before any
    # figure is created.
    if color_by == "alive":
        branch_cmap, branch_values = None, None
    elif color_by == "thickness":
        branch_cmap = plt.cm.copper
        branch_values = np.asarray(skeleton.thickness, dtype=float)
    elif color_by == "light":
        branch_cmap = plt.cm.viridis
        branch_values = np.asarray(compute_light_capture(skeleton), dtype=float)
    else:
        raise ValueError(
            f"color_by must be 'alive', 'thickness' or 'light', got {color_by!r}"
        )
    if branch_values is not None:
        # Normalize by the max; the small offset (as in the other plots here)
        # avoids dividing by zero when every value is zero.
        branch_values = branch_values / (float(np.max(branch_values)) + 0.001)

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))
    
    x, y = compute_segment_positions_2d(skeleton)
    x, y = np.array(x), np.array(y)
    
    tip_indices = get_tip_indices(skeleton.depth)
    
    # Color palette (warm browns for wood)
    trunk_color = '#8B4513'  # saddle brown
    leaf_color = '#228B22'   # forest green
    flower_colors = ['#FF6B6B', '#FF8E8E', '#FFB4B4', '#E74C3C']  # reds/pinks
    
    # Draw branches
    for idx in range(skeleton.num_segments):
        alive = float(skeleton.alive[idx])
        if alive < 0.1:
            continue
        
        # Start at the parent's end point (heap layout: parent = (i-1)//2),
        # or at the origin for the trunk.
        if idx == 0:
            start_x, start_y = 0.0, 0.0
        else:
            parent = (idx - 1) // 2
            start_x, start_y = x[parent], y[parent]
        
        end_x, end_y = x[idx], y[idx]
        
        # Line width based on thickness
        thickness = float(skeleton.thickness[idx])
        linewidth = 2 + 8 * thickness * alive
        
        # Opacity tracks alive-ness (more alive = more opaque)
        alpha = 0.3 + 0.7 * alive
        
        if branch_cmap is None:
            color = trunk_color
        else:
            color = branch_cmap(branch_values[idx])
        
        ax.plot([start_x, end_x], [start_y, end_y], 
                color=color, linewidth=linewidth, 
                alpha=alpha, solid_capstyle='round')
    
    # Draw leaves at tips
    if show_leaves:
        for tip_idx in tip_indices:
            tip_idx = int(tip_idx)
            leaf_area = float(skeleton.leaf_area[tip_idx])
            alive = float(skeleton.alive[tip_idx])
            
            if leaf_area * alive < 0.01:
                continue
            
            # Draw leaf as ellipse
            leaf_size = 0.1 + 0.3 * np.sqrt(leaf_area * alive)
            
            # Deterministic per-tip tilt (-30 degrees for even indices, 0 for
            # odd) so the leaves are not all aligned.
            angle = (tip_idx * 30) % 60 - 30
            
            ellipse = mpatches.Ellipse(
                (x[tip_idx], y[tip_idx]),
                width=leaf_size * 0.6,
                height=leaf_size,
                angle=angle,
                facecolor=leaf_color,
                edgecolor='#1B5E20',
                linewidth=1.5,
                alpha=0.4 + 0.5 * alive,
            )
            ax.add_patch(ellipse)
    
    # Draw flowers at tips
    if show_flowers:
        for i, tip_idx in enumerate(tip_indices):
            tip_idx = int(tip_idx)
            flower_area = float(skeleton.flower_area[tip_idx])
            alive = float(skeleton.alive[tip_idx])
            
            if flower_area * alive < 0.01:
                continue
            
            flower_size = 0.05 + 0.15 * np.sqrt(flower_area * alive)
            flower_color = flower_colors[i % len(flower_colors)]
            
            # Draw flower as circle slightly offset from leaf
            offset_x = 0.05 * np.cos(tip_idx)
            offset_y = 0.05 * np.sin(tip_idx)
            
            circle = mpatches.Circle(
                (x[tip_idx] + offset_x, y[tip_idx] + offset_y),
                radius=flower_size,
                facecolor=flower_color,
                edgecolor='#C0392B',
                linewidth=1,
                alpha=0.6 + 0.4 * alive,
            )
            ax.add_patch(circle)
    
    # Ground line
    ax.axhline(0, color='#654321', linewidth=3, alpha=0.5)
    ax.fill_between([-1.5, 1.5], -0.3, 0, color='#8B7355', alpha=0.3)
    
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 2.5)
    ax.set_aspect('equal')
    ax.axis('off')
    
    if title:
        ax.set_title(title, fontsize=14, fontweight='bold')
    
    return ax


# Sanity-check the initial skeleton produced by the simulator
skeleton = SkeletonState.initial(depth=4)
trunk_length = float(skeleton.length[0])
trunk_thickness = float(skeleton.thickness[0])
print(f"Skeleton: {skeleton.num_segments} segments, depth {skeleton.depth}")
print(f"Initial trunk length: {trunk_length:.2f}")
print(f"Initial trunk thickness: {trunk_thickness:.2f}")

In [None]:
# Create a "grown" skeleton for visualization
def create_example_skeleton(growth_stage: float = 1.0):
    """
    Build a hand-tuned depth-4 (15-segment) skeleton for plotting demos.

    Segment ``i`` sits on tree level ``floor(log2(i + 1))``. Each level has a
    taper factor ``1 - 0.2 * level`` that scales length (``0.4 * taper``) and
    thickness (``0.3 * taper ** 1.5``); both are then multiplied by
    ``growth_stage``. ``alive`` is ``growth_stage * taper`` plus a
    deterministic sin-based jitter of up to 0.2, capped at 1.0. Only tip
    segments get leaf area, and flower area is only set once ``growth_stage``
    exceeds 0.6.

    Args:
        growth_stage: Overall maturity multiplier (this notebook uses 0.25-1.0).

    Returns:
        SkeletonState built from the arrays described above.
    """
    depth = 4
    n_segments = 2**depth - 1

    lengths, thicknesses, alive_levels = [], [], []
    for seg in range(n_segments):
        level = (seg + 1).bit_length() - 1  # == floor(log2(seg + 1))
        taper = 1.0 - 0.2 * level
        # Deterministic stand-in for noise, in [0, 1].
        jitter = (np.sin(seg * 1.5) + 1) / 2
        lengths.append(float(0.4 * taper * growth_stage))
        thicknesses.append(float(0.3 * (taper ** 1.5) * growth_stage))
        alive_levels.append(float(min(1.0, growth_stage * taper + 0.2 * jitter)))

    leaf_levels = [0.0] * n_segments
    flower_levels = [0.0] * n_segments
    for raw_tip in get_tip_indices(depth):
        tip = int(raw_tip)
        leaf_levels[tip] = float(growth_stage * (0.5 + 0.5 * np.sin(tip)))
        # Flowering only starts after 60% of growth.
        if growth_stage > 0.6:
            bloom = (growth_stage - 0.6) / 0.4 * 0.3 * (1 + np.cos(tip * 2))
            flower_levels[tip] = float(max(0, bloom))

    return SkeletonState(
        length=jnp.array(lengths),
        thickness=jnp.array(thicknesses),
        alive=jnp.array(alive_levels),
        leaf_area=jnp.array(leaf_levels),
        flower_area=jnp.array(flower_levels),
    )


# Show growth stages side by side, one panel per stage
stages = [0.25, 0.5, 0.75, 1.0]
fig, axes = plt.subplots(1, len(stages), figsize=(16, 6))

for panel, stage in enumerate(stages):
    skeleton = create_example_skeleton(stage)
    draw_skeleton(skeleton, ax=axes[panel], title=f'Growth Stage {stage:.0%}')

fig.suptitle('Tree Development Stages', fontsize=16, fontweight='bold')
fig.tight_layout()
plt.show()

## 2. Stained Glass Style

In [None]:
def draw_stained_glass_tree(
    skeleton: SkeletonState,
    ax=None,
    background_color: str = '#FFF8DC',  # Cornsilk
    panel_colors: list = None,
):
    """
    Draw tree in stained-glass style with colored panels.

    Draws, in order:
    - the background;
    - each branch with ``alive >= 0.1`` as a brown stroke with a dark "lead"
      outline;
    - each tip leaf as an outlined ellipse pointing away from the origin
      (trunk base), with a center vein;
    - faint radial background divisions;
    - three colored ground panels.

    Args:
        skeleton: SkeletonState to visualize.
        ax: Matplotlib axes (creates a new 10x12 figure if None).
        background_color: Color painted behind the whole drawing.
        panel_colors: Colors for the ground panels, cycled if there are
            fewer than three. If None or empty, the default light green /
            teal / coral ground palette is used.

    Returns:
        The axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 12))
    
    # Ground panel palette (this argument used to be accepted but ignored).
    if panel_colors:
        ground_colors = list(panel_colors)
    else:
        ground_colors = ['#8BC34A', '#4A90A4', '#E07B54']
    
    # Background. ``ax.axis('off')`` below stops Matplotlib from drawing the
    # axes patch, so set_facecolor alone is never visible; paint an explicit
    # rectangle covering the view limits behind everything else.
    ax.set_facecolor(background_color)
    ax.add_patch(mpatches.Rectangle(
        (-1.5, -0.5), 3.0, 3.0,
        facecolor=background_color,
        edgecolor='none',
        zorder=-1,
    ))
    
    x, y = compute_segment_positions_2d(skeleton)
    x, y = np.array(x), np.array(y)
    
    num_segments = skeleton.num_segments
    tip_indices = get_tip_indices(skeleton.depth)
    
    # Draw "lead" lines (black outlines) for branches
    for idx in range(num_segments):
        alive = float(skeleton.alive[idx])
        if alive < 0.1:
            continue
        
        # Parent's end point (heap layout: parent = (i-1)//2), or the origin.
        if idx == 0:
            start_x, start_y = 0.0, 0.0
        else:
            parent = (idx - 1) // 2
            start_x, start_y = x[parent], y[parent]
        
        end_x, end_y = x[idx], y[idx]
        
        thickness = float(skeleton.thickness[idx])
        linewidth = 3 + 12 * thickness * alive
        
        # Brown wood drawn over a slightly wider near-black outline
        ax.plot([start_x, end_x], [start_y, end_y], 
                color='#2C1810', linewidth=linewidth + 2, 
                solid_capstyle='round', zorder=1)
        ax.plot([start_x, end_x], [start_y, end_y], 
                color='#8B4513', linewidth=linewidth, 
                solid_capstyle='round', zorder=2)
    
    # Draw stylized leaves (pointed ovals)
    leaf_colors = ['#C0392B', '#E74C3C', '#27AE60', '#2ECC71', '#F39C12', '#E67E22']
    
    for i, tip_idx in enumerate(tip_indices):
        tip_idx = int(tip_idx)
        leaf_area = float(skeleton.leaf_area[tip_idx])
        alive = float(skeleton.alive[tip_idx])
        
        if leaf_area * alive < 0.01:
            continue
        
        # Leaf size and color
        leaf_size = 0.15 + 0.25 * np.sqrt(leaf_area * alive)
        color = leaf_colors[i % len(leaf_colors)]
        
        # Direction (degrees) from the trunk base at the origin to the tip
        angle = np.degrees(np.arctan2(y[tip_idx], x[tip_idx] + 0.001))
        
        # Ellipse ``angle`` rotates anti-clockwise from the x-axis; subtracting
        # 90 puts the long (height) axis along ``angle`` so the leaf points out.
        ellipse = mpatches.Ellipse(
            (x[tip_idx], y[tip_idx]),
            width=leaf_size * 0.5,
            height=leaf_size,
            angle=angle - 90,
            facecolor=color,
            edgecolor='#1a1a1a',
            linewidth=2,
            alpha=0.9,
            zorder=3,
        )
        ax.add_patch(ellipse)
        
        # Leaf vein along the same long axis, i.e. in direction ``angle``.
        # (Previously used (sin, cos) of angle-90, which mirrored the axis.)
        vein_length = leaf_size * 0.4
        theta = np.radians(angle)
        dx = vein_length * np.cos(theta)
        dy = vein_length * np.sin(theta)
        ax.plot(
            [x[tip_idx] - dx, x[tip_idx] + dx],
            [y[tip_idx] - dy, y[tip_idx] + dy],
            color='#1a1a1a', linewidth=1, alpha=0.5, zorder=4
        )
    
    # Background panels (geometric divisions like stained glass)
    # Simple radial divisions from trunk base
    for i in range(8):
        angle = i * np.pi / 4 + np.pi / 8
        ax.plot([0, 3*np.cos(angle)], [0, 3*np.sin(angle)],
                color='#1a1a1a', linewidth=1, alpha=0.2, zorder=0)
    
    # Ground
    for i, (x_start, x_end) in enumerate([(-1.5, -0.3), (-0.3, 0.6), (0.6, 1.5)]):
        rect = mpatches.Rectangle(
            (x_start, -0.5), x_end - x_start, 0.5,
            facecolor=ground_colors[i % len(ground_colors)],
            edgecolor='#1a1a1a',
            linewidth=2,
            alpha=0.7,
        )
        ax.add_patch(rect)
    
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.5, 2.5)
    ax.set_aspect('equal')
    ax.axis('off')
    
    return ax


# Render the fully grown example tree in stained-glass style
skeleton = create_example_skeleton(1.0)
fig, ax = plt.subplots(figsize=(10, 12))
ax = draw_stained_glass_tree(skeleton, ax=ax)
ax.set_title('Arborhedron - Stained Glass Tree', fontsize=16, fontweight='bold', pad=20)
fig.tight_layout()
plt.show()

## 3. Light Capture and Self-Shading

In [None]:
# Report per-tip and total light capture for the fully grown example tree
skeleton = create_example_skeleton(1.0)
light = compute_light_capture(skeleton)

print("Light capture by tip:")
tip_indices = get_tip_indices(skeleton.depth)
for tip_number, raw_index in enumerate(tip_indices):
    seg = int(raw_index)
    tip_light = float(light[seg])
    tip_leaf = float(skeleton.leaf_area[seg])
    print(f"  Tip {tip_number}: light={tip_light:.3f}, leaf_area={tip_leaf:.3f}")

total_light = float(jnp.sum(light))
print(f"\nTotal light capture: {total_light:.3f}")

In [None]:
# Visualize light capture with color intensity
def draw_with_light(skeleton, ax=None):
    """Draw tree with leaves colored by light capture.

    Branches with ``alive >= 0.1`` are drawn in brown. Each visible tip leaf
    is an ellipse whose green channel rises from 100 to 255 with its light
    capture relative to the maximum, and is labelled with its light value.

    Args:
        skeleton: SkeletonState to visualize.
        ax: Matplotlib axes (creates a new 10x10 figure if None).

    Returns:
        The axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    captured = compute_light_capture(skeleton)
    xs, ys = (np.array(coords) for coords in compute_segment_positions_2d(skeleton))

    tips = get_tip_indices(skeleton.depth)

    # Branches: parent's end point (or the origin for the trunk) -> own end.
    for seg in range(skeleton.num_segments):
        vitality = float(skeleton.alive[seg])
        if vitality < 0.1:
            continue
        if seg == 0:
            origin = (0.0, 0.0)
        else:
            origin = (xs[(seg - 1) // 2], ys[(seg - 1) // 2])
        width = 2 + 8 * float(skeleton.thickness[seg]) * vitality
        ax.plot([origin[0], xs[seg]], [origin[1], ys[seg]],
                color='#8B4513', linewidth=width,
                solid_capstyle='round')

    # Small offset keeps the normalization finite when all light is zero.
    brightest = float(jnp.max(captured)) + 0.001

    for raw_tip in tips:
        tip = int(raw_tip)
        area = float(skeleton.leaf_area[tip])
        vitality = float(skeleton.alive[tip])
        tip_light = float(captured[tip])

        if area * vitality < 0.01:
            continue

        size = 0.1 + 0.3 * np.sqrt(area * vitality)

        # Dark green in shade -> bright green in full sun (R and B fixed at 0x1e).
        green = int(100 + 155 * (tip_light / brightest))
        shade = f'#1e{green:02x}1e'

        ax.add_patch(mpatches.Ellipse(
            (xs[tip], ys[tip]),
            width=size * 0.6,
            height=size,
            facecolor=shade,
            edgecolor='#1B5E20',
            linewidth=1.5,
            alpha=0.8,
        ))

        ax.annotate(f'{tip_light:.2f}', (xs[tip], ys[tip]),
                    fontsize=8, ha='center', va='center', color='white', fontweight='bold')

    ax.axhline(0, color='#654321', linewidth=3, alpha=0.5)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 2.5)
    ax.set_aspect('equal')
    ax.set_title('Light Capture by Leaf Position\n(brighter = more light)', fontsize=12)
    ax.axis('off')

    return ax


# Render the fully grown example tree with leaves shaded by light capture
skeleton = create_example_skeleton(1.0)
fig, ax = plt.subplots(figsize=(10, 10))
ax = draw_with_light(skeleton, ax=ax)
fig.tight_layout()
plt.show()

## 4. Wind Stress Visualization

In [None]:
# Compare trees under different wind levels.
# Branch colors use ONE normalization shared by all panels: normalizing each
# panel by its own maximum would hide how exposure changes with wind level
# (e.g. if exposure scales with wind, every panel would look the same).
wind_levels = [0.1, 0.4, 0.8]
skeleton = create_example_skeleton(1.0)

# squeeze=False keeps `axes` 2-D even when there is a single wind level.
fig, axes = plt.subplots(1, len(wind_levels), figsize=(15, 6), squeeze=False)

# Geometry and tips do not depend on wind, so compute them once.
x, y = compute_segment_positions_2d(skeleton)
x, y = np.array(x), np.array(y)
tip_indices = get_tip_indices(skeleton.depth)

exposures = [compute_wind_exposure(skeleton, wind) for wind in wind_levels]
# Shared scale across panels; the small offset avoids dividing by zero.
max_exp = max(float(jnp.max(exposure)) for exposure in exposures) + 0.001

for ax, wind, exposure in zip(axes[0], wind_levels, exposures):
    # Draw branches colored by exposure
    for idx in range(skeleton.num_segments):
        alive = float(skeleton.alive[idx])
        if alive < 0.1:
            continue
        
        # Parent's end point (heap layout: parent = (i-1)//2), or the origin.
        if idx == 0:
            start_x, start_y = 0.0, 0.0
        else:
            parent = (idx - 1) // 2
            start_x, start_y = x[parent], y[parent]
        
        thickness = float(skeleton.thickness[idx])
        linewidth = 2 + 8 * thickness * alive
        
        # Color by exposure: brown #8B4513 at zero -> pure red at the shared max.
        # Clipping guarantees valid 0-255 channel values.
        exp_frac = float(np.clip(float(exposure[idx]) / max_exp, 0.0, 1.0))
        red = int(139 + 116 * exp_frac)
        green = int(69 * (1 - exp_frac))
        blue = int(19 * (1 - exp_frac))
        color = f'#{red:02x}{green:02x}{blue:02x}'
        
        ax.plot([start_x, x[idx]], [start_y, y[idx]], 
                color=color, linewidth=linewidth, 
                solid_capstyle='round')
    
    # Draw leaves
    for tip_idx in tip_indices:
        tip_idx = int(tip_idx)
        leaf_area = float(skeleton.leaf_area[tip_idx])
        alive = float(skeleton.alive[tip_idx])
        
        if leaf_area * alive < 0.01:
            continue
        
        leaf_size = 0.1 + 0.3 * np.sqrt(leaf_area * alive)
        
        ellipse = mpatches.Ellipse(
            (x[tip_idx], y[tip_idx]),
            width=leaf_size * 0.6,
            height=leaf_size,
            facecolor='#228B22',
            edgecolor='#1B5E20',
            linewidth=1.5,
            alpha=0.7,
        )
        ax.add_patch(ellipse)
    
    ax.axhline(0, color='#654321', linewidth=3, alpha=0.5)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 2.5)
    ax.set_aspect('equal')
    ax.set_title(f'Wind Level: {wind:.1f}\n(red = high exposure)', fontsize=11)
    ax.axis('off')

plt.suptitle('Wind Exposure by Branch Position', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Summary

The geometric skeleton provides:

1. **Spatial representation**: Tree form is explicit, not just scalar biomass
2. **Light competition**: Higher leaves capture more light, shading lower ones
3. **Wind exposure**: Position affects vulnerability
4. **Visual aesthetics**: Stained-glass style matches the Platonic ideal vision

**Next steps:**
- Integrate skeleton dynamics with resource allocation policy
- Train policy to optimize spatial growth under stress
- Animate growth over a season