In [20]:
from model import *
from helpers import *

In [74]:
# dictionary of all model inputs and their default values as [default, min, max]
model_inputs = {
    # ice-albedo feedback
    'Ti': [260, 240, 265],               # ice-covered threshold temperature [K]
    'To': [295, 285, 300],               # ice-free threshold temperature [K]
    'ai': [0.6, 0.5, 0.75],               # ice-covered albedo
    'ao': [0.28, 0.2, 0.35],              # ice-free albedo

    # temperatures
    #'T0': [288, None, None],               # initial temperature [K]
    #'Teq': [288, None, None],              # equilibrium temperature [K]

    # volcanic degassing and weathering
    'V_C': [8.5, 5, 15],              # volcanic degassing [examol/Myr]
    'V_red': [1.7, 0.5, 3],            # volcanic reduced gases [examol/Myr]
    'W_sea': [1.6, 0.0, 2.4],              # seafloor weathering [examol/Myr]
    'n': [0.2, 0.0, 1.0],                  # silicate weathering feedback strength

    # productivity
    'forg': [0.2, 0.1, 0.4],               # fraction of buried carbon that is organic
    #'CPsed': [250, 106, 400],            # sedimentary organic C:P ratio
    'nb': [1, 1.0, 2.0],                   # exponent for organic burial dependence on P

    # oxygen
    #'O20': [0.1, 0.01, 0.1],              # O2 as % of PAL

    # phosphorus
    'P_conc': [2.2, 0.01, 2],           # seawater phosphate concentration [uM]
    'W_pho0': [4e-2, 1e-2, 2e-1],          # phosphorus weathering [examol/Myr]

    # perturbation
    #'C_imb': [0, 0, 40],              # imposed generalized carbon cycle imbalance [examol/Myr]
    'tau': [float('inf'), 1e-1, 80.0],      # e-folding timescale of forcing decay [Myr]
    #'beta': [0, None, None],               # fraction of carbon cycle imbalance happening organically
    'W_LIP': [1e-2, 0, 45],              # imposed LIP C sequestration [examol/Myr] (replaced C_imb)
    'n_LIP': [False, 0, 1],          # LIP silicate weathering feedback strength
    'PC_LIP': [0.009, 0.007, 0.011],       # molar ratio of weatherable P to Ca+Mg in LIP
    #'suppress_Borg': [False, None, None],  # suppress burial enhancement to isolate effect of generalized C imbalance
}


In [75]:
# create Latin hypercube to make a representative sample of parameter space
import pandas as pd
from scipy.stats import qmc

param_names = list(model_inputs.keys()) 
mins = np.array([model_inputs[var][1] for var in model_inputs])
maxs = np.array([model_inputs[var][2] for var in model_inputs])

#log_params = ["W_LIP","tau"] # parameters to sample in log space
log_params = []
for lp in log_params:
    j = param_names.index(lp)
    assert mins[j] > 0, f"Log-param {lp} must have min > 0"

N = 10000  # samples
d = len(param_names)
sampler = qmc.LatinHypercube(d=d, seed=42)
U = sampler.random(N)
X = np.empty_like(U)

for j, name in enumerate(param_names):
    if name in log_params:
        log_min = np.log10(mins[j])
        log_max = np.log10(maxs[j])
        X[:, j] = 10 ** (log_min + U[:, j] * (log_max - log_min))
    else:
        X[:, j] = mins[j] + U[:, j] * (maxs[j] - mins[j])
        
df = pd.DataFrame(X, columns=param_names)


In [76]:
# Run model over sampled parameter combinations and collect metrics

def run_one(params_row,
            dt=1e4/1e6,
            verbose=False,
            # for dynamically guessing t_max:
            Cimb_min_guess=0.4,
            max_t_max=100,
            safety_pad=10,
            min_window=10):
    params = {k: params_row[k] for k in param_names}

    # Dynamically guess t_max using tau and W_LIP (or C_imb if present)
    tau = float(params.get('tau', np.inf))
    C_imb_effect = float(params.get('W_LIP', params.get('C_imb', 0)))

    if np.isfinite(tau) and tau > 0 and C_imb_effect > 0:
        guess_duration = -tau * np.log(Cimb_min_guess / C_imb_effect)
        t_max = min(max(guess_duration + safety_pad, min_window), max_t_max)
    else:
        # Fallback when tau<=0, tau is inf, or no imbalance
        t_max = min_window

    if verbose:
        print(f"Running with t_max = {t_max:.0f} Myr")

    try:
        with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                warnings.simplefilter("ignore", RuntimeWarning)
                results = run_model(t_max=100, dt=dt, **params)

    # catch negative Worg
    except Exception as e:
        if verbose:
            print(f"run_one failed: {e}. Returning NaNs.")
        return {
            'snowball_num': np.nan,
            'end_time': np.nan,
            'sb_dur_first': np.nan,
            'ig_dur_last': np.nan,
            'LIP_volume_end': np.nan,
            't_max_used': t_max,
        }
    
    t = results['t']
    snowball = results['snowball']
    sb_starts, sb_ends, sb_durs, ig_durs = get_times(t, snowball)
    sb_dur = sb_durs[0] if len(sb_durs) > 0 else np.nan
    ig_dur = ig_durs[-1] if len(ig_durs) > 0 else np.nan
    end_time = sb_ends[-1] if len(sb_ends) > 0 else (t[-1] if len(t) > 0 else np.nan)
    # LIP diagnostics
    lip_t, lip_vol, lip_height = LIP_volume(results, verbose=False)
    lip_vol_end = lip_vol[-1] if lip_vol is not None and len(lip_vol) > 0 else np.nan
    res = {
        'snowball_num': len(sb_starts),
        'end_time': end_time,
        'sb_dur_first': sb_dur,
        'ig_dur_last': ig_dur,
        'LIP_volume_end': lip_vol_end,
        't_max_used': t_max,
    }
    if verbose:
        print(f"Result: {len(sb_starts)} snowballs, end_time={end_time:.2f} Myr")
    return res

In [77]:
# test on one row
res = run_one(df.iloc[0],verbose=True)

Running with t_max = 53 Myr
Result: 0 snowballs, end_time=100.00 Myr


In [78]:
# parallelize running over all samples
from joblib import Parallel, delayed
from tqdm.auto import tqdm

n_workers = -1 #-1 uses all available cores

results_list = Parallel(n_jobs=n_workers, backend='loky')(
    delayed(run_one)(df.iloc[i]) for i in tqdm(range(len(df)))
)

metrics_df = df.copy()
metrics_df['snowball_num'] = [r['snowball_num'] for r in results_list]
metrics_df['end_time'] = [r['end_time'] for r in results_list]
metrics_df['sb_dur_first'] = [r['sb_dur_first'] for r in results_list]
metrics_df['ig_dur_last'] = [r['ig_dur_last'] for r in results_list]
metrics_df['LIP_volume_end'] = [r['LIP_volume_end'] for r in results_list]
metrics_df['t_max_used'] = [r['t_max_used'] for r in results_list]

metrics_df.head()



  0%|          | 0/10000 [00:00<?, ?it/s]

Unnamed: 0,Ti,To,ai,ao,V_C,V_red,W_sea,n,forg,nb,...,tau,W_LIP,n_LIP,PC_LIP,snowball_num,end_time,sb_dur_first,ig_dur_last,LIP_volume_end,t_max_used
0,241.203065,290.480342,0.722154,0.239575,7.752906,1.266756,1.193577,0.776621,0.299106,1.123155,...,41.818636,1.107798,0.425656,0.010849,0.0,100.0,,,1848244000000000.0,52.599164
1,259.773614,296.294904,0.709704,0.276356,6.569242,2.967661,1.912327,0.701511,0.353657,1.004181,...,57.866467,22.888426,0.953726,0.008418,,,,,,100.0
2,253.714185,285.510944,0.721863,0.302807,12.11987,0.946131,1.913946,0.214233,0.166857,1.586817,...,9.70532,42.791379,0.829561,0.008294,3.0,23.67,7.37,1.17,1.288609e+16,55.349345
3,263.643294,299.03679,0.67312,0.23234,6.415213,0.908834,2.266631,0.358222,0.303566,1.016443,...,78.644349,6.14488,0.718943,0.00973,6.0,,17.78,3.14,2.316257e+16,100.0
4,259.395913,289.65367,0.570811,0.264585,9.172969,1.846891,1.962668,0.089259,0.104474,1.923277,...,1.471934,0.609021,0.374244,0.009274,,,,,,10.618785


In [79]:
csv_path = 'data/parameter_sweep.csv'
metrics_df.to_csv(csv_path, index=False)