# Overview

In [1]:
import os
import re
import tempfile
import warnings

import japanize_matplotlib
import lightgbm as lgb
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import mlflow
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll.base import scope
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema
from pyspark.sql import SparkSession
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import RFE
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler

from JapanHorseRaceAnalytics.utilities.base import get_base_dir, read_hive_table
from JapanHorseRaceAnalytics.utilities.metrics import (
    calculate_binary_classifier_statistics,
)
from JapanHorseRaceAnalytics.utilities.mlflow import get_colspecs
from JapanHorseRaceAnalytics.utilities.structured_logger import logger

# Set pandas display options
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 200)

# Set seed for reproducibility
os.environ["PYTHONHASHSEED"] = str(42)
np.random.seed(42)
tf.random.set_seed(42)
random_state = 42

# Prepare the data

In [2]:
warehouse_dir = f"{get_base_dir()}/spark-warehouse"
postgres_driver_path = f"{get_base_dir()}/jars/postgresql-42.7.1.jar"

spark = (
    SparkSession.builder.appName("20240211_competitors")
    .config("spark.driver.memory", "21g")
    .config("spark.sql.warehouse.dir", warehouse_dir)
    .config("spark.jars", postgres_driver_path)
    .config("spark.executor.extraClassPath", postgres_driver_path)
    .config("spark.driver.extraClassPath", postgres_driver_path)
    .enableHiveSupport()
    .getOrCreate()
)

24/02/24 16:43:21 WARN Utils: Your hostname, Hanks-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.40.105 instead (on interface en0)
24/02/24 16:43:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/02/24 16:43:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/24 16:43:22 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
data = read_hive_table(
    table_name="features_20240217_v1",
    schema="jhra_curated",
    spark_session=spark,
    # use_cache=False,
)

# drop from data where cat_トラック種別 == "障害"
data = data[(data["cat_トラック種別"] != "障害")]

# This would mess up the number of horses in the race
# drop from data where meta_int_race_horses_異常区分 != '0'
# data = data[(data["meta_int_race_horses_異常区分"] == "0")]

data = data.reset_index(drop=True)
data.head()

{"event": "Read from hive jhra_curated.features_20240217_v1", "level": "info", "timestamp": "2024-02-24T07:43:23.622759Z", "logger": "JapanHorseRaceAnalytics.utilities.base"}
24/02/24 16:43:24 WARN HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist
24/02/24 16:43:24 WARN HiveConf: HiveConf of name hive.stats.retries.wait does not exist
24/02/24 16:43:25 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
{"event": "Write to parquet /Users/hankehly/Projects/JapanHorseRaceAnalytics/data/sql_tables/features_20240217_v1.snappy.parquet", "level": "info", "timestamp": "2024-02-24T07:43:26.130335Z", "logger": "JapanHorseRaceAnalytics.utilities.base"}
24/02/24 16:43:26 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
{"event": "Read from parquet /Users/hankehly/Projects/JapanHorseRaceAnalytics/data/sql_tables/features_2024021

Unnamed: 0,meta_レースキー,meta_馬番,meta_着順,meta_本賞金,meta_単勝的中,meta_単勝払戻金,meta_複勝的中,meta_複勝払戻金,meta_int_races_レースキー,meta_発走日時,meta_場コード,cat_四半期,cat_距離,cat_事前_馬場状態コード,cat_事前_レース条件_トラック情報_右左,cat_事前_レース条件_トラック情報_内外,cat_事前_レース条件_種別,cat_事前_レース条件_条件,cat_事前_レース条件_記号,cat_事前_レース条件_重量,cat_事前_レース条件_グレード,num_事前_馬場差,num_頭数,cat_トラック種別,cat_馬場状態内,cat_馬場状態中,cat_馬場状態外,num_直線馬場差最内,num_直線馬場差内,num_直線馬場差中,num_直線馬場差外,num_直線馬場差大外,cat_芝種類,cat_草丈,cat_転圧,cat_凍結防止剤,num_中間降水量,meta_int_race_horses_レースキー,meta_int_race_horses_馬番,meta_int_race_horses_血統登録番号,meta_int_race_horses_発走日時,meta_int_race_horses_異常区分,num_事前ＩＤＭ,cat_事前脚質,num_事前単勝オッズ,num_事前複勝オッズ,cat_事前馬体,cat_事前気配コード,cat_事前上昇度,cat_事前クラスコード,num_事前テン指数,num_事前ペース指数,num_事前上がり指数,num_負担重量,num_馬体重,num_馬体重増減,cat_性別,cat_トラック種別瞬発戦好走馬,cat_トラック種別消耗戦好走馬,num_一走前不利,num_二走前不利,num_三走前不利,num_一走前着順,num_二走前着順,num_三走前着順,num_四走前着順,num_五走前着順,num_六走前着順,num_1走前上昇度,num_2走前上昇度,num_3走前上昇度,num_4走前上昇度,num_5走前上昇度,num_騎手指数,num_情報指数,num_オッズ指数,num_パドック指数,num_総合指数,cat_馬具変更情報,cat_脚元情報,cat_見習い区分,cat_オッズ印,cat_パドック印,cat_直前総合印,cat_距離適性,num_ローテーション,num_基準オッズ,num_基準人気順位,num_基準複勝オッズ,num_基準複勝人気順位,num_特定情報◎,num_特定情報○,num_特定情報▲,num_特定情報△,num_特定情報×,num_総合情報◎,num_総合情報○,num_総合情報▲,num_総合情報△,num_総合情報×,...,num_競争相手平均調教師1位完走,num_競争相手調教師1位完走標準偏差,num_競争相手最高調教師トップ3完走,num_競争相手最低調教師トップ3完走,num_競争相手平均調教師トップ3完走,num_競争相手調教師トップ3完走標準偏差,num_競争相手最高調教師1位完走率,num_競争相手最低調教師1位完走率,num_競争相手平均調教師1位完走率,num_競争相手調教師1位完走率標準偏差,num_競争相手最高調教師トップ3完走率,num_競争相手最低調教師トップ3完走率,num_競争相手平均調教師トップ3完走率,num_競争相手調教師トップ3完走率標準偏差,num_競争相手最高調教師場所レース数,num_競争相手最低調教師場所レース数,num_競争相手平均調教師場所レース数,num_競争相手調教師場所レース数標準偏差,num_競争相手最高調教師場所1位完走,num_競争相手最低調教師場所1位完走,num_競争相手平均調教師場所1位完走,num_競争相手調教師場所1位完走標準偏差,num_競争相手最高調教師場所トップ3完走,num_競争相手最低調教師場所トップ3完走,num_競争相手平均調教師場所トップ3完走,num_競争相手調教師場所トップ3完走標準偏差,num_競争相手最高調教師場所1位完走率,num_競争相手最低調教師場所1位完走率,num_競争相手平均調教師場所1位完走率,num_競争相手調教師場所1位完走率標準偏差,num_競争相手最高調教師場所トップ3完走率,num_競争相手最低調教師場所トップ3完走率,num_競争相手平均調教師場所トップ3完走率,num_競争相手調教師場所トップ3完走率標準偏差,num_競争相手最高調教師本賞金累計,num_競争相手最低調教師本賞金累計,num_競争相手平均調教師本賞金累計,num_競争相手調教師本賞金累計標準偏差,num_競争相手最高調教師1位完走平均賞金,num_競争相手最低調教師1位完走平均賞金,num_競争相手平均調教師1位完走平均賞金,num_競争相手調教師1位完走平均賞金標準偏差,num_競争相手最高調教師レース数平均賞金,num_競争相手最低調教師レース数平均賞金,num_競争相手平均調教師レース数平均賞金,num_競争相手調教師レース数平均賞金標準偏差,num_競争相手平均調教師レース数差,num_競争相手平均調教師1位完走差,num_競争相手平均調教師トップ3完走差,num_競争相手平均調教師1位完走率差,num_競争相手平均調教師トップ3完走率差,num_競争相手平均調教師場所レース数差,num_競争相手平均調教師場所1位完走差,num_競争相手平均調教師場所トップ3完走差,num_競争相手平均調教師場所1位完走率差,num_競争相手平均調教師場所トップ3完走率差,num_競争相手平均調教師本賞金累計差,num_競争相手平均調教師1位完走平均賞金差,num_競争相手平均調教師レース数平均賞金差,meta_int_combinations_レースキー,meta_int_combinations_馬番,num_馬騎手レース数,num_馬騎手1位完走,num_馬騎手1位完走率,num_馬騎手トップ3完走,num_馬騎手トップ3完走率,num_馬騎手初二走,num_馬騎手同騎手,num_馬騎手場所レース数,num_馬騎手場所1位完走,num_馬騎手場所1位完走率,num_馬騎手場所トップ3完走,num_馬騎手場所トップ3完走率,num_馬調教師レース数,num_馬調教師1位完走,num_馬調教師1位完走率,num_馬調教師トップ3完走,num_馬調教師トップ3完走率,num_馬調教師初二走,num_馬調教師同調教師,num_馬調教師場所レース数,num_馬調教師場所1位完走,num_馬調教師場所1位完走率,num_馬調教師場所トップ3完走,num_馬調教師場所トップ3完走率,meta_int_race_weather_レースキー,num_temperature,num_precipitation,num_snowfall,num_snow_depth,num_wind_speed,cat_wind_direction,num_solar_radiation,num_local_air_pressure,num_sea_level_air_pressure,num_relative_humidity,num_vapor_pressure,num_dew_point_temperature,cat_weather,num_visibility
0,1011103,4,6.0,0.0,False,0,False,0,1011103,2001-08-04 01:45:00,1,3,1200,20,1,1,12,A3,102,3,,,16,芝,1,1,1,1,1,0,0,0,,,False,False,,1011103,4,98102049,2001-08-04 01:45:00,0,36.0,好位差し,11.5,2.9,,,3,18.0,-12.4,-21.1,-10.9,550,476.0,14,牡,True,False,0.0,0.0,0.0,7.0,2.0,7.0,,,,3.0,3.0,3.0,,,0.4,0.4,0.0,1.8,38.6,0,0,0,,4.0,4.0,5,4.0,16.8,6,3.4,6,0,0,0,10,0,3,6,8,87,0,...,34.133333,18.575492,204,32,98.066667,44.395896,0.162376,0.027668,0.075334,0.036725,0.40396,0.120735,0.213261,0.076365,47,3,18.733333,12.390677,8,0,2.066667,2.112397,15,0,4.733333,4.464178,0.333333,0.0,0.09418,0.088041,0.5,0.0,0.207743,0.136352,231606.0,14187.0,65859.266667,53242.331542,1687.439024,567.857143,913.419366,273.792817,458.625743,49.167979,141.408061,101.993179,73.0,-3.133333,-5.066667,-0.016287,-0.036119,-6.733333,-1.066667,-0.733333,-0.010847,0.12559,-18662.266667,-159.225818,-51.509013,1011103,4,2,0,0.0,0,0.0,False,False,0,0,0.0,0,0.0,9,0,0.0,1,0.111111,False,True,0,0,0.0,0,0.0,1011103,22.8,0.0,,0.0,3.9,北西,2.93,1010.95,1013.95,60.75,16.875,14.85,,
1,1011103,9,2.0,200.0,False,0,True,120,1011103,2001-08-04 01:45:00,1,3,1200,20,1,1,12,A3,102,3,,,16,芝,1,1,1,1,1,0,0,0,,,False,False,,1011103,9,98102902,2001-08-04 01:45:00,0,38.0,先行,4.4,1.6,,,3,16.0,-10.6,-23.6,-5.1,550,482.0,0,牡,True,False,0.0,0.0,,3.0,2.0,,,,,3.0,3.0,,,,1.6,2.5,2.5,2.0,46.6,0,2,0,3.0,3.0,3.0,5,3.0,4.2,2,1.5,2,5,2,4,0,0,30,45,33,27,0,...,34.6,18.402174,204,32,99.0,44.131621,0.162376,0.027668,0.076291,0.036095,0.40396,0.120735,0.215263,0.074985,47,3,18.933333,12.25543,8,0,2.066667,2.112397,15,0,4.8,4.445222,0.333333,0.0,0.092329,0.087958,0.5,0.0,0.207743,0.136352,231606.0,14187.0,66490.8,52960.294858,1687.439024,567.857143,916.365603,271.818667,458.625743,49.167979,142.718028,101.193884,85.8,-10.6,-20.0,-0.031598,-0.06815,-9.933333,-1.066667,-1.8,0.018783,0.12559,-28766.8,-206.365603,-72.468494,1011103,9,3,0,0.0,2,0.666667,False,True,0,0,0.0,0,0.0,8,0,0.0,3,0.375,False,True,0,0,0.0,0,0.0,1011103,22.8,0.0,,0.0,3.9,北西,2.93,1010.95,1013.95,60.75,16.875,14.85,,
2,1011204,14,6.0,0.0,False,0,False,0,1011204,2001-08-05 02:15:00,1,3,1800,10,1,1,12,A3,102,3,,,14,芝,1,1,1,1,1,0,0,0,,,False,False,,1011204,14,98110058,2001-08-05 02:15:00,0,33.4,差し,7.6,2.3,,,3,38.0,-16.2,-9.6,-23.2,550,470.0,6,牡,False,False,0.0,,,8.0,,,,,,3.0,,,,,1.3,0.0,1.5,3.0,39.2,0,0,0,5.0,2.0,6.0,5,6.0,8.6,6,2.7,6,0,1,1,7,0,11,7,9,71,0,...,41.384615,23.470238,231,21,120.846154,58.455566,0.150621,0.033766,0.083075,0.030124,0.358696,0.119481,0.248023,0.072792,65,1,25.230769,17.129165,5,0,2.153846,1.511299,20,0,8.076923,5.980244,0.2,0.0,0.081879,0.051502,0.6,0.0,0.289074,0.151223,164902.0,12493.0,74442.5,47183.126061,1498.970588,626.153846,964.732308,274.777458,266.831715,48.657143,150.209357,65.950213,47.076923,13.615385,35.153846,0.021887,0.049687,15.769231,6.846154,9.923077,0.137634,0.14995,49377.5,316.358601,86.088353,1011204,14,4,0,0.0,0,0.0,False,False,0,0,0.0,0,0.0,10,0,0.0,1,0.1,False,True,0,0,0.0,0,0.0,1011204,22.3,0.0,,0.0,4.625,北北西,3.14,1010.325,1013.325,64.0,17.225,15.15,1.0,30.0
3,1011303,6,3.0,130.0,False,0,True,1090,1011303,2001-08-11 01:45:00,1,3,1700,10,1,1,12,A3,2,3,,,13,ダート,1,1,1,1,1,0,0,0,,,False,False,,1011303,6,98103267,2001-08-11 01:45:00,0,17.0,差し,50.8,8.7,,,3,18.0,-8.4,-24.3,-28.8,550,436.0,-4,牡,False,False,1.0,0.0,0.0,12.0,13.0,7.0,,,,3.0,3.0,3.0,,,0.3,-1.0,0.0,0.0,16.3,0,0,0,,,,5,0.0,89.5,13,14.5,13,0,0,0,0,0,0,0,0,1,0,...,38.25,16.573699,180,50,122.25,44.358624,0.096825,0.029126,0.064903,0.022876,0.294828,0.121359,0.208923,0.055122,65,2,35.25,19.472737,10,0,2.666667,2.838231,20,1,8.083333,7.193265,0.153846,0.0,0.053959,0.048886,0.5,0.039216,0.256171,0.157059,172151.0,24692.0,77010.125,41784.766816,1655.294118,716.296296,972.226859,235.029654,296.812069,59.932039,130.679572,63.348222,-41.333333,13.75,21.75,0.033582,0.063804,7.75,1.333333,7.916667,0.039064,0.115922,-1223.125,-199.919166,12.856413,1011303,6,1,0,0.0,0,0.0,True,True,1,0,0.0,0,0.0,5,0,0.0,0,0.0,False,True,1,0,0.0,0,0.0,1011303,23.475,0.0,,0.0,0.825,北,1.435,1009.925,1012.925,65.0,18.725,16.475,,
4,1011304,7,1.0,510.0,True,230,True,120,1011304,2001-08-11 02:15:00,1,3,2000,10,1,1,12,A3,102,3,,,16,芝,1,1,1,1,1,0,0,0,,,False,False,,1011304,7,98101610,2001-08-11 02:15:00,0,42.8,追込,2.2,1.0,,,2,,-21.6,-33.9,7.1,550,502.0,0,牡,False,False,,,,,,,,,,,,,,,2.9,4.1,3.5,1.8,55.1,0,3,0,1.0,4.0,1.0,2,43.0,2.8,1,1.3,1,8,3,0,0,0,47,55,18,11,0,...,47.866667,20.603775,232,66,137.733333,48.078362,0.151235,0.02729,0.085412,0.02996,0.358025,0.128655,0.246196,0.064278,66,7,29.933333,19.739864,7,0,2.466667,2.499778,20,0,7.666667,6.559133,0.269231,0.0,0.070152,0.073384,0.5,0.0,0.229625,0.133773,208944.0,27159.0,93025.866667,46309.291859,1294.605263,567.857143,950.1511,181.940753,343.657895,52.94152,164.958116,69.914071,139.066667,97.133333,153.266667,0.125038,0.176155,45.066667,15.533333,30.333333,0.169848,0.277041,230843.133333,516.538555,305.098488,1011304,7,1,0,0.0,1,1.0,True,True,0,0,0.0,0,0.0,1,0,0.0,1,1.0,True,True,0,0,0.0,0,0.0,1011304,24.125,0.0,,0.0,1.3,北北西,1.7175,1009.85,1012.825,61.5,18.425,16.225,2.0,30.0


# Train/test split

In [4]:
X = data
y = data["meta_複勝的中"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=random_state
)

print(f"X_train: {X_train.shape}")
print(f"X_test: {X_test.shape}")
print(f"y_train: {y_train.shape}")
print(f"y_test: {y_test.shape}")

X_train: (857061, 1008)
X_test: (214266, 1008)
y_train: (857061,)
y_test: (214266,)


In [5]:
assert len(X_train.filter(regex="実績", axis=1).columns.tolist()) == 0

In [None]:
num_columns = X_train.filter(regex="num_", axis=1).columns.tolist()
cat_columns = X_train.filter(regex="cat_", axis=1).columns.tolist()
ord_columns = X_train.filter(regex="ord_", axis=1).columns.tolist()
meta_columns = X_train.filter(regex="meta_", axis=1).columns.tolist()

print(f"num_columns: {len(num_columns)}, cat_columns: {len(cat_columns)}, ord_columns: {len(ord_columns)}, meta_columns: {len(meta_columns)}")
print(f"Total columns: {len(num_columns) + len(cat_columns) + len(ord_columns) + len(meta_columns)}, X_train.shape[1]: {X_train.shape[1]}")

# Define objective function

In [6]:
def create_objective_fn(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    df_payout: pd.DataFrame,
    experiment_name: str,
):
    """
    df_payout should have the same index as *_test and have the following columns:
    * レースキー
    * 馬番
    * 距離
    * 発走日時
    * 年齢
    * payout - amount won if betting 100 yen.
    """

    def train(params):
        def profit_loss(row, payout_column_name, bet_amount=100):
            if row["pred"] and row["actual"]:
                payout = row[payout_column_name] * (bet_amount / 100)
                return payout - bet_amount
            elif row["pred"] and not row["actual"]:
                return -bet_amount
            else:
                return 0

        mlflow.set_experiment(experiment_name=experiment_name)
        with mlflow.start_run():
            mlflow.log_params(params)

            categories = {"距離": [1000, 1150, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300, 2400, 2500, 2600, 3000, 3200, 3400, 3600]}
            categorical_features = X_train.select_dtypes("category").columns.tolist()
            numeric_features = X_train.select_dtypes("number").columns.tolist()
            preprocessor = ColumnTransformer(
                transformers=[
                    ("ord", OrdinalEncoder(categories=[categories["距離"]]), ["距離"]),
                    ("num", StandardScaler(), numeric_features),
                    ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_features),
                ]
            )
            model = Pipeline(
                steps=[
                    ("preprocessor", preprocessor),
                    ("classifier", lgb.LGBMClassifier(**params)),
                ]
            )

            model.fit(X_train, y_train)
            mlflow.sklearn.log_model(sk_model=model, artifact_path="model")

            y_pred_proba = model.predict_proba(X_test)[:, 1]
            y_pred = model.predict(X_test)

            metrics = {
                "loss": log_loss(y_test, y_pred_proba),
                "accuracy": accuracy_score(y_test, y_pred),
                "precision": precision_score(y_test, y_pred),
                "recall": recall_score(y_test, y_pred),
                "f1": f1_score(y_test, y_pred),
                "roc_auc": roc_auc_score(y_test, y_pred),
            }
            mlflow.log_metrics(metrics)

            # Calculate payout rates by group
            results = pd.concat(
                [
                    df_payout,
                    pd.DataFrame(
                        np.c_[y_test, y_pred, y_pred_proba],
                        columns=["actual", "pred", "pred_proba_true"],
                    ),
                ],
                axis=1,
            )
            payout_all = calculate_binary_classifier_statistics(
                results, group_by=None, payout_column_name="payout"
            )
            payout_month = calculate_binary_classifier_statistics(
                results,
                group_by=results["発走日時"].dt.month,
                payout_column_name="payout",
            )
            payout_distance = calculate_binary_classifier_statistics(
                results,
                group_by=pd.cut(x=results["距離"], bins=[0, 1400, 1800, 10000]),
                payout_column_name="payout",
            )
            payout_season = calculate_binary_classifier_statistics(
                results,
                group_by=results["発走日時"].dt.month % 12 // 3,
                payout_column_name="payout",
            )
            payout_year = calculate_binary_classifier_statistics(
                results,
                group_by=results["発走日時"].dt.year,
                payout_column_name="payout",
            )
            payout_age = calculate_binary_classifier_statistics(
                results,
                group_by=pd.cut(results["年齢"], bins=[0, 3, 6, 100]),
                payout_column_name="payout",
            )
            payout_racetrack = calculate_binary_classifier_statistics(
                results, group_by=results["場コード"], payout_column_name="payout"
            )
            payout = (
                pd.concat(
                    [
                        pd.DataFrame(payout_all).T.assign(group="all"),
                        pd.DataFrame(payout_month).T.assign(group="month"),
                        pd.DataFrame(payout_distance).T.assign(group="distance"),
                        pd.DataFrame(payout_season).T.assign(group="season"),
                        pd.DataFrame(payout_year).T.assign(group="year"),
                        pd.DataFrame(payout_age).T.assign(group="horse_age"),
                        pd.DataFrame(payout_racetrack).T.assign(group="racetrack"),
                    ],
                    axis=0,
                )
                .rename_axis(index="part")
                .reset_index()
            )
            # Move "group" and "part" columns to the first position in this dataframe
            payout = payout[
                ["group", "part"]
                + [c for c in payout.columns if c not in ["group", "part"]]
            ]

            # Save payout rates as csv
            with tempfile.NamedTemporaryFile(prefix="payout_rate_", suffix=".csv") as f:
                payout.to_csv(f.name, index=False)
                mlflow.log_artifact(f.name)

            # Log payout rates as metrics
            payout_metrics = {}
            for group_name, group in payout.groupby("group"):
                for i, row in group.iterrows():
                    key = re.sub(r"\W", "_", f"payout_rate_{group_name}_{row['part']}")
                    payout_metrics[key] = row["payout_rate"]
            mlflow.log_metrics(payout_metrics)

            # Suppress UserWarning messages from matplotlib
            warnings.filterwarnings("ignore", category=UserWarning)

            # Plot payout rates by group
            sns.set_theme(style="whitegrid")
            fig, axes = plt.subplots(2, 4, figsize=(20, 10))
            for (group, df), ax in zip(payout.groupby("group"), axes.flatten()):
                sns.barplot(x="part", y="payout_rate", data=df, ax=ax)
                ax.set_title(group)
                ax.set_ylim(0, 150)
                ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
                ax.set_ylabel("payout rate")
                ax.set_xlabel("")
                ax.yaxis.set_major_formatter(ticker.PercentFormatter())
            with tempfile.NamedTemporaryFile(prefix="payout_rate_", suffix=".png") as f:
                plt.tight_layout()
                plt.savefig(f.name)
                plt.close()
                mlflow.log_artifact(f.name)

            # Plot bank balance over time
            results["profit_loss"] = results.apply(
                profit_loss, args=("payout", 100), axis=1
            )
            daily_profit_loss = results.groupby("発走日時")["profit_loss"].sum()
            bank_balance = daily_profit_loss.cumsum()
            plt.figure(figsize=(10, 10))
            ax = plt.subplot(1, 1, 1)
            ax.plot(bank_balance.index, bank_balance.values)
            ax.set_title("Bank Balance")
            ax.set_xlabel("Date")
            ax.set_ylabel("Bank Balance")
            ax.grid(True)
            ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))
            with tempfile.NamedTemporaryFile(
                prefix="bank_balance_", suffix=".png"
            ) as f:
                plt.tight_layout()
                plt.savefig(f.name)
                plt.close()
                mlflow.log_artifact(f.name)

            # Confusion Matrix
            conf_matrix = confusion_matrix(y_test, y_pred)
            _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            sns.heatmap(conf_matrix, annot=True, fmt="g", cmap="Blues", ax=ax1)
            ax1.set_xlabel("Predicted")
            ax1.set_ylabel("Actual")
            ax1.set_title("Confusion Matrix")
            sns.heatmap(
                conf_matrix / conf_matrix.sum(axis=1)[:, None],
                annot=True,
                fmt=".2%",
                cmap="Blues",
                ax=ax2,
            )
            ax2.set_xlabel("Predicted")
            ax2.set_ylabel("Actual")
            ax2.set_title("Normalized Confusion Matrix")
            with tempfile.NamedTemporaryFile(prefix="confusion_matrix_", suffix=".png") as f:
                plt.tight_layout()
                plt.savefig(f.name)
                plt.close()
                mlflow.log_artifact(f.name)

            # ROC Curve
            fpr, tpr, _ = roc_curve(y_test, y_pred)
            roc_auc = auc(fpr, tpr)
            _, ax = plt.subplots(figsize=(10, 10))
            ax.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc)
            ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
            ax.set_xlim([0.0, 1.0])
            ax.set_ylim([0.0, 1.0])
            ax.set_xlabel("False Positive Rate")
            ax.set_ylabel("True Positive Rate")
            ax.set_title("Receiver Operating Characteristic")
            ax.legend(loc="lower right")
            with tempfile.NamedTemporaryFile(prefix="roc_curve_", suffix=".png") as f:
                plt.tight_layout()
                plt.savefig(f.name)
                plt.close()
                mlflow.log_artifact(f.name)

            # Feature Importances
            feature_importances = zip(
                model.named_steps["preprocessor"].get_feature_names_out(),
                model.named_steps["classifier"].feature_importances_,
            )
            feature_importances_df = (
                pd.DataFrame(
                    data=feature_importances, columns=["feature", "importance"]
                )
                .sort_values("importance", ascending=False)
                .reset_index(drop=True)
            )
            with tempfile.NamedTemporaryFile(
                prefix="feature_importance_", suffix=".csv"
            ) as f:
                feature_importances_df.to_csv(f.name, index=False)
                mlflow.log_artifact(f.name)

            sns.set_theme(style="whitegrid")
            plt.figure(figsize=(10, 12))
            ax = sns.barplot(x="importance", y="feature", data=feature_importances_df.iloc[:50])
            ax.set_title("Feature Importances (Top 50)")
            ax.set_xlabel("Importance")
            ax.set_ylabel("Features")
            with tempfile.NamedTemporaryFile(prefix="feature_importance_", suffix=".png") as f:
                plt.tight_layout()
                plt.savefig(f.name)
                plt.close()
                mlflow.log_artifact(f.name)

            return {"status": STATUS_OK, "params": params, "model": model, **metrics}

    return train

# Optimize hyperparameters

In [7]:
space = {
    "boosting_type": hp.choice("boosting_type", ["gbdt", "dart", "goss"]),
    "learning_rate": hp.loguniform("learning_rate", -5, 0),  # between e^-5 and 1
    "n_estimators": scope.int(hp.quniform("n_estimators", 100, 1000, 1)),
    "max_depth": scope.int(hp.quniform("max_depth", 3, 10, 1)),
    "num_leaves": scope.int(hp.quniform("num_leaves", 20, 150, 1)),
    "min_child_samples": scope.int(hp.quniform("min_child_samples", 20, 500, 1)),
    "feature_fraction": hp.uniform("feature_fraction", 0.5, 1.0),
    "lambda_l1": hp.uniform("lambda_l1", 0, 5),
    "lambda_l2": hp.uniform("lambda_l2", 0, 5),
    "min_split_gain": hp.uniform("min_split_gain", 0, 1),
    "min_child_weight": hp.uniform("min_child_weight", 0.001, 10),
    "subsample": hp.uniform("subsample", 0.5, 1),
    "colsample_bytree": hp.uniform("colsample_bytree", 0.5, 1),
    "reg_alpha": hp.uniform("reg_alpha", 0.0, 1.0),
    "reg_lambda": hp.uniform("reg_lambda", 0.0, 1.0),
    "objective": "binary",
    "class_weight": "balanced",
    "verbose": -1,
    "seed": 80,
}

In [9]:
df_payout_renamed_columns = {
    "meta_レースキー": "レースキー",
    "meta_馬番": "馬番",
    "cat_距離": "距離",
    "meta_発走日時": "発走日時",
    "meta_複勝払戻金": "payout",
    "num_年齢": "年齢",
    "meta_場コード": "場コード",
}

df_payout = (
    data.iloc[X_test.index]
    .reset_index(drop=True)
    .rename(columns=df_payout_renamed_columns)[df_payout_renamed_columns.values()]
)

experiment_name = "20240223_rfe_full_features_lgbm"
if mlflow.get_experiment_by_name(experiment_name) is None:
    mlflow.create_experiment(experiment_name)

fn = create_objective_fn(
    X_train,
    y_train,
    X_test,
    y_test,
    df_payout=df_payout,
    experiment_name=experiment_name,
)

In [None]:
# trials_turf = SparkTrials(parallelism=3, spark_session=spark)
trials_turf = Trials()
fmin(fn=fn, space=space, algo=tpe.suggest, max_evals=60, trials=trials_turf)

In [None]:
numeric_features = X_train.select_dtypes("number").columns.tolist()
# categorical_features = X_train.select_dtypes("category").columns.tolist()
preprocessor = ColumnTransformer(
    transformers=[
        ("num", StandardScaler(), numeric_features),
        # (
        #     "cat",
        #     OneHotEncoder(handle_unknown="ignore"),
        #     categorical_features,
        # ),
    ]
)
model = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("classifier", lgb.LGBMClassifier(**params)),
    ]
)

params = {
    "boosting_type": "gbdt",
    "class_weight": "balanced",
    "colsample_bytree": 0.8016642153767848,
    "feature_fraction": 0.5578235667548754,
    "lambda_l1": 2.551673582227088,
    "lambda_l2": 1.3506414200964172,
    "learning_rate": 0.02904727910263315,
    "max_depth": 10,
    "min_child_samples": 68,
    "min_child_weight": 7.736782598405014,
    "min_split_gain": 0.0071078853628913415,
    "n_estimators": 861,
    "num_leaves": 121,
    "objective": "binary",
    "reg_alpha": 0.25409327833670503,
    "reg_lambda": 0.4275373164043184,
    "seed": 42,
    "subsample": 0.9179630226670973,
    "verbose": -1,
}

rfe = RFE(estimator=lgb.LGBMClassifier(), step=1)
rfe.fit(X_train, y_train)