In [None]:
# ============================================================
# Core scientific stack
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx


# ============================================================
# Optional local circular coordinates
# ============================================================
from dreimac import CircularCoords


# ============================================================
# Persistent homology
# ============================================================
from ripser import ripser
from persim import plot_diagrams


# ============================================================
# circle_bundles core API + analysis tools
# ============================================================
from circle_bundles.api import build_bundle
from circle_bundles.base_covers import MetricBallCover
from circle_bundles.metrics import RP1AngleMetric as rp1_metric


from circle_bundles.analysis.local_analysis import (
    get_local_rips, 
    plot_local_rips, 
    get_local_pca, 
    plot_local_pca
)

from circle_bundles.analysis.fiberwise_clustering import (
    fiberwise_clustering, 
    plot_fiberwise_pca_grid,
    plot_fiberwise_summary_bars,
    get_weights,
    get_cluster_persistence,
    get_filtered_cluster_graph,
)


# ============================================================
# Optical flow data processing + features
# ============================================================

from circle_bundles.optical_flow.contrast import get_predominant_dirs
from circle_bundles.synthetic.step_edges import get_patch_types_list, make_step_edges


# ============================================================
# Visualization utilities
# ============================================================
from circle_bundles.optical_flow.patch_viz import make_patch_visualizer
from circle_bundles.viz.thumb_grids import show_data_vis
from circle_bundles.viz.lattice_vis import lattice_vis
from circle_bundles.viz.nerve_vis import nerve_vis
from circle_bundles.viz.circle_vis import circle_vis, circle_vis_grid
from circle_bundles.viz.fiberwise_clustering_vis import (
    make_patch_cluster_diagram,
    get_G_vertex_coords
)
from circle_bundles.viz.gudhi_graph_utils import create_st_dicts


# ============================================================
# Attach optional BundleResult visualization methods
# ============================================================
from circle_bundles.bundle import attach_bundle_viz_methods
# Presumably patches the optional visualization methods onto BundleResult (per the
# section header above); called once, at import time, as a module-level side effect.
attach_bundle_viz_methods()


# Load The Dataset

In [None]:
import pickle
import pandas as pd

#Load the dataset of preprocessed high-contrast optical flow patches.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
# TODO: '...' is a placeholder; point file_path at the local copy of the patches.
file_path = '.../HC20_Flow_patches.pkl'   #path to pre-processed patches
patch_df = pd.read_pickle(file_path)

#Get the data in X(k,p): the top fraction p of patches ranked by the precomputed
#'density_{k}' column (presumably a k-nearest-neighbour density estimate -- confirm).
k = 50          #    density_options = [10, 50, 100, 200, 300, 1500]
p = 0.60
column = f'density_{k}'
patch_df = patch_df.sort_values(by=column, ascending=False)
N = int(p * len(patch_df))
assert N > 0, f'p={p} keeps no patches out of {len(patch_df)}'
# Slice before stacking so that only the N retained patches are materialized.
data = np.vstack(patch_df['patch'].iloc[:N].to_list())

#Compute the predominant direction and directionality for each patch
predom_dirs, ratios = get_predominant_dirs(data)

#Create a visualizer for optical flow patches
patch_vis = make_patch_visualizer()

print(f'Sample contains {len(data)} 3-by-3 high-contrast optical flow patches.')

In [None]:
# Display a handful of patches sampled across predominant directions, each
# labelled with its direction expressed as a multiple of pi.
n_samples = 8

label_func = [
    fr"$\theta = {np.round(theta/np.pi, 2)}$" + r"$\pi$"
    for theta in predom_dirs
]
fig = show_data_vis(
    data,
    patch_vis,
    label_func=label_func,
    angles=predom_dirs,
    sampling_method='angle',
    max_samples=n_samples,
)
plt.show()


# Bundle Analysis

## Base Projections And Open Cover

In [None]:
# Cover the base space of directions (angles in [0, pi)) with n_landmarks evenly
# spaced metric balls. Neighbouring landmarks are pi/n_landmarks apart, so a radius
# of pi/(2*n_landmarks) would only let neighbouring balls touch; `overlap` > 1
# enlarges the balls so adjacent sets intersect.
n_landmarks = 16
landmarks = np.linspace(0, np.pi, n_landmarks, endpoint=False)

overlap = 1.5
radius = overlap * np.pi / (2 * n_landmarks)

cover = MetricBallCover(predom_dirs, landmarks, radius, metric=rp1_metric())
cover_data = cover.build()
summ = cover.summarize(plot=True)

In [None]:
# Project the points of a few sample cover sets (3, 6 and 14) onto their top two
# principal components and plot the projections side by side.
fiber_ids, dense_idx_list, proj_list = get_local_pca(
    data,
    cover.U,
    to_view=[3, 6, 14],
    n_components=2,
    p_values=None,
    random_state=None,
)

fig, axes = plot_local_pca(fiber_ids, proj_list, n_cols=3, titles='default', font_size=20)

## Fiberwise Clustering

In [None]:
# Compute local Rips persistence (dimensions 0 and 1) in a few sample cover sets;
# the 0-dimensional diagrams guide the choice of the DBSCAN eps value.
to_view = [3, 6, 14]
fiber_ids, dense_idx_list, rips_list = get_local_rips(
    data,
    cover.U,
    to_view=to_view,
    maxdim=1,
    n_perm=500,
    p_values=None,
    random_state=None,
)

fig, axes = plot_local_rips(fiber_ids, rips_list, n_cols=3, titles='default', font_size=20)

In [None]:
#Run fiberwise clustering with one (eps, min_samples) pair per cover set.
# Size the parameter arrays from the cover's landmarks rather than `n_landmarks`,
# which later cells reassign (to 100) for circular coordinates.
n_cover_sets = len(landmarks)
eps_values = np.full(n_cover_sets, 0.3)
min_sample_values = np.full(n_cover_sets, 5, dtype=int)   # min_samples is a count
to_view = [3, 6, 14]


components, G, graph_dict, cl, summary = fiberwise_clustering(
    data,
    cover.U,
    eps_values,
    min_sample_values,
)

plot_fiberwise_pca_grid(summary, to_view=to_view)
plot_fiberwise_summary_bars(summary, hide_biggest=True)

# Label -1 marks unclustered points. Exclude it explicitly rather than assuming it
# is present and sorts first among the unique labels.
global_labels = np.unique(components[components != -1])
n_clusters = len(global_labels)
print(f'Total number of global clusters: {n_clusters}')
print(f'Total number of unclustered points: {np.sum(components == -1)}')
print('')

# Must be an array: `100 * <list>` would repeat the list 100 times instead of scaling it.
point_count = np.array([np.sum(components == j) for j in global_labels])
print('Global components by size:')
print(global_labels[np.argsort(point_count)[::-1]])
print('')
print('Cardinality of each global component:')
print([int(x) for x in point_count])
print('')
percentages = 100 * point_count / max(point_count.sum(), 1)
print('Percentage of the data:')
print(np.round(percentages, 2))

# Global Clusters

In [None]:
# Build the 0/1 indicator matrix of the global components: row c (for every row but
# the last) marks the points whose component label is c, and the final row marks the
# unclustered points (label -1).
n_rows = len(np.unique(components))
row_labels = np.append(np.arange(n_rows - 1), -1)
C = np.zeros((n_rows, len(data)))
C[:] = components[np.newaxis, :] == row_labels[:, np.newaxis]


In [None]:
#Check to see which global clusters have a significant 1D persistent class.
# A cluster counts as circular when its most persistent H1 class dies after more
# than twice its birth time (2*b < d), taken as strong enough to run circular coordinates.

fiber_ids, dense_idx_list, rips_list = get_local_rips(
    data,
    C,
    p_values=None,
    to_view=None,
    maxdim=1,
    n_perm=500,
    random_state=None,
)


circular_components = []
others = []
for r in range(len(rips_list)):
    rips = rips_list[r]
    # Entries may be None (the filtered-graph analysis later in the notebook guards
    # against this too); treat them, and empty H1 diagrams, as non-circular.
    if rips is None or len(rips['dgms'][1]) == 0:
        others.append(r)
        continue
    dgm = rips['dgms'][1]   #1D persistence diagram for component r
    b, d = dgm[np.argmax(dgm[:, 1] - dgm[:, 0])]
    if 2*b < d:    #sufficient to run cc
        circular_components.append(r)
    else:
        others.append(r)

print('components with circular features:')
print(circular_components)
print(f'number of circular components: {len(circular_components)}, number of other components: {len(others)}')


In [None]:
# Run circular coordinates on every global cluster flagged as circular, except
# component 0 (the optical flow torus cluster).
n_landmarks = 100
prime = 17

datasets, angles_list, titles = [], [], []

for j in circular_components:
    if j <= 0:
        continue   # component 0 is the optical flow torus cluster
    cluster_points = data[components == j]
    datasets.append(cluster_points)
    coords = CircularCoords(cluster_points, n_landmarks, prime=prime)
    angles_list.append(coords.get_coordinates())
    titles.append(f'Component {j}')

fig, axes = circle_vis_grid(
    datasets,
    angles_list,
    patch_vis,
    titles=titles,
    n_cols=3,
    per_circle=8,
    circle_radius=1.0,
    circle_zoom=0.13,
    circle_linewidth=1.0,
    circle_color="black",
    extent_factor=1.2,
    title_fontsize=16,
    figsize_per_panel=5,
    fig_dpi=150,
)
plt.show()

In [None]:
#View a visualization of a circular component of G

#Get the subgraph of G containing a single component
#(index m in networkx's discovery order -- components are not sorted by size)
m = 1
G_comps = list(nx.connected_components(G))
Gm = G.subgraph(G_comps[m]).copy()

#Get coordinates for nodes in visualization, compressed vertically
vertex_coords = get_G_vertex_coords(Gm)
vertex_coords[:, 1] = 0.2 * vertex_coords[:, 1]
# NOTE(review): vertex_coords is not passed to make_patch_cluster_diagram below;
# confirm whether the diagram is meant to use it.

#Get one patch representative (the lowest-index member) for each local cluster.
#Nodes are (j, label) pairs (presumably cover-set index and local label) and the
#points n of that local cluster satisfy cl[j][n] == label. The loop variables are
#not named `k`, which would overwrite the density parameter set when loading data.
cluster_nodes = list(Gm.nodes())
clusters = np.zeros((len(cluster_nodes), 2))
indices = np.zeros(len(cluster_nodes), dtype=int)
for cluster_num, (fiber, label) in enumerate(cluster_nodes):
    clusters[cluster_num] = (fiber, label)
    n_matches = np.where(cl[fiber] == label)[0]
    if len(n_matches) == 0:
        # Falling back to index -1 would silently select the last patch in `data`.
        raise ValueError(f'Local cluster {(fiber, label)} contains no points.')
    indices[cluster_num] = n_matches.min()

patch_reps = data[indices]


fig, ax = make_patch_cluster_diagram(
    patch_reps,
    clusters,
    Gm,
    patch_vis,
    image_zoom=0.35,
    row_spacing=4,
    col_spacing=3.5,
    line_color='lightgray',
    line_width=5,
)
plt.show()



In [None]:
# Show a few sample patches from every global cluster without a strong circular
# feature. Index n_clusters is the final row of C, i.e. the unclustered points (-1).
max_samples = 5
n_clusters = len(np.unique(components)) - 1
cluster_sizes = np.sum(C, axis=1)
for j in others:
    label = -1 if j == n_clusters else j
    inds = components == label

    print(f'Sample patches from global cluster {j} ({int(cluster_sizes[j])} total):')
    fig = show_data_vis(
        data[inds],
        patch_vis,
        angles=predom_dirs[inds],
        sampling_method='angle',
        max_samples=max_samples,
    )
    plt.show()


In [None]:
#Get a single representative from each outlier cluster (the non-circular global
#clusters, excluding row n_clusters, which holds the unclustered points).
#Index rows by position in the filtered list so the noise row's position in
#`others` cannot misalign the rows or overflow the array.
outlier_ids = [j for j in others if j != n_clusters]
patch_reps = np.zeros((len(outlier_ids), data.shape[1]))
for t, j in enumerate(outlier_ids):
    members = data[components == j]
    # Cluster 25 uses a hand-picked member (index 27) instead of its first patch.
    patch_reps[t] = members[27] if j == 25 else members[0]

fig = show_data_vis(patch_reps, patch_vis, n_cols=7)
plt.show()



## Cluster Persistence

In [None]:
# Rebind G to the edge-weighted graph returned by get_weights (method 'rel_card2'),
# then compute cluster persistence from that weighted graph.
G = get_weights(G, method='rel_card2')

comp_pers = get_cluster_persistence(G)

In [None]:
# Build the filtered cluster graph G^{(0.07)}; G itself is kept unchanged so the
# number of deleted edges can be reported.
thresh = 0.07

(filtered_components, filtered_G, filtered_graph_dict,
 filtered_cl, comp_inds) = get_filtered_cluster_graph(
    data,
    G,
    cl,
    thresh=thresh,
    rule="to_smaller_cluster",
    show_results=True,
    hide_biggest=True,
)

n_deleted_edges = len(G.edges) - len(filtered_G.edges)
print(f'threshold: {thresh}, global clusters: {len(comp_inds)}')
print(f'number of deleted edges: {n_deleted_edges}')


## The Filtered Cluster Graph 

In [None]:
#View a visualization of G0 before and after filtering G
#(G0 is the largest connected component of G)

#nx.connected_components does not yield components in order of size,
#so select the largest one explicitly instead of taking the first.
G_comps = list(nx.connected_components(G))
G0 = G.subgraph(max(G_comps, key=len)).copy()

#Get the filtered version of G0
thresh = 0.07
components_filtered, filtered_G0, graph_dict_filtered, G0_cl, G0_comp_inds = get_filtered_cluster_graph(data, G0, cl, thresh = thresh)

#Get coordinates for nodes in visualization
vertex_coords = get_G_vertex_coords(G0)

#Convert to a simplex tree and get color assignments
G0_st, filtered_G0_st, vertex_dict, edge_dict, node_to_index = create_st_dicts(G0, filtered_G0)


#Color for each cochain value in vertex_dict / edge_dict
cmap = {
    4: '#DCE6F2',  # very light blue
    3: '#B0C4E8',  # soft periwinkle
    2: '#8090D1',  # medium lavender-blue
    0: 'black',    # black
    1: '#3F3FBF',  # rich indigo
    5: '#2A2A9F',  # dark blue accent
    6: '#8866CC',  # soft purple accent
    7: '#5C2ABF',  # vibrant purple
   -1: 'red',      # red -- presumably marks special/missing values
}


fig, axes = nerve_vis(
    G0_st,
    vertex_coords,
    cochains={0: vertex_dict, 1: edge_dict},
    base_colors={0: 'black', 1: 'black'},
    cochain_cmaps={0: cmap, 1: cmap},
    opacity=0,
    node_size=15,
    line_width=1,
    node_labels=None,
    vis_func=None,
    data=None,
    image_zoom=0.1,
    title=r'Visualization Of $G_{0}$'
)

plt.show()


In [None]:
# Flag each filtered global cluster as circular when its most persistent H1 class
# dies after more than twice its birth (2*b < d), which is strong enough to run
# circular coordinates. Missing (None) results and empty H1 diagrams count as "other".
fiber_ids, dense_idx_list, rips_list = get_local_rips(
    data,
    comp_inds,
    to_view=None,
    maxdim=1,
    n_perm=500,
    p_values=None,
    random_state=None,
)


circular_components = []
others = []
for r in range(len(rips_list)):
    result = rips_list[r]
    h1 = None if result is None else result['dgms'][1]
    if h1 is None or len(h1) == 0:
        others.append(r)
        continue
    b, d = h1[np.argmax(h1[:, 1] - h1[:, 0])]
    if 2*b < d:
        circular_components.append(r)
    else:
        others.append(r)
print('components with circular features:')
print(circular_components)
print(f'number of circular components: {len(circular_components)}, number of other components: {len(others)}')


In [None]:
# Run circular coordinates on each filtered cluster with a strong circular feature,
# skipping the largest cluster (the optical flow torus).
biggest = np.argmax([np.sum(row) for row in comp_inds])

n_landmarks = 100
prime = 17

datasets, angles_list, titles = [], [], []

for j in circular_components:
    if j == biggest:
        continue
    cluster_points = data[comp_inds[j].astype(bool)]
    datasets.append(cluster_points)
    coords = CircularCoords(cluster_points, n_landmarks, prime=prime)
    angles_list.append(coords.get_coordinates())
    titles.append(f"Component {j}")

# Show a sample of coordinatized patches from each circular component
fig, axes = circle_vis_grid(
    datasets,
    angles_list,
    patch_vis,
    titles=titles,
    n_cols=4,
    per_circle=8,
    circle_radius=1.0,
    circle_zoom=0.13,
    circle_linewidth=1.0,
    circle_color="black",
    extent_factor=1.2,
    title_fontsize=16,
    figsize_per_panel=5,
    fig_dpi=150,
)
plt.show()

In [None]:
# Show the first patch of every circular cluster as its representative.
samples = np.vstack([circle_patches[0] for circle_patches in datasets])
fig = show_data_vis(samples, patch_vis)
plt.show()


## Composite Circles

In [None]:
# Show up to five sample patches from each filtered cluster without a strong circular feature.
for j in others:
    inds = comp_inds[j].astype(bool)
    n_members = int(np.sum(inds))
    print(f'Sample patches from global cluster {j} ({n_members} total)')
    fig = show_data_vis(
        data[inds],
        patch_vis,
        angles=predom_dirs[inds],
        max_samples=5,
    )
    plt.show()


In [None]:
#Fuse hand-picked pairs of filtered clusters into composite circles.
#bunch_inds[j] accumulates the indicator rows of every cluster in bunches[j].
bunches = [[2, 38], [39, 40]]
bunch_inds = np.zeros((len(bunches), len(data)))
for j, bunch in enumerate(bunches):
    for component in bunch:
        bunch_inds[j] += comp_inds[component]

#Title each panel by its fused components. The `titles` list left over from the
#previous circular-coordinates cell describes different clusters.
bunch_titles = [f"Components {', '.join(map(str, bunch))}" for bunch in bunches]

#Show PCA of the composite circles
fiber_ids, dense_idx_list, proj_list = get_local_pca(
    data,
    bunch_inds,
    p_values=None,
    to_view=None,
    n_components=2,
    random_state=None)


fig, axes = plot_local_pca(
    fiber_ids,
    proj_list,
    n_cols=3,
    titles=bunch_titles,
    font_size=20,
)

#Show persistence of the composite circles
fiber_ids, dense_idx_list, rips_list = get_local_rips(
    data,
    bunch_inds,
    p_values=None,
    to_view=None,
    maxdim=1,
    n_perm=500,
    random_state=None,
)

fig, axes = plot_local_rips(
    fiber_ids,
    rips_list,
    n_cols=3,
    titles=bunch_titles,
    font_size=20,
)

In [None]:
# Generate synthetic binary step-edge patches on the two hypothesized composite
# circles (entries 18 and 16 of patch_types_list), preview them, and append them
# after the real members of the corresponding fused clusters.
samples_per_filament = 250
patch_types_list = get_patch_types_list()

# Circle A
spots = 18
synth_patches_A, synth_angles_A = make_step_edges(
    samples_per_filament, patch_types_list[spots]
)
fig, axes = show_data_vis(synth_patches_A, patch_vis, max_samples=8)
plt.show()

# Circle B
spots = 16
synth_patches_B, synth_angles_B = make_step_edges(
    samples_per_filament, patch_types_list[spots]
)
fig, axes = show_data_vis(synth_patches_B, patch_vis, max_samples=8)
plt.show()

# Real patches first, synthetic patches last (later cells rely on this order)
boosted_circs = [
    np.concatenate((data[bunch_inds[0].astype(bool)], synth_patches_A)),
    np.concatenate((data[bunch_inds[1].astype(bool)], synth_patches_B)),
]
boosted_circ_A, boosted_circ_B = boosted_circs


In [None]:
# Run circular coordinates on each composite circle (real patches plus synthetic
# step edges), then keep only the coordinates of the real patches.
n_landmarks = 100
prime = 17

composite_datasets = []
composite_angs_list = []
titles = []

for j, circ in enumerate(boosted_circs):
    real_patches = data[bunch_inds[j].astype(bool)]
    composite_datasets.append(real_patches)
    cc = CircularCoords(circ, n_landmarks, prime=prime)
    # The synthetic patches were appended after the real ones, so the leading
    # len(real_patches) coordinates belong to real data. Slicing with
    # [:-samples_per_filament] would instead return nothing when that count is 0.
    composite_angs_list.append(cc.get_coordinates()[:len(real_patches)])
    titles.append(f"Components {', '.join(map(str, bunches[j]))}")


fig, axes = circle_vis_grid(
    composite_datasets,
    composite_angs_list,
    patch_vis,
    titles=titles,
    per_circle=8,
    circle_radius=1.0,
    extent_factor=1.2,
    circle_zoom=0.13,
    circle_linewidth=1.0,
    circle_color="black",
    title_fontsize=16,
    figsize_per_panel=5,
    fig_dpi=150,
)
plt.show()

# Additional Visualizations And Summaries

In [None]:
#Get a sample patch from each noise cluster: the non-circular clusters that were not
#fused into a composite circle. Derive the fused ids from `bunches` so the two stay in sync.
fused_components = {comp for bunch in bunches for comp in bunch}
noise_clusters = [comp for comp in others if comp not in fused_components]

datasets = [data[comp_inds[clust].astype(bool)] for clust in noise_clusters]
samples = np.vstack([cluster_patches[0] for cluster_patches in datasets])

fig = show_data_vis(samples, patch_vis, n_cols=7)
plt.show()


In [None]:
#Create an array whose rows track the final global clusters

#Combine the 'noise' clusters into a single cluster
#(list.copy() is shallow, but appending below only changes new_bunches, not `bunches`;
# NOTE(review): min() below raises if noise_clusters is empty)
new_bunches = bunches.copy()
new_bunches.append(noise_clusters)

#Combine the bunches into single clusters: each bunch is accumulated into the row of
#its smallest member, and the other members' rows are dropped afterwards. np.delete
#removes them all at once, so the row indices collected in to_delete stay valid.
C_final = np.copy(comp_inds)
to_delete = []
for bunch in new_bunches:
    m = min(bunch)
    for component in bunch:
        if component != m:
            C_final[m] += C_final[component]            
            to_delete.append(component)
C_final = np.delete(C_final, to_delete, axis=0)

#Rearrange everything to make it logical: swaps rows 0<->4, 2<->27 and 28<->29
#(the right-hand fancy index makes a copy, so the swap is simultaneous).
#NOTE(review): indices are hand-tuned for this run (k=50, p=0.60); later cells treat
#row 0 and the last row as excluded and rows 27-28 as the composite circles -- confirm.
C_final[[0, 2, 4, 27, 28, 29]] = C_final[[4, 27, 0, 2, 29, 28]]

#Note that the last 'component' is just the outlier data combined
print(C_final.shape)

#Create a 1D array with the final cluster labels: points with no entry equal to 1
#keep label -1; otherwise the label is the row of the largest entry (first on ties).
cluster_labels = np.full(C_final.shape[1], -1, dtype=int)

cols_with_one = np.any(C_final == 1, axis=0)
cluster_labels[cols_with_one] = np.argmax(C_final[:, cols_with_one], axis=0)


In [None]:
# Show the coordinatized patches in all 28 recovered step-edge circles.
# Rows 0 and 29 of C_final are skipped; rows 27-28 reuse the composite-circle
# coordinates computed earlier (with synthetic patches).
n_landmarks = 100
prime = 17

circle_datasets = []
all_angles = []
titles = []

for j in range(len(C_final)):
    # C_final rows are 0/1 indicators, possibly numeric. Convert to a boolean mask:
    # indexing with a float array raises, and with an int array it would select
    # rows 0 and 1 of `data` over and over.
    indices = C_final[j].astype(bool)
    if j not in [0, 27, 28, 29]:
        circle_datasets.append(data[indices])
        cc = CircularCoords(data[indices], n_landmarks, prime=prime)
        all_angles.append(cc.get_coordinates())
        titles.append(f'Component {j}')
    elif j in [27, 28]:
        circle_datasets.append(composite_datasets[j-27])
        all_angles.append(composite_angs_list[j-27])
        titles.append(f"Components {', '.join(map(str, bunches[j-27]))}")

fig, axes = circle_vis_grid(
    circle_datasets,
    all_angles,
    patch_vis,
    titles=titles,
    per_circle=8,
    circle_radius=1.0,
    extent_factor=1.2,
    circle_zoom=0.13,
    circle_linewidth=1.0,
    circle_color="black",
    n_cols=4,
    title_fontsize=16,
    figsize_per_panel=5,
    fig_dpi=150,
)
plt.show()

In [None]:
#Show a representative from each circular component
samples = np.vstack([data[0] for data in circle_datasets])

fig = show_data_vis(samples, patch_vis, n_cols = 14)
plt.show()


## Lifting The Step Edge Circles

In [None]:
#Isolate the data in the step edge circles
circ_inds = (C_final[0]+C_final[-1] == 0)*(np.any(C_final != 0, axis=0))
circ_data = data[circ_inds]
C_circ = C_final[1:-1, circ_inds]

print(f'{np.sum(circ_inds)} patches from the step edge circles')


In [None]:
#Get a small number of synthetic binary step edge patches to fill in the composite circles
samples_per_filament = 250
patch_types_list = get_patch_types_list()
spots = [18, 16]

#Start from the un-augmented circle data so that re-running this cell does not
#append the synthetic patches a second time.
circ_data = data[circ_inds]
C_circ = C_final[1:-1, circ_inds]
#The composite circles occupy the last len(spots) rows of C_circ.
first_composite_row = len(C_circ) - len(spots)

for j, spot in enumerate(spots):
    #Generate patches for this circle
    synth_patches = make_step_edges(samples_per_filament, patch_types_list[spot])[0]

    #Show some samples
    fig, axes = show_data_vis(synth_patches, patch_vis, max_samples=8)
    plt.show()

    #Add the synthetic data to the circle dataset, marking each new patch as a
    #member of the matching composite circle only
    circ_data = np.concatenate((circ_data, synth_patches))
    new_cols = np.zeros((len(C_circ), len(synth_patches)), dtype=bool)
    new_cols[first_composite_row + j, :] = True
    C_circ = np.hstack([C_circ, new_cols])

print(f'Augmented dataset contains {len(circ_data)} patches')

In [None]:
#Save the circle patches and their membership matrix as [circ_data, C_circ].
from pathlib import Path

# TODO: machine-specific absolute path; make this configurable before sharing.
folder_path = '/Users/bradturow/Desktop/Circle Bundle Code'
file_name = 'K_50_60_Circles.pkl'
# Join with a path separator. Plain concatenation produced
# '.../Circle Bundle CodeK_50_60_Circles.pkl', i.e. a file outside the folder.
save_path = Path(folder_path) / file_name
with open(save_path, 'wb') as f:
    pickle.dump([circ_data, C_circ], f)

        