In [None]:
from niagara import Chain, Model, ModelIntrinsicLogProb, NullTransformation, LogisticRegressionCalibrator
from niagara import OpenAIClient, FireworksClient, OneSidedAsymptoticLog, TwoSidedAsymptoticLog
import pickle
import os
# Placeholder value only: the string itself says no real key is needed. This
# notebook reads cached chain results from disk rather than querying models.
# NOTE(review): presumably some niagara client expects FIREWORKS_API_KEY to be
# set at import/construction time — confirm before removing.
os.environ["FIREWORKS_API_KEY"] = "leave-this-line-but-there-is-no-need-to-add-an-API-key"

# Five-model Llama chain, ordered by model size (1B -> 405B). Every model uses
# its intrinsic log-probability as confidence signal and gets its own
# logistic-regression calibrator instance (one per list element).
# The thresholds here look like placeholders (NOTE(review): confirm); the
# notebook optimizes cascade thresholds separately. The NullTransformation is
# overwritten with the benchmark-specific transform later in the notebook.
llama_chain = Chain(
    models = [
        Model(
            model_name=name, 
            thresholds={"reject": -10000, "accept": 0.0},
            conf_signal=ModelIntrinsicLogProb(),
            conf_signal_transform=NullTransformation(),
            conf_signal_calibrator=LogisticRegressionCalibrator()
        )
        for name in ["llama3.2-1b", "llama3.2-3b", "llama3.1-8b", "llama3.1-70b", "llama3.1-405b"]
    ]
)

# Four-model chain mixing OpenAI and Qwen models. Each entry pairs a model
# name with a client; all clients are None here.
# NOTE(review): presumably Model falls back to a default client when
# client=None — confirm in niagara.
qwen_oai_chain = Chain(
    models = [
        Model(
            model_name=name, 
            thresholds={"reject": -10000, "accept": 0.0},
            conf_signal=ModelIntrinsicLogProb(),
            conf_signal_transform=NullTransformation(),
            conf_signal_calibrator=LogisticRegressionCalibrator(),
            client=client
        )
        for name, client in [("gpt-4o-mini", None), ("qwen2.5-32b-coder-instruct", None), ("qwen2.5-72b-instruct", None), ("gpt-4o", None)]
    ]
)

from itertools import combinations

# Benchmark whose cached chain results are analyzed, and the chain to use.
NAME = "mmlu"
CHAIN_NAME = "llama_chain"

# Confidence-signal transform: one-sided for mmlu/medmcqa, two-sided otherwise.
TRANSFORM = OneSidedAsymptoticLog() if NAME in {'mmlu', 'medmcqa'} else TwoSidedAsymptoticLog()

# Resolve the chain by name. Failing loudly avoids silently reusing a CHAIN
# left over from an earlier notebook run (or a NameError much later).
CHAINS = {
    "llama_chain": llama_chain,
    "qwen_oai_chain": qwen_oai_chain,
}
if CHAIN_NAME not in CHAINS:
    raise ValueError(f"Unknown CHAIN_NAME {CHAIN_NAME!r}; expected one of {sorted(CHAINS)}")
CHAIN = CHAINS[CHAIN_NAME]

# Update the transformation for the chain
for model in CHAIN.models:
    model.conf_signal_transform = TRANSFORM

# All sub-cascades of at least two models, as increasing model indices.
# They are derived from the chain length so that shorter chains (e.g. the
# 4-model qwen_oai_chain) never receive out-of-range indices.
# Order: grouped by the last (largest) index, then by cascade length, then
# lexicographically. For a 5-model chain this is exactly
# [0, 1], [0, 2], [1, 2], [0, 1, 2], [0, 3], ..., [1, 2, 3, 4], [0, 1, 2, 3, 4].
N_MODELS = len(CHAIN.models)
ALL_MODEL_INDICES = [
    [*prefix, last]
    for last in range(1, N_MODELS)
    for n_prefix in range(1, last + 1)
    for prefix in combinations(range(last), n_prefix)
]

# Load the cached chain outputs (per-model correctness and raw confidences)
# for the train and test splits of the selected benchmark and chain.
with open(f'../benchmarks/data/{NAME}/chain_results/{NAME}_full_{CHAIN_NAME}_results_train.pkl', 'rb') as f:
    results_train = pickle.load(f)
with open(f'../benchmarks/data/{NAME}/chain_results/{NAME}_full_{CHAIN_NAME}_results_test.pkl', 'rb') as f:
    results_test = pickle.load(f)


def process_scores(scores):
    """Binarize one xsum sample: correct when its score values sum to at least 20."""
    return sum(scores.values()) >= 20


### Training split: correctness labels, transformed and calibrated confidences

# For xsum, correctness is stored as per-sample score dicts and is binarized.
raw_corr_train = (
    {
        model_name: [process_scores(sample) for sample in samples]
        for model_name, samples in results_train['model_correctness'].items()
    }
    if NAME == "xsum"
    else results_train['model_correctness']
)
raw_conf_train = results_train['raw_confidences']

# Per-model lists, in chain order.
corr_train = [raw_corr_train[model_name] for model_name in CHAIN.model_names]
transformed_conf_train = [
    list(TRANSFORM.transform_confidence_signal(raw_conf_train[model_name]))
    for model_name in CHAIN.model_names
]

# (correctness, transformed confidence) pairs per model; CHAIN.calibrate
# presumably fits each model's calibrator on them.
calibration_data = [
    {"correctness": corr, "transformed_confidence": conf}
    for corr, conf in zip(corr_train, transformed_conf_train)
]
CHAIN.calibrate(calibration_data)

calibrated_conf_train = [
    list(CHAIN.models[idx].conf_signal_calibrator.calibrate_confidence_signal(conf))
    for idx, conf in enumerate(transformed_conf_train)
]

### Test split: same pipeline, reusing the calibrators fitted above

raw_corr_test = (
    {
        model_name: [process_scores(sample) for sample in samples]
        for model_name, samples in results_test['model_correctness'].items()
    }
    if NAME == "xsum"
    else results_test['model_correctness']
)
raw_conf_test = results_test['raw_confidences']

corr_test = [raw_corr_test[model_name] for model_name in CHAIN.model_names]
transformed_conf_test = [
    list(TRANSFORM.transform_confidence_signal(raw_conf_test[model_name]))
    for model_name in CHAIN.model_names
]

calibrated_conf_test = [
    list(CHAIN.models[idx].conf_signal_calibrator.calibrate_confidence_signal(conf))
    for idx, conf in enumerate(transformed_conf_test)
]

In [24]:
from scipy.integrate import quad
from scipy.interpolate import interp1d
import numpy as np

### Define function for computing area under the curve (AUC)

def compute_auc(x, y, x_min, x_max, integration_limit=200, method='linear'):
    """
    Integrate a curve given by point samples (x, y) over [x_min, x_max].

    The samples are sorted by x and interpolated with
    ``scipy.interpolate.interp1d`` (``kind=method``), extrapolating outside the
    sampled range. The interpolant is integrated with ``scipy.integrate.quad``,
    using at most ``integration_limit`` subintervals.

    Returns:
        The ``(integral, abs_error_estimate)`` tuple returned by ``quad``.
    """
    xs = np.array(x)
    ys = np.array(y)
    by_x = xs.argsort()
    interpolant = interp1d(
        xs[by_x],
        ys[by_x],
        kind=method,
        bounds_error=False,
        fill_value='extrapolate',
    )
    return quad(interpolant, x_min, x_max, limit=integration_limit)


In [25]:
import pandas as pd
import numpy as np
from optimize_cascade import get_expected_uncumulated_costs

# Per-model token prices for each chain, with separate "in" and "out" rates.
# NOTE(review): presumably USD per million tokens (cf. Model.cpm_tokens) —
# confirm against get_expected_uncumulated_costs.
MODEL_COSTS_BY_CHAIN = {
    "llama_chain": {
        "llama3.2-1b": {"in": 0.10, "out": 0.10},
        "llama3.2-3b": {"in": 0.10, "out": 0.10},
        "llama3.1-8b": {"in": 0.20, "out": 0.20},
        "llama3.1-70b": {"in": 0.90, "out": 0.90},
        "llama3.1-405b": {"in": 3.00, "out": 3.00},
    },
    "qwen_oai_chain": {
        "gpt-4o-mini": {"in": 0.15, "out": 0.60},
        "qwen2.5-32b-coder-instruct": {"in": 0.90, "out": 0.90},
        "qwen2.5-72b-instruct": {"in": 0.90, "out": 0.90},
        "gpt-4o": {"in": 2.50, "out": 10.00},
    },
}

# Fail loudly instead of leaving raw_model_costs undefined, or stale from an
# earlier notebook run, for an unknown chain.
if CHAIN_NAME not in MODEL_COSTS_BY_CHAIN:
    raise ValueError(
        f"No cost table for CHAIN_NAME {CHAIN_NAME!r}; expected one of {sorted(MODEL_COSTS_BY_CHAIN)}"
    )
raw_model_costs = MODEL_COSTS_BY_CHAIN[CHAIN_NAME]

# Expected (non-cumulative) per-model costs for each split, as computed by
# optimize_cascade.get_expected_uncumulated_costs. They are indexed by model
# index later in the notebook.
expected_uncumulated_costs_train = get_expected_uncumulated_costs(raw_model_costs, results_train)
expected_uncumulated_costs_test = get_expected_uncumulated_costs(raw_model_costs, results_test)

def compute_utilization_from_conditional_deferral_probs(conditional_deferral_probs):
    """Turn per-model conditional deferral probabilities into utilizations.

    With ``p = conditional_deferral_probs``, the utilization of model ``i`` is
    ``prod(p[:i]) * (1 - p[i])``: every earlier model deferred and model ``i``
    did not.

    Returns:
        List with one utilization per entry of ``conditional_deferral_probs``.
    """
    utilizations = []
    reach_prob = 1.0  # product of the deferral probabilities seen so far
    for p_defer in conditional_deferral_probs:
        utilizations.append(reach_prob * (1 - p_defer))
        reach_prob *= p_defer
    return utilizations

def compute_utilization_and_error(T, model_indices, calibrated_conf_train, conditional_deferral_probs=None):
    """Estimate how often each cascade model returns a query, and how correct it is.

    A query first goes to model ``model_indices[0]``. The model at cascade
    position ``i`` returns the query when its calibrated confidence is strictly
    greater than ``T[i]``; otherwise it defers to position ``i + 1``. The last
    model returns every query that reaches it, so ``T`` needs one entry per
    model except the last.

    Args:
        T: Thresholds, indexed by cascade position.
        model_indices: Indices into ``calibrated_conf_train`` selecting the
            cascade's models, in cascade order.
        calibrated_conf_train: One sequence of calibrated confidences per model,
            all of the same length (one entry per query).
        conditional_deferral_probs: Optional per-model deferral probabilities.
            When given, the returned utilizations are computed from them with
            ``compute_utilization_from_conditional_deferral_probs`` instead of
            from the data.

    Returns:
        ``(utilizations, conditional_corrs, unconditional_corrs)``, each a list
        with one value per cascade model:
        - the fraction of queries the model returns;
        - its mean calibrated confidence over the queries it returns (NaN, with
          a NumPy warning, if it returns none);
        - its mean calibrated confidence over all queries.
    """
    # Rows are queries; columns are the selected models in cascade order.
    conf = np.array(calibrated_conf_train).transpose()[:, model_indices]
    n_models = len(model_indices)

    # One boolean mask per cascade position: the queries that model returns.
    returned_by = [conf[:, 0] > T[0]]  # the first model sees every query
    for pos in range(1, n_models):
        # A query reaches `pos` only if every earlier model deferred on it.
        reached = np.all(conf[:, :pos] <= np.array(T)[np.newaxis, :pos], axis=1)
        is_last = pos == n_models - 1
        returned_by.append(reached if is_last else reached & (conf[:, pos] > T[pos]))

    utilizations = [np.mean(mask) for mask in returned_by]
    conditional_corrs = [np.mean(conf[mask, pos]) for pos, mask in enumerate(returned_by)]
    unconditional_corrs = [np.mean(conf[:, pos]) for pos in range(n_models)]

    if conditional_deferral_probs is not None:
        utilizations = compute_utilization_from_conditional_deferral_probs(conditional_deferral_probs)

    return utilizations, conditional_corrs, unconditional_corrs


def get_ecorr_ecost_estimates(threshold_list, model_indices, calibrated_conf_train):
    """Estimate expected cost and expected correctness for each threshold vector.

    Per-model costs come from the module-level
    ``expected_uncumulated_costs_train``. A query returned by cascade position
    ``i`` is charged the cumulative cost of positions ``0..i``. NaN terms, such
    as the conditional correctness of a model that returns nothing, are treated
    as zero via ``np.nansum``.

    Returns:
        List of ``(expected_cost, expected_correctness)`` tuples, one per
        threshold vector in ``threshold_list``.
    """
    estimates = []
    for thresholds in threshold_list:
        util, cond_corr, _uncond = compute_utilization_and_error(thresholds, model_indices, calibrated_conf_train)
        stage_costs = np.array(expected_uncumulated_costs_train)[model_indices]
        util = np.array(util)

        expected_corr = np.nansum(util * np.array(cond_corr))
        expected_cost = np.nansum(util * np.cumsum(stage_costs))
        estimates.append((expected_cost, expected_corr))

    return estimates

In [None]:
### Setup for getting the probabilistic model
import os
import pickle
from time import time
from itertools import product
from paretoset import paretoset
from optimize_cascade import train_probability_model

def get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h = 0.025, eps=0.01):
    """Grid-search cascade thresholds and keep the cost/correctness Pareto front.

    Each deferring model (every model except the last) gets a candidate grid
    made of empirical quantiles of its training confidences, at levels
    ``eps, eps + quantile_h, ...`` below ``1 - eps``. Every combination in the
    Cartesian product of these grids is scored with
    ``get_ecorr_ecost_estimates``.

    Returns:
        ``(pareto_df, opt_tholds)``: the Pareto-optimal rows (columns
        ``expected_cost`` and ``expected_correctness``) and the matching
        threshold vectors, in candidate order.
    """
    per_model_grids = [
        np.quantile(data_train[idx], q=np.arange(0 + eps, 1 - eps, quantile_h))
        for idx in model_indices[:-1]
    ]
    candidates = list(map(np.array, product(*per_model_grids)))

    estimates = get_ecorr_ecost_estimates(candidates, model_indices, data_train)
    frame = pd.DataFrame(estimates, columns=['expected_cost', 'expected_correctness'])

    # Minimize cost and maximize correctness simultaneously.
    on_front = paretoset(frame, sense=["min", "max"])
    pareto_df = frame.loc[on_front]
    opt_tholds = [candidate for candidate, keep in zip(candidates, on_front) if keep]
    return pareto_df, opt_tholds

# Load the cached probabilistic model for this benchmark, or train it from the
# calibrated training confidences (queries as rows) when no cache exists.
start = time()
filename = f"data/probabilistic_model_results_{NAME}.pkl"
SAVE_TO_FILE = False

if not os.path.exists(filename):
    prob_results = train_probability_model(full_data=np.array(calibrated_conf_train).T)
    if SAVE_TO_FILE:
        with open(filename, 'wb') as file:
            pickle.dump(prob_results, file)
else:
    with open(filename, 'rb') as file:
        prob_results = pickle.load(file)

stop = time()
print(stop-start)

In [27]:
import numpy as np

def fill_parameter_gaps_adaptively(parameters_list, model_indices, data_train, max_prob_gap=0.1):
    """
    Insert midpoints between consecutive parameter vectors that are too far apart.

    The gap between two consecutive vectors ``a`` and ``b`` is the largest, over
    components ``j``, of the fraction of training observations of model
    ``model_indices[j]`` that lie strictly between ``a[j]`` and ``b[j]``. While
    a pair's gap exceeds ``max_prob_gap``, its midpoint is inserted between the
    two vectors and the new left-hand pair is checked again.

    Args:
        parameters_list: List of parameter vectors (equal-length 1-D arrays, one
            component per model in ``model_indices[:-1]``), in curve order.
        model_indices: Cascade model indices. All but the last select the
            entries of ``data_train`` that the parameter components refer to.
        data_train: Per-model sequences of observations, indexable by model index.
        max_prob_gap: Maximum allowed probability mass between consecutive parameters.

    Returns:
        ``parameters_list`` itself if it has at most one element. Otherwise a new
        list of float arrays that includes the inserted midpoints.
    """
    # Nothing to fill between 0 or 1 parameter vectors.
    if len(parameters_list) <= 1:
        return parameters_list

    # Work in floating point: np.insert into an integer array would truncate
    # the midpoints, which can make the loop below insert the same value forever.
    params = np.asarray(parameters_list, dtype=float)

    # Rows are observations; column j matches parameter component j.
    observed_values = np.array([data_train[i] for i in model_indices[:-1]]).transpose()

    def _max_prob_mass_between(a, b):
        """Largest per-component fraction of observations strictly between a and b."""
        max_prob = 0
        for j in range(len(a)):
            lower, upper = min(a[j], b[j]), max(a[j], b[j])
            column = observed_values[:, j]
            max_prob = max(max_prob, np.mean((column > lower) & (column < upper)))
        return max_prob

    # Single left-to-right pass. After an insertion, re-check the new left-hand
    # pair instead of restarting from the beginning: pairs before i are
    # unchanged and already satisfy the criterion. The result is the same as
    # restarting, without the quadratic rescans.
    i = 0
    while i < len(params) - 1:
        if _max_prob_mass_between(params[i], params[i + 1]) > max_prob_gap:
            midpoint = (params[i] + params[i + 1]) / 2
            params = np.insert(params, i + 1, midpoint, axis=0)
        else:
            i += 1

    return [np.array(x) for x in params.tolist()]

In [None]:
from optimize_cascade import profile_cascade, profile_cascade_adaptively, make_full_data, score_cascade
from tqdm import tqdm
import matplotlib.pyplot as plt

# Start from previously saved comparison records, if any. Records for the
# current benchmark/cascade/method are replaced in place below, so re-running
# this cell for one benchmark and saving keeps other benchmarks' results.
if os.path.exists("data/cascade_comparison_records.pkl"):
    with open("data/cascade_comparison_records.pkl", "rb") as file:
        ALL_RECORDS = pickle.load(file)
else:
    ALL_RECORDS = []

RUN_GRID = True  # False: run and record only the continuous optimization
SAVE_TO_FILE = False

# The test scoring inputs do not depend on the cascade; build them once.
test_data = {
    'calib_conf': make_full_data(calibrated_conf_test),
    'corr': make_full_data(corr_test)
}

for model_indices in tqdm(ALL_MODEL_INDICES):
    data_train = calibrated_conf_train

    # Optimize the cascade on the training data, then insert midpoints where
    # consecutive solutions are separated by too much probability mass.
    start = time()
    cascade_record = profile_cascade_adaptively(
        model_indices,
        expected_uncumulated_costs_train,
        prob_results,
        start_sensitivities=[0, 1e-10, 1e-8, 1e-6],
        cost_threshold_multiplier=1.25,
        stop_val=1000,
        max_iterations=24,
        sensitivity_increase_factor=2.0
    )
    opt_tholds_cts_optim = fill_parameter_gaps_adaptively(
        cascade_record['optimal_thresholds'], model_indices, data_train, max_prob_gap=0.1
    )
    stop = time()
    continuous_optim_time = stop - start
    print(f"Continuous optimization took {continuous_optim_time}s")

    # Optimize the thresholds via grid search
    if RUN_GRID:
        quantile_h = 0.025
        start = time()
        pareto_df, opt_tholds_grid_search = get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h)
        stop = time()
        grid_search_time = stop - start
        print(f"Grid search took {grid_search_time}s")

    # Score the solutions on the test data as (expected cost, expected correctness) pairs
    scores_continuous_optim = [
        score_cascade(T, model_indices, expected_uncumulated_costs_test, test_data)
        for T in opt_tholds_cts_optim
    ]
    ecost_ecorr_continuous_optim = [
        (rec['expected_cost_test'], rec['expected_correctness_test']) for rec in scores_continuous_optim
    ]
    min_cost_continuous_optim = min(cost for cost, _ in ecost_ecorr_continuous_optim)
    max_cost_continuous_optim = max(cost for cost, _ in ecost_ecorr_continuous_optim)
    min_cost_overall, max_cost_overall = min_cost_continuous_optim, max_cost_continuous_optim

    if RUN_GRID:
        scores_grid_search = [
            score_cascade(T, model_indices, expected_uncumulated_costs_test, test_data)
            for T in opt_tholds_grid_search
        ]
        ecost_ecorr_gridsearch = [
            (rec['expected_cost_test'], rec['expected_correctness_test']) for rec in scores_grid_search
        ]
        min_cost_gridsearch = min(cost for cost, _ in ecost_ecorr_gridsearch)
        max_cost_gridsearch = max(cost for cost, _ in ecost_ecorr_gridsearch)

        # Integrate both curves over the cost range they share, so their AUCs are comparable
        min_cost_overall = max(min_cost_continuous_optim, min_cost_gridsearch)
        max_cost_overall = min(max_cost_continuous_optim, max_cost_gridsearch)

    # auc_norm divides by the width of the integration range. It is undefined
    # (NaN) when that range is empty or the cost ranges do not overlap.
    cost_range = max_cost_overall - min_cost_overall

    # Gather all the results: performance, time, resolution
    auc_cts = compute_auc(*zip(*ecost_ecorr_continuous_optim), x_min=min_cost_overall, x_max=max_cost_overall)[0]
    record_cts = {
        "benchmark": NAME,
        "cascade": model_indices,
        "cascade_len": len(model_indices),
        "method": "continuous_optimization",
        "performance": auc_cts,
        "auc_bounds": [min_cost_overall, max_cost_overall],
        "auc_norm": (1 - (auc_cts / cost_range)) if cost_range > 0 else float('nan'),
        "time": continuous_optim_time,
        "n_grid": len(opt_tholds_cts_optim),
        "data": ecost_ecorr_continuous_optim,
        "scores": scores_continuous_optim,
        "n_obs": len(np.unique([cost for cost, _ in ecost_ecorr_continuous_optim]))
    }
    new_records = [record_cts]

    if RUN_GRID:
        auc_grid = compute_auc(*zip(*ecost_ecorr_gridsearch), x_min=min_cost_overall, x_max=max_cost_overall)[0]
        record_grid = {
            "benchmark": NAME,
            "cascade": model_indices,
            "cascade_len": len(model_indices),
            "method": "gridsearch",
            "performance": auc_grid,
            "auc_bounds": [min_cost_overall, max_cost_overall],
            "auc_norm": (1 - (auc_grid / cost_range)) if cost_range > 0 else float('nan'),
            "time": grid_search_time,
            "n_grid": len(opt_tholds_grid_search),
            "data": ecost_ecorr_gridsearch,
            "scores": scores_grid_search,
            "n_obs": len(np.unique([cost for cost, _ in ecost_ecorr_gridsearch]))
        }
        new_records.append(record_grid)

    # Replace existing records for this benchmark/cascade/method in place and
    # append those not present yet.
    for new_record in new_records:
        matches = [
            i for i, record in enumerate(ALL_RECORDS)
            if record['benchmark'] == NAME
            and record['cascade'] == model_indices
            and record['method'] == new_record['method']
        ]
        for idx in matches:
            ALL_RECORDS[idx] = new_record
        if not matches:
            ALL_RECORDS.append(new_record)

    plt.figure()
    plt.title("->".join([str(x) for x in model_indices]))
    if RUN_GRID:
        plt.scatter(*zip(*ecost_ecorr_gridsearch), label='grid')
    plt.scatter(*zip(*ecost_ecorr_continuous_optim), label='cts')
    plt.legend()
    plt.show()


# Save all records to file
if SAVE_TO_FILE:
    with open("data/cascade_comparison_records.pkl", "wb") as file:
        pickle.dump(ALL_RECORDS, file)

In [None]:
from optimize_cascade import profile_cascade, profile_cascade_adaptively, make_full_data, score_cascade
from tqdm import tqdm
import matplotlib.pyplot as plt

# Runtime study. For every cascade, each continuous-optimization setting
# (sensitivity increase factor) is paired elementwise, via zip, with one
# grid-search resolution (quantile step). The wall-clock time of each method
# is recorded.
GRID_RESOLUTIONS = [ 0.1, 0.05, 0.0333, 0.025, 0.02 ]
SENSITIVITY_INCREASE_FACTORS = [ 3.0, 2.0, 1.5, 1.3, 1.2 ]

ALL_RECORDS = []
SAVE_TO_FILE = False

for model_indices in tqdm(ALL_MODEL_INDICES):
    data_train = calibrated_conf_train

    for sens_increase, grid_h in zip(SENSITIVITY_INCREASE_FACTORS, GRID_RESOLUTIONS):
        # Iteration budget such that sens_increase ** max_iter == 2 ** 24, i.e.
        # the sweep spans the same sensitivity range as 24 doublings.
        # NOTE(review): max_iter is a float, not an int; confirm that
        # profile_cascade_adaptively handles a non-integer max_iterations.
        max_iter = 24*np.log(2)/np.log(sens_increase)

        # Optimize the cascade on the data
        start = time()
        cascade_record = profile_cascade_adaptively(
            model_indices, 
            expected_uncumulated_costs_train, 
            prob_results, 
            start_sensitivities=[0, 1e-10, 1e-8, 1e-6],
            cost_threshold_multiplier=1.25,
            stop_val=1000,
            max_iterations=max_iter,
            sensitivity_increase_factor=sens_increase
        )
        opt_tholds_cts_optim = cascade_record['optimal_thresholds']
        # Number of solutions before gap filling (the "resolution" of this setting).
        original_n_grid = len(cascade_record['optimal_thresholds'])
        # opt_tholds_cts_optim = fill_parameter_gaps(opt_tholds_cts_optim, max_gap=0.05)
        opt_tholds_cts_optim = fill_parameter_gaps_adaptively(
            opt_tholds_cts_optim, model_indices, data_train, max_prob_gap=0.1
        )
        stop = time()
        # The timing includes the gap-filling step.
        continuous_optim_time = stop-start
        print(f"Continuous optimization took {continuous_optim_time}s")

        # Optimize the thresholds via grid search
        quantile_h = grid_h
        start = time()
        pareto_df, opt_tholds_grid_search = get_optimal_thresholds_using_grid_search(model_indices, data_train, quantile_h)
        stop = time()
        grid_search_time = stop-start
        print(f"Grid search took {grid_search_time}s")

        # Gather all the results: performance, time, resolution

        record_cts = {
            "benchmark": NAME,
            "cascade": model_indices,
            "cascade_len": len(model_indices),
            "method": "continuous_optimization",
            "time": continuous_optim_time,
            "n_grid": len(opt_tholds_cts_optim),
            "sens_increase": sens_increase,
            "original_n_grid": original_n_grid,
        }

        record_grid = {
            "benchmark": NAME,
            "cascade": model_indices,
            "cascade_len": len(model_indices),
            "method": "gridsearch",
            "time": grid_search_time,
            "n_grid": len(opt_tholds_grid_search),
            # NOTE(review): this equals n_grid (the Pareto-set size), not the
            # size of the full candidate grid — confirm the intended value.
            "full_grid_size": len(opt_tholds_grid_search),
            # Nominal number of quantile levels per threshold (e.g. 0.0333 -> 30).
            "original_n_grid": int(1/grid_h),
        }

        ALL_RECORDS.append(record_cts)
        ALL_RECORDS.append(record_grid)


# Save all records to file
if SAVE_TO_FILE:
    with open("data/cascade_runtime.pkl", "wb") as file:
        pickle.dump(ALL_RECORDS, file)

In [34]:
import numpy as np
import statsmodels.api as sm

def fit_linear_trend_with_stderr(x, y, alpha=0.05, x_min=None, x_max=None):
    """
    Fit the OLS line y ~ a + b*x and evaluate it at the unique x values.

    Args:
        x, y: Observations (array-likes of equal length).
        alpha: Significance level of the prediction interval (0.05 -> 95%).
        x_min, x_max: Optional extra evaluation points, prepended / appended to
            the sorted unique x values as given (no re-sorting or
            de-duplication), e.g. to extend the fitted line across a plot range.

    Returns:
        x_unique: Evaluation points.
        y_pred: Fitted mean at ``x_unique``.
        stderr: Standard error for predicting a new individual observation at
            ``x_unique`` (statsmodels ``se_obs``, which includes residual variance).
        pi_lower, pi_upper: Bounds of the (1 - alpha) prediction interval for
            individual observations (pandas Series).
    """
    X = np.array(x)
    Y = np.array(y)

    # Unique x values (np.unique already returns them sorted).
    x_unique = np.unique(X)

    if x_min is not None:
        x_unique = np.insert(x_unique, 0, x_min)
    if x_max is not None:
        x_unique = np.append(x_unique, x_max)

    # has_constant='add': by default add_constant *skips* the intercept column
    # when its input is itself constant (e.g. a single evaluation point). That
    # would leave the design matrix with the wrong number of columns for the
    # two-parameter model.
    X = sm.add_constant(X, has_constant='add')
    X_unique = sm.add_constant(x_unique, has_constant='add')

    # Fit model
    results = sm.OLS(Y, X).fit()

    # Predictions and observation-level standard errors
    pred_ints = results.get_prediction(X_unique)
    y_pred = pred_ints.predicted_mean
    stderr = pred_ints.se_obs

    # Prediction intervals for individual observations
    pred_data = pred_ints.summary_frame(alpha=alpha)
    pi_lower = pred_data['obs_ci_lower']
    pi_upper = pred_data['obs_ci_upper']

    return x_unique, y_pred, stderr, pi_lower, pi_upper

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Set seaborn style first
sns.set_style("white")
sns.set_context("paper", font_scale=1.0)

# Then matplotlib settings (text.usetex requires a local LaTeX installation)
rcParams['text.usetex'] = True
rcParams['font.family'] = 'serif'
rcParams['font.serif'] = ['Computer Modern Roman']
rcParams['font.size'] = 10

df_time = pd.DataFrame(ALL_RECORDS)
df_time_grid = df_time[df_time['method'] == 'gridsearch']
df_time_cts = df_time[df_time['method'] == 'continuous_optimization']

# Get mean runtimes for each cascade len
mean_time_cts = df_time_cts.groupby(by='cascade_len')['time'].mean()
mean_time_grid = df_time_grid.groupby(by='cascade_len')['time'].mean()

fig, ax = plt.subplots(figsize=(5,4))
ax.set_xscale('log')
ax.set_yscale('log')

color_cts = 'tab:blue'
color_grid = 'gray'

# For each method and cascade length k, fit log(runtime) ~ log(resolution),
# extended over resolutions 10..100. Draw the back-transformed fit with a
# +/- one standard-error band. Grid-search lines are drawn behind (zorder=-1).
for df_method, color, line_kwargs in (
    (df_time_cts, color_cts, {}),
    (df_time_grid, color_grid, {'zorder': -1}),
):
    for k in sorted(df_method['cascade_len'].unique()):
        k_cascades = df_method['cascade_len'] == k
        x_unique, y_pred, std_error, pi_lower, pi_upper = fit_linear_trend_with_stderr(
            np.log(df_method[k_cascades]['original_n_grid']),
            np.log(df_method[k_cascades]['time']),
            x_min=np.log(10),
            x_max=np.log(100)
        )
        ax.plot(np.exp(x_unique), np.exp(y_pred), color=color, linewidth=1, **line_kwargs)
        ax.fill_between(
            np.exp(x_unique),
            np.exp(y_pred - std_error),  # Lower bound
            np.exp(y_pred + std_error),  # Upper bound
            color=color,
            alpha=0.2,  # Transparency
        )

# Hand-placed annotations (positions tuned for the llama chain, k = 2..5)
ax.text(50, 2500, "$k=5$", color=color_grid, fontweight='bold').set_rotation(24)
ax.text(60, 0.0045, "$k=2$", color=color_grid, fontweight='bold').set_rotation(3.5)
ax.text(10, 23, "$k=5$", color=color_cts, fontweight='bold').set_rotation(3)
ax.text(10, 0.06, "$k=2$", color=color_cts, fontweight='bold').set_rotation(3)

ax.text(9.9, 1.0, "continuous", color=color_cts, fontweight='bold').set_rotation(5)
ax.text(48, 500, "grid search", color=color_grid, fontweight='bold').set_rotation(22)
ax.set_ylabel("Runtime (s)", fontsize=14)
ax.set_xlabel('Resolution of Cost-Error Curve', fontsize=14)

# Relabel the visible x ticks (resolution r) as 1/r, i.e. the grid spacing.
visible_ticks = [tick.get_position()[0] for tick in ax.get_xticklabels() if tick.get_visible()]
visible_labels = [f'$\\mathdefault{{{1/x:.2f}}}$' for x in visible_ticks]

ax.set_xticks(visible_ticks)
ax.set_xticklabels(visible_labels)
ax.set_xlim([9,110])

ax.set_title("Continuous optimization scales linearly independent of cascade length $k$", fontsize=14)

# Lay out only after the title exists, so tight_layout reserves room for it.
plt.tight_layout()