## Testing orbital on larger models

### Goals

`orbital` translates all model scoring logic to SQL. This is trivial for linear models but increasingly complex for tree-based and ensemble models. Let's stress-test `orbital` with a larger, more complex model by scaling up a random forest.

We'll train a random forest with 100 trees of maximum depth 10. By default, `RandomForestClassifier` is very permissive about small leaf nodes and does little pruning, so the resulting tree logic could contain around 100K (100 * 2^10) decision nodes.
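
As a rough check on that estimate, the arithmetic can be sketched as follows. The helper name `max_forest_nodes` is hypothetical (it is not part of `orbital` or scikit-learn), and it assumes every tree is a complete binary tree, which gives an upper bound rather than the size of any fitted forest.

```python
def max_forest_nodes(n_trees: int, max_depth: int) -> dict:
    """Upper bound on node counts if every tree were a complete binary tree."""
    splits_per_tree = 2 ** max_depth - 1   # internal (decision) nodes
    leaves_per_tree = 2 ** max_depth       # terminal nodes
    return {
        "splits": n_trees * splits_per_tree,
        "leaves": n_trees * leaves_per_tree,
    }

# Same settings as the forest described above
bounds = max_forest_nodes(n_trees=100, max_depth=10)
print(bounds)
```

Real trees stop splitting once a node is pure, so a fitted forest is usually smaller than this bound. Once a scikit-learn forest is fit, summing `est.tree_.node_count` over its `estimators_` gives the actual total.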

To be clear, I don't care about model performance. I just care if this works and how fast it runs. 

### Set Up

In [None]:
import orbital
import duckdb
import sqlglot
import pandas as pd
import numpy as np
import re
import joblib
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification

In [None]:
#| label: model-prep

# make dataset
X_train, y_train = make_classification(int(1e6), random_state = 102)
X_train = X_train.round(3)

# get column names for use in pipeline
n_cols = len(X_train[0])
nm_cols = [f"f{i}" for i in range(n_cols)]
feat_dict = {c:orbital.types.DoubleColumnType() for c in nm_cols}

# fit sklearn pipeline
model_path = "sample-outputs/big-rfc.joblib" 
if os.path.exists(model_path):
  pipeline = joblib.load(model_path)
else: 
  pipeline = Pipeline([
    ("preprocess", ColumnTransformer([("scaler", StandardScaler(), [])], remainder="passthrough")),
    ("gbm", RandomForestClassifier(max_depth = 10, n_estimators = 100)),
  ])
  pipeline.fit(X_train, y_train)
  # create the output directory if needed so the dump does not fail
  os.makedirs(os.path.dirname(model_path), exist_ok=True)
  joblib.dump(pipeline, model_path)

### Run `orbital`

In [None]:
%%time

orbital_pipeline = orbital.parse_pipeline(pipeline, features=feat_dict)
sql_raw = orbital.export_sql("DATA_TABLE", orbital_pipeline, dialect="duckdb")

In [None]:
#| label: cleanup

# parse AST from SQL script
ast = sqlglot.parse_one(sql_raw)

# clean up SQL
## drop the class prediction and negative-event predictions
ast.expressions[0] = None
ast.expressions[1] = None 

## pretty print -- not important for usage, but we'll take a peek at the output at the end here
sql_mod = ast.sql()
sql_fmt = sqlglot.transpile(sql_mod, write="duckdb", identify=True, pretty=True)[0]

In [None]:
# count CASEs to gauge size of tree
# divide by 3 since raw output repeats logic 3x

cases = [match.start() for match in re.finditer('CASE', sql_raw)]
len(cases)/3
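
A minimal sketch of the same counting idea on a toy query. The SQL string below is made up for illustration and is not real `orbital` output. The word-boundary pattern is a small variation that avoids counting identifiers that merely contain the letters CASE.

```python
import re

# Hypothetical SQL standing in for generated output
toy_sql = """
SELECT
  CASE WHEN f0 > 0.5 THEN CASE WHEN f1 > 0.1 THEN 1 ELSE 0 END ELSE 0 END AS a,
  CASE WHEN f0 > 0.5 THEN 0.8 ELSE 0.2 END AS "SHOWCASE_COL"
FROM DATA_TABLE
"""

naive = len(re.findall("CASE", toy_sql))          # also matches inside SHOWCASE_COL
bounded = len(re.findall(r"\bCASE\b", toy_sql))   # whole-word matches only
print(naive, bounded)
```

With feature names like `f0` the two counts agree, so the simpler pattern is fine here. The bounded version only matters when column names or aliases could contain the keyword.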

### Testing Output

We can now assess scoring time and double check the validity of our predictions.

First, we can see the benefit of cleaning up our SQL to compute only the single prediction column. Both CPU and wall time are roughly halved when we calculate only the positive-class probability. This makes sense because the `orbital` code produces separate (repeated) logic for the class prediction, positive probability, and negative probability instead of reusing the computation. That may be fine for small problems, but for larger and more complex ones, optimization matters.

In [None]:
DATA_TABLE = pd.DataFrame(X_train[:1000,], columns = nm_cols)

In [None]:
%%timeit -n 1 -r 10

#| label: scores-raw

df_preds = duckdb.sql(sql_raw).df()

In [None]:
%%timeit -n 1 -r 10

#| label: scores-fmt

df_preds = duckdb.sql(sql_fmt).df()
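
Outside a notebook, roughly the same measurement can be taken with the standard-library `timeit` module. This is only a sketch: `score_once` is a hypothetical stand-in for the `duckdb.sql(...).df()` call, so the numbers it prints say nothing about `orbital` itself.

```python
import timeit

def score_once():
    # Placeholder workload; swap in the real scoring call when adapting this
    return sum(i * i for i in range(10_000))

# number=1 and repeat=10 mirror `%%timeit -n 1 -r 10`
runs = timeit.repeat(score_once, number=1, repeat=10)
print(f"best: {min(runs):.6f}s, mean: {sum(runs) / len(runs):.6f}s")
```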

We can also confirm that our outputs still match in this more complex case. 

In [None]:
#| label: valid-values

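# variables assigned inside a %%timeit cell do not persist, so run the query once more here
df_preds = duckdb.sql(sql_fmt).df()
preds_orb = df_preds['output_probability.1']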
preds_ppl = pipeline.predict_proba(DATA_TABLE)[:,1]

print(f"ppl and orb match: {np.all(np.isclose(preds_ppl, preds_orb))}")
print(f"ppl and orb prop mismatch: {sum(~np.isclose(preds_ppl, preds_orb)) / len(preds_ppl)}")
print(f"ppl and orb corr: {np.corrcoef(preds_ppl, preds_orb)[0][1]:.2f}")
print(f"ppl and orb MAE: {np.mean(np.abs(preds_ppl - preds_orb)):.10f}")
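
For context on what "match" means here: with default arguments, `np.isclose(a, b)` treats two values as equal when `|a - b| <= atol + rtol * |b|`, with `rtol=1e-05` and `atol=1e-08`. The sketch below re-implements that rule in plain Python on made-up numbers. `is_close` is a hypothetical helper, not part of NumPy, and it ignores NaN and infinity handling.

```python
def is_close(a: float, b: float, rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    """Same tolerance rule as NumPy's isclose, for finite scalars only."""
    return abs(a - b) <= atol + rtol * abs(b)

# Made-up probabilities for illustration only
reference = [0.25, 0.50, 0.75]
candidate = [0.25, 0.5000001, 0.76]

flags = [is_close(c, r) for c, r in zip(candidate, reference)]
mismatch_rate = flags.count(False) / len(flags)
print(flags, mismatch_rate)
```

Small floating-point differences between the SQL engine and scikit-learn pass under this rule, while a genuinely different probability (the last pair) does not.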

### Write Output

We can save out this long query to see what it looks like. Spoiler alert: prepare to scroll.

In [None]:
with open("sample-outputs/long_query.sql", "w") as file:
    file.write(sql_fmt)