# Serialize/export models to ONNX
Read the contents of results.tsv and serialize all of the top models.

In [None]:
import pandas as pd
import numpy as np

from serialize import SerializeFailure, COL_SEP


# Serialization failures collected while exporting, so the most recent one
# can be inspected/retried after the export loop has finished.
ONNX_FAILS: list[SerializeFailure] = []

# load the results
df = pd.read_csv(
    "results.tsv",
    index_col=False,
    sep=COL_SEP,
    usecols=['Sport', 'Service', 'Style', 'Type', 'y', 'ModelType', 'R2', 'Date', 'Params'],
)

# Drop the separator row (first row whose Sport starts with '*') and
# everything after it.
# na=False: rows with a missing Sport must not be mistaken for the separator
# (without it startswith() yields NaN, which numpy treats as truthy).
separator_rows = np.flatnonzero(
    df['Sport'].str.startswith('*', na=False).to_numpy()
)
# When no separator row is present keep every row instead of raising IndexError.
seperator_idx = separator_rows[0] if len(separator_rows) else len(df)
df = df.iloc[:seperator_idx]

# df = df.query('ModelType == "tpot"')

with pd.option_context('display.max_rows', 1000, 'display.max_colwidth', 1000):
    display(df)

In [None]:
import ast
import os
import tempfile

import numpy as np

from fantasy_py import ContestStyle, CLSRegistry, CONTEST_DOMAIN, lineup

from automl import create_automl_model, error_report
from serialize import serialize_model, SerializeFailure, SUPPORTED_EXPORT_MODELS
from generate_train_test import generate_train_test, load_csv


# Number of PCA components used for '-pca' model types when the model
# parameters do not provide their own 'n_components' value.
DEFAULT_PCA_COMPONENTS = 5


def train_export(
    sport, service, style: str,
    contest_type: str, model_type: str,
    y_type,
    model_def_: dict,
    skip_fit=False,
    datapath=".",
    modelpath=".",
    overwrite=True,
) -> None:
    """
    Recreate a model from its recorded parameters, train it and export it
    to '<modelpath>/<full model name>.onnx'.

    sport, service - used to load the training data and to name the model
    style - name of a ContestStyle member (case-insensitive)
    contest_type - name looked up in the CLSRegistry under CONTEST_DOMAIN
    model_type - must start with 'skautoml' or 'tpot'; a '-pca' suffix adds
        a 'pca_components' argument to the model creation call
    y_type - which target to train on, 'top' or 'last'
    model_def_ - model parameters. The dict is copied, not modified.
        'random_state', 'model_cols' and (for '-pca' models) 'n_components'
        are extracted, everything else is forwarded to create_automl_model
    skip_fit - create the model but do not fit, report or serialize it
    datapath - folder the training csv is loaded from
    modelpath - folder the model is written to, created if it does not exist
    overwrite - if False and the model file already exists, do nothing

    raises ValueError if no train/test data could be generated or if
        model_type / y_type are not recognized
    """
    if not os.path.isdir(modelpath):
        print(f"Creating model path '{modelpath}'")
        os.makedirs(modelpath, exist_ok=True)

    contest_style = ContestStyle[style.upper()]
    full_model_name = f'{sport}_{service}_{contest_style}_{contest_type}_{model_type}_{y_type}'
    model_filepath = os.path.join(modelpath, full_model_name + ".onnx")

    if os.path.isfile(model_filepath) and not overwrite:
        print(f"Model '{model_filepath}' already exists, skipping")
        return
    print(f"Exporting model to '{model_filepath=}'")

    contest_type_cls = CLSRegistry.get_class(CONTEST_DOMAIN, contest_type)
    data_df = load_csv(sport, service, contest_style,
                       contest_type_cls, data_folder=datapath)
    assert len(data_df) > 0, "CSV load returned no data"

    # work on a copy so the caller's dict is left untouched
    model_def = dict(model_def_)
    random_state = model_def.pop("random_state", None)
    model_cols = model_def.pop('model_cols', None)
    train_test_data = generate_train_test(
        data_df,
        model_cols=model_cols,
        random_state=random_state,
    )
    if train_test_data is None:
        display("Failed to generate a train/test data set from...", data_df)
        # fail with a clear error instead of a TypeError on the unpack below
        raise ValueError(
            f"Failed to generate a train/test data set for {full_model_name}"
        )
    (X_train, X_test, y_top_train, y_top_test,
     y_last_win_train, y_last_win_test) = train_test_data

    create_model_params = {
        'random_state': random_state,
    }
    if model_type.endswith('-pca'):
        create_model_params['pca_components'] = model_def.pop(
            'n_components', DEFAULT_PCA_COMPONENTS
        )

    # Set the framework in place for every branch so that random_state and
    # any pca_components set above are forwarded regardless of framework.
    if model_type.startswith('skautoml'):
        create_model_params['framework'] = 'skautoml'
    elif model_type.startswith('tpot'):
        create_model_params['framework'] = 'tpot'
    else:
        raise ValueError(f"Don't know how to process model type {model_type}")

    if y_type == 'top':
        y_train = y_top_train
        y_test = y_top_test
    elif y_type == 'last':
        y_train = y_last_win_train
        y_test = y_last_win_test
    else:
        raise ValueError(f"Unexpected y of {y_type}")

    # add all remaining
    create_model_params.update(model_def)
    model, fit_params = create_automl_model(
        full_model_name,
        **create_model_params,
    )
    if skip_fit:
        print("Skipping fit...")
        return

    print("Training model...")
    model.fit(X_train, y_train, **fit_params)
    # model_cols was popped from model_def above, so report the local value
    error_report(model, X_test, y_test,
                 f"{full_model_name}: model_cols={model_cols}")
    serialize_model(model, model_type, X_train, y_train,
                    full_model_name, model_filepath)


In [None]:
from tqdm.notebook import tqdm

import traceback

# When False, models whose .onnx file already exists are skipped
OVERWRITE = False
# skip fit and serialize... dryrun
DRYRUN = False

# Values applied on top of every row's recorded model parameters
PARAM_OVERRIDES = {
    # 'generations': 10, 
    # 'early_stop': 1, 
    # 'max_train_time': 60
}

# Train + export every model listed in the results dataframe. A failure for
# one row is reported and the loop moves on to the next row.
pbar = tqdm(df.iterrows(), total=len(df))
for _, row in pbar:
    model_desc = f"sport={row.Sport} service={row.Service} style={row.Style} type={row.Type} y={row.y}"
    pbar.set_postfix_str(model_desc)

    # Check for a missing model type first: a NaN would otherwise be caught
    # by the supported-model test below and misreported as "not supported".
    if pd.isna(row.ModelType):
        print(
            f"Skipping row with no model type... {model_desc}")
        continue

    if row.ModelType not in SUPPORTED_EXPORT_MODELS:
        display(
            f"Failed to train+export model of type {row.ModelType=}. Export of this type is not supported."
        )
        continue

    # Params holds the python literal of the model parameter dict
    try:
        model_def: dict = ast.literal_eval(row.Params)
    except Exception:
        print("Failed to parse params", row.Params)
        raise

    try:
        model_def.update(PARAM_OVERRIDES)
        train_export(row.Sport, row.Service, row.Style,
                     row.Type, row.ModelType, row.y,
                     model_def, 
                     skip_fit=DRYRUN, overwrite=OVERWRITE,
                     datapath="data", modelpath="models")
    except SerializeFailure as se:
        # keep the failure so it can be examined/retried after the loop
        ONNX_FAILS.append(se)
        display(
            f"Failed to serialize: {model_desc} {row.ModelType=} {row.Params=}"
        )
    except Exception:
        display(
            f"Failed to train+export: {model_desc} {row.ModelType=} {row.Params=}"
        )
        print(traceback.format_exc())

print("Done!")


In [None]:
# Retry serializing the most recently recorded failure so the error can be
# reproduced and debugged without rerunning the whole export loop.
if not ONNX_FAILS:
    print("no previous errors found")
else:
    last_failure = ONNX_FAILS[-1]
    try:
        print("###### attempting to serialize last failed model!!! ####")
        # the failure is indexed for the pieces needed to redo the export
        serialize_model(
            last_failure['model'],
            'tpot',
            last_failure['X'],
            last_failure['y'],
            last_failure['name'],
        )
        print("#### serialization successful! ###")
    except Exception as retry_error:
        print("#### serializeation failed!!!", retry_error)
        raise
        # optional debugging aids:
        # export_data_df = pd.read_csv('/tmp/tpot-data.csv', sep=COL_SEP, dtype=np.float64)
        # display("export df", export_data_df)
        # display("exported pipeline", ONNX_FAILS[-1].get('exported_pipeline'))
        # display("ONNX_FAILS[-1]", ONNX_FAILS[-1])