In [1]:
from pathlib import Path
import pandas as pd
from pysr import PySRRegressor

def collect_all_equations(N: int | None = None):
    """Collect and aggregate all symbolic equations from saved PySRRegressor runs in 'outputs/'.

    Returns:
        pd.DataFrame: DataFrame containing all equations, their corresponding run_ids, features present,
                    and features actually used in each equation.
    """


    outputs_dir = Path("/Users/danielgieseler/Documents/Code/early_stop/outputs")
    run_folders = [f for f in outputs_dir.iterdir() if f.is_dir()]
    run_folders = sorted(run_folders, key=lambda x: x.name)

    summary = []
    run_folders = run_folders[:-N] if N is not None else run_folders
    for run_folder in run_folders:
        try:
            model = PySRRegressor.from_file(run_directory=str(run_folder))
            eqs = model.equations_.copy()
            eqs['run_id'] = run_folder.name
            eqs['features_in'] = [set(model.feature_names_in_)] * len(eqs)
            eqs['features_used'] = eqs['sympy_format'].apply(lambda s: s.free_symbols)
            summary.append(eqs)
            del model  # Explicitly cleanup to reduce Julia thread issues
        except Exception as e:
            print(f"Could not load model from {run_folder}: {e}")
            continue

    if summary:
        all_equations = pd.concat(summary, ignore_index=True)
    else:
        all_equations = pd.DataFrame()
    return all_equations

Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [2]:
from pathlib import Path

# Enumerate the run directories that sit directly under the outputs folder.
outputs_dir = Path("/Users/danielgieseler/Documents/Code/early_stop/outputs")
run_folders = list(filter(Path.is_dir, outputs_dir.iterdir()))
run_folders

[PosixPath('/Users/danielgieseler/Documents/Code/early_stop/outputs/20251201_235222_HhSTqh')]

In [3]:
import sympy as sp
from IPython.display import display, Markdown

def display_global_pareto_front(all_equations):
    """
    Display the global Pareto front and return it as a DataFrame.

    Rows are visited in order of increasing complexity (ties broken by loss);
    a row is kept only if its loss is strictly lower than every loss seen so
    far, i.e. no equation of lower or equal complexity already achieves a
    lower or equal loss.

    Args:
        all_equations (pd.DataFrame): Must provide the columns ``complexity``,
            ``loss``, ``sympy_format`` and ``run_id`` unless it is empty.

    Returns:
        pd.DataFrame: The Pareto-optimal rows with columns ``complexity``,
        ``loss``, ``sympy_format`` and ``run_id``. The columns are present
        even when the result has no rows.
    """
    columns = ["complexity", "loss", "sympy_format", "run_id"]

    # collect_all_equations() hands back a column-less DataFrame when nothing
    # could be loaded; sorting that by "complexity" would raise KeyError.
    if all_equations.empty:
        print("No equations available: the global Pareto front is empty.")
        return pd.DataFrame(columns=columns)

    # Sort by complexity, then by loss, so a single pass can track the best loss.
    pareto = []
    eqs = all_equations.sort_values(["complexity", "loss"])
    min_loss_so_far = float('inf')
    for row in eqs.itertuples():
        if row.loss < min_loss_so_far:
            pareto.append(row)
            min_loss_so_far = row.loss
        # else: this row is dominated, skip

    # Passing `columns` keeps the schema stable even if nothing survived
    # (e.g. every loss is NaN).
    pareto_df = pd.DataFrame(
        [
            {
                "complexity": r.complexity,
                "loss": r.loss,
                "sympy_format": r.sympy_format,
                "run_id": r.run_id,
            }
            for r in pareto
        ],
        columns=columns,
    )

    # Display as a Markdown table with LaTeX-rendered equations.
    table_lines = ["## Global Pareto Front\n\n| C | L | Equation |\n|---|---|---|\n"]
    for r in pareto:
        latex_eq = sp.latex(r.sympy_format)
        table_lines.append(f"| {r.complexity} | {r.loss:.6f} | ${latex_eq}$ |\n")

    display(Markdown("".join(table_lines)))
    return pareto_df


In [4]:
from collections import Counter

def compute_feature_rank_df(all_equations):
    """Rank features by how often equations use them when they are available.

    For every row, each entry of ``features_in`` counts once towards ``IN`` and
    each entry of ``features_used`` (converted with ``str``) counts once
    towards ``USED``. ``SCORE`` is ``USED / IN``, or ``0.0`` when ``IN`` is 0.

    Args:
        all_equations (pd.DataFrame): Needs iterable ``features_in`` and
            ``features_used`` columns. An empty frame (with or without
            columns) is accepted.

    Returns:
        pd.DataFrame: Columns ``feature_name``, ``SCORE``, ``IN``, ``USED``,
        sorted by ``SCORE`` descending; equal scores stay in alphabetical
        feature order. The result is empty but keeps its columns when the
        input has no rows.
    """
    all_features_in = Counter()
    all_features_used = Counter()

    for row in all_equations.itertuples():
        all_features_in.update(row.features_in)
        # features_used may hold sympy symbols; normalise to names so both
        # counters share the same keys.
        all_features_used.update(str(f) for f in row.features_used)

    all_feature_names = set(all_features_in) | set(all_features_used)
    rank_rows = []
    for feat in sorted(all_feature_names):
        IN = all_features_in[feat]
        USED = all_features_used[feat]
        SCORE = USED / IN if IN > 0 else 0.0
        rank_rows.append({"feature_name": feat, "SCORE": SCORE, "IN": IN, "USED": USED})

    # Explicit columns keep sort_values("SCORE") from raising KeyError when
    # there are no rows at all.
    feature_rank_df = pd.DataFrame(rank_rows, columns=["feature_name", "SCORE", "IN", "USED"])
    # A stable sort preserves the alphabetical order among equal scores.
    feature_rank_df = feature_rank_df.sort_values(
        "SCORE", ascending=False, kind="stable"
    ).reset_index(drop=True)
    return feature_rank_df

In [5]:
import json

def update_global_pareto_with_tradeoff_curve(global_pareto, outputs_dir=None):
    """Attach each equation's trade-off curve to the Pareto table.

    For every row, ``<outputs_dir>/<run_id>/tradeoff_curve.json`` is read and
    the entry stored under the row's complexity (as an integer string key such
    as ``"5"``) is placed in a new ``tradeoff_curve`` column.

    Args:
        global_pareto (pd.DataFrame): Needs ``run_id`` and ``complexity``
            columns. It is modified in place.
        outputs_dir: Directory containing one folder per run. Defaults to the
            project's ``outputs`` folder.

    Returns:
        pd.DataFrame: The same ``global_pareto`` object, now with the
        ``tradeoff_curve`` column.

    Raises:
        FileNotFoundError: If a run has no ``tradeoff_curve.json``.
        KeyError: If a run's JSON has no entry for a row's complexity.
    """
    if outputs_dir is None:
        outputs_dir = "/Users/danielgieseler/Documents/Code/early_stop/outputs"
    outputs_dir = Path(outputs_dir)

    curves_by_run = {}  # run_id -> parsed JSON, so each file is read once
    new_columns = []
    for _, row in global_pareto.iterrows():
        run_id = row['run_id']
        if run_id not in curves_by_run:
            with open(outputs_dir / run_id / 'tradeoff_curve.json', 'r') as f:
                curves_by_run[run_id] = json.load(f)
        # A float complexity column would yield 5.0, and str(5.0) == "5.0"
        # would miss the "5" key, hence the int() cast.
        key = str(int(row['complexity']))
        new_columns.append(curves_by_run[run_id][key])
    global_pareto['tradeoff_curve'] = new_columns
    return global_pareto

In [6]:
def plot_tradeoff_curve(curve):
    """
    Draw one line per row of ``curve`` and show the resulting plotly figure.

    Each row contributes a trace whose x values are the first element and whose
    y values are the second element of every pair in its 'tradeoff_curve'
    entry. Traces are coloured on a green-to-red scale by the row's
    'complexity'. When the largest y value over all rows is positive, a dotted
    black horizontal line named "max_step" is added at that height.

    Args:
        curve (pd.DataFrame): Needs a 'complexity' column and a
                              'tradeoff_curve' column holding
                              [[epsilon, avg_step], ...] pairs.
    """
    import plotly.graph_objs as go
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap, Normalize, to_hex

    # Map every complexity onto a green (lowest) -> red (highest) colour.
    levels = curve['complexity'].tolist()
    lowest, highest = min(levels), max(levels)
    palette = LinearSegmentedColormap.from_list("greenred", ["#2ca02c", "#d62728"])
    scale = Normalize(vmin=lowest, vmax=highest)
    colour_of = {}
    for level in levels:
        colour_of[level] = to_hex(palette(scale(level)))

    fig = go.Figure()

    # Largest y value seen on any trace; used for the overlay line below.
    tallest = -np.inf

    for _, record in curve.iterrows():
        level = record['complexity']
        points = record['tradeoff_curve']
        xs = [pair[0] for pair in points]
        ys = [pair[1] for pair in points]  # plotted as-is, no transformation

        peak = max(ys)
        if peak > tallest:
            tallest = peak

        trace = go.Scatter(
            x=xs,
            y=ys,
            mode="lines+markers",
            name=f"complexity {level}",
            line=dict(color=colour_of[level]),
        )
        fig.add_trace(trace)

    # Flat "max_step" reference line spanning every x value that occurs.
    if tallest > 0:
        seen_x = set()
        for points in curve['tradeoff_curve']:
            for pair in points:
                seen_x.add(pair[0])
        ordered_x = sorted(seen_x)
        overlay = go.Scatter(
            x=ordered_x,
            y=[tallest] * len(ordered_x),
            mode="lines",
            name="max_step",
            line=dict(color="black", width=2, dash="dot"),
        )
        fig.add_trace(overlay)

    layout = dict(
        title="Trade-off: early-stop vs. accuracy",
        xaxis_title="Max Error",
        yaxis_title="Avg Step",
        xaxis_type="log",
        legend_title="Complexity",
        yaxis_type="log",
    )
    fig.update_layout(**layout)

    fig.show()

In [7]:
# Load the equations of every saved run under the outputs directory into one DataFrame.
all_equations = collect_all_equations()

Attempting to load model from /Users/danielgieseler/Documents/Code/early_stop/outputs/20251201_235222_HhSTqh/checkpoint.pkl...


In [8]:
# Render the non-dominated (complexity, loss) equations as a table and keep them as a DataFrame.
global_pareto = display_global_pareto_front(all_equations)

## Global Pareto Front

| C | L | Equation |
|---|---|---|
| 1 | 0.000060 | $last_{loss}$ |
| 3 | 0.000029 | $last_{loss} - 0.00537$ |
| 5 | 0.000019 | $- 4.9 \cdot 10^{-6} \delta_{steps} + last_{loss}$ |
| 7 | 0.000019 | $- 5.30796194291437 \cdot 10^{-6} \delta_{steps} + 1.00150225338007 last_{loss}$ |
| 9 | 0.000017 | $\delta_{steps} \left(0.05133 derivative_{3} - 4.2 \cdot 10^{-6}\right) + last_{loss}$ |
| 11 | 0.000015 | $\delta_{steps} \left(derivative_{3} \left(last_{loss} - 0.4148\right) - 4.2 \cdot 10^{-6}\right) + last_{loss}$ |
| 12 | 0.000012 | $- \frac{4.0 \cdot 10^{-7} \delta_{steps}}{e^{5000000.0 derivative_{3}} + 0.07056} + last_{loss}$ |
| 14 | 0.000009 | $- \frac{2.0 \cdot 10^{-7} \delta_{steps}}{{0.0006075}^{last_{loss}} + e^{10000000.0 derivative_{3}}} + last_{loss}$ |
| 16 | 0.000009 | $last_{loss} + \frac{- 1.0 \cdot 10^{-7} \delta_{steps} + derivative_{3}}{{0.0003195}^{last_{loss}} + e^{10000000.0 derivative_{3}}}$ |
| 18 | 0.000009 | $last_{loss} + 0.0003111 + \frac{- 1.0 \cdot 10^{-7} \delta_{steps} + derivative_{3}}{{0.0003195}^{last_{loss}} + e^{10000000.0 derivative_{3}}}$ |
| 20 | 0.000009 | $\frac{\delta_{steps} \left(derivative_{3} \left(last_{loss} - 0.455\right) - 1.97 \cdot 10^{-6}\right)}{5.86498544246967 e^{5000000.0 derivative_{3}} + 0.3853} + last_{loss}$ |
| 22 | 0.000008 | $\frac{\delta_{steps} \left(derivative_{3} \left(last_{loss} - 0.455\right) - 1.97 \cdot 10^{-6}\right)}{{0.1326}^{last_{loss}} + 5.86498544246967 e^{5000000.0 derivative_{3}}} + last_{loss}$ |


In [9]:
# NOTE(review): ad-hoc scratch arithmetic; the operands are unexplained magic
# numbers. Document where they come from or remove this cell.
889.5 - 748, 2457 - 2311, 889.5 - 1685, 2451 - 1517

(141.5, 146, -795.5, 934)

In [10]:
# Add each equation's curve from its run's tradeoff_curve.json, then plot one trace per row.
# Note: the update also modifies the existing global_pareto frame in place.
global_pareto = update_global_pareto_with_tradeoff_curve(global_pareto)
plot_tradeoff_curve(global_pareto)

In [33]:
# Inspect the Pareto table including the tradeoff_curve column.
# NOTE(review): this cell's execution count (33) is out of sequence with the
# surrounding cells, so its output may reflect a different kernel state --
# re-run the notebook top to bottom.
global_pareto

Unnamed: 0,complexity,loss,sympy_format,run_id,tradeoff_curve
0,1,6e-05,last_loss,20251201_235222_HhSTqh,"[[0.001, 3299.7894736842104], [0.002, 3049.315..."
1,3,2.9e-05,last_loss - 0.00537,20251201_235222_HhSTqh,"[[0.001, 4265.0], [0.002, 4265.0], [0.004, 426..."
2,5,1.9e-05,-4.9e-6*delta_steps + last_loss,20251201_235222_HhSTqh,"[[0.001, 3682.315789473684], [0.002, 3081.7368..."
3,7,1.9e-05,-5.30796194291437e-6*delta_steps + 1.001502253...,20251201_235222_HhSTqh,"[[0.001, 3603.9473684210525], [0.002, 3019.526..."
4,9,1.7e-05,delta_steps*(0.05133*derivative_3 - 4.2e-6) + ...,20251201_235222_HhSTqh,"[[0.001, 3757.0], [0.002, 3361.842105263158], ..."
5,11,1.5e-05,delta_steps*(derivative_3*(last_loss - 0.4148)...,20251201_235222_HhSTqh,"[[0.001, 3747.7894736842104], [0.002, 3236.894..."
6,12,1.2e-05,-4.0e-7*delta_steps/(exp(5000000.0*derivative_...,20251201_235222_HhSTqh,"[[0.001, 3641.2631578947367], [0.002, 3228.368..."
7,14,9e-06,-2.0e-7*delta_steps/(0.0006075**last_loss + ex...,20251201_235222_HhSTqh,"[[0.001, 3744.0526315789475], [0.002, 3307.684..."
8,16,9e-06,last_loss + (-1.0e-7*delta_steps + derivative_...,20251201_235222_HhSTqh,"[[0.001, 3606.9473684210525], [0.002, 3183.210..."
9,18,9e-06,last_loss + 0.0003111 + (-1.0e-7*delta_steps +...,20251201_235222_HhSTqh,"[[0.001, 3475.0526315789475], [0.002, 3054.105..."


In [11]:
# Per feature: in how many equations it was available (IN), used (USED), and the ratio (SCORE).
compute_feature_rank_df(all_equations)

Unnamed: 0,feature_name,SCORE,IN,USED
0,last_loss,1.0,12,12
1,delta_steps,0.833333,12,10
2,derivative_3,0.666667,12,8


In [22]:
# Inspect the full table of collected equations.
# NOTE(review): this cell's execution count (22) is out of sequence with the
# surrounding cells -- re-run the notebook top to bottom.
all_equations

Unnamed: 0,complexity,loss,equation,score,sympy_format,lambda_format,run_id,features_in,features_used
0,1,6e-05,last_loss,0.0,last_loss,PySRFunction(X=>last_loss),20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}",{last_loss}
1,3,2.9e-05,last_loss - 0.00537,0.36322,last_loss - 0.00537,PySRFunction(X=>last_loss - 0.00537),20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}",{last_loss}
2,5,1.9e-05,last_loss - (delta_steps * 4.9e-6),0.202733,-4.9e-6*delta_steps + last_loss,PySRFunction(X=>-4.9e-6*delta_steps + last_loss),20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, delta_steps}"
3,7,1.9e-05,((delta_steps * -5.3e-6) + last_loss) / 0.9985,0.009464,-5.30796194291437e-6*delta_steps + 1.001502253...,PySRFunction(X=>-5.30796194291437e-6*delta_ste...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, delta_steps}"
4,9,1.7e-05,(delta_steps * ((derivative_3 * 0.05133) - 4.2...,0.051384,delta_steps*(0.05133*derivative_3 - 4.2e-6) + ...,PySRFunction(X=>delta_steps*(0.05133*derivativ...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"
5,11,1.5e-05,last_loss + (delta_steps * (((last_loss + -0.4...,0.062582,delta_steps*(derivative_3*(last_loss - 0.4148)...,PySRFunction(X=>delta_steps*(derivative_3*(las...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"
6,12,1.2e-05,last_loss - ((delta_steps * 4.0e-7) / (exp(der...,0.219816,-4.0e-7*delta_steps/(exp(5000000.0*derivative_...,PySRFunction(X=>-4.0e-7*delta_steps/(exp(50000...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"
7,14,9e-06,last_loss - (delta_steps * (2.0e-7 / ((0.00060...,0.13998,-2.0e-7*delta_steps/(0.0006075**last_loss + ex...,PySRFunction(X=>-2.0e-7*delta_steps/(0.0006075...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"
8,16,9e-06,((derivative_3 - (delta_steps * 1.0e-7)) / ((0...,0.016761,last_loss + (-1.0e-7*delta_steps + derivative_...,PySRFunction(X=>last_loss + (-1.0e-7*delta_ste...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"
9,18,9e-06,((derivative_3 - (delta_steps * 1.0e-7)) / ((0...,0.005714,last_loss + 0.0003111 + (-1.0e-7*delta_steps +...,PySRFunction(X=>last_loss + 0.0003111 + (-1.0e...,20251201_235222_HhSTqh,"{delta_steps, derivative_3, last_loss}","{last_loss, derivative_3, delta_steps}"


In [23]:

# NOTE(review): `a` is not defined in any visible cell, so this fails on a
# fresh kernel. Presumably it was a sympy expression from an earlier session;
# define it explicitly or remove this cell.
a.free_symbols


{delta_steps, derivative_3, last_loss}

In [32]:
import numpy as np

# Take the callable form of the equation stored at row position 7.
eq = all_equations['lambda_format'].iloc[7]

# Order in which the feature values are laid out in the input row.
# NOTE(review): assumes this matches the column order the model expects -- confirm.
variable_names = ['delta_steps', 'last_loss', 'derivative_3']


def delta_steps():
    """Return the fixed sample value used for the delta_steps feature."""
    return 50


def last_loss():
    """Return the fixed sample value used for the last_loss feature."""
    return 0.5


def derivative_3():
    """Return the fixed sample value used for the derivative_3 feature."""
    return 1e-9


# Feature name -> zero-argument provider of that feature's value.
functions = dict(
    delta_steps=delta_steps,
    last_loss=last_loss,
    derivative_3=derivative_3,
)

# Evaluate the equation on a single-row 2-D array built in variable_names order.
eq(np.array([[functions[v]() for v in variable_names]]))

array([0.49999034])

In [12]:
# NOTE(review): this entire cell is commented-out code (a disabled experiment).
# Restore it with the imports it needs or delete the cell.
# # complexity = n
# # get that equations lambda
# from dataset import get_dataset
# from traditional_run import FM

# C = 9
# feature_syms = global_pareto.loc[global_pareto['complexity'] == C]['sympy_format'].values[0].free_symbols
# feature_names = set(str(s) for s in feature_syms)
# feature_names = feature_names - {'feature_step', 'target_step'}
# print(feature_names)

# df = get_dataset({f: FM[f] for f in feature_names}, path='runs_data.json')





In [13]:
# Print the Pareto equations for complexities 5 through 10; levels with no
# equation on the front are skipped silently.
complexities = list(range(5, 11))
for c in complexities:
    subset = global_pareto.loc[global_pareto['complexity'] == c]
    if subset.empty:
        continue
    print(f"Complexity {c}:")
    for sf in subset['sympy_format']:
        print(sf)
    print()


Complexity 5:
-4.9e-6*delta_steps + last_loss

Complexity 7:
-5.30796194291437e-6*delta_steps + 1.00150225338007*last_loss

Complexity 9:
delta_steps*(0.05133*derivative_3 - 4.2e-6) + last_loss

