# Introduction

Create a binary classifier using the `features_20240217_v1` features dataset and deep learning. Some features exist in two versions that share the same base name: a pre-race (predicted) version marked `_事前_` and an actual version marked `_実績_`. The marker, which follows the `cat_`/`num_` type prefix, identifies the source of the value. Use the actual values for training and the predicted values for testing.

In [28]:
import re
import tempfile
import warnings

import lightgbm as lgb
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
from hyperopt import STATUS_OK, SparkTrials, Trials, fmin, hp, tpe
from hyperopt.pyll.base import scope
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.compose import ColumnTransformer
from sklearn.metrics import (
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sqlalchemy import create_engine

from JapanHorseRaceAnalytics.utilities.base import get_base_dir, get_data_dir
from JapanHorseRaceAnalytics.utilities.metrics import (
    calculate_binary_classifier_statistics,
)
from JapanHorseRaceAnalytics.utilities.mlflow import get_colspecs
from JapanHorseRaceAnalytics.utilities.structured_logger import logger


# Raise pandas' display limits so wide feature tables render without truncation.
for option_name, option_value in (
    ("display.max_rows", 1000),
    ("display.max_columns", 100),
):
    pd.set_option(option_name, option_value)

# Data Collection

In [4]:
# Locations of the Hive warehouse and the PostgreSQL JDBC driver, both under the project base dir.
base_dir = get_base_dir()
warehouse_dir = f"{base_dir}/spark-warehouse"
postgres_driver_path = f"{base_dir}/jars/postgresql-42.7.1.jar"

# Spark settings applied to the session builder, in this order.
spark_settings = {
    "spark.driver.memory": "21g",
    "spark.driver.maxResultSize": "5g",
    "spark.sql.warehouse.dir": warehouse_dir,
    "spark.jars": postgres_driver_path,
    "spark.executor.extraClassPath": postgres_driver_path,
    "spark.driver.extraClassPath": postgres_driver_path,
}

session_builder = SparkSession.builder.appName("20240211_competitors")
for setting_name, setting_value in spark_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)

# Hive support is required because the feature tables are read with `spark.read.table`.
spark = session_builder.enableHiveSupport().getOrCreate()

24/02/17 17:05:08 WARN Utils: Your hostname, Hanks-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.40.105 instead (on interface en0)
24/02/17 17:05:08 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/02/17 17:05:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [46]:
def read_hive_table(
    table_name: str,
    schema: str,
    spark_session: SparkSession,
    use_cache: bool = True,
):
    """
    Load a Hive table into a pandas DataFrame through a local parquet copy.

    The parquet copy lives at ``<data dir>/sql_tables/<table_name>.snappy.parquet``.
    When ``use_cache`` is true and that path already exists, it is read directly.
    Otherwise the table ``<schema>.<table_name>`` is read with ``spark_session``,
    written to that path (overwriting any previous copy) and then read back.

    Args:
        table_name: Name of the Hive table; also used for the parquet file name.
        schema: Hive schema (database) that contains the table.
        spark_session: Session used to read the table when the cache is not used.
        use_cache: Reuse an existing parquet copy instead of querying Hive.

    Returns:
        The table contents as a pandas DataFrame.
    """
    save_path = get_data_dir() / "sql_tables" / f"{table_name}.snappy.parquet"
    cache_hit = use_cache and save_path.exists()

    if not cache_hit:
        # Refresh the local copy from Hive before reading it into pandas.
        logger.info(f"Read from hive {schema}.{table_name}")
        hive_table = spark_session.read.table(f"{schema}.{table_name}")
        logger.info(f"Write to parquet {save_path}")
        hive_table.write.mode("overwrite").parquet(str(save_path))

    logger.info(f"Read from parquet {save_path} to pandas")
    return pd.read_parquet(save_path)


# Load the curated feature table (pass use_cache=False to force a refresh from Hive).
data = read_hive_table(
    spark_session=spark,
    schema="jhra_curated",
    table_name="features_20240217_v1",
)

# Feature columns are identified by their name prefix.
categorical_columns = [col for col in data.columns if re.match(r"^cat_", col)]
numeric_columns = [col for col in data.columns if re.match(r"^num_", col)]

# cat_* columns: cast to string, replace nulls with "missing", then store as category.
# num_* columns: cast to float64.
data = (
    data.astype(dict.fromkeys(categorical_columns, "string"))
    .fillna(dict.fromkeys(categorical_columns, "missing"))
    .astype(dict.fromkeys(categorical_columns, "category"))
    .astype(dict.fromkeys(numeric_columns, "float64"))
)

data.head()

{"event": "Read from parquet /Users/hankehly/Projects/JapanHorseRaceAnalytics/data/sql_tables/features_20240217_v1.snappy.parquet to pandas", "level": "info", "timestamp": "2024-02-17T08:39:33.184484Z", "logger": "__main__"}


Unnamed: 0,meta_レースキー,meta_馬番,meta_着順,meta_本賞金,meta_単勝的中,meta_単勝払戻金,meta_複勝的中,meta_複勝払戻金,meta_int_races_レースキー,meta_発走日時,meta_場コード,cat_四半期,cat_距離,cat_事前_馬場状態コード,cat_事前_レース条件_トラック情報_右左,cat_事前_レース条件_トラック情報_内外,cat_事前_レース条件_種別,cat_事前_レース条件_条件,cat_事前_レース条件_記号,cat_事前_レース条件_重量,cat_事前_レース条件_グレード,num_事前_頭数,cat_実績_馬場状態コード,cat_実績_レース条件_トラック情報_右左,cat_実績_レース条件_トラック情報_内外,cat_実績_レース条件_種別,cat_実績_レース条件_条件,cat_実績_レース条件_記号,cat_実績_レース条件_重量,cat_実績_レース条件_グレード,num_実績_頭数,cat_トラック種別,num_事前_馬場差,num_実績_馬場差,cat_馬場状態内,cat_馬場状態中,cat_馬場状態外,num_直線馬場差最内,num_直線馬場差内,num_直線馬場差中,num_直線馬場差外,num_直線馬場差大外,cat_芝種類,cat_草丈,cat_転圧,cat_凍結防止剤,num_中間降水量,meta_int_race_horses_レースキー,meta_int_race_horses_馬番,num_事前_馬体重,...,num_競争相手平均調教師トップ3完走率差,num_競争相手平均調教師場所レース数差,num_競争相手平均調教師場所1位完走差,num_競争相手平均調教師場所トップ3完走差,num_競争相手平均調教師場所1位完走率差,num_競争相手平均調教師場所トップ3完走率差,num_競争相手平均調教師本賞金累計差,num_競争相手平均調教師1位完走平均賞金差,num_競争相手平均調教師レース数平均賞金差,meta_int_combinations_レースキー,meta_int_combinations_馬番,num_馬騎手レース数,num_馬騎手1位完走,num_馬騎手1位完走率,num_馬騎手トップ3完走,num_馬騎手トップ3完走率,num_馬騎手初二走,num_馬騎手同騎手,num_馬騎手場所レース数,num_馬騎手場所1位完走,num_馬騎手場所1位完走率,num_馬騎手場所トップ3完走,num_馬騎手場所トップ3完走率,num_馬調教師レース数,num_馬調教師1位完走,num_馬調教師1位完走率,num_馬調教師トップ3完走,num_馬調教師トップ3完走率,num_馬調教師初二走,num_馬調教師同調教師,num_馬調教師場所レース数,num_馬調教師場所1位完走,num_馬調教師場所1位完走率,num_馬調教師場所トップ3完走,num_馬調教師場所トップ3完走率,meta_int_race_weather_レースキー,num_temperature,num_precipitation,num_snowfall,num_snow_depth,num_wind_speed,cat_wind_direction,num_solar_radiation,num_local_air_pressure,num_sea_level_air_pressure,num_relative_humidity,num_vapor_pressure,num_dew_point_temperature,cat_weather,num_visibility
0,1011103,4,6.0,0.0,False,0,False,0,1011103,2001-08-04 01:45:00,1,3,1200,20,1,1,12,A3,102,3,missing,16.0,21,1,1,12,A3,102,3,missing,16.0,芝,,-18.0,1,1,1,1.0,1.0,0.0,0.0,0.0,missing,missing,False,False,,1011103,4,476.0,...,-0.036119,-6.733333,-1.066667,-0.733333,-0.010847,0.12559,-18662.266667,-159.225818,-51.509013,1011103,4,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.0,0.0,0.0,1.0,0.111111,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1011103,22.8,0.0,,0.0,3.9,北西,2.93,1010.95,1013.95,60.75,16.875,14.85,missing,
1,1011103,9,2.0,200.0,False,0,True,120,1011103,2001-08-04 01:45:00,1,3,1200,20,1,1,12,A3,102,3,missing,16.0,21,1,1,12,A3,102,3,missing,16.0,芝,,-18.0,1,1,1,1.0,1.0,0.0,0.0,0.0,missing,missing,False,False,,1011103,9,482.0,...,-0.06815,-9.933333,-1.066667,-1.8,0.018783,0.12559,-28766.8,-206.365603,-72.468494,1011103,9,3.0,0.0,0.0,2.0,0.666667,0.0,1.0,0.0,0.0,0.0,0.0,0.0,8.0,0.0,0.0,3.0,0.375,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1011103,22.8,0.0,,0.0,3.9,北西,2.93,1010.95,1013.95,60.75,16.875,14.85,missing,
2,1011204,14,6.0,0.0,False,0,False,0,1011204,2001-08-05 02:15:00,1,3,1800,10,1,1,12,A3,102,3,missing,14.0,11,1,1,12,A3,102,3,missing,14.0,芝,,-14.0,1,1,1,1.0,1.0,0.0,0.0,0.0,missing,missing,False,False,,1011204,14,470.0,...,0.049687,15.769231,6.846154,9.923077,0.137634,0.14995,49377.5,316.358601,86.088353,1011204,14,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,0.0,0.0,1.0,0.1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1011204,22.3,0.0,,0.0,4.625,北北西,3.14,1010.325,1013.325,64.0,17.225,15.15,1,30.0
3,1011303,6,3.0,130.0,False,0,True,1090,1011303,2001-08-11 01:45:00,1,3,1700,10,1,1,12,A3,2,3,missing,13.0,11,1,1,12,A3,2,3,missing,13.0,ダート,,-19.0,1,1,1,1.0,1.0,0.0,0.0,0.0,missing,missing,False,False,,1011303,6,436.0,...,0.063804,7.75,1.333333,7.916667,0.039064,0.115922,-1223.125,-199.919166,12.856413,1011303,6,1.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1011303,23.475,0.0,,0.0,0.825,北,1.435,1009.925,1012.925,65.0,18.725,16.475,missing,
4,1011304,7,1.0,510.0,True,230,True,120,1011304,2001-08-11 02:15:00,1,3,2000,10,1,1,12,A3,102,3,missing,16.0,11,1,1,12,A3,102,3,missing,16.0,芝,,-17.0,1,1,1,1.0,1.0,0.0,0.0,0.0,missing,missing,False,False,,1011304,7,502.0,...,0.176155,45.066667,15.533333,30.333333,0.169848,0.277041,230843.133333,516.538555,305.098488,1011304,7,1.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1011304,24.125,0.0,,0.0,1.3,北北西,1.7175,1009.85,1012.825,61.5,18.425,16.225,2,30.0


# Model Training

## Split Train and Test Data

In [47]:
# The full frame is split here; model feature columns are selected per model later on.
# The target is the `meta_複勝的中` column.
X, y = data, data["meta_複勝的中"]

# 80/20 random split with a fixed seed so the partition is repeatable.
split_parts = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split_parts

print(f"X_train: {X_train.shape}")
print(f"X_test: {X_test.shape}")
print(f"y_train: {y_train.shape}")
print(f"y_test: {y_test.shape}")

X_train: (886532, 886)
X_test: (221634, 886)
y_train: (886532,)
y_test: (221634,)


## Define Objective Function

In [15]:
def _profit_loss(row, payout_column_name, bet_amount=100):
    """Return the profit (positive) or loss (negative) of betting on one result row.

    No bet is placed when ``row["pred"]`` is falsy (result 0). A placed bet wins
    ``payout * bet_amount / 100 - bet_amount`` when ``row["actual"]`` is truthy and
    loses ``bet_amount`` otherwise. The payout column holds the amount returned
    for a 100 yen bet.
    """
    if not row["pred"]:
        return 0
    if row["actual"]:
        return row[payout_column_name] * (bet_amount / 100) - bet_amount
    return -bet_amount


def _log_dataframe_as_csv(frame: pd.DataFrame, prefix: str) -> None:
    """Write ``frame`` to a temporary CSV file and log it to the active MLflow run."""
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".csv") as tmp:
        frame.to_csv(tmp.name, index=False)
        mlflow.log_artifact(tmp.name)


def _log_current_figure(prefix: str) -> None:
    """Save the current matplotlib figure as a temporary PNG, close it and log it to MLflow."""
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".png") as tmp:
        plt.tight_layout()
        plt.savefig(tmp.name)
        plt.close()
        mlflow.log_artifact(tmp.name)


def _payout_rates_by_group(results: pd.DataFrame) -> pd.DataFrame:
    """Compute payout statistics overall and per month/distance/season/year/age/grade/racetrack.

    Each grouping is passed to ``calculate_binary_classifier_statistics``; its result
    is turned into a frame, transposed and tagged with the grouping name in ``group``.
    The index of each transposed frame becomes the ``part`` column.
    ``group`` and ``part`` are moved to the front of the returned frame.
    """
    start_time = results["発走日時"]
    groupings = {
        "all": None,
        "month": start_time.dt.month,
        "distance": pd.cut(results["距離"], bins=[0, 1400, 1800, 10000]),
        "season": start_time.dt.month % 12 // 3,
        "year": start_time.dt.year,
        "horse_age": pd.cut(results["年齢"], bins=[0, 3, 6, 100]),
        "grade": results["グレード"],
        "racetrack": results["場コード"],
    }
    frames = [
        pd.DataFrame(
            calculate_binary_classifier_statistics(
                results, group_by=group_by, payout_column_name="payout"
            )
        ).T.assign(group=group_name)
        for group_name, group_by in groupings.items()
    ]
    payout = pd.concat(frames, axis=0).rename_axis(index="part").reset_index()
    leading = ["group", "part"]
    return payout[leading + [c for c in payout.columns if c not in leading]]


def _log_payout_rate_plot(payout: pd.DataFrame) -> None:
    """Log a 2x4 grid of bar charts showing the payout rate of every part, one panel per group."""
    sns.set_theme(style="whitegrid")
    _, axes = plt.subplots(2, 4, figsize=(20, 10))
    for (group_name, group_df), ax in zip(payout.groupby("group"), axes.flatten()):
        sns.barplot(x="part", y="payout_rate", data=group_df, ax=ax)
        ax.set_title(group_name)
        ax.set_ylim(0, 150)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        ax.set_ylabel("payout rate")
        ax.set_xlabel("")
        ax.yaxis.set_major_formatter(ticker.PercentFormatter())
    _log_current_figure("payout_rate_")


def _log_bank_balance_plot(results: pd.DataFrame) -> None:
    """Log a line chart of cumulative profit/loss when betting 100 yen on every positive prediction."""
    profit_loss = results.apply(_profit_loss, args=("payout", 100), axis=1)
    # groupby sorts by race start time, so the cumulative sum runs chronologically.
    bank_balance = profit_loss.groupby(results["発走日時"]).sum().cumsum()
    plt.figure(figsize=(10, 10))
    ax = plt.subplot(1, 1, 1)
    ax.plot(bank_balance.index, bank_balance.values)
    ax.set_title("Bank Balance")
    ax.set_xlabel("Date")
    ax.set_ylabel("Bank Balance")
    ax.grid(True)
    ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))
    _log_current_figure("bank_balance_")


def _log_confusion_matrix_plot(y_true, y_pred) -> None:
    """Log the confusion matrix as raw counts and as row-normalized rates, side by side."""
    conf_matrix = confusion_matrix(y_true, y_pred)
    _, (ax_counts, ax_rates) = plt.subplots(1, 2, figsize=(15, 5))
    sns.heatmap(conf_matrix, annot=True, fmt="g", cmap="Blues", ax=ax_counts)
    ax_counts.set_xlabel("Predicted")
    ax_counts.set_ylabel("Actual")
    ax_counts.set_title("Confusion Matrix")
    sns.heatmap(
        conf_matrix / conf_matrix.sum(axis=1)[:, None],
        annot=True,
        fmt=".2%",
        cmap="Blues",
        ax=ax_rates,
    )
    ax_rates.set_xlabel("Predicted")
    ax_rates.set_ylabel("Actual")
    ax_rates.set_title("Normalized Confusion Matrix")
    _log_current_figure("confusion_matrix_")


def _log_roc_curve_plot(y_true, y_score) -> None:
    """Log the ROC curve computed from continuous positive-class scores."""
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.plot(
        fpr,
        tpr,
        color="darkorange",
        lw=2,
        label="ROC curve (area = %0.2f)" % roc_auc,
    )
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver Operating Characteristic")
    ax.legend(loc="lower right")
    _log_current_figure("roc_curve_")


def _log_feature_importances(model) -> None:
    """Log all feature importances as CSV and the top 50 as a bar chart."""
    import os  # local import: `os` is not among the notebook's top-level imports

    feature_importances_df = (
        pd.DataFrame(
            data=zip(
                model.named_steps["preprocessor"].get_feature_names_out(),
                model.named_steps["classifier"].feature_importances_,
            ),
            columns=["feature", "importance"],
        )
        .sort_values("importance", ascending=False)
        .reset_index(drop=True)
    )
    _log_dataframe_as_csv(feature_importances_df, prefix="feature_importance_")

    # Feature names contain Japanese text. Use the macOS Hiragino font when it is
    # installed and fall back to matplotlib's default font on other hosts instead
    # of failing on a missing font file.
    font_path = "/System/Library/Fonts/ヒラギノ角ゴシック W3.ttc"
    if os.path.exists(font_path):
        font_properties = fm.FontProperties(fname=font_path)
    else:
        font_properties = fm.FontProperties()

    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 12))
    ax = sns.barplot(
        x="importance", y="feature", data=feature_importances_df.iloc[:50]
    )
    ax.set_title("Feature Importances (Top 50)", fontproperties=font_properties)
    ax.set_xlabel("Importance", fontproperties=font_properties)
    ax.set_ylabel("Features", fontproperties=font_properties)
    for label in ax.get_yticklabels():
        label.set_fontproperties(font_properties)
    _log_current_figure("feature_importance_")


def create_objective_fn(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    df_payout: pd.DataFrame,
    experiment_name: str,
):
    """
    Build a hyperopt objective that trains, evaluates and logs one LightGBM model.

    The returned ``train(params)`` callable fits a preprocessing + ``LGBMClassifier``
    pipeline on ``X_train``/``y_train``, evaluates it on ``X_test``/``y_test`` and
    logs parameters, the model, metrics and diagnostic artifacts (payout rates,
    bank balance, confusion matrix, ROC curve, feature importances) to a new MLflow
    run in ``experiment_name``.

    df_payout must contain exactly one row per test row, in the same row order as
    X_test / y_test. Rows are matched by position (its index is reset internally).
    It must have the following columns:
    * レースキー
    * 馬番
    * 距離 - must be numeric; it is binned with ``pd.cut``.
    * 発走日時 - must be datetime-like; it is accessed through ``.dt``.
    * 年齢 - must be numeric; it is binned with ``pd.cut``.
    * グレード
    * 場コード
    * payout - amount won if betting 100 yen.

    Returns:
        The ``train(params)`` callable. It returns a dict with ``status``,
        ``params``, the fitted ``model`` and the evaluation metrics, including
        ``loss`` (test log loss).
    """

    def train(params):
        mlflow.set_experiment(experiment_name=experiment_name)
        with mlflow.start_run():
            mlflow.log_params(params)

            numeric_features = X_train.select_dtypes("number").columns.tolist()
            categorical_features = X_train.select_dtypes("category").columns.tolist()
            preprocessor = ColumnTransformer(
                transformers=[
                    ("num", StandardScaler(), numeric_features),
                    (
                        "cat",
                        OneHotEncoder(handle_unknown="ignore"),
                        categorical_features,
                    ),
                ]
            )
            model = Pipeline(
                steps=[
                    ("preprocessor", preprocessor),
                    ("classifier", lgb.LGBMClassifier(**params)),
                ]
            )

            # The training dataset is intentionally not logged with mlflow.log_input:
            # doing so raised "RestException: INVALID_PARAMETER_VALUE: Dataset schema
            # exceeds the maximum length of 65535".

            input_schema = Schema(get_colspecs(X_train))
            output_schema = Schema([ColSpec("double", y_train.name)])
            signature = ModelSignature(inputs=input_schema, outputs=output_schema)
            input_example = X_train.iloc[:25]
            model.fit(X_train, y_train)
            mlflow.sklearn.log_model(
                sk_model=model,
                signature=signature,
                input_example=input_example,
                artifact_path="model",
            )

            y_pred_proba = model.predict_proba(X_test)[:, 1]
            y_pred = model.predict(X_test)

            metrics = {
                "loss": log_loss(y_test, y_pred_proba),
                "accuracy": accuracy_score(y_test, y_pred),
                "precision": precision_score(y_test, y_pred),
                "recall": recall_score(y_test, y_pred),
                "f1": f1_score(y_test, y_pred),
                # ROC AUC is a ranking metric, so it needs the continuous scores;
                # hard 0/1 predictions would reduce it to a single threshold.
                "roc_auc": roc_auc_score(y_test, y_pred_proba),
            }
            mlflow.log_metrics(metrics)

            # Join payout metadata and predictions row by row. Both sides get a
            # 0..n-1 index so the concat aligns by position rather than by
            # whatever index labels df_payout happens to carry.
            results = pd.concat(
                [
                    df_payout.reset_index(drop=True),
                    pd.DataFrame(
                        np.c_[y_test, y_pred, y_pred_proba],
                        columns=["actual", "pred", "pred_proba_true"],
                    ),
                ],
                axis=1,
            )

            # Payout rates by group: CSV artifact plus one metric per group/part.
            payout = _payout_rates_by_group(results)
            _log_dataframe_as_csv(payout, prefix="payout_rate_")
            payout_metrics = {
                re.sub(
                    r"\W", "_", f"payout_rate_{row['group']}_{row['part']}"
                ): row["payout_rate"]
                for _, row in payout.iterrows()
            }
            mlflow.log_metrics(payout_metrics)

            # Silence UserWarnings (e.g. from matplotlib) only while plotting,
            # instead of changing the process-wide warning filters for good.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=UserWarning)
                _log_payout_rate_plot(payout)
                _log_bank_balance_plot(results)
                _log_confusion_matrix_plot(y_test, y_pred)
                _log_roc_curve_plot(y_test, y_pred_proba)
                _log_feature_importances(model)

            return {"status": STATUS_OK, "params": params, "model": model, **metrics}

    return train

## Hyperparameter Optimization

In [40]:
# Hyperopt search space for LGBMClassifier.
#
# Each LightGBM parameter is listed under exactly one name. The previous space also
# sampled `colsample_bytree`, `reg_alpha` and `reg_lambda`, which are aliases of
# `feature_fraction`, `lambda_l1` and `lambda_l2`. Searching both names spends TPE
# budget on dimensions whose sampled values LightGBM ignores.
space = {
    "boosting_type": hp.choice("boosting_type", ["gbdt", "dart", "goss"]),
    "learning_rate": hp.loguniform("learning_rate", -5, 0),  # between e^-5 and 1
    "n_estimators": scope.int(hp.quniform("n_estimators", 100, 1000, 1)),
    "max_depth": scope.int(hp.quniform("max_depth", 3, 10, 1)),
    "num_leaves": scope.int(hp.quniform("num_leaves", 20, 150, 1)),
    "min_child_samples": scope.int(hp.quniform("min_child_samples", 20, 500, 1)),
    # Column subsampling per tree (alias: colsample_bytree).
    "feature_fraction": hp.uniform("feature_fraction", 0.5, 1.0),
    # L1 / L2 regularization (aliases: reg_alpha / reg_lambda).
    "lambda_l1": hp.uniform("lambda_l1", 0, 5),
    "lambda_l2": hp.uniform("lambda_l2", 0, 5),
    "min_split_gain": hp.uniform("min_split_gain", 0, 1),
    "min_child_weight": hp.uniform("min_child_weight", 0.001, 10),
    # NOTE(review): row subsampling only takes effect when `subsample_freq` > 0,
    # which is not set here - confirm whether bagging is intended.
    "subsample": hp.uniform("subsample", 0.5, 1),
    "objective": "binary",
    "class_weight": "balanced",
    "verbose": -1,
    "seed": 80,
}

In [48]:
names = data.columns.tolist()


def _model_feature_names(all_names, excluded_marker):
    """Return the feature columns to model on.

    Skips columns containing ``excluded_marker``, ``meta_`` columns, and the
    track-type column ``cat_トラック種別``.
    """
    return [
        name
        for name in all_names
        if excluded_marker not in name
        and not name.startswith("meta_")
        and name != "cat_トラック種別"
    ]


# Actual-value features: drop the pre-race (`_事前_`) variants, then strip the
# `_実績` marker so both variants end up with the same column names.
names_actual = _model_feature_names(names, excluded_marker="_事前_")
names_actual_prep = [name.replace("_実績", "") for name in names_actual]

# Pre-race features: drop the actual (`_実績_`) variants, then strip the `_事前` marker.
names_before = _model_feature_names(names, excluded_marker="_実績_")
names_before_prep = [name.replace("_事前", "") for name in names_before]

# Both variants must describe exactly the same set of features.
assert sorted(names_actual_prep) == sorted(names_before_prep)

### Turf Model

In [49]:
# Row masks selecting turf (芝) races in each split.
mask_train_turf = X_train["cat_トラック種別"] == "芝"
mask_test_turf = X_test["cat_トラック種別"] == "芝"

# Train on the actual-value columns, renamed to the shared feature names.
X_train_turf = X_train.loc[mask_train_turf, names_actual].set_axis(
    names_actual_prep, axis=1
)
y_train_turf = y_train.loc[mask_train_turf]

# Test on the pre-race columns, renamed to the same shared feature names.
X_test_turf = X_test.loc[mask_test_turf, names_before].set_axis(
    names_before_prep, axis=1
)
y_test_turf = y_test.loc[mask_test_turf]

assert set(X_train_turf.columns) == set(X_test_turf.columns)

In [17]:
! pip install scikeras

IOStream.flush timed out
Collecting scikeras
  Downloading scikeras-0.12.0-py3-none-any.whl.metadata (4.0 kB)
Downloading scikeras-0.12.0-py3-none-any.whl (27 kB)
Installing collected packages: scikeras
Successfully installed scikeras-0.12.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [64]:
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Sequential
from sklearn.impute import SimpleImputer

numeric_features = X_train_turf.select_dtypes("number").columns.tolist()
categorical_features = X_train_turf.select_dtypes("category").columns.tolist()

numeric_pipeline = Pipeline(
    steps=[
        ("imputer", SimpleImputer()),  # First, impute missing values
        ("scaler", StandardScaler()),  # Then, scale the data
    ]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_pipeline, numeric_features),
        (
            "cat",
            OneHotEncoder(handle_unknown="ignore", drop="if_binary"),
            categorical_features,
        ),
    ],
    # OneHotEncoder produces sparse output by default. Always stack the result as
    # a dense array, because the Keras network cannot consume a scipy sparse matrix.
    sparse_threshold=0,
)

# Fit once up front only to learn the width of the encoded feature matrix.
input_shape = preprocessor.fit_transform(X_train_turf).shape[1]


def build_network(input_dim: int) -> Sequential:
    """Return a compiled two-hidden-layer network with a sigmoid output.

    Args:
        input_dim: Number of columns produced by the preprocessor.
    """
    network = Sequential(
        [
            Dense(500, activation="relu", input_shape=(input_dim,)),
            Dropout(0.5),
            Dense(250, activation="relu"),
            Dropout(0.5),
            Dense(1, activation="sigmoid"),
        ]
    )
    network.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return network


# `model` refers only to the scikit-learn compatible wrapper; the raw Keras
# network is built by `build_network` and never bound to the same name.
model = KerasClassifier(
    model=build_network(input_shape),
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=1,
)

pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", model)])

In [65]:
# Fit the preprocessing steps and the Keras classifier on the turf training data.
pipeline.fit(X_train_turf, y_train_turf)

Epoch 1/100


2024-02-17 17:51:30.491731: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Epoch 2/100
Epoch 3/100
Epoch 4/100
 2198/10627 [=====>........................] - ETA: 2:30 - loss: 705.3856 - accuracy: 0.9008

In [None]:
# Evaluate through the fitted pipeline so the test rows receive the same
# preprocessing as the training rows. `model` is the scikeras wrapper, which has
# no Keras-style `evaluate` method and cannot take the raw DataFrame.
test_pred_proba = pipeline.predict_proba(X_test_turf)[:, 1]
test_pred = pipeline.predict(X_test_turf)
test_loss = log_loss(y_test_turf, test_pred_proba)
test_acc = accuracy_score(y_test_turf, test_pred)
print(f"Test Accuracy: {test_acc}")

In [19]:
# Source column -> name expected by the objective function's payout report.
df_payout_renamed_columns = {
    "meta_レースキー": "レースキー",
    "meta_馬番": "馬番",
    "cat_距離": "距離",
    "meta_発走日時": "発走日時",
    "meta_複勝払戻金": "payout",
    "num_年齢": "年齢",
    "cat_実績_レース条件_グレード": "グレード",
    "meta_場コード": "場コード",
}

# X_test_turf.index holds index *labels* of `data`, so select with .loc. (.iloc
# would treat them as positions and only works while `data` has a default
# RangeIndex.) The index is reset afterwards so rows line up positionally with
# the test predictions.
df_payout_turf = (
    data.loc[X_test_turf.index, list(df_payout_renamed_columns)]
    .rename(columns=df_payout_renamed_columns)
    .reset_index(drop=True)
)
# `cat_距離` is stored as a string category, but the payout report bins the
# distance with pd.cut, which needs numbers. Non-numeric values become NaN.
df_payout_turf["距離"] = pd.to_numeric(
    df_payout_turf["距離"].astype(str), errors="coerce"
)

experiment_name_turf = "20240217_before_after__turf"
if mlflow.get_experiment_by_name(experiment_name_turf) is None:
    mlflow.create_experiment(experiment_name_turf)

fn_turf = create_objective_fn(
    X_train_turf,
    y_train_turf,
    X_test_turf,
    y_test_turf,
    df_payout=df_payout_turf,
    experiment_name=experiment_name_turf,
)

In [26]:
# Run the TPE search for the turf model: 60 evaluations, up to 3 trials at a time on Spark.
trials_turf = SparkTrials(spark_session=spark, parallelism=3)
fmin(
    fn=fn_turf,
    space=space,
    algo=tpe.suggest,
    max_evals=60,
    trials=trials_turf,
)

  0%|          | 0/60 [00:00<?, ?trial/s, best loss=?]

build_posterior_wrapper took 0.001952 seconds
TPE using 0 trials
build_posterior_wrapper took 0.002027 seconds
TPE using 1/1 trials with best loss inf
build_posterior_wrapper took 0.002119 seconds
TPE using 2/2 trials with best loss inf
build_posterior_wrapper took 0.002132 seconds
TPE using 3/3 trials with best loss inf
Closing down clientserver connection                                            
[Stage 2:>                  (0 + 1) / 1][Stage 3:>                  (0 + 1) / 1]

  2%|▏         | 1/60 [02:45<2:42:19, 165.08s/trial, best loss: 0.9462831393144859]

build_posterior_wrapper took 0.006418 seconds
TPE using 4/4 trials with best loss 0.946283
Closing down clientserver connection                                            


  3%|▎         | 2/60 [02:59<1:13:52, 76.43s/trial, best loss: 0.9239293371013857] 

build_posterior_wrapper took 0.002626 seconds                       (0 + 1) / 1]
TPE using 5/5 trials with best loss 0.923929
Closing down clientserver connection                                            


  5%|▌         | 3/60 [03:19<48:14, 50.77s/trial, best loss: 0.7720435299667909]  

build_posterior_wrapper took 0.014627 secondse 5:>                  (0 + 1) / 1]
TPE using 6/6 trials with best loss 0.772044
Closing down clientserver connection                                            


  7%|▋         | 4/60 [05:35<1:18:35, 84.21s/trial, best loss: 0.7720435299667909]

build_posterior_wrapper took 0.003944 secondse 6:>                  (0 + 1) / 1]
TPE using 7/7 trials with best loss 0.772044
Closing down clientserver connection                                            


  8%|▊         | 5/60 [06:47<1:13:22, 80.05s/trial, best loss: 0.7330719958835319]

build_posterior_wrapper took 0.004351 secondse 7:>                  (0 + 1) / 1]
TPE using 8/8 trials with best loss 0.733072
Closing down clientserver connection                                            


 10%|█         | 6/60 [07:42<1:04:12, 71.34s/trial, best loss: 0.7330719958835319]

build_posterior_wrapper took 0.004967 secondse 8:>                  (0 + 1) / 1]
TPE using 9/9 trials with best loss 0.733072
Closing down clientserver connection                                            


 12%|█▏        | 7/60 [10:51<1:36:56, 109.75s/trial, best loss: 0.7330719958835319]

build_posterior_wrapper took 0.002827 seconds
TPE using 10/10 trials with best loss 0.733072
Closing down clientserver connection                                            


 13%|█▎        | 8/60 [11:48<1:20:42, 93.13s/trial, best loss: 0.6437892481832604] 

build_posterior_wrapper took 0.030781 secondse 10:>                 (0 + 1) / 1]
TPE using 11/11 trials with best loss 0.643789
Closing down clientserver connection                                            


 15%|█▌        | 9/60 [11:50<54:57, 64.66s/trial, best loss: 0.6437892481832604]  

build_posterior_wrapper took 0.006282 seconds
TPE using 12/12 trials with best loss 0.643789
Closing down clientserver connection                                            
[Stage 11:>                                                         (0 + 1) / 1]

 17%|█▋        | 10/60 [12:30<47:24, 56.88s/trial, best loss: 0.6437892481832604]

build_posterior_wrapper took 0.017812 seconds
TPE using 13/13 trials with best loss 0.643789
24/02/17 16:23:16 WARN HikariPool: HikariPool-1 - Thread starvation or clock leap detected (housekeeper delta=45s439ms).
24/02/17 16:23:16 WARN HikariPool: HikariPool-2 - Thread starvation or clock leap detected (housekeeper delta=45s436ms).
Closing down clientserver connection                                            


 18%|█▊        | 11/60 [15:39<1:19:29, 97.34s/trial, best loss: 0.6437892481832604]

build_posterior_wrapper took 0.017210 seconds
Closing down clientserver connection
TPE using 14/14 trials with best loss 0.643789
[Stage 12:>                                                         (0 + 1) / 1]

 20%|██        | 12/60 [15:40<54:32, 68.19s/trial, best loss: 0.6437892481832604]  

build_posterior_wrapper took 0.003382 seconds
TPE using 15/15 trials with best loss 0.643789
Closing down clientserver connection                                            


 22%|██▏       | 13/60 [15:56<40:57, 52.29s/trial, best loss: 0.6437892481832604]

build_posterior_wrapper took 0.006494 seconds
TPE using 16/16 trials with best loss 0.643789
Closing down clientserver connection                                            
[Stage 14:>                 (0 + 1) / 1][Stage 16:>                 (0 + 1) / 1]

 23%|██▎       | 14/60 [18:21<1:01:28, 80.18s/trial, best loss: 0.6437892481832604]

build_posterior_wrapper took 0.004367 seconds
TPE using 17/17 trials with best loss 0.643789
Closing down clientserver connection                                            


 25%|██▌       | 15/60 [18:24<42:45, 57.00s/trial, best loss: 0.5279504399227686]  

build_posterior_wrapper took 0.005603 seconds                       (0 + 1) / 1]
TPE using 18/18 trials with best loss 0.527950




Closing down clientserver connection                                            


 27%|██▋       | 16/60 [19:55<49:21, 67.30s/trial, best loss: 0.5279504399227686]

build_posterior_wrapper took 0.010563 seconds
TPE using 19/19 trials with best loss 0.527950 18:>                 (0 + 1) / 1]
Closing down clientserver connection                                            


 28%|██▊       | 17/60 [22:43<1:09:58, 97.64s/trial, best loss: 0.5279504399227686]

build_posterior_wrapper took 0.016223 secondse 19:>                 (0 + 1) / 1]
TPE using 20/20 trials with best loss 0.527950
Closing down clientserver connection                                            
[Stage 19:>                 (0 + 1) / 1][Stage 20:>                 (0 + 1) / 1]

 30%|███       | 18/60 [26:57<1:41:09, 144.51s/trial, best loss: 0.5279504399227686]

build_posterior_wrapper took 0.032197 seconds
TPE using 21/21 trials with best loss 0.527950
Closing down clientserver connection                                            


 32%|███▏      | 19/60 [27:39<1:17:38, 113.62s/trial, best loss: 0.5279504399227686]

build_posterior_wrapper took 0.155659 seconds                       (0 + 1) / 1]
TPE using 22/22 trials with best loss 0.527950 21:>                 (0 + 1) / 1]
Closing down clientserver connection                                            


 33%|███▎      | 20/60 [28:05<58:11, 87.29s/trial, best loss: 0.5279504399227686]   

build_posterior_wrapper took 0.003725 seconds                       (0 + 1) / 1]
TPE using 23/23 trials with best loss 0.527950
Closing down clientserver connection                                            


 35%|███▌      | 21/60 [29:29<56:16, 86.58s/trial, best loss: 0.5279504399227686]

build_posterior_wrapper took 0.014530 seconds
TPE using 24/24 trials with best loss 0.527950 23:>                 (0 + 1) / 1]
Closing down clientserver connection                                            


 37%|███▋      | 22/60 [31:02<56:03, 88.52s/trial, best loss: 0.51280440351991]  

build_posterior_wrapper took 0.003256 secondse 24:>                 (0 + 1) / 1]
TPE using 25/25 trials with best loss 0.512804
Closing down clientserver connection                                            


 38%|███▊      | 23/60 [33:34<1:06:11, 107.33s/trial, best loss: 0.5122738079398443]

build_posterior_wrapper took 0.014985 seconds
TPE using 26/26 trials with best loss 0.512274
Closing down clientserver connection                                            


 40%|████      | 24/60 [34:07<51:05, 85.16s/trial, best loss: 0.5122738079398443]   

build_posterior_wrapper took 0.044216 seconds
TPE using 27/27 trials with best loss 0.512274
Closing down clientserver connection                                            
[Stage 24:>                 (0 + 1) / 1][Stage 27:>                 (0 + 1) / 1]

 42%|████▏     | 25/60 [36:37<1:01:02, 104.64s/trial, best loss: 0.5122738079398443]

build_posterior_wrapper took 0.003735 seconds
TPE using 28/28 trials with best loss 0.512274
Closing down clientserver connection                                            


 43%|████▎     | 26/60 [37:01<45:31, 80.33s/trial, best loss: 0.5122738079398443]   

build_posterior_wrapper took 0.035589 seconds                       (0 + 1) / 1]
TPE using 29/29 trials with best loss 0.512274
Closing down clientserver connection                                            


 45%|████▌     | 27/60 [38:03<41:06, 74.74s/trial, best loss: 0.5122738079398443]

build_posterior_wrapper took 0.066373 seconds                       (0 + 1) / 1]
TPE using 30/30 trials with best loss 0.512274
Closing down clientserver connection                                            
[Stage 28:>                 (0 + 1) / 1][Stage 30:>                 (0 + 1) / 1]

 47%|████▋     | 28/60 [44:04<1:25:48, 160.89s/trial, best loss: 0.5122738079398443]

build_posterior_wrapper took 0.008146 seconds
TPE using 31/31 trials with best loss 0.512274
Closing down clientserver connection                                            
[Stage 30:>                 (0 + 1) / 1][Stage 31:>                 (0 + 1) / 1]

 48%|████▊     | 29/60 [44:50<1:05:16, 126.34s/trial, best loss: 0.5122738079398443]

build_posterior_wrapper took 0.048672 seconds
TPE using 32/32 trials with best loss 0.512274
