In [None]:
# TODO: use functions in tda_compute.py to compute persistence intervals and betti curves

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from tda.compute import *
# --- Example Usage ---
# Walks through the tda.compute helpers on a noisy circle: compute persistence
# intervals, plot the persistence diagram, then compute and plot Betti curves.

# 1. Create sample data: points evenly spaced on a circle, perturbed by
#    Gaussian noise. A fixed seed keeps the printed output and plots
#    reproducible from run to run.
num_points = 20
radius = 1.0
noise_scale = 0.1
random_seed = 0
rng = np.random.default_rng(random_seed)

angles = np.linspace(0, 2 * np.pi, num_points, endpoint=False)
# Vectorized construction: shape (num_points, 2), columns are (x, y).
sample_points = radius * np.column_stack((np.cos(angles), np.sin(angles)))
sample_points += rng.normal(scale=noise_scale, size=sample_points.shape)
print("Generated Sample Point Cloud (first 5 points):\n", sample_points[:5])
print("-" * 30)

# 2. Compute persistence using the tda.compute helper.
# Passed as max_edge_length and reused below as the upper end of the
# Betti-curve threshold range.
max_filt_scale = 1.5
# Passed to compute_persistence as max_dimension; the intent is to get H0
# and H1. NOTE(review): confirm whether max_dimension is the maximum simplex
# dimension (2 -> up to triangles, enough for H0 and H1) or the maximum
# homology dimension (2 would also compute H2).
max_hom_dim = 2

persistence_intervals, st = compute_persistence(
    sample_points,
    max_edge_length=max_filt_scale,
    max_dimension=max_hom_dim,
)

print("\nComputed Persistence Intervals:")
# Each entry is unpacked as (dimension, (birth, death)); death may be inf
# for classes that never die within the filtration.
if persistence_intervals:
    for dim, (birth, death) in persistence_intervals:
        death_str = "inf" if np.isinf(death) else f"{death:.4f}"
        print(f"  Dim {dim}: [{birth:.4f}, {death_str})")
else:
    print("  No persistence intervals computed.")
print("-" * 30)

# 3. Plot Persistence Diagram
if persistence_intervals:
    plot_persistence_diagram(persistence_intervals)
else:
    print("Skipping persistence diagram plot (no intervals).")


# 4. Compute and Plot Betti Curves, sampled at 100 evenly spaced filtration
#    values. Guarded like the steps above so an empty result is reported
#    instead of being passed on to the Betti-curve helpers.
if persistence_intervals:
    threshold_values = np.linspace(0, max_filt_scale, 100)
    betti_data = compute_betti_curves(st, persistence_intervals, threshold_values)
    plot_betti_curves(threshold_values, betti_data)
else:
    print("Skipping Betti curve computation (no intervals).")

print("\nAnalysis complete.")