# Imports

In [None]:
# ============================================================
# Core scientific stack
# ============================================================
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# circle_bundles core API
# ============================================================
from circle_bundles.api import build_bundle
from circle_bundles.bundle import attach_bundle_viz_methods
attach_bundle_viz_methods()


# ============================================================
# Local circular coordinates (optional)
# ============================================================
from dreimac import CircularCoords


# ============================================================
# Cover constructions + base metric
# ============================================================
from circle_bundles.covers.triangle_cover_builders_fibonacci import (
    make_rp2_fibonacci_star_cover,
)
from circle_bundles.base_covers import MetricBallCover
from circle_bundles.metrics import RP1AngleMetric as rp1_metric


# ============================================================
# Synthetic data generation
# ============================================================
from circle_bundles.synthetic.so3_sampling import sample_so3
from circle_bundles.synthetic.meshes import make_tri_prism

from circle_bundles.synthetic.densities import (
    mesh_to_density,
    rotate_density,
    get_mesh_sample,
    get_density_axes,
)


# ============================================================
# Synthetic visualization helpers
# ============================================================
from circle_bundles.synthetic.mesh_vis import (
    make_density_visualizer,
    make_tri_prism_visualizer,
)


# ============================================================
# General visualization utilities (lightweight, non-interactive)
# ============================================================
from circle_bundles.viz.thumb_grids import show_data_vis
from circle_bundles.viz.lattice_vis import lattice_vis


# Generate A Synthetic Dataset

In [None]:
# --- Sanity check: build one prism density and preview a few SO(3) rotations of it ---

# Prism dimensions and voxelization settings.
height = 1.0
radius = 1.0
grid_size = 32
sigma = 0.05

mesh, face_groups = make_tri_prism(height=height, radius=radius)
density = mesh_to_density(mesh, grid_size=grid_size, sigma=sigma)

# Thumbnail renderers: one for voxel densities, one for the prism mesh.
vis_density = make_density_visualizer(grid_size=grid_size)
vis_mesh = make_tri_prism_visualizer(mesh, face_groups)

# Fixed seed keeps the preview reproducible; only the first entry of
# sample_so3's result is kept.
rng = np.random.default_rng(0)
so3_data = sample_so3(n_samples=8, rng=rng)[0]

# Both preview strips use the same sample cap, column count and sampling method.
strip_layout = dict(max_samples=8, n_cols=8, sampling_method="first")

# Strip 1: the rotated voxel densities.
density_sample = rotate_density(density, so3_data, grid_size=grid_size)
fig = show_data_vis(density_sample, vis_density, **strip_layout)
plt.show()

# Strip 2: the same rotations applied to the mesh, with extra padding.
mesh_sample = get_mesh_sample(mesh, so3_data)
fig = show_data_vis(mesh_sample, vis_mesh, **strip_layout, pad_frac=0.3)
plt.show()


In [None]:
# --- Generate the dataset ---

n_samples = 5000

# Fresh generator with a fixed seed, so this cell is reproducible on its own.
rng = np.random.default_rng(0)

# Only the first entry of sample_so3's result is used.
R = sample_so3(n_samples=n_samples, rng=rng)[0]

# Apply the same rotations to the mesh (used for thumbnails) and to the
# voxel density (the dataset that is analysed below).
mesh_data = get_mesh_sample(mesh, R)
data = rotate_density(density, R, grid_size=grid_size)

print(
    f"Generated {n_samples} SO(3)-rotated prism densities "
    f"represented as {data.shape[1]}-dimensional voxel vectors."
)


# Bundle Analysis

## Open Cover

In [None]:
# --- Compute base projections ---

# One base point per row of `data`; the result is the notebook-level
# `base_points` consumed by the RP^2 cover construction and by the
# equator restriction further down.
# NOTE(review): presumably each base point is an axis (defined up to sign)
# extracted from the density — confirm against get_density_axes.
base_points = get_density_axes(data)
print("Base projection coordinates computed.")


In [None]:
# Construct an open cover of RP2

# Handed to the cover builder as its `n_pairs` argument.
n_landmarks = 60
cover = make_rp2_fibonacci_star_cover(base_points, n_pairs = n_landmarks)

# Summarize the cover (plot=True also requests the accompanying plot);
# the returned summary is kept in `summ`.
summ = cover.summarize(plot = True)


In [None]:
# Optional: Plotly visualization of the nerve of the cover constructed above
# NOTE(review): the description says Plotly, yet the figure is displayed via
# matplotlib's plt.show(). Confirm which backend show_nerve returns — a
# Plotly figure would need its own display call instead.

fig = cover.show_nerve()
plt.show()

## Characteristic Classes

In [None]:
# --- Compute local trivializations + O(2) transitions + characteristic classes ---

# Inputs: the voxel dataset `data` and the RP^2 cover `cover`.
# The resulting `bundle` is reused by the persistence, pullback,
# equator-restriction and nerve-visualization cells.
bundle = build_bundle(
    data,
    cover,
    CircularCoords_cls=CircularCoords,  # use sparse CC algorithm for local circular coordinates
    show=True,                          # print summary + basic diagnostics
)


In [None]:
# Compute class persistence on the weights filtration of the nerve.
# show=True requests that the result be displayed; the return value is kept in `pers`.
pers = bundle.get_persistence(show = True)


## Coordinate Bundle 

In [None]:
# Construct a classifying map to the Stiefel manifold and compute the pullback bundle.
# The `total_data` and `metric` attributes of the result feed the next build_bundle call.
pullback_results = bundle.get_pullback_data(
    subcomplex = 'full',     # work on the full nerve rather than a subcomplex
    base_weight=1.0,         # base and fiber weights are equal here
    fiber_weight=1.0,        # (presumably their relative scale in the pullback metric — confirm)
    packing = 'coloring2',
    show_summary = True,     # print a summary of the computation
)




In [None]:
# Construct a pullback coordinate bundle object and verify it has the correct classification

# Same cover as the original bundle, but the total space is now the pullback
# data, measured with the metric returned alongside it.
pb_bundle = build_bundle(
    pullback_results.total_data,
    cover,
#    CircularCoords_cls=CircularCoords,     # optionally use sparse cc's, but PCA2 is sufficient
    show=True,                              # print summary + diagnostics for comparison with `bundle`
    total_metric = pullback_results.metric
)


# Restriction To The Equator $\mathbb{RP}^{1}\subset \mathbb{RP}^{2}$

In [None]:
# --- Restrict to the RP^1 "equator" in RP^2 ---

# Half-width of the equatorial band: a sample is kept when the absolute value
# of its last base coordinate is strictly below eps.
eps = 0.15

eq_mask = np.abs(base_points[:, -1]) < eps

# Apply the same mask to the base points, the bundle data and the mesh thumbnails.
eq_base_points = base_points[eq_mask]
eq_data = bundle.data[eq_mask]
eq_mesh_data = mesh_data[eq_mask]

# Angle of the first two base coordinates, reduced mod pi so that theta and
# theta + pi describe the same point of RP^1.
eq_base_angles = np.arctan2(eq_base_points[:, 1], eq_base_points[:, 0]) % np.pi

print(f"Equator band: {eq_data.shape[0]} / {bundle.data.shape[0]} samples (eps={eps}).")

In [None]:
# --- Build an open cover of the base circle (RP^1) ---

# NOTE: this cell rebinds the notebook-level `n_landmarks` (60 for the RP^2
# cover) and `radius` (the prism radius) that earlier cells assigned.
n_landmarks = 12
# Evenly spaced landmark angles on [0, pi); pi itself is excluded (endpoint=False).
landmarks = np.linspace(0, np.pi, n_landmarks, endpoint=False)

# Landmark spacing is pi / n_landmarks, so radius = overlap * spacing / 2.
# With overlap = 1.99 each ball's radius is just under one full spacing.
overlap = 1.99
radius = overlap * np.pi / (2 * n_landmarks)

eq_cover = MetricBallCover(
    eq_base_angles,
    landmarks,
    radius,
    metric=rp1_metric(),  # RP1AngleMetric, imported under the alias rp1_metric
)
# The return value of build() is kept but not used by later cells in this notebook.
eq_cover_data = eq_cover.build()

# Show a summary of the construction
eq_summ = eq_cover.summarize(plot = True)

In [None]:
# --- Build the restricted bundle ---

# Same builder as for the full dataset, applied to the equator samples and
# the RP^1 cover. CircularCoords_cls is left at build_bundle's default
# (presumably PCA-based, per the commented-out option below — confirm).
eq_bundle = build_bundle(
    eq_data,
    eq_cover,
#    CircularCoords_cls=CircularCoords,     # optionally use sparse cc's, but PCA2 is sufficient
    show=True,
)


In [None]:
# Compute global coordinates on the equator data using a filtration of the nerve

# Class persistence for the restricted bundle (show=True displays the result).
eq_pers = eq_bundle.get_persistence(show=True)

# Global trivialization of the restricted bundle; its `F` attribute is used
# as the second plotting coordinate in the next cell.
eq_triv_result = eq_bundle.get_global_trivialization()

# Plain string literal: there is nothing to interpolate, so no f-prefix.
print("Global trivialization computed.")

In [None]:
# Show a visualization of the coordinatized densities

# Two columns per equator sample: its base angle and the `F` value from the
# global trivialization (presumably the fiber coordinate — confirm).
coords = np.column_stack((eq_base_angles, eq_triv_result.F))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 10), dpi=200)

# Grid settings shared by both panels.
lattice_grid = dict(per_row=7, per_col=7, figsize=10)

# Left panel: thumbnails rendered from the rotated meshes.
lattice_vis(eq_mesh_data, coords, vis_mesh, **lattice_grid, thumb_px=100, dpi=200, ax=ax1)

# Right panel: thumbnails rendered from the voxel densities (slightly larger).
lattice_vis(eq_data, coords, vis_density, **lattice_grid, thumb_px=120, dpi=200, ax=ax2)

ax1.set_title("Equator restriction: coordinates visualized with meshes")
ax2.set_title("Equator restriction: coordinates visualized with densities")
plt.show()


# Additional Visualizations

In [None]:
# Show a visualization of a 'fat fiber' of the projection map
from circle_bundles.viz.fiber_vis import fiber_vis
from circle_bundles.viz.base_vis import base_vis

center_ind = 299  # index of the base point at the centre of the neighbourhood
r = 0.2           # neighbourhood radius, measured in the cover's base metric

# Pairwise base distances; the full matrix is also handed to base_vis below.
dist_mat = cover.metric.pairwise(X=cover.base_points)

# Samples whose base point lies strictly within r of the chosen centre.
nearby_indices = np.flatnonzero(dist_mat[center_ind] < r)

fiber_data = data[nearby_indices]
vis_data = mesh_data[nearby_indices]

# Three side-by-side 3D panels.
fig = plt.figure(figsize=(18, 6), dpi=120)
ax1, ax2, ax3 = (fig.add_subplot(1, 3, k, projection="3d") for k in (1, 2, 3))

# Panel 1: fiber PCA, points drawn as mesh thumbnails.
fiber_vis(fiber_data, vis_mesh, vis_data=vis_data, max_images=200, zoom=0.08, ax=ax1, show=False)
ax1.set_title("Fiber PCA (Meshes)")

# Panel 2: the same fiber, points drawn as density projections.
fiber_vis(fiber_data, vis_func=vis_density, max_images=200, zoom=0.05, ax=ax2, show=False)
ax2.set_title("Fiber PCA (Densities)")

# Panel 3: where the neighbourhood sits in the base.
base_vis(cover.base_points, center_ind, r, dist_mat, use_pca=False, ax=ax3, show=False)
ax3.set_title("Base neighborhood")

plt.tight_layout()
plt.show()


In [None]:
# Show a visualization of the 1-skeleton of the nerve of the cover
from circle_bundles.viz.nerve_vis import nerve_vis

# Pick one representative sample per landmark: the base point at minimum
# distance from it (assumes pairwise(X, Y) returns a len(X) x len(Y)
# matrix — TODO confirm).
dist_mat = cover.metric.pairwise(X = cover.landmarks, Y = cover.base_points)
inds = np.argmin(dist_mat, axis = 1)
node_data = mesh_data[inds]

# 1-based labels, one per landmark.
# NOTE: these are not used below, because the call passes node_labels=None.
node_labels = [f"{i+1}" for i in range(cover.landmarks.shape[0])]

fig, axes = nerve_vis(
    cover,
    cochains={1:bundle.classes.sw1_O1},          # 1-cochain supplied for the edges
    base_colors={0:'black', 1:'black', 2:'pink'},
    cochain_cmaps={1:{1: 'blue', -1:'darkred'}},  # cochain value +1 -> blue, -1 -> darkred
    opacity=0,
    node_size=22,
    line_width=1,
    node_labels=None,
    fontsize=8,
    font_color='white',
    vis_func=vis_mesh,                            # mesh renderer paired with node_data
    data=node_data,
    image_zoom=0.065,
    title='1-Skeleton Of The Nerve Of The Cover'
)
plt.show()


In [None]:
# Show a visualization of the maximal subcomplex of N(U) on which the class reps become coboundaries

# One representative sample per landmark: the base point at minimum distance
# from it (assumes pairwise(X, Y) returns a len(X) x len(Y) matrix — TODO confirm).
dist_mat = cover.metric.pairwise(X=cover.landmarks, Y=cover.base_points)
inds = np.argmin(dist_mat, axis=1)
node_data = mesh_data[inds]

# Indices of the open sets that contain at least one equator point.
contains_equator_point = cover.U[:, eq_mask].any(axis=1)
eq_sets = {int(j) for j in np.flatnonzero(contains_equator_point)}

max_triv = bundle.get_max_trivial_subcomplex()

# Simplices of the subcomplex, with vertex indices coerced to plain ints.
sub_edges = [(int(u), int(v)) for u, v in max_triv.kept_edges]
sub_tris = [(int(u), int(v), int(w)) for u, v, w in max_triv.kept_triangles]

# Vertices that appear in at least one kept edge or triangle.
sub_verts = sorted({v for simplex in sub_edges + sub_tris for v in simplex})

# 0-cochain with int keys: 1 on equator-containing sets, 0 elsewhere.
highlight_cochain = {v: int(v in eq_sets) for v in sub_verts}

# 1-based labels, one per landmark (not passed below: node_labels=None).
node_labels = [f"{i+1}" for i in range(cover.landmarks.shape[0])]

fig, ax = nerve_vis(
    cover,
    vertices=sub_verts,
    edges=sub_edges,
    max_dim=1,
    cochains={1: bundle.classes.sw1_O1, 0: highlight_cochain},
    base_colors={0: "black", 1: "black", 2: "pink"},
    cochain_cmaps={
        1: {1: "blue", -1: "darkred"},
        0: {0: "darkgray", 1: "black"},
    },
    opacity=0,
    node_size=22,
    line_width=1,
    node_labels=None,
    fontsize=8,
    font_color="white",
    vis_func=vis_mesh,
    data=node_data,
    image_zoom=0.065,
    title='Subcomplex With Equator Sets Highlighted',
)
plt.show()


In [None]:
# Attach extra visualization methods to bundle objects.
# NOTE: the imports cell already makes this same call; repeating it here lets
# this section be run without it (presumably the call is idempotent — confirm).
attach_bundle_viz_methods()

In [None]:
# Show a visualization of the 2-skeleton of the subcomplex
# on which class representatives restrict to coboundaries.
# Both highlight options are switched off for this view.

fig = bundle.show_max_trivial(highlight_kept = False, highlight_removed = False)
plt.show()

In [None]:
# Show a visualization of the restricted nerve of the equator cover

# Compute a potential for the restricted orientation class
eq_subcomplex = eq_bundle.get_max_trivial_subcomplex()
edges = eq_subcomplex.kept_edges
# Restrict the cocycle to the kept edges only.
Omega = eq_bundle.classes.cocycle_used.restrict(edges)
# Only the third entry of orient_if_possible's result is used
# (presumably the per-vertex potential — TODO confirm against its return signature).
phi_vec = Omega.orient_if_possible(edges)[2]
# Re-key as {landmark index: value}, one entry per landmark of the equator cover.
phi = {lmk: phi_vec[lmk] for lmk in range(len(eq_cover.landmarks))}
omega = eq_bundle.classes.omega_O1_used

fig = eq_bundle.show_circle_nerve(omega = omega, phi = phi)
plt.show()


