# Real Optical Flow Torus Data Analysis

In [None]:
# ============================================================
# Core scientific stack
# ============================================================
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# Optional local circular coordinates
# ============================================================
from dreimac import CircularCoords


# ============================================================
# Persistent homology
# ============================================================
from ripser import ripser
from persim import plot_diagrams


# ============================================================
# circle_bundles core API + analysis tools
# ============================================================
from circle_bundles.api import build_bundle
from circle_bundles.base_covers import MetricBallCover
from circle_bundles.metrics import RP1AngleMetric as rp1_metric
from circle_bundles.analysis.local_analysis import get_local_rips, plot_local_rips


# ============================================================
# Optical flow data processing + features
# ============================================================
from circle_bundles.optical_flow.flow_processing import (
    get_patch_sample,
    preprocess_flow_patches,
)
from circle_bundles.optical_flow.contrast import get_predominant_dirs


# ============================================================
# Visualization utilities
# ============================================================
from circle_bundles.optical_flow.patch_viz import make_patch_visualizer
from circle_bundles.viz.thumb_grids import show_data_vis
from circle_bundles.viz.lattice_vis import lattice_vis


# ============================================================
# Optional BundleResult visualization methods: importing the
# helper is not enough, it must be called once to attach them.
# ============================================================
from circle_bundles.bundle import attach_bundle_viz_methods

attach_bundle_viz_methods()


# Get A Sample of Optical Flow Patches From The Sintel Dataset

In [None]:
import os
import pickle
import pandas as pd

# Sample optical flow patches from the Sintel training flow frames.
#
# patches_per_frame and d=3 are forwarded to get_patch_sample
# (d is presumably the patch side length -- confirm in flow_processing).
patches_per_frame = 400

# Directory holding the Sintel flow frames (".../MPI-Sintel-complete/training/flow").
# Set the SINTEL_FLOW_DIR environment variable rather than editing this cell;
# the hard-coded path is kept only as a fallback for the original machine.
folder_path = os.environ.get(
    'SINTEL_FLOW_DIR',
    '/Users/bradturow/Desktop/TDA/MPI-Sintel-complete/training/flow',
)
patch_df, file_paths = get_patch_sample(
    folder_path,
    patches_per_frame=patches_per_frame,
    d=3,
)

print()
print(f'{len(patch_df)} optical flow patches sampled')

# Downsample if necessary. A fixed random_state makes this step reproducible.
# NOTE(review): get_patch_sample may itself sample randomly; its seeding is not
# controlled here.
max_samples = 400000
random_seed = 0
if len(patch_df) > max_samples:
    patch_df = patch_df.sample(n=max_samples, random_state=random_seed)


In [None]:
# Preprocess the sampled patches with preprocess_flow_patches.
# NOTE(review): hc_frac presumably selects a high-contrast fraction of the
# patches, and k_list the neighbour count(s) for a density estimate. The next
# cell relies on the output being sorted by decreasing density. Confirm both
# against circle_bundles.optical_flow.flow_processing.
k = [300]
max_samples = 50000
hc_frac = 0.2

print('Preprocessing data...')
patch_df = preprocess_flow_patches(
    patch_df, hc_frac=hc_frac, max_samples=max_samples, k_list=k
)
print('Preprocessing complete.')

In [None]:
# Keep the densest fraction p of the patches. Rows are already sorted by
# decreasing density, so a leading slice selects the densest ones.
p = 0.5
n_samples = int(p * len(patch_df))
data = np.vstack(patch_df['patch'])[:n_samples]
print(f'Downsampled to {len(data)} patches')

# Preliminary Analysis

In [None]:
# Preview up to 30 patches from the dataset (no angle-based sampling).

# Build the patch visualization function reused by later cells.
patch_vis = make_patch_visualizer()

fig = show_data_vis(data, patch_vis, sampling_method=None, max_samples=30)
plt.show()

In [None]:
# Persistence diagrams (H0, H1, H2) for the full dataset. n_perm=500 has
# Ripser work on a 500-point greedy-permutation subsample.
diagrams = ripser(data, n_perm=500, maxdim=2)['dgms']
plot_diagrams(diagrams, show=True)

# Bundle Analysis

In [None]:
# Give each patch its predominant flow axis (an angle in RP^1) and cover RP^1
# with metric balls centred on evenly spaced landmark angles.

# The directionality ratios are kept for later filtering.
predom_dirs, ratios = get_predominant_dirs(data)

# Base-space cover: n_landmarks angles evenly spaced in [0, pi).
n_landmarks = 16
landmarks = np.linspace(0, np.pi, n_landmarks, endpoint=False)
# Neighbouring landmarks are pi/n_landmarks apart. The radius is `overlap`
# times half that spacing, so an overlap just under 2 gives a radius just
# under the landmark spacing.
overlap = 1.99
radius = overlap * np.pi / (2 * n_landmarks)

cover = MetricBallCover(predom_dirs, landmarks, radius, metric=rp1_metric())
cover_data = cover.build()

# Summary (and plot) of the cover construction
summ = cover.summarize(plot=True)

In [None]:
# Preview patches arranged by predominant flow direction. Each label shows
# the angle as a multiple of pi.
n_samples = 8

label_func = [fr"$\theta = {np.round(ang / np.pi, 2)}$$\pi$" for ang in predom_dirs]
fig = show_data_vis(
    data,
    patch_vis,
    label_func=label_func,
    angles=predom_dirs,
    sampling_method='angle',
    max_samples=n_samples,
)
plt.show()


In [None]:
# Build local circular coordinates on each cover set and model the
# transitions on overlaps as O(2) matrices.
# To use dreimac's sparse circular-coordinates algorithm, also pass
# CircularCoords_cls=CircularCoords.
bundle = build_bundle(data, cover, show=True)


In [None]:
# Compare local circular coordinates on overlapping cover sets (4-column grid).
fig = bundle.compare_trivs(ncols=4)
plt.show()

In [None]:
# Draw the nerve of the cover, labelled with the first Stiefel-Whitney class (SW1).
fig = bundle.show_circle_nerve()
plt.show()


## Restrict To High-Directionality Data 

In [None]:
# Restrict to patches whose directionality ratio exceeds `thresh`, then cover
# RP^1 using the same landmarks and radius as the full-data cover.
thresh = 0.8
high_inds = ratios > thresh
print(f'{high_inds.sum()} high-directionality patches')

high_cover = MetricBallCover(
    predom_dirs[high_inds],
    landmarks,
    radius,
    metric=rp1_metric(),
)
high_cover_data = high_cover.build()

# Summary (and plot) of the cover construction
high_summ = high_cover.summarize(plot=True)

In [None]:
# Local circular coordinates and O(2) transition matrices for the
# high-directionality subset only.
# To use dreimac's sparse circular-coordinates algorithm, also pass
# CircularCoords_cls=CircularCoords.
high_bundle = build_bundle(data[high_inds], high_cover, show=True)


In [None]:
# Global coordinates for the high-directionality bundle. The result's .F is
# used below as the second lattice coordinate.
high_triv_result = high_bundle.get_global_trivialization()
print('global coordinates computed.')

In [None]:
# Recovered patch diagram: place high-directionality patches on a lattice
# indexed by (predominant direction, global coordinate F).
per_row, per_col = 5, 9
coords = np.array([predom_dirs[high_inds], high_triv_result.F]).T

fig = lattice_vis(
    high_bundle.data,
    coords,
    patch_vis,
    per_row=per_row,
    per_col=per_col,
    figsize=19,
    thumb_px=350,
    dpi=350,
    padding=0,
)

plt.show()


In [None]:
# Persistence diagrams (H0, H1, H2) of the high-directionality data, using a
# 500-point greedy-permutation subsample.
diagrams = ripser(high_bundle.data, n_perm=500, maxdim=2)['dgms']
plot_diagrams(diagrams, show=True)

In [None]:
# Visualize the low-directionality patches on a lattice indexed by
# (predominant direction, global coordinate F).
#
# `low_inds` is a boolean mask over every row of `data`, so the global
# coordinates must come from the full-data `bundle`. `high_triv_result.F`
# only covers the high-directionality subset and cannot be indexed by
# `low_inds`.
thresh = 0.7
low_inds = ratios < thresh
print(f'{np.sum(low_inds)} low-directionality patches')
low_data = data[low_inds]

# Global coordinates for the full dataset (previously referenced but never defined).
# NOTE(review): assumes bundle.data is `data` in its original row order, so
# triv_result.F aligns with `low_inds` -- confirm.
triv_result = bundle.get_global_trivialization()

# Show a recovered patch diagram
per_row = 5
per_col = 9
coords = np.array([predom_dirs[low_inds], triv_result.F[low_inds]]).T

fig = lattice_vis(
    low_data,
    coords,
    patch_vis,
    per_row=per_row,
    per_col=per_col,
    figsize=19,
    thumb_px=350,
    dpi=350,
    padding=0,
)

plt.show()



In [None]:
# Persistence diagrams (H0, H1, H2) of the low-directionality patches, using a
# 500-point greedy-permutation subsample.
diagrams = ripser(low_data, n_perm=500, maxdim=2)['dgms']
plot_diagrams(diagrams, show=True)