# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
from scipy.spatial import KDTree

def read_csv(csv_path, skip_header=1):
    """Read a CSV file of polylines and group its rows by path id.

    Column 0 of every row is treated as a numeric path id; the remaining
    columns are kept as that row's data. Rows sharing a path id are
    collected, in file order, into one 2-D array.

    Parameters
    ----------
    csv_path : str or path-like
        Path to the comma-separated input file.
    skip_header : int, optional
        Number of leading lines to skip before parsing (default 1, i.e. one
        header row). Pass 0 for files without a header line.

    Returns
    -------
    list of numpy.ndarray
        One ``(n_rows, n_cols - 1)`` array per distinct path id, ordered by
        ascending path id. An empty list when the file has no data rows.
    """
    data = np.genfromtxt(csv_path, delimiter=',', skip_header=skip_header)
    if data.size == 0:
        return []
    # genfromtxt collapses a single data row to a 1-D array; restore the row
    # axis so the column indexing below works for any number of rows.
    data = np.atleast_2d(data)

    return [data[data[:, 0] == path_id][:, 1:] for path_id in np.unique(data[:, 0])]

def detect_occlusion(curves, threshold=0.01):
    """Find ordered pairs of curves that come within ``threshold`` of each other.

    For every ordered pair ``(i, j)`` with ``i != j``, each point of curve
    ``i`` is matched to its nearest point on curve ``j``. The pair is
    reported when the smallest of those distances is strictly below
    ``threshold``. That minimum is symmetric, so a touching pair is reported
    in both orders. Curves with no points never take part in a pair.

    This only flags proximity between curves; it does not classify the kind
    of occlusion.

    Parameters
    ----------
    curves : sequence of array_like
        Each curve is an ``(n_points, n_dims)`` array of coordinates.
    threshold : float, optional
        Exclusive upper bound on the gap between two curves for them to be
        reported (default 0.01).

    Returns
    -------
    list of tuple of int
        The ``(i, j)`` pairs, ordered by ``i`` and then ``j``.
    """
    point_sets = []
    for curve in curves:
        pts = np.asarray(curve, dtype=float)
        point_sets.append(np.atleast_2d(pts) if pts.size else None)

    # One tree per curve, built once: querying a curve against a tree that
    # also contains its own points would always return distance 0.
    trees = [KDTree(pts) if pts is not None else None for pts in point_sets]

    occlusions = []
    for i, pts in enumerate(point_sets):
        if pts is None:
            continue
        for j, tree in enumerate(trees):
            if i == j or tree is None:
                continue
            distances, _ = tree.query(pts, k=1)
            if np.min(distances) < threshold:
                occlusions.append((i, j))

    return occlusions

def fit_spline(points, num_samples=100):
    """Resample a polyline along an interpolating cubic spline.

    The spline is parameterised by point index (0, 1, ..., n-1), so it
    passes through every input point. It is then evaluated at
    ``num_samples`` evenly spaced parameter values from the first to the
    last point.

    Parameters
    ----------
    points : array_like
        ``(n_points, n_dims)`` coordinates; each column is interpolated
        independently along axis 0.
    num_samples : int, optional
        Number of points in the returned curve (default 100).

    Returns
    -------
    numpy.ndarray
        ``(num_samples, n_dims)`` resampled curve. Inputs with fewer than two
        points cannot define a spline and are returned unchanged as an array.
    """
    points = np.asarray(points)
    # CubicSpline needs at least two knots; nothing to interpolate otherwise.
    if len(points) < 2:
        return points
    t = np.arange(len(points))
    t_new = np.linspace(0, len(points) - 1, num_samples)
    return CubicSpline(t, points, axis=0)(t_new)

def complete_curve(curves, occlusions):
    """Resample every curve that appears in at least one occlusion pair.

    Parameters
    ----------
    curves : sequence of array_like
        Input polylines, indexed consistently with ``occlusions``.
    occlusions : iterable of tuple of int
        Pairs of curve indices, as produced by ``detect_occlusion``.

    Returns
    -------
    list of numpy.ndarray
        One array per input curve, in input order. A curve named in any pair
        that has more than one point is replaced by ``fit_spline``'s output.
        Every other curve is returned as an array copy of the input.
    """
    # Gather the occluded indices once instead of rescanning every pair for
    # every curve.
    occluded = {index for pair in occlusions for index in pair}

    completed_curves = []
    for i, curve in enumerate(curves):
        points = np.array(curve)
        if i in occluded and len(points) > 1:
            completed_curves.append(fit_spline(points))
        else:
            completed_curves.append(points)

    return completed_curves

def plot_curves(curves, labels=None):
    """Plot curves on one equal-aspect figure and show it.

    Parameters
    ----------
    curves : sequence of array_like
        Each curve is an ``(n_points, >=2)`` array; column 0 is drawn as x
        and column 1 as y. Curves with no points are skipped.
    labels : list of str, optional
        Legend labels matched to ``curves`` by position. Missing entries
        default to ``'Curve <k>'`` (1-based index). The caller's list is not
        modified.
    """
    plt.figure(figsize=(10, 10))

    # Work on a copy so a caller-supplied list is never extended in place.
    labels = list(labels) if labels is not None else []
    labels += [f'Curve {i+1}' for i in range(len(labels), len(curves))]

    plotted = []
    for i, curve in enumerate(curves):
        if len(curve) > 0:  # Ensure curve has data
            curve = np.asarray(curve)
            plt.plot(curve[:, 0], curve[:, 1], label=labels[i])
            plotted.append(curve)

    if plotted:
        plt.legend()
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.title('Completed Curves')

    if plotted:
        # Concatenate rather than wrap in one array: curves usually differ in
        # length, and a ragged list cannot be turned into a NumPy array.
        xs = np.concatenate([curve[:, 0] for curve in plotted])
        ys = np.concatenate([curve[:, 1] for curve in plotted])
        plt.xlim(xs.min(), xs.max())
        plt.ylim(ys.min(), ys.max())

    plt.show()

def process_and_complete_curves(csv_path):
    """Load polylines from ``csv_path``, detect occlusions and complete them.

    Parameters
    ----------
    csv_path : str
        Path passed straight to ``read_csv``.

    Returns
    -------
    list of numpy.ndarray
        Whatever ``complete_curve`` produces for the loaded curves and the
        occlusion pairs ``detect_occlusion`` found among them.
    """
    loaded = read_csv(csv_path)
    return complete_curve(loaded, detect_occlusion(loaded))


# Run the pipeline on the sample problem only when executed directly (this is
# also true inside a notebook), so importing the module has no side effects.
if __name__ == "__main__":
    csv_file_path = 'problems/occlusion1.csv'
    completed_curves = process_and_complete_curves(csv_file_path)
    plot_curves(
        completed_curves,
        labels=[f'Completed Curve {i+1}' for i in range(len(completed_curves))],
    )
