# Imports

In [None]:
# ============================================================
# Core scientific stack
# ============================================================
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# circle_bundles core API
# ============================================================
from circle_bundles.metrics import RP1AngleMetric as rp1_metric
from circle_bundles.base_covers import MetricBallCover
from circle_bundles.api import build_bundle
from circle_bundles.analysis.local_analysis import get_local_rips, plot_local_rips


# ============================================================
# Persistent homology
# ============================================================
from ripser import ripser
from persim import plot_diagrams


# ============================================================
# Local circular coordinates (optional)
# ============================================================
from dreimac import CircularCoords


# ============================================================
# Synthetic datasets
# ============================================================
from circle_bundles.synthetic.nat_img_patches import (
    sample_nat_img_kb,
    get_gradient_dirs,
)

from circle_bundles.synthetic.opt_flow_patches import sample_opt_flow_torus
from circle_bundles.optical_flow.contrast import get_predominant_dirs


# ============================================================
# Visualization utilities
# ============================================================
from circle_bundles.viz.thumb_grids import show_data_vis
from circle_bundles.viz.lattice_vis import lattice_vis
from circle_bundles.optical_flow.patch_viz import make_patch_visualizer

from circle_bundles.bundle import attach_bundle_viz_methods
# Called once, at import time, before any bundle is built.
# NOTE(review): judging by its name this attaches optional visualization
# methods to bundle objects (the show_circle_nerve calls at the end of this
# notebook may depend on it) — confirm against circle_bundles.bundle.
attach_bundle_viz_methods()


# Torus Model For High-Contrast Optical Flow Patches

In [None]:
# Sample the optical flow torus model (a circle bundle over RP1).
# The generator is seeded so the notebook is reproducible.
n_flow_patches = 5000  # number of patches to draw
n_flow = 3             # patch side length (patches are n_flow x n_flow)

rng = np.random.default_rng(0)
# Only the first item returned by the sampler (the patch data) is kept.
flow_data = sample_opt_flow_torus(n_flow_patches, dim=n_flow, rng=rng)[0]
print(f'{n_flow_patches} {n_flow}-by-{n_flow} optical flow patches generated.')

# Callable used below to render individual patches as thumbnails.
patch_vis = make_patch_visualizer()

In [None]:
# Project each flow patch to the base space RP1: its predominant flow
# direction, which later cells use as the angle/base coordinate.
# Only the first item of get_predominant_dirs' return value is used.
predom_dirs = get_predominant_dirs(flow_data)[0]
# Plain string: the message has no placeholders, so no f-prefix (ruff F541).
print('Predominant directions computed.')

In [None]:
# View a sample of the optical flow patches, arranged by predominant flow
# direction (the RP1 base coordinate computed above, not a gradient).
n_samples = 30

# One label per patch: its predominant direction in units of pi,
# e.g. "theta = 0.25 pi" rendered with mathtext.
label_func = [
    fr"$\theta = {np.round(predom_dir / np.pi, 2)}$" + r"$\pi$"
    for predom_dir in predom_dirs
]

fig = show_data_vis(
    flow_data,
    patch_vis,
    label_func=label_func,
    angles=predom_dirs,
    sampling_method='angle',
    max_samples=n_samples,
)
plt.show()


In [None]:
# Persistent homology (up to H2) of the flow patches over Z/2 and Z/3.
# n_perm=500 makes ripser work on a 500-point greedy (furthest-point)
# subsample. A torus has the same Betti numbers over both fields, whereas
# a Klein bottle's H2 vanishes over Z/3, so the two diagrams are compared.
dgms_2, dgms_3 = (
    ripser(flow_data, coeff=p, maxdim=2, n_perm=500)["dgms"] for p in (2, 3)
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)
for ax, dgms, title in zip(axes, (dgms_2, dgms_3), ("coeff = 2", "coeff = 3")):
    plot_diagrams(dgms, ax=ax, title=title)

plt.tight_layout()
plt.show()


In [None]:
# Cover the base space RP1 = [0, pi) with metric balls around evenly spaced
# landmarks. Neighbouring landmarks are pi / n_flow_landmarks apart, so a
# radius of pi / (2 * n_flow_landmarks) would make adjacent balls just touch
# (assuming rp1_metric measures angular distance in these same radians).
# The overlap factor > 1 enlarges the balls so consecutive sets intersect.
n_flow_landmarks = 12
flow_overlap = 1.4
flow_landmarks = np.linspace(0, np.pi, n_flow_landmarks, endpoint=False)
flow_radius = flow_overlap * np.pi / (2 * n_flow_landmarks)

flow_cover = MetricBallCover(
    predom_dirs, flow_landmarks, flow_radius, metric=rp1_metric()
)
flow_cover_data = flow_cover.build()

# Draw a summary of the cover that was just built.
flow_summ = flow_cover.summarize(plot=True)
plt.show()

In [None]:
# Persistent homology (H0 and H1) of the data restricted to a few cover sets,
# to check that each fiber looks like a circle (expect one prominent H1 bar).
# NOTE(review): random_state=None presumably leaves any subsampling here
# unseeded, so these plots may vary between runs — confirm.
fiber_ids, dense_idx_list, rips_list = get_local_rips(
    flow_data,
    flow_cover.U,
    maxdim=1,
    n_perm=500,
    random_state=None,
    to_view=[0, 3, 8],  # which cover sets to inspect
)

fig, axes = plot_local_rips(
    fiber_ids, rips_list, n_cols=3, titles='default', font_size=20
)

In [None]:
# Build the circle bundle: local circular coordinates on each cover set, with
# transitions between overlapping sets modelled as O(2) matrices.
# Option: to compute the local circular coordinates with DREiMac, add
# CircularCoords_cls=CircularCoords to the call below.
flow_bundle = build_bundle(flow_data, flow_cover, show=True)


In [None]:
# Ask the bundle for a global trivialization, i.e. a single toroidal
# coordinate system for the whole flow dataset.
flow_triv_result = flow_bundle.get_global_trivialization()
print("Global coordinates computed.")

In [None]:
# Display a sample of patches placed on a lattice by their global coordinates:
# column 0 is the base angle (predominant flow direction), column 1 is
# presumably the fiber coordinate produced by the trivialization.
per_row, per_col = 5, 9
coords = np.column_stack((predom_dirs.ravel(), flow_triv_result.F.ravel()))

fig = lattice_vis(
    flow_data, coords, patch_vis,
    per_row=per_row, per_col=per_col,
    figsize=19, thumb_px=350, dpi=350,
)

plt.show()


# Klein Bottle Model For High-Contrast Natural Image Patches

In [None]:
# Sample the Klein bottle model of high-contrast natural image patches
# (a circle bundle over RP1), seeded for reproducibility.
n_img_patches = 5000  # number of patches to draw
n_img = 3             # patch side length (patches are n_img x n_img)

rng = np.random.default_rng(0)
# Only the first item returned by the sampler (the patch data) is kept.
img_data = sample_nat_img_kb(n_img_patches, n=n_img, rng=rng)[0]

print(f'{n_img_patches} {n_img}-by-{n_img} natural image patches generated.')

In [None]:
# Project each image patch to the base space RP1: its predominant gradient
# direction, which later cells use as the angle/base coordinate.
# Only the first item of get_gradient_dirs' return value is used.
grad_dirs = get_gradient_dirs(img_data)[0]
# Plain string: the message has no placeholders, so no f-prefix (ruff F541).
print('Predominant gradient directions computed.')

In [None]:
# View a small sample of the image patches, arranged by gradient direction.
# NOTE(review): patch_vis was created by make_patch_visualizer from
# circle_bundles.optical_flow.patch_viz; confirm it is also meant to render
# natural image patches.
n_samples = 30

# One label per patch: its gradient direction in units of pi.
label_func = [
    fr"$\theta = {np.round(angle / np.pi, 2)}$" + r"$\pi$"
    for angle in grad_dirs
]

fig = show_data_vis(
    img_data,
    patch_vis,
    label_func=label_func,
    angles=grad_dirs,
    sampling_method='angle',
    max_samples=n_samples,
)
plt.show()


In [None]:
# Persistent homology (up to H2) of the image patches over Z/2 and Z/3.
# n_perm=500 makes ripser work on a 500-point greedy (furthest-point)
# subsample. A Klein bottle's H2 is nonzero over Z/2 but vanishes over Z/3,
# which is what comparing the two diagrams is meant to reveal.
dgms_2, dgms_3 = (
    ripser(img_data, coeff=p, maxdim=2, n_perm=500)["dgms"] for p in (2, 3)
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)
for ax, dgms, title in zip(axes, (dgms_2, dgms_3), ("coeff = 2", "coeff = 3")):
    plot_diagrams(dgms, ax=ax, title=title)

plt.tight_layout()
plt.show()


In [None]:
# Cover the base space RP1 = [0, pi) with metric balls around evenly spaced
# landmarks. Neighbouring landmarks are pi / n_img_landmarks apart, so a
# radius of pi / (2 * n_img_landmarks) would make adjacent balls just touch
# (assuming rp1_metric measures angular distance in these same radians).
# The overlap factor > 1 enlarges the balls so consecutive sets intersect.
n_img_landmarks = 12
img_overlap = 1.4
img_landmarks = np.linspace(0, np.pi, n_img_landmarks, endpoint=False)
img_radius = img_overlap * np.pi / (2 * n_img_landmarks)

img_cover = MetricBallCover(
    grad_dirs, img_landmarks, img_radius, metric=rp1_metric()
)
img_cover_data = img_cover.build()

# Draw a summary of the cover that was just built.
img_summ = img_cover.summarize(plot=True)
plt.show()

In [None]:
# Persistent homology (H0 and H1) of the data restricted to a few cover sets,
# to confirm that each fiber looks like a circle (expect one prominent H1 bar).
# NOTE(review): random_state=None presumably leaves any subsampling here
# unseeded, so these plots may vary between runs — confirm.
fiber_ids, dense_idx_list, rips_list = get_local_rips(
    img_data,
    img_cover.U,
    p_values=None,
    maxdim=1,
    n_perm=500,
    random_state=None,
    to_view=[0, 3, 8],  # which cover sets to inspect
)

fig, axes = plot_local_rips(
    fiber_ids, rips_list, n_cols=3, titles='default', font_size=20
)

In [None]:
# Build the circle bundle: local circular coordinates on each cover set, with
# transitions between overlapping sets modelled as O(2) matrices.
# Option: to compute the local circular coordinates with DREiMac, add
# CircularCoords_cls=CircularCoords to the call below.
img_bundle = build_bundle(img_data, img_cover, show=True)


In [None]:
# Compute a global coordinate system for the image patches. The original note
# says this happens after dropping an edge from the nerve of the cover.
# NOTE(review): no argument here requests that, so presumably
# get_global_trivialization drops the edge internally — confirm.
img_triv_result = img_bundle.get_global_trivialization()
print("Global coordinates computed.")

In [None]:
# Display a sample of patches placed on a lattice by their global coordinates:
# column 0 is the base angle (gradient direction), column 1 is presumably the
# fiber coordinate produced by the trivialization.
per_row, per_col = 5, 9
coords = np.column_stack((grad_dirs.ravel(), img_triv_result.F.ravel()))

fig = lattice_vis(
    img_data, coords, patch_vis,
    per_row=per_row, per_col=per_col,
    figsize=19, thumb_px=350, dpi=350,
)

plt.show()


# Additional Visualization -- Orientation Cocycle Comparisons

In [None]:
# Side-by-side nerve plots of the two bundles, comparing the representatives
# of their orientation classes.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7), dpi=150)

for bundle, ax, title in (
    (flow_bundle, ax1, "Optical Flow Patches"),
    (img_bundle, ax2, "Natural Image Patches"),
):
    # show=False presumably defers display so both panels share one figure.
    bundle.show_circle_nerve(title=title, ax=ax, show=False)

plt.tight_layout()
plt.show()
