In [40]:
from pathlib import Path
import pandas as pd
from pysr import PySRRegressor

def collect_all_equations(N):
    """Collect and aggregate all symbolic equations from saved PySRRegressor runs in 'outputs/'.

    Returns:
        pd.DataFrame: DataFrame containing all equations, their corresponding run_ids, features present,
                    and features actually used in each equation.
    """


    outputs_dir = Path("/Users/danielgieseler/Documents/Code/early_stop/outputs")
    run_folders = [f for f in outputs_dir.iterdir() if f.is_dir()]
    run_folders = sorted(run_folders, key=lambda x: x.name)

    summary = []
    for run_folder in run_folders[:-N]:
        try:
            model = PySRRegressor.from_file(run_directory=str(run_folder))
            eqs = model.equations_.copy()
            eqs['run_id'] = run_folder.name
            eqs['features_in'] = [set(model.feature_names_in_)] * len(eqs)
            eqs['features_used'] = eqs['sympy_format'].apply(lambda s: s.free_symbols)
            summary.append(eqs)
            del model  # Explicitly cleanup to reduce Julia thread issues
        except Exception as e:
            print(f"Could not load model from {run_folder}: {e}")
            continue

    if summary:
        all_equations = pd.concat(summary, ignore_index=True)
    else:
        all_equations = pd.DataFrame()
    return all_equations

In [41]:
from pathlib import Path

outputs_dir = Path("/Users/danielgieseler/Documents/Code/early_stop/outputs")
run_folders = [f for f in outputs_dir.iterdir() if f.is_dir()]
run_folders

[PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_170509_FT9tuW'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_164613_8ZyZQI'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_164413_msSvGs'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_215453_ElV6fP'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_170658_KtmoID'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_170958_oXfkRu'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251128_083918_7Tx1bP'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_145659_aHbujh'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_215913_MMyTOV'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_215155_Pu6lgr'),
 PosixPath('/Users/danielgieseler/Documents/Code/early_stop/

In [42]:
import sympy as sp
from IPython.display import display, Markdown

def display_global_pareto_front(all_equations):
    """
    Display the global Pareto front, removing any equations that are dominated
    (i.e., have both higher or equal complexity and higher or equal loss than another).
    Only keep an equation if there is no equation at lower complexity with a lower/equal loss.
    """
    # Sort by complexity, then by loss
    pareto = []
    eqs = all_equations.sort_values(["complexity", "loss"])
    min_loss_so_far = float('inf')
    for row in eqs.itertuples():
        if row.loss < min_loss_so_far:
            pareto.append(row)
            min_loss_so_far = row.loss
        # else: this row is dominated, skip

    # Format as DataFrame
    pareto_df = pd.DataFrame([{
        "complexity": r.complexity,
        "loss": r.loss,
        "sympy_format": r.sympy_format,
        "run_id": r.run_id
    } for r in pareto])

    # Display as table
    equations_md = "## Global Pareto Front\n\n| C | L | Equation |\n|---|---|---|\n"
    for _, row in pareto_df.iterrows():
        latex_eq = sp.latex(row['sympy_format'])
        equations_md += f"| {row['complexity']} | {row['loss']:.6f} | ${latex_eq}$ |\n"

    display(Markdown(equations_md))
    return pareto_df


In [43]:
from collections import Counter

def compute_feature_rank_df(all_equations):
    all_features_in = Counter()
    all_features_used = Counter()

    for row in all_equations.itertuples():
        for f in row.features_in:
            all_features_in[f] += 1
        for f in row.features_used:
            all_features_used[str(f)] += 1

    all_feature_names = set(all_features_in.keys()) | set(all_features_used.keys())
    rank_rows = []
    for feat in sorted(all_feature_names):
        IN = all_features_in.get(feat, 0)
        USED = all_features_used.get(feat, 0)
        SCORE = USED / IN if IN > 0 else 0.0
        rank_rows.append({"feature_name": feat, "SCORE": SCORE, "IN": IN, "USED": USED})

    feature_rank_df = pd.DataFrame(rank_rows)
    feature_rank_df = feature_rank_df.sort_values("SCORE", ascending=False).reset_index(drop=True)
    return feature_rank_df

In [44]:
def plot_tradeoff_curve(curve):
    """
    Plot the trade-off between early stopping and accuracy for a given curve DataFrame.

    Args:
        curve (pd.DataFrame): DataFrame where rows are epsilon values (max errors), 
                              columns are 'model_n' (e.g. model_1, model_2, ...),
                              and values are (presumed log10(avg_step)).
    """
    import plotly.graph_objs as go
    import numpy as np
    import re
    from matplotlib.colors import LinearSegmentedColormap, Normalize, to_hex

    fig = go.Figure()
    curve_sorted = curve.sort_index()

    def extract_model_num(name):
        m = re.match(r"model_(\d+)", name)
        return int(m.group(1)) if m else None

    model_nums = [extract_model_num(col) for col in curve_sorted.columns]
    min_model, max_model = min(model_nums), max(model_nums)

    # Direct green-to-red map, no yellow
    green_red = LinearSegmentedColormap.from_list("greenred", ["#2ca02c", "#d62728"])
    norm = Normalize(vmin=min_model, vmax=max_model)
    model_colors = {f"model_{n}": to_hex(green_red(norm(n))) for n in model_nums}

    x_vals_log = np.array(curve_sorted.index, dtype=float)
    for model in curve_sorted.columns:
        fig.add_trace(
            go.Scatter(
                x=x_vals_log,
                y=10 ** curve_sorted[model],  # Convert log values to linear before plotting
                mode="lines+markers",
                name=model,
                line=dict(color=model_colors[model])
            )
        )

    max_step = curve_sorted.values.max()
    fig.add_trace(
        go.Scatter(
            x=x_vals_log,
            y=[10 ** max_step] * len(curve_sorted.index),  # Convert the log value to linear
            mode="lines",
            name="max_step",
            line=dict(color="black", width=2, dash="dot")
        )
    )

    fig.update_layout(
        title="Trade-off: early-stop vs. accuracy",
        xaxis_title="Max Error",
        yaxis_title="Avg Step",
        xaxis_type="log",
        legend_title="Model",
        yaxis_type="log"
    )

    fig.show()

In [45]:
import json

def update_global_pareto_with_tradeoff_curve(global_pareto):
    outputs_dir = Path("/Users/danielgieseler/Documents/Code/early_stop/outputs")
    new_columns = []
    for _, row in global_pareto.iterrows():
        with open(outputs_dir / row['run_id'] / 'tradeoff_curve.json', 'r') as f:
            tradeoff_curve = json.load(f)
        new_columns.append(tradeoff_curve[str(row['complexity'])])
    global_pareto['tradeoff_curve'] = new_columns
    return global_pareto

In [46]:
def plot_tradeoff_curve(curve):
    """
    Plot the trade-off between early stopping and accuracy using the new 'curve' DataFrame.

    Args:
        curve (pd.DataFrame): DataFrame where each row is a model/complexity, 
                              and 'tradeoff_curve' column holds [[epsilon, avg_step], ...] pairs.
    """
    import plotly.graph_objs as go
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, Normalize, to_hex

    # Assume each row is a curve (by complexity). We'll plot each as a series.
    # Extract all complexities for colormap
    complexities = curve['complexity'].tolist()
    min_model, max_model = min(complexities), max(complexities)
    green_red = LinearSegmentedColormap.from_list("greenred", ["#2ca02c", "#d62728"])
    norm = Normalize(vmin=min_model, vmax=max_model)
    # Color by complexity
    model_colors = {c: to_hex(green_red(norm(c))) for c in complexities}

    fig = go.Figure()

    max_y = -np.inf
    max_label = None

    for idx, row in curve.iterrows():
        complexity = row['complexity']
        tradeoff_points = row['tradeoff_curve'] # List of [epsilon, avg_step]
        epsilons = [xy[0] for xy in tradeoff_points]
        # Convert log-step value to normal value on y axis
        avg_steps = [10 ** xy[1] for xy in tradeoff_points]
        # Track max for overlay
        y_max = max(avg_steps)
        if y_max > max_y:
            max_y = y_max
            max_label = f"complexity {complexity}"
        fig.add_trace(
            go.Scatter(
                x=epsilons,
                y=avg_steps,
                mode="lines+markers",
                name=f"complexity {complexity}",
                line=dict(color=model_colors[complexity])
            )
        )

    # Overlay "max_step" as a flat line if wanted
    if max_y > 0:
        # Use the sorted and unique x-axis values present (all epsilons)
        all_epsilons = sorted({xy[0] for sublist in curve['tradeoff_curve'] for xy in sublist})
        fig.add_trace(
            go.Scatter(
                x=all_epsilons,
                y=[max_y] * len(all_epsilons),
                mode="lines",
                name="max_step",
                line=dict(color="black", width=2, dash="dot")
            )
        )

    fig.update_layout(
        title="Trade-off: early-stop vs. accuracy",
        xaxis_title="Max Error",
        yaxis_title="Avg Step",
        xaxis_type="log",
        legend_title="Complexity",
        yaxis_type="log"
    )

    fig.show()

In [None]:
N = 10 # 10 8 5 4 0
offset = 2
all_equations = collect_all_equations(N+offset)

Attempting to load model from /Users/danielgieseler/Documents/Code/early_stop/outputs/20251125_145348_iCKEDY/checkpoint.pkl...


In [48]:
global_pareto = display_global_pareto_front(all_equations)

## Global Pareto Front

| C | L | Equation |
|---|---|---|
| 1 | 0.000923 | $0.4714$ |
| 3 | 0.000814 | $0.942 last_{loss}$ |
| 4 | 0.000505 | $\log{\left(last_{loss} + 1.1045 \right)}$ |
| 6 | 0.000472 | $first_{derivative ws2} + \log{\left(last_{loss} + 1.1045 \right)}$ |
| 8 | 0.000458 | $\frac{first_{derivative ws2}}{last_{loss}} + \log{\left(last_{loss} + 1.1045 \right)}$ |
| 10 | 0.000433 | $0.01819 feature_{step} + last_{loss} e^{- 0.3813 last_{loss}}$ |
| 12 | 0.000409 | $0.0190995 feature_{step} + last_{loss} e^{- 0.3813 last_{loss}}$ |
| 13 | 0.000360 | $0.01819 feature_{step} \log{\left(feature_{step} \right)} + last_{loss} e^{- 0.3813 last_{loss}}$ |
| 15 | 0.000331 | $0.01819 feature_{step} \log{\left(feature_{step} \right)} + first_{derivative ws2} + last_{loss} e^{- 0.3813 last_{loss}}$ |


In [49]:
global_pareto = update_global_pareto_with_tradeoff_curve(global_pareto)
plot_tradeoff_curve(global_pareto)

In [50]:
compute_feature_rank_df(all_equations)

Unnamed: 0,feature_name,SCORE,IN,USED
0,last_loss,0.888889,9,8
1,feature_step,0.444444,9,4
2,first_derivative_ws2,0.333333,9,3
3,target_step,0.0,9,0


In [51]:
# # complexity = n
# # get that equations lambda
# from dataset import get_dataset
# from traditional_run import FM

# C = 9
# feature_syms = global_pareto.loc[global_pareto['complexity'] == C]['sympy_format'].values[0].free_symbols
# feature_names = set(str(s) for s in feature_syms)
# feature_names = feature_names - {'feature_step', 'target_step'}
# print(feature_names)

# df = get_dataset({f: FM[f] for f in feature_names}, path='runs_data.json')





In [52]:
complexities = [5, 6, 7, 8, 9, 10]
for c in complexities:
    subset = global_pareto[global_pareto['complexity'] == c]
    if not subset.empty:
        print(f"Complexity {c}:")
        for sf in subset['sympy_format']:
            print(sf)
        print()


Complexity 6:
first_derivative_ws2 + log(last_loss + 1.1045)

Complexity 8:
first_derivative_ws2/last_loss + log(last_loss + 1.1045)

Complexity 10:
0.01819*feature_step + last_loss*exp(-0.3813*last_loss)

