# Big Data, Big Waves - symbolic regression

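Distil the final multi-head neural network into a closed-form expression using symbolic regression (PySR), polish the expression's constants by gradient descent, and evaluate its calibration and validation scores.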

⚠️ **WARNING:** Do not run this notebook on a machine with less than 128 GB of RAM. ⚠️

In [None]:
# Directory that holds the aggregated input files; it is combined with the
# glob pattern "*-agg.parquet" when loading. E.g. /data/rogue/
DATA_PATH = ""

# Feature groups of the final model. Each inner list is one group handed to
# MultiHeadMLP (presumably one network head per group).
MODEL_FEATURE_GROUPS = [
    # group 0
    [
        "sea_state_dynamic_crest_trough_correlation",
    ],
    # group 1
    [
        "sea_state_dynamic_steepness",
        "sea_state_dynamic_peak_relative_depth_log10",
        "sea_state_dynamic_bandwidth_peakedness",
        "direction_dominant_spread",
    ],
]

## Imports

In [None]:
import os
import glob
import math
from functools import partial

import numpy as np
import sympy
from sklearn.preprocessing import RobustScaler
from pysr import PySRRegressor

import jax
import jax.numpy as jnp
import optax

In [None]:
from constants import (
    VALIDATION_STATIONS,
    BASE_CONSTRAINTS,
    TRAIN_SIZE,
    RANDOM_SEED,
    SWAG_EPOCHS,
    L1_REG,
    L2_REG,
    LEARNING_RATE,
    MLP_LAYERS,
    EPOCHS,
)

from data_functions import (
    read_files,
    convert_log_features,
    drop_invalid,
    apply_constraints,
    train_test_split,
    get_model_inputs,
    generate_subsets,
)

from training_functions import (
    score,
    train,
    train_swag,
    sample_swag,
    MultiHeadMLP,
    cross_entropy_regularized,
    inverse_logit,
)

from plot_functions import (
    check_calibration,
)

In [None]:
# Record the execution environment (watermark's default summary) and the
# versions of the listed packages, so results can be tied to a software stack.
%load_ext watermark
%watermark
%watermark -p jax,jaxlib,flax,optax,matplotlib,seaborn,scipy,scikit-learn,numpy,pandas,PyALE,pysr

In [None]:
# Headline metrics collected throughout the notebook for the report.
scores = dict()

In [None]:
# Flat, de-duplicated list of every feature used by any group.
# The order follows first appearance in MODEL_FEATURE_GROUPS, so column order
# (and with it the model inputs and the symbolic-regression variable order)
# is identical on every run. Iterating a `set` of strings would give an order
# that depends on Python's per-process hash randomisation instead.
MODEL_FEATURES = list(
    dict.fromkeys(feat for grp in MODEL_FEATURE_GROUPS for feat in grp)
)

## Load data

In [None]:
# Split the aggregated parquet files into training files and files whose
# name starts with one of VALIDATION_STATIONS (kept apart from training).
infiles = []
validation_infiles = []

for f in sorted(glob.glob(os.path.join(DATA_PATH, "*-agg.parquet"))):
    if any(os.path.basename(f).startswith(s) for s in VALIDATION_STATIONS):
        validation_infiles.append(f)
    else:
        infiles.append(f)

# Fail early with a clear message (e.g. DATA_PATH left at its empty default)
# instead of failing later inside `read_files`.
if not infiles:
    raise FileNotFoundError(
        f"no training files matching '*-agg.parquet' found in {DATA_PATH!r}"
    )

len(infiles), len(validation_infiles)

In [None]:
# Read all training files and run them through the project's cleaning steps
# in order: log-feature conversion, invalid-row removal for the model
# features, then the base constraints.
df_all = convert_log_features(read_files(infiles))
df_all = apply_constraints(drop_invalid(df_all, MODEL_FEATURES), BASE_CONSTRAINTS)

# Calendar day of each aggregate's start timestamp.
df_all["day_of_year"] = df_all["aggregate_100_start_time"].dt.dayofyear

In [None]:
df_train, df_val = train_test_split(df_all, train_ratio=TRAIN_SIZE)

x_train, y_train = get_model_inputs(df_train, MODEL_FEATURES)
x_val, y_val = get_model_inputs(df_val, MODEL_FEATURES)

# Robust scaling fitted on the training split only. Scaling uses the
# 0.01-99.99 percentile spread rather than the interquartile range.
preprocess = RobustScaler(quantile_range=(1e-2, 100 - 1e-2))
preprocess.fit(x_train)
x_train, x_val = (preprocess.transform(arr) for arr in (x_train, x_val))

# Everything downstream works on float32 JAX arrays.
x_train = jnp.array(x_train, dtype="float32")
y_train = jnp.array(y_train, dtype="float32")
x_val = jnp.array(x_val, dtype="float32")
y_val = jnp.array(y_val, dtype="float32")

In [None]:
# Number of samples and fraction of positive labels in each split.
for split_name, labels in (("train", y_train), ("val", y_val)):
    print(f"{split_name} samples: {len(labels)}")
    print(f"{split_name} base rate: {np.count_nonzero(labels) / len(labels)}")

In [None]:
# Validation subsets: scale inputs with the training-fitted scaler and
# convert both inputs and labels to float32 JAX arrays, replacing each
# entry in place.
sub_xy = generate_subsets(df_val, MODEL_FEATURES)

for subset_name in list(sub_xy):
    raw_x, raw_y = sub_xy[subset_name]
    sub_xy[subset_name] = (
        jnp.array(preprocess.transform(raw_x), dtype="float32"),
        jnp.array(raw_y, dtype="float32"),
    )

# The whole validation split is evaluated alongside the subsets.
sub_xy["full"] = (x_val, y_val)

for subset_name, (subset_x, subset_y) in sub_xy.items():
    print(f"{subset_name}: {len(subset_y) / 1e6:.2f}M")

## Train multi-head neural network

In [None]:
np.random.seed(RANDOM_SEED)

# Fraction of positive training labels, handed to the model.
base_rate = float(y_train.sum()) / len(y_train)

# For every feature group, the column positions of its features in MODEL_FEATURES.
feature_group_idx = []
for group in MODEL_FEATURE_GROUPS:
    feature_group_idx.append([MODEL_FEATURES.index(feat) for feat in group])

# Each head's hidden widths are MLP_LAYERS divided by sqrt(#groups), rounded up.
width_divisor = np.sqrt(len(feature_group_idx))
model_layers = [math.ceil(width / width_divisor) for width in MLP_LAYERS]

model = MultiHeadMLP(
    features=feature_group_idx,
    hidden_layers=model_layers,
    base_rate=base_rate,
)

# Cross-entropy loss with L1/L2 regularisation strengths bound in.
reg_loss = partial(cross_entropy_regularized, l1_reg=L1_REG, l2_reg=L2_REG)

state = train(
    model,
    x_train,
    y_train,
    num_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    loss_fn=reg_loss,
    x_val=x_val,
    y_val=y_val,
)

In [None]:
# Continue from the trained state with SWAG; `swag_out` is later unpacked
# into `sample_swag` to draw weight samples.
state, swag_out = train_swag(
    state,
    x_train,
    y_train,
    max_cols_deviation=30,
    accumulate_every=1,
    num_steps=SWAG_EPOCHS,
    loss_fn=reg_loss,
)

In [None]:
# 100 SWAG draws evaluated on each split; axis 1 (the draws) is averaged or
# reduced further down.
swag_samples_train, swag_samples_val = (
    sample_swag(state, inputs, *swag_out, 100) for inputs in (x_train, x_val)
)

In [None]:
# Score both splits using the logits averaged over the SWAG draws.
train_logits_mean = swag_samples_train.mean(axis=1)
val_logits_mean = swag_samples_val.mean(axis=1)

scores.update(
    train_score=score(state, x_train, y_train, logits=train_logits_mean),
    val_score=score(state, x_val, y_val, logits=val_logits_mean),
)

print(f"train score: {scores['train_score']:.3e}, val score: {scores['val_score']:.3e}")

## Run symbolic regression

In [None]:
# Requested size of the training subsample used for symbolic regression.
num_samples = 100_000

# SWAG-averaged network output per training row. NOTE: despite the name these
# are logits, not probabilities (the same quantity is passed as `logits=`
# to `score` above).
p_pred = swag_samples_train.mean(axis=1)

# Rows drawn without replacement. The size is capped at the number of
# available rows so that a small training set does not make `choice` raise.
num_rows = x_train.shape[0]
symreg_samples = np.random.choice(num_rows, size=min(num_samples, num_rows), replace=False)
x_symreg = x_train[symreg_samples]
logit_symreg = p_pred[symreg_samples]

# Per-row weight: inverse spread of the SWAG logits (rows where the draws
# agree count more), normalised to sum to one.
weight_symreg = 1 / swag_samples_train[symreg_samples].std(axis=1)
weight_symreg /= weight_symreg.sum()

# Feature names as seen by symbolic regression. The "_log10" suffix is
# dropped because `preprocess_symreg` raises those columns back to 10**value.
symreg_features = [
    feat[: -len("_log10")] if feat.endswith("_log10") else feat
    for feat in MODEL_FEATURES
]


def preprocess_symreg(x, scaler=None, features=None):
    """Map scaled model inputs back to physical units for symbolic regression.

    The fitted scaler is undone first. Then:

    - columns whose name ends in ``_log10`` are raised back to ``10 ** value``;
    - columns whose name ends in ``_spread`` are multiplied by ``pi / 180``
      (degrees to radians).

    Args:
        x: 2-D array of scaled inputs, columns ordered like ``features``.
        scaler: Fitted object with ``inverse_transform``. Defaults to the
            notebook-level ``preprocess`` scaler.
        features: Column names of ``x``. Defaults to ``MODEL_FEATURES``.

    Returns:
        The array returned by ``scaler.inverse_transform`` with the column
        conversions above applied in place. With sklearn's default
        ``copy=True`` this is a new array and ``x`` is left untouched.
    """
    # Resolve the notebook globals lazily so callers can inject their own.
    if scaler is None:
        scaler = preprocess
    if features is None:
        features = MODEL_FEATURES

    out = scaler.inverse_transform(x)

    for i, feat in enumerate(features):
        if feat.endswith("_log10"):
            out[:, i] = 10 ** out[:, i]
        elif feat.endswith("_spread"):
            out[:, i] = np.pi / 180 * out[:, i]

    return out


input_symreg = preprocess_symreg(x_symreg)

# Regression target is log(inverse_logit(logit)), i.e. a log-probability
# (assuming inverse_logit is the sigmoid). NOTE(review): the polishing step
# later fits the expression to logits instead; log p and logit p agree only
# for p << 1. Confirm this is intended.
prob_symreg = inverse_logit(logit_symreg)
output_symreg = np.array(np.log(prob_symreg))

In [None]:
# Short symbols for features in the discovered expressions. Keys ending in
# "_log10" are never looked up, because `symreg_features` has that suffix
# stripped before the lookup.
variable_names = dict(
    sea_state_dynamic_crest_trough_correlation="r",
    sea_state_dynamic_peak_relative_depth_log10="log(kD)",
    sea_state_dynamic_peak_relative_depth="kD",
    sea_state_dynamic_steepness="eps",
    sea_state_dynamic_bandwidth_peakedness="nu",
    direction_directionality_index_log10="log(R)",
    direction_directionality_index="R",
    direction_dominant_spread="sig",
)

In [None]:
# PySR search configuration, grouped by purpose.
symreg_options = dict(
    # search budget / schedule
    niterations=40,
    ncyclesperiteration=600,
    populations=48,
    population_size=1000,
    warmup_maxsize_by=0.5,
    annealing=False,
    batching=False,
    # expression size limits
    maxsize=32,
    maxdepth=16,
    # operator set; "inv" is a custom Julia operator, mapped to sympy below
    binary_operators=["+", "*", "-", "/"],
    unary_operators=["log", "inv(x) = 1/x", "square", "sqrt"],
    # forbid log directly inside log and sqrt directly inside sqrt
    nested_constraints={"log": {"log": 0}, "sqrt": {"sqrt": 0}},
    # maximum complexity of each unary operator's argument
    constraints={"square": 4, "sqrt": 2, "inv": 4},
    extra_sympy_mappings={"inv": lambda x: 1 / x},
    # objective and final model choice
    loss="L2DistLoss()",
    model_selection="best",
    # parallelism
    multithreading=True,
    procs=48,
)
symreg = PySRRegressor(**symreg_options)

# Short symbol when one is defined, otherwise the feature name itself.
symreg_varnames = [variable_names.get(feat, feat) for feat in symreg_features]
symreg.fit(
    input_symreg,
    output_symreg,
    variable_names=symreg_varnames,
    weights=weight_symreg,
)

In [None]:
# Each candidate expression: its loss, then exp(expression). The search
# target was a log-probability, so exp(...) reads as a probability.
for row in symreg.equations_.itertuples():
    prob_expr = sympy.sympify(f"exp({row.sympy_format})")
    print(f"{row.loss:.3f}  {prob_expr}\n")

In [None]:
# Pick the highest-scoring expression that uses every input variable.
# Variables are matched against the expression's free symbols. A substring
# test on the printed expression would wrongly count the single-letter
# variable "r" as present in any expression containing "sqrt".
best_eq = None
for eq in symreg.equations_.sort_values("score", ascending=False).itertuples():
    used_vars = {str(sym) for sym in sympy.sympify(eq.sympy_format).free_symbols}
    if used_vars.issuperset(symreg_varnames):
        best_eq = eq
        break

# Fail with a clear message instead of an AttributeError on None below.
if best_eq is None:
    raise RuntimeError(
        f"no symbolic-regression equation uses all variables {symreg_varnames}"
    )

print(
    f"{best_eq.score:.3f}  {best_eq.loss:.4f}  {sympy.sympify(f'exp({best_eq.sympy_format})')}\n"
)

In [None]:
# JAX export of the chosen expression. `func(X, params)` evaluates it with
# the constant vector `params`, which the loop below refines.
best_eq_jax = symreg.jax(best_eq.Index)
func = best_eq_jax["callable"]
params = best_eq_jax["parameters"]

# Polishing data: the whole training set in physical units, with the
# SWAG-averaged network logits as targets.
# NOTE(review): the expression was searched with log-probability targets, not
# logits; the two are close only for rare events. Confirm this is intended.
x_symreg_full = preprocess_symreg(x_train)
logits_train = swag_samples_train.mean(axis=1)

opt = optax.adam(learning_rate=1e-4)
opt_state = opt.init(params)


@jax.jit
def polish_symreg(params, opt_state, X, y):
    """Take one full-batch optimiser step on the expression's constants.

    Minimises the mean squared error between ``func(X, params)`` and ``y``
    using the notebook-level optimiser ``opt``.

    Returns:
        Tuple of (loss at the incoming ``params``, updated params,
        updated optimiser state).
    """

    def mse(p, inputs, targets):
        residual = func(inputs, p) - targets
        return jnp.mean(residual**2)

    loss, grads = jax.value_and_grad(mse)(params, X, y)
    updates, new_opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, new_opt_state


# 1000 polishing steps; print the loss every 100 steps.
for step in range(1000):
    loss, params, opt_state = polish_symreg(
        params, opt_state, x_symreg_full, logits_train
    )
    if not step % 100:
        print(loss)

# Constants as exported by PySR next to the polished constants.
best_eq_jax["parameters"], params

In [None]:
def final_eq(x):
    """Evaluate the polished expression on `preprocess_symreg` outputs.

    `params` is looked up when called, so it reflects the latest polishing run.
    """
    return best_eq_jax["callable"](x, params)

In [None]:
# Calibration of the symbolic model on the validation split; the prediction
# gets a trailing length-1 axis before being passed in.
symreg_val_logits = final_eq(preprocess_symreg(x_val))
check_calibration(symreg_val_logits[:, None], y_val)

In [None]:
# Score the polished expression on every validation subset (including the
# full split), then report the unweighted mean across subsets.
subset_scores = {}
for sub_name, (sub_inputs, sub_labels) in sub_xy.items():
    sub_logits = final_eq(preprocess_symreg(np.array(sub_inputs)))
    subset_scores[sub_name] = score(None, None, sub_labels, logits=sub_logits)
    print(f"{sub_name}: {subset_scores[sub_name]:.2e}")

total_score = sum(subset_scores.values())
print("---")
print(f"total score: {total_score / len(sub_xy):.2e}")