# 18. PySR Complex Formula Search

## Goal
Find more complex formulas with **depth >= 2** to improve AUC.

## Strategy
1. **parsimony=0**: No complexity penalty
2. **maxsize=50**: Allow larger formulas
3. **Use SHAP top 8 features**: Reduce search space
4. **Longer search**: niterations=400, timeout=1800s

## Previous Results (depth=1)
- Hypertension: 0.684 (exp(SBP_T1))
- Hyperglycemia: 0.899 (FBG_T2)
- Dyslipidemia: 0.795 (exp(TC_T1))

## Date: 2026-01-13

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score, balanced_accuracy_score
from scipy import stats
import time

print("Packages loaded")

Packages loaded


In [2]:
# Load data
df = pd.read_csv('../../data/01_primary/SUA/processed/SUA_CVDs_wide_format.csv')
print(f"Data: {len(df):,} patients")

# SHAP top features for each disease (from shap_feature_importance.csv)
shap_top_features = {
    'Hypertension': ['SBP_T1', 'SBP_T2', 'Delta1_SBP', 'GFR_T1', 'DBP_T1', 'Delta1_GFR', 'UA_T2', 'Delta1_FBG'],
    'Hyperglycemia': ['FBG_T2', 'FBG_T1', 'Delta1_FBG', 'Delta1_TC', 'TC_T1', 'Delta1_GFR', 'Age', 'BMI_T1'],
    'Dyslipidemia': ['TC_T2', 'TC_T1', 'Age', 'Delta1_GFR', 'Delta1_TC', 'Delta1_SBP', 'UA_T1', 'Delta1_FBG']
}

# Common features (union of top 8 from each disease)
all_top_features = list(set(
    shap_top_features['Hypertension'] + 
    shap_top_features['Hyperglycemia'] + 
    shap_top_features['Dyslipidemia']
))
print(f"\nUnion of SHAP top features: {len(all_top_features)}")
print(all_top_features)

Data: 6,056 patients

Union of SHAP top features: 16
['TC_T1', 'UA_T2', 'Age', 'FBG_T1', 'SBP_T1', 'BMI_T1', 'SBP_T2', 'GFR_T1', 'Delta1_FBG', 'DBP_T1', 'Delta1_SBP', 'UA_T1', 'FBG_T2', 'Delta1_GFR', 'TC_T2', 'Delta1_TC']


In [3]:
# Define targets
target_cols = {
    'Hypertension': 'hypertension_T3',
    'Hyperglycemia': 'hyperglycemia_T3',
    'Dyslipidemia': 'dyslipidemia_T3'
}

targets = {}
for name, col in target_cols.items():
    targets[name] = (df[col] == 2).astype(int)

print("Class distribution:")
for name, y in targets.items():
    print(f"  {name}: {y.mean()*100:.1f}% positive")

Class distribution:
  Hypertension: 16.7% positive
  Hyperglycemia: 5.5% positive
  Dyslipidemia: 6.0% positive


In [4]:
# Import PySR
try:
    from pysr import PySRRegressor
    print("PySR available")
except ImportError:
    print("PySR not installed")
    raise

Detected Jupyter notebook. Loading juliacall extension. Set `PYSR_AUTOLOAD_EXTENSIONS=no` to disable.
PySR available


In [5]:
def run_pysr_experiment(X, y, disease_name, feature_names, 
                        parsimony=0, maxsize=50, niterations=400, timeout=1800):
    """
    Run PySR with aggressive settings for complex formulas.
    
    Key changes from previous experiments:
    - parsimony=0: No complexity penalty at all
    - maxsize=50: Allow much larger formulas
    - niterations=400: More generations
    - timeout=1800: 30 min per run
    """
    print(f"\n{'='*60}")
    print(f"{disease_name}")
    print(f"{'='*60}")
    print(f"Features: {feature_names}")
    print(f"Settings: parsimony={parsimony}, maxsize={maxsize}, niterations={niterations}")
    
    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    # Standardize
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    X_train_df = pd.DataFrame(X_train_scaled, columns=feature_names)
    X_test_df = pd.DataFrame(X_test_scaled, columns=feature_names)
    y_train_reset = y_train.reset_index(drop=True)
    y_test_reset = y_test.reset_index(drop=True)
    
    # PySR model with aggressive settings
    model = PySRRegressor(
        niterations=niterations,
        populations=30,              # More populations for diversity
        population_size=50,          # Larger populations
        binary_operators=["+", "-", "*", "/"],
        unary_operators=["exp", "log", "sqrt", "abs", "square", "neg"],
        maxsize=maxsize,
        parsimony=parsimony,         # No complexity penalty!
        timeout_in_seconds=timeout,
        temp_equation_file=True,
        verbosity=1,
        random_state=42,
        # Encourage exploration
        weight_mutate_constant=0.5,
        weight_mutate_operator=0.5,
        weight_add_node=0.5,
        weight_insert_node=0.5,
        weight_delete_node=0.3,
        weight_simplify=0.1,
        # Cross-over for combining good solutions
        crossover_probability=0.1,
        # Tournament selection pressure
        tournament_selection_n=15,
    )
    
    start_time = time.time()
    model.fit(X_train_df, y_train_reset)
    elapsed = time.time() - start_time
    
    # Get all equations from Pareto front
    print(f"\nTraining completed in {elapsed/60:.1f} minutes")
    print(f"\nPareto front (top equations by complexity):")
    
    equations_df = model.equations_
    if equations_df is not None and len(equations_df) > 0:
        # Sort by complexity (ascending) and show top 10
        equations_df = equations_df.sort_values('complexity')
        for i, row in equations_df.head(10).iterrows():
            print(f"  Complexity {row['complexity']:2.0f}: loss={row['loss']:.4f}, {row['equation']}")
    
    # Best equation
    best_eq = str(model.sympy())
    print(f"\nBest equation: {best_eq}")
    
    # Evaluate on test set
    y_pred_raw = model.predict(X_test_df)
    y_pred_proba = np.clip(y_pred_raw, 0, 1)
    
    threshold = y_train_reset.mean()
    y_pred = (y_pred_proba >= threshold).astype(int)
    
    try:
        auc = roc_auc_score(y_test_reset, y_pred_proba)
    except:
        auc = 0.5
    
    tn, fp, fn, tp = confusion_matrix(y_test_reset, y_pred).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    print(f"\nTest Results:")
    print(f"  AUC: {auc:.3f}")
    print(f"  Sensitivity: {sensitivity:.3f}")
    print(f"  Specificity: {specificity:.3f}")
    
    return {
        'disease': disease_name,
        'auc': auc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'formula': best_eq,
        'time_min': elapsed / 60,
        'equations_df': equations_df,
        'model': model
    }

## Experiment 1: Disease-specific top features (parsimony=0)

In [6]:
# Run for each disease with disease-specific SHAP top 8 features
results = []

for disease_name, y in targets.items():
    features = shap_top_features[disease_name]
    X = df[features].copy()
    
    result = run_pysr_experiment(
        X, y, disease_name, features,
        parsimony=0,      # No penalty
        maxsize=50,       # Large formulas OK
        niterations=400,  # More iterations
        timeout=1800      # 30 min
    )
    results.append(result)


Hypertension
Features: ['SBP_T1', 'SBP_T2', 'Delta1_SBP', 'GFR_T1', 'DBP_T1', 'Delta1_GFR', 'UA_T2', 'Delta1_FBG']
Settings: parsimony=0, maxsize=50, niterations=400
Compiling Julia backend...


[ Info: Started!



Expressions evaluated per second: 0.000e+00
Head worker occupation: 0.0%
Progress: 0 / 12000 total iterations (0.000%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
---------------------------------------------------------------------------------------------------
Press 'q' and then <enter> to stop execution early.

Expressions evaluated per second: 5.980e+03
Head worker occupation: 69.9%. This is high, and will prevent efficient resource usage. Increase `ncycles_per_iteration` to reduce load on head worker.
Progress: 22 / 12000 total iterations (0.183%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
2           1.390e-01  7.971e+00  y = abs(-0.1667)
4           1.390e-01  5.960e-08  y = abs(square(abs(-0.4084)))
5           1.390e-01  -0.000e+00  y = abs(abs(abs(square

[ Info: Started!



Expressions evaluated per second: 0.000e+00
Head worker occupation: 0.0%
Progress: 0 / 12000 total iterations (0.000%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
---------------------------------------------------------------------------------------------------
Press 'q' and then <enter> to stop execution early.

Expressions evaluated per second: 7.330e+02
Head worker occupation: 99.2%. This is high, and will prevent efficient resource usage. Increase `ncycles_per_iteration` to reduce load on head worker.
Progress: 9 / 12000 total iterations (0.075%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
2           5.232e-02  7.971e+00  y = abs(0.062893)
3           5.232e-02  1.687e-05  y = -0.036522 - -0.099357
4           4.089e-02  2.465e-01  y = abs(0.1027 * FBG_T2)
8 

[ Info: Started!



Expressions evaluated per second: 0.000e+00
Head worker occupation: 0.0%
Progress: 0 / 12000 total iterations (0.000%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
---------------------------------------------------------------------------------------------------
Press 'q' and then <enter> to stop execution early.

Expressions evaluated per second: 1.660e+01
Head worker occupation: 0.0%
Progress: 1 / 12000 total iterations (0.008%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
2           5.837e-02  7.971e+00  y = abs(0.10731)
3           5.809e-02  4.910e-03  y = sqrt(square(0.10421))
4           5.618e-02  3.334e-02  y = neg(0.59752 * -0.11483)
5           5.610e-02  1.422e-03  y = abs(-0.11483 * log(0.59752))
7           5.476e-02  1.209e-02  y = abs(-0.11483 * abs

In [7]:
# Summary
print("\n" + "="*80)
print("Summary: Complex Formula Search")
print("="*80)

print("\n| Disease | AUC | Formula | Time |")
print("|---------|-----|---------|------|")
for r in results:
    formula_short = r['formula'][:60] + '...' if len(r['formula']) > 60 else r['formula']
    print(f"| {r['disease']} | {r['auc']:.3f} | {formula_short} | {r['time_min']:.1f}min |")

print("\n\nDetailed Formulas:")
for r in results:
    print(f"\n{r['disease']}:")
    print(f"  {r['formula']}")


Summary: Complex Formula Search

| Disease | AUC | Formula | Time |
|---------|-----|---------|------|
| Hypertension | 0.500 | 0.166794640000000 | 32.1min |
| Hyperglycemia | 0.910 | 0.119475365*FBG_T2 | 32.0min |
| Dyslipidemia | 0.815 | 0.043782398*exp(re(TC_T1)) | 32.3min |


Detailed Formulas:

Hypertension:
  0.166794640000000

Hyperglycemia:
  0.119475365*FBG_T2

Dyslipidemia:
  0.043782398*exp(re(TC_T1))


In [8]:
# Save results
summary_df = pd.DataFrame([{
    'Disease': r['disease'],
    'AUC': r['auc'],
    'Sensitivity': r['sensitivity'],
    'Specificity': r['specificity'],
    'Formula': r['formula'],
    'Time_min': r['time_min']
} for r in results])

summary_df.to_csv('../../results/pysr_complex_formulas.csv', index=False)
print("Saved: results/pysr_complex_formulas.csv")

Saved: results/pysr_complex_formulas.csv


## Compare with Previous Results

In [9]:
# Previous results (depth=1, parsimony=0.0001)
previous = {
    'Hypertension': {'auc': 0.684, 'formula': '0.11 * exp(SBP_T1)'},
    'Hyperglycemia': {'auc': 0.899, 'formula': '0.12 * FBG_T2'},
    'Dyslipidemia': {'auc': 0.795, 'formula': '0.04 * exp(TC_T1)'}
}

print("\n" + "="*80)
print("Comparison: Previous (depth=1) vs New (parsimony=0)")
print("="*80)

print("\n| Disease | Previous AUC | New AUC | Diff |")
print("|---------|-------------|---------|------|")
for r in results:
    prev = previous[r['disease']]
    diff = r['auc'] - prev['auc']
    print(f"| {r['disease']} | {prev['auc']:.3f} | {r['auc']:.3f} | {diff:+.3f} |")


Comparison: Previous (depth=1) vs New (parsimony=0)

| Disease | Previous AUC | New AUC | Diff |
|---------|-------------|---------|------|
| Hypertension | 0.684 | 0.500 | -0.184 |
| Hyperglycemia | 0.899 | 0.910 | +0.011 |
| Dyslipidemia | 0.795 | 0.815 | +0.020 |
