In [1]:
%load_ext autoreload
%autoreload 2
%aimport connectivity

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
from connectivity.curves import CURVES

# --- Configuration -----------------------------------------------------------
curve_name = "4P"            # key into connectivity.curves.CURVES
infinity_replacement = 100   # finite stand-in for ±inf bounds (sliders need a finite range)
curve = CURVES[curve_name]


curve_function = curve["function"]
param_names = curve["param_names"]
initial_values = curve["initial_values"]
bounds = curve["bounds"]


def _clamp_bounds(values, limit):
    """Return ``values`` as floats clipped to the window ``[-limit, limit]``.

    Infinite entries become ``-limit`` / ``+limit``. Finite entries outside the
    window are clamped symmetrically, so an out-of-range lower bound is limited
    just like an out-of-range upper bound.
    """
    return [float(np.clip(b, -limit, limit)) for b in values]


if bounds is not None:
    bounds = (
        _clamp_bounds(bounds[0], infinity_replacement),  # Lower bounds
        _clamp_bounds(bounds[1], infinity_replacement),  # Upper bounds
    )
else:
    # No bounds declared: give every parameter the full symmetric window.
    bounds = (
        [-infinity_replacement] * len(param_names),  # Lower bounds
        [infinity_replacement] * len(param_names),  # Upper bounds
    )

# The sliders are built by position, so every per-parameter list must line up.
assert len(bounds[0]) == len(bounds[1]) == len(param_names) == len(initial_values), (
    "curve bounds / initial_values do not match param_names"
)

# Build one FloatSlider per curve parameter. Each slider's range is that
# parameter's resolved (finite) bounds, and it starts at the curve's initial value.
sliders = {
    param_names[k]: widgets.FloatSlider(
        min=bounds[0][k],
        max=bounds[1][k],
        step=0.01,
        value=initial_values[k],
        description=param_names[k],
    )
    for k in range(len(param_names))
}


# Dense, evenly spaced grid over [0, 1] (endpoints included) on which the
# curve and its derivative are evaluated and plotted.
x_fit = np.linspace(start=0, stop=1, num=1000)

def plot_curve(**params):
    """Plot the selected curve and its numerical derivative.

    Meant as the ``widgets.interactive`` callback. Each keyword argument is one
    curve parameter, keyed by a name from ``param_names``, holding the current
    slider value.

    Draws a two-row figure with a shared x axis:
      * top:    ``curve_function`` evaluated on ``x_fit``. The title reports the
                trapezoidal area under the curve over the ``x_fit`` range.
      * bottom: ``dy/dx`` from ``np.gradient``, with a dashed red line at the x
                where ``|dy/dx|`` is largest. The title reports that x and the
                signed slope there.
    """
    # Look parameters up by name so they reach curve_function in its declared
    # order, regardless of the order the kwargs happen to arrive in.
    param_values = [params[name] for name in param_names]
    y_fit = curve_function(x_fit, *param_values)

    # Compute numerical derivative
    dy_dx = np.gradient(y_fit, x_fit)

    # Area under curve. ``np.trapezoid`` was added in NumPy 2.0, replacing the
    # deprecated ``np.trapz``. Fall back to the old name on earlier NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    area = trapezoid(y_fit, x_fit)

    # Create subplots: a tall panel for the curve, a short one for its slope.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 7), sharex=True, gridspec_kw={'height_ratios': [3, 1]})

    # Main curve
    ax1.plot(x_fit, y_fit, label=f"{curve['name']} Curve")
    ax1.set_ylabel("y")
    ax1.set_title(f"{curve['name']} Curve (AUC = {area:.2f})")
    ax1.legend()
    ax1.grid()

    # Index of maximum absolute slope (steepest point of the curve).
    peak_idx = np.argmax(np.abs(dy_dx))
    x_peak = x_fit[peak_idx]

    # Derivative
    ax2.plot(x_fit, dy_dx, color='gray', label="dy/dx")
    ax2.axvline(x_peak, color='red', linestyle='--', alpha=0.5)  # Vertical line at peak
    ax2.set_xlabel("x")
    ax2.set_ylabel("dy/dx")
    ax2.legend()
    ax2.set_title(f"Peak at x = {x_peak:.2f}, max_slope = {dy_dx[peak_idx]:.2f}")
    ax2.grid()

    plt.tight_layout()
    plt.show()

# Wire every slider to plot_curve. Moving a slider re-invokes plot_curve with
# the current values, and the controls plus live figure render below the cell.
curve_explorer = widgets.interactive(plot_curve, **sliders)
display(curve_explorer)

interactive(children=(FloatSlider(value=1.0, description='upper_plateau', min=-100.0, step=0.01), FloatSlider(…