In [1]:
import os
import gc
import glob
from joblib import Parallel, delayed, dump, load
import pandas as pd
from pandas.core.common import flatten
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, QuantileTransformer
from sklearn.cluster import KMeans
import scipy as sc
from sklearn.model_selection import KFold, GroupKFold
import lightgbm as lgb
import warnings
warnings.filterwarnings("ignore")
pd.set_option("max_columns", 300)

# Parameters

In [2]:
CONFIG = {
    "root_dir": "../../input/optiver-realized-volatility-prediction/",
    "ckpt_dir": "../../ckpts",
    "kfold_seed": 42,
    "n_splits": 5,
    "n_clusters": 7,
}

# Read train test

In [3]:
def read_train_test():
    
    train = pd.read_csv("../../input/optiver-realized-volatility-prediction/train.csv")
    test = pd.read_csv("../../input/optiver-realized-volatility-prediction/test.csv")
    
    # Create a key to merge with book and trade data
    train["row_id"] = train["stock_id"].astype(str) + "-" + train["time_id"].astype(str)
    test["row_id"] = test["stock_id"].astype(str) + "-" + test["time_id"].astype(str)
    
    print("Our training set has {} rows".format(train.shape[0]))
    
    return train, test

# Define basic metric function

In [4]:
def activity_counts(df):
    activity_counts_ = df.groupby(["time_id"])["seconds_in_bucket"].agg("count").reset_index()
    activity_counts_ = activity_counts_.rename(columns={"seconds_in_bucket": "activity_counts"})
    return activity_counts_


def calc_wap(df, pos=1):
    wap = (df["bid_price{}".format(pos)] * df["ask_size{}".format(pos)] + df["ask_price{}".format(pos)] * df[
        "bid_size{}".format(pos)]) / (df["bid_size{}".format(pos)] + df["ask_size{}".format(pos)])
    return wap


def calc_wap2(df, pos=1):
    wap = (df["bid_price{}".format(pos)] * df["bid_size{}".format(pos)] + df["ask_price{}".format(pos)] * df[
        "ask_size{}".format(pos)]) / (df["bid_size{}".format(pos)] + df["ask_size{}".format(pos)])
    return wap


def wp(df):
    wp_ = (df["bid_price1"] * df["bid_size1"] + df["ask_price1"] * df["ask_size1"] + df["bid_price2"] * df[
        "bid_size2"] + df["ask_price2"] * df["ask_size2"]) / (
                  df["bid_size1"] + df["ask_size1"] + df["bid_size2"] + df["ask_size2"])
    return wp_


def maximum_drawdown(series, window=600):
    # window for 10 minutes, use min_periods=1 if you want to allow the expanding window
    roll_max = series.rolling(window, min_periods=1).max()
    second_drawdown = series / roll_max - 1.0
    max_drawdown = second_drawdown.rolling(window, min_periods=1).min()

    return max_drawdown


def log_return(series):
    return np.log(series).diff().fillna(0)


def rolling_log_return(series, rolling=60):
    return np.log(series.rolling(rolling)).diff().fillna(0)


def realized_volatility(series):
    return np.sqrt(np.sum(series ** 2))


def diff(series):
    return series.diff().fillna(0)


def time_diff(series):
    return series.diff().fillna(series)


def order_flow_imbalance(df, pos=1):
    df["bid_price{}_diff".format(pos)] = df.groupby(["time_id"])["bid_price{}".format(pos)].apply(diff)
    df["bid_size{}_diff".format(pos)] = df.groupby(["time_id"])["bid_price{}".format(pos)].apply(diff)
    df["bid_order_flow{}".format(pos)] = df["bid_size{}".format(pos)].copy(deep=True)
    df["bid_order_flow{}".format(pos)].loc[df["bid_price{}_diff".format(pos)] < 0] *= -1
    df["bid_order_flow{}".format(pos)].loc[df["bid_price{}_diff".format(pos)] == 0] = \
        df["bid_size{}_diff".format(pos)].loc[df["bid_price{}_diff".format(pos)] == 0]

    df["ask_price{}_diff".format(pos)] = df.groupby(["time_id"])["ask_price{}".format(pos)].apply(diff)
    df["ask_size{}_diff".format(pos)] = df.groupby(["time_id"])["ask_price{}".format(pos)].apply(diff)
    df["ask_order_flow{}".format(pos)] = df["ask_size{}".format(pos)].copy(deep=True)
    df["ask_order_flow{}".format(pos)].loc[df["ask_price{}_diff".format(pos)] < 0] *= -1
    df["ask_order_flow{}".format(pos)].loc[df["ask_price{}_diff".format(pos)] == 0] = \
        df["ask_size{}_diff".format(pos)].loc[df["ask_price{}_diff".format(pos)] == 0]

    order_flow_imbalance_ = df["bid_order_flow{}".format(pos)] - df["ask_order_flow{}".format(pos)]

    df.drop(["bid_price{}_diff".format(pos), "bid_size{}_diff".format(pos), "bid_order_flow{}".format(pos),
             "ask_price{}_diff".format(pos), "ask_size{}_diff".format(pos), "ask_order_flow{}".format(pos)], axis=1,
            inplace=True)

    return order_flow_imbalance_ + 1e-8


def order_book_slope(df):

    df["mid_point"] = (df["bid_price1"] + df["ask_price1"]) / 2
    best_mid_point_ = df.groupby(["time_id"])["mid_point"].agg("max").reset_index()
    best_mid_point_ = best_mid_point_.rename(columns={"mid_point": "best_mid_point"})
    df = df.merge(best_mid_point_, how="left", on="time_id")

    best_mid_point = df["best_mid_point"].copy()
    df.drop(["mid_point", "best_mid_point"], axis=1, inplace=True)

    def ratio(series):
        ratio_ = series / series.shift()
        return ratio_

    bid_price1_ratio = df.groupby(["time_id"])["bid_price1"].apply(ratio)
    bid_price1_mid_point_ratio = df["bid_price1"] / best_mid_point
    bid_price1_ratio = abs(bid_price1_ratio.fillna(bid_price1_mid_point_ratio) - 1)

    bid_size1_ratio = df.groupby(["time_id"])["bid_size1"].apply(ratio) - 1
    bid_size1_ratio = bid_size1_ratio.fillna(df["bid_size1"])
    df["DE"] = (bid_size1_ratio / bid_price1_ratio).replace([np.inf, -np.inf], np.nan).fillna(0)

    ask_price1_ratio = df.groupby(["time_id"])["ask_price1"].apply(ratio)
    ask_price1_mid_point_ratio = df["ask_price1"] / best_mid_point
    ask_price1_ratio = abs(ask_price1_ratio.fillna(ask_price1_mid_point_ratio) - 1)

    ask_size1_ratio = df.groupby(["time_id"])["ask_size1"].apply(ratio) - 1
    ask_size1_ratio = ask_size1_ratio.fillna(df["ask_size1"])
    df["SE"] = (ask_size1_ratio / ask_price1_ratio).replace([np.inf, -np.inf], np.nan).fillna(0)

    df["order_book_slope"] = (df["DE"] + df["SE"]) / 2
    order_book_slope_ = df.groupby(["time_id"])["order_book_slope"].agg("mean").reset_index()
    df.drop(["order_book_slope", "DE", "SE"], axis=1, inplace=True)

    return order_book_slope_


def ldispersion(df):
    LDispersion = 1 / 2 * (
            df["bid_size1"] / (df["bid_size1"] + df["bid_size2"]) * abs(df["bid_price1"] - df["bid_price2"]) + df[
        "ask_size1"] / (df["ask_size1"] + df["ask_size2"]) * abs(df["ask_price1"] - df["ask_price2"]))
    return LDispersion


def depth_imbalance(df, pos=1):
    depth_imbalance_ = (df["bid_size{}".format(pos)] - df["ask_size{}".format(pos)]) / (
            df["bid_size{}".format(pos)] + df["ask_size{}".format(pos)])

    return depth_imbalance_


def height_imbalance(df, pos=1):
    height_imbalance_ = (df["bid_price{}".format(pos)] - df["ask_price{}".format(pos)]) / (
            df["bid_price{}".format(pos)] + df["ask_price{}".format(pos)])

    return height_imbalance_


def pressure_imbalance(df):
    mid_price = (df["bid_price1"] + df["ask_price1"]) / 2

    weight_buy = mid_price / (mid_price - df["bid_price1"]) + mid_price / (mid_price - df["bid_price2"])
    pressure_buy = df["bid_size1"] * (mid_price / (mid_price - df["bid_price1"])) / weight_buy + df["bid_size2"] * (
            mid_price / (mid_price - df["bid_price2"])) / weight_buy

    weight_sell = mid_price / (df["ask_price1"] - mid_price) + mid_price / (df["ask_price2"] - mid_price)
    pressure_sell = df["ask_size1"] * (mid_price / (df["ask_price1"] - mid_price)) / weight_sell + df["ask_size2"] * (
            mid_price / (df["ask_price2"] - mid_price)) / weight_sell

    pressure_imbalance_ = np.log(pressure_buy) - np.log(pressure_sell)

    return pressure_imbalance_


def relative_spread(df, pos=1):
    relative_spread_ = 2 * (df["ask_price{}".format(pos)] - df["bid_price{}".format(pos)]) / (
            df["ask_price{}".format(pos)] + df["bid_price{}".format(pos)])

    return relative_spread_


def count_unique(series):
    return len(np.unique(series))

# Processor

In [5]:
# Function to preprocess book data (for each stock id)
def book_preprocessor(file_path):
    df = pd.read_parquet(file_path)

    # float 64 to float 32
    float_cols = df.select_dtypes(include=[np.float64]).columns
    df[float_cols] = df[float_cols].astype(np.float32)

    # int 64 to int 32
    int_cols = df.select_dtypes(include=[np.int64]).columns
    df[int_cols] = df[int_cols].astype(np.int32)

    # Calculate seconds gap
    df["seconds_gap"] = df.groupby(["time_id"])["seconds_in_bucket"].apply(time_diff)

    # Calculate Wap
    df["wap1"] = calc_wap(df, pos=1)
    df["wap2"] = calc_wap(df, pos=2)

    # Calculate wap balance
    df["wap_balance"] = abs(df["wap1"] - df["wap2"])

    # Calculate log returns
    df["log_return1"] = df.groupby(["time_id"])["wap1"].apply(log_return)
    df["log_return2"] = df.groupby(["time_id"])["wap2"].apply(log_return)

    # Calculate spread
    df["bid_ask_spread1"] = df["ask_price1"] / df["bid_price1"] - 1
    df["bid_ask_spread2"] = df["ask_price2"] / df["bid_price2"] - 1

    # order flow imbalance
    df["order_flow_imbalance1"] = order_flow_imbalance(df, 1)
    df["order_flow_imbalance2"] = order_flow_imbalance(df, 2)

    # order book slope
    order_slope_ = order_book_slope(df)
    df = df.merge(order_slope_, how="left", on="time_id")

    # depth imbalance
    df["depth_imbalance1"] = depth_imbalance(df, pos=1)
    df["depth_imbalance2"] = depth_imbalance(df, pos=2)

    # height imbalance
    df["height_imbalance1"] = height_imbalance(df, pos=1)
    df["height_imbalance2"] = height_imbalance(df, pos=2)

    # pressure imbalance
    df["pressure_imbalance"] = pressure_imbalance(df)

    # total volume
    df["total_volume"] = (df["ask_size1"] + df["ask_size2"]) + (df["bid_size1"] + df["bid_size2"])

    # Dict for aggregations
    create_feature_dict = {
        "wap1": [np.sum, np.std],
        "wap2": [np.sum, np.std],
        "log_return1": [realized_volatility],
        "log_return2": [realized_volatility],
        "wap_balance": [np.sum, np.max, np.min, np.std],
        "bid_ask_spread1": [np.sum, np.max, np.min, np.std],
        "bid_ask_spread2": [np.sum, np.max, np.min, np.std],
        "order_flow_imbalance1": [np.sum, np.max, np.min, np.std],
        "order_flow_imbalance2": [np.sum, np.max, np.min, np.std],
        "order_book_slope": [np.mean, np.max],
        "depth_imbalance1": [np.sum, np.max, np.std],
        "depth_imbalance2": [np.sum, np.max, np.std],
        "height_imbalance1": [np.sum, np.max, np.std],
        "height_imbalance2": [np.sum, np.max, np.std],
        "pressure_imbalance": [np.sum, np.max, np.std],
        "total_volume": [np.sum],
        "seconds_gap": [np.mean]
    }
    create_feature_dict_time = {
        "log_return1": [realized_volatility],
        "log_return2": [realized_volatility],
        "wap_balance": [np.sum, np.max, np.min, np.std],
        "bid_ask_spread1": [np.sum, np.max, np.min, np.std],
        "bid_ask_spread2": [np.sum, np.max, np.min, np.std],
        "order_flow_imbalance1": [np.sum, np.max, np.min, np.std],
        "order_flow_imbalance2": [np.sum, np.max, np.min, np.std],
        "total_volume": [np.sum],
        "seconds_gap": [np.mean]
    }

    # Function to get group stats for different windows (seconds in bucket)
    def get_stats_window(feature_dict, seconds_in_bucket, add_suffix=False):
        # Group by the window
        df_feature_ = df[df["seconds_in_bucket"] >= seconds_in_bucket].groupby(["time_id"]).agg(
            feature_dict).reset_index()
        # Rename columns joining suffix
        df_feature_.columns = ["_".join(col) for col in df_feature_.columns]
        # Add a suffix to differentiate windows
        if add_suffix:
            df_feature_ = df_feature_.add_suffix("_" + str(seconds_in_bucket))
        return df_feature_

    # Get the stats for different windows
    windows = [0, 150, 300, 450]
    add_suffixes = [False, True, True, True]
    df_feature = None

    for window, add_suffix in zip(windows, add_suffixes):
        if df_feature is None:
            df_feature = get_stats_window(feature_dict=create_feature_dict, seconds_in_bucket=window,
                                          add_suffix=add_suffix)
        else:
            new_df_feature = get_stats_window(feature_dict=create_feature_dict_time, seconds_in_bucket=window,
                                              add_suffix=add_suffix)
            df_feature = df_feature.merge(new_df_feature, how="left", left_on="time_id_",
                                          right_on="time_id__{}".format(window))

            # Drop unnecesary time_ids
            df_feature.drop(["time_id__{}".format(window)], axis=1, inplace=True)

    # Create row_id so we can merge
    stock_id = file_path.split("=")[1]
    df_feature["row_id"] = df_feature["time_id_"].apply(lambda x: f"{stock_id}-{x}")
    df_feature.drop(["time_id_"], axis=1, inplace=True)

    return df_feature

In [6]:
# Function to preprocess trade data (for each stock id)
def trade_preprocessor(file_path):
    df = pd.read_parquet(file_path)

    # float 64 to float 32
    float_cols = df.select_dtypes(include=[np.float64]).columns
    df[float_cols] = df[float_cols].astype(np.float32)

    # int 64 to int 32
    int_cols = df.select_dtypes(include=[np.int64]).columns
    df[int_cols] = df[int_cols].astype(np.int32)

    # Calculate seconds gap
    df["seconds_gap"] = df.groupby(["time_id"])["seconds_in_bucket"].apply(time_diff)

    # Calculate log return
    df["price_log_return"] = df.groupby("time_id")["price"].apply(log_return)

    # Calculate volumes
    df["volumes"] = df["price"] * df["size"]

    # Dict for aggregations
    create_feature_dict = {
        "price_log_return": [realized_volatility],
        "volumes": [np.sum, np.max, np.std],
        "order_count": [np.sum],
        "seconds_gap": [np.mean]
    }
    create_feature_dict_time = {
        "price_log_return": [realized_volatility],
        "volumes": [np.sum, np.max, np.std],
        "order_count": [np.sum],
        "seconds_gap": [np.mean]
    }

    # Function to get group stats for different windows (seconds in bucket)
    def get_stats_window(feature_dict, seconds_in_bucket, add_suffix=False):
        # Group by the window
        df_feature_ = df[df["seconds_in_bucket"] >= seconds_in_bucket].groupby(["time_id"]).agg(
            feature_dict).reset_index()
        # Rename columns joining suffix
        df_feature_.columns = ["_".join(col) for col in df_feature_.columns]
        # Add a suffix to differentiate windows
        if add_suffix:
            df_feature_ = df_feature_.add_suffix("_" + str(seconds_in_bucket))
        return df_feature_

    # Get the stats for different windows
    windows = [0, 150, 300, 450]
    add_suffixes = [False, True, True, True]
    df_feature = None

    for window, add_suffix in zip(windows, add_suffixes):
        if df_feature is None:
            df_feature = get_stats_window(feature_dict=create_feature_dict, seconds_in_bucket=window,
                                          add_suffix=add_suffix)
        else:
            new_df_feature = get_stats_window(feature_dict=create_feature_dict_time, seconds_in_bucket=window,
                                              add_suffix=add_suffix)
            df_feature = df_feature.merge(new_df_feature, how="left", left_on="time_id_",
                                          right_on="time_id__{}".format(window))

            # Drop unnecesary time_ids
            df_feature.drop(["time_id__{}".format(window)], axis=1, inplace=True)

    def tendency(price, vol):
        df_diff = np.diff(price)
        val = (df_diff / price[1:]) * 100
        power = np.sum(val * vol[1:])
        return (power)

    lis = []
    for n_time_id in df["time_id"].unique():
        
        df_id = df[df["time_id"] == n_time_id]
        
        tendencyV = tendency(df_id["price"].values, df_id["size"].values)
        energy = np.mean(df_id["price"].values ** 2)

        lis.append(
            {
                "time_id": n_time_id,
                "tendency": tendencyV,
                "energy": energy,
            }
        )

    df_lr = pd.DataFrame(lis)
    df_feature = df_feature.merge(df_lr, how="left", left_on="time_id_", right_on="time_id")

    # Create row_id so we can merge
    df_feature = df_feature.add_prefix("trade_")
    stock_id = file_path.split("=")[1]
    df_feature["row_id"] = df_feature["trade_time_id_"].apply(lambda x: f"{stock_id}-{x}")
    df_feature.drop(["trade_time_id_", "trade_time_id"], axis=1, inplace=True)
    return df_feature

In [7]:
# Funtion to make preprocessing function in parallel (for each stock id)
def preprocessor(list_stock_ids, is_train = True):
    
    # Parrallel for loop
    def for_joblib(stock_id):
        # Train
        if is_train:
            file_path_book = CONFIG["root_dir"] + "book_train.parquet/stock_id=" + str(stock_id)
            file_path_trade = CONFIG["root_dir"] + "trade_train.parquet/stock_id=" + str(stock_id)
        # Test
        else:
            file_path_book = CONFIG["root_dir"] + "book_test.parquet/stock_id=" + str(stock_id)
            file_path_trade = CONFIG["root_dir"] + "trade_test.parquet/stock_id=" + str(stock_id)
    
        # Preprocess book and trade data and merge them
        df_tmp = pd.merge(book_preprocessor(file_path_book), trade_preprocessor(file_path_trade), on = "row_id", how = "left")
        
        # Return the merge dataframe
        return df_tmp
    
    # Use parallel api to call paralle for loop
    df = Parallel(n_jobs = -1, verbose = 1)(delayed(for_joblib)(stock_id) for stock_id in list_stock_ids)
    
    # Concatenate all the dataframes that return from Parallel
    df = pd.concat(df, ignore_index = True)
    
    return df

In [8]:
# Read train and test
train, test = read_train_test()

# Get unique stock ids 
train_stock_ids = train["stock_id"].unique()

# Preprocess them using Parallel and our single stock id functions
train_ = preprocessor(train_stock_ids, is_train=True)
train = train.merge(train_, on=["row_id"], how="left")

# Get unique stock ids 
test_stock_ids = test["stock_id"].unique()

# Preprocess them using Parallel and our single stock id functions
test_ = preprocessor(test_stock_ids, is_train=False)
test = test.merge(test_, on=["row_id"], how="left")

Our training set has 428932 rows


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 24 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:   49.7s
[Parallel(n_jobs=-1)]: Done 112 out of 112 | elapsed:  4.6min finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 24 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:    0.1s finished


# Log Transformation

In [9]:
# abs log columns
abs_log_columns = [column for column in train.columns if 
                       "order_flow_imbalance" in column or 
                       "order_book_slope" in column or 
                       "depth_imbalance" in column or 
                       "pressure_imbalance" in column or
                       "total_volume" in column or
                       "seconds_gap" in column or
                       "trade_volumes" in column or
                       "trade_order_count" in column or
                       "trade_seconds_gap" in column or
                       "trade_tendency" in column
                      ]

# apply abs + 1e-8 + log
train[abs_log_columns] = (train[abs_log_columns].apply(np.abs) + 1e-8).apply(np.log)
test[abs_log_columns] = (test[abs_log_columns].apply(np.abs) + 1e-8).apply(np.log)

# Fill inf with nan

In [10]:
train = train.replace([np.inf, -np.inf], np.nan)
test = test.replace([np.inf, -np.inf], np.nan)

train = train.fillna(train.mean())
test = test.fillna(train.mean())

In [11]:
test.head()

Unnamed: 0,stock_id,time_id,row_id,wap1_sum,wap1_std,wap2_sum,wap2_std,log_return1_realized_volatility,log_return2_realized_volatility,wap_balance_sum,wap_balance_amax,wap_balance_amin,wap_balance_std,bid_ask_spread1_sum,bid_ask_spread1_amax,bid_ask_spread1_amin,bid_ask_spread1_std,bid_ask_spread2_sum,bid_ask_spread2_amax,bid_ask_spread2_amin,bid_ask_spread2_std,order_flow_imbalance1_sum,order_flow_imbalance1_amax,order_flow_imbalance1_amin,order_flow_imbalance1_std,order_flow_imbalance2_sum,order_flow_imbalance2_amax,order_flow_imbalance2_amin,order_flow_imbalance2_std,order_book_slope_mean,order_book_slope_amax,depth_imbalance1_sum,depth_imbalance1_amax,depth_imbalance1_std,depth_imbalance2_sum,depth_imbalance2_amax,depth_imbalance2_std,height_imbalance1_sum,height_imbalance1_amax,height_imbalance1_std,height_imbalance2_sum,height_imbalance2_amax,height_imbalance2_std,pressure_imbalance_sum,pressure_imbalance_amax,pressure_imbalance_std,total_volume_sum,seconds_gap_mean,log_return1_realized_volatility_150,log_return2_realized_volatility_150,wap_balance_sum_150,wap_balance_amax_150,wap_balance_amin_150,wap_balance_std_150,bid_ask_spread1_sum_150,bid_ask_spread1_amax_150,bid_ask_spread1_amin_150,bid_ask_spread1_std_150,bid_ask_spread2_sum_150,bid_ask_spread2_amax_150,bid_ask_spread2_amin_150,bid_ask_spread2_std_150,order_flow_imbalance1_sum_150,order_flow_imbalance1_amax_150,order_flow_imbalance1_amin_150,order_flow_imbalance1_std_150,order_flow_imbalance2_sum_150,order_flow_imbalance2_amax_150,order_flow_imbalance2_amin_150,order_flow_imbalance2_std_150,total_volume_sum_150,seconds_gap_mean_150,log_return1_realized_volatility_300,log_return2_realized_volatility_300,wap_balance_sum_300,wap_balance_amax_300,wap_balance_amin_300,wap_balance_std_300,bid_ask_spread1_sum_300,bid_ask_spread1_amax_300,bid_ask_spread1_amin_300,bid_ask_spread1_std_300,bid_ask_spread2_sum_300,bid_ask_spread2_amax_300,bid_ask_spread2_amin_300,bid_ask_spread2_std_300,order_flow_imbalance1_sum_300,order_flow_imbalance1_amax_300,order_flow_imbalance1_amin_300,order_flow_imbalance1_std_300,order_flow_imbalance2_sum_300,order_flow_imbalance2_amax_300,order_flow_imbalance2_amin_300,order_flow_imbalance2_std_300,total_volume_sum_300,seconds_gap_mean_300,log_return1_realized_volatility_450,log_return2_realized_volatility_450,wap_balance_sum_450,wap_balance_amax_450,wap_balance_amin_450,wap_balance_std_450,bid_ask_spread1_sum_450,bid_ask_spread1_amax_450,bid_ask_spread1_amin_450,bid_ask_spread1_std_450,bid_ask_spread2_sum_450,bid_ask_spread2_amax_450,bid_ask_spread2_amin_450,bid_ask_spread2_std_450,order_flow_imbalance1_sum_450,order_flow_imbalance1_amax_450,order_flow_imbalance1_amin_450,order_flow_imbalance1_std_450,order_flow_imbalance2_sum_450,order_flow_imbalance2_amax_450,order_flow_imbalance2_amin_450,order_flow_imbalance2_std_450,total_volume_sum_450,seconds_gap_mean_450,trade_price_log_return_realized_volatility,trade_volumes_sum,trade_volumes_amax,trade_volumes_std,trade_order_count_sum,trade_seconds_gap_mean,trade_price_log_return_realized_volatility_150,trade_volumes_sum_150,trade_volumes_amax_150,trade_volumes_std_150,trade_order_count_sum_150,trade_seconds_gap_mean_150,trade_price_log_return_realized_volatility_300,trade_volumes_sum_300,trade_volumes_amax_300,trade_volumes_std_300,trade_order_count_sum_300,trade_seconds_gap_mean_300,trade_price_log_return_realized_volatility_450,trade_volumes_sum_450,trade_volumes_amax_450,trade_volumes_std_450,trade_order_count_sum_450,trade_seconds_gap_mean_450,trade_tendency,trade_energy
0,0,4,0-4,3.001215,0.00017,3.00165,0.000153,0.000294,0.000252,0.000436,0.000169,0.000125,2.2e-05,0.001672,0.00059,0.000541,2.8e-05,0.003197,0.00123,0.000984,0.000142,2.995732,-17.727534,2.995732,2.446426,2.70805,-17.727534,2.70805,2.158744,11.666106,11.666106,-0.252666,-0.13815,-0.634768,0.703567,-0.299243,-2.740748,-0.000836,-0.00027,1.4e-05,-0.001598,-0.000492,7.1e-05,1.194424,0.935679,0.226678,6.958448,0.510826,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.000295,5.30336,4.605229,4.045865,2.397895,2.197225,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,1.047792,1.000302
1,0,32,0-32,389.911957,0.00111,389.912109,0.001149,0.004233,0.005808,0.092346,0.001088,2e-06,0.000216,0.222789,0.001408,0.000244,0.00022,0.374884,0.001955,0.000558,0.000266,8.243604,6.474128,6.664381,5.029068,8.392779,6.719712,6.578603,5.159263,9.808663,9.808663,3.365088,-0.021464,-0.633288,3.316562,-0.108935,-0.761674,-0.111319,-0.000122,0.00011,-0.187273,-0.000279,0.000133,4.268167,1.186723,0.08646,12.579914,0.496146,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.002669,9.42212,7.169559,5.509533,5.435508,2.198447,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,3.922755,1.000005
2,0,34,0-34,389.911957,0.00111,389.912109,0.001149,0.004233,0.005808,0.092346,0.001088,2e-06,0.000216,0.222789,0.001408,0.000244,0.00022,0.374884,0.001955,0.000558,0.000266,8.243604,6.474128,6.664381,5.029068,8.392779,6.719712,6.578603,5.159263,9.808663,9.808663,3.365088,-0.021464,-0.633288,3.316562,-0.108935,-0.761674,-0.111319,-0.000122,0.00011,-0.187273,-0.000279,0.000133,4.268167,1.186723,0.08646,12.579914,0.496146,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.002669,9.42212,7.169559,5.509533,5.435508,2.198447,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,3.922755,1.000005


In [12]:
train.head()

Unnamed: 0,stock_id,time_id,target,row_id,wap1_sum,wap1_std,wap2_sum,wap2_std,log_return1_realized_volatility,log_return2_realized_volatility,wap_balance_sum,wap_balance_amax,wap_balance_amin,wap_balance_std,bid_ask_spread1_sum,bid_ask_spread1_amax,bid_ask_spread1_amin,bid_ask_spread1_std,bid_ask_spread2_sum,bid_ask_spread2_amax,bid_ask_spread2_amin,bid_ask_spread2_std,order_flow_imbalance1_sum,order_flow_imbalance1_amax,order_flow_imbalance1_amin,order_flow_imbalance1_std,order_flow_imbalance2_sum,order_flow_imbalance2_amax,order_flow_imbalance2_amin,order_flow_imbalance2_std,order_book_slope_mean,order_book_slope_amax,depth_imbalance1_sum,depth_imbalance1_amax,depth_imbalance1_std,depth_imbalance2_sum,depth_imbalance2_amax,depth_imbalance2_std,height_imbalance1_sum,height_imbalance1_amax,height_imbalance1_std,height_imbalance2_sum,height_imbalance2_amax,height_imbalance2_std,pressure_imbalance_sum,pressure_imbalance_amax,pressure_imbalance_std,total_volume_sum,seconds_gap_mean,log_return1_realized_volatility_150,log_return2_realized_volatility_150,wap_balance_sum_150,wap_balance_amax_150,wap_balance_amin_150,wap_balance_std_150,bid_ask_spread1_sum_150,bid_ask_spread1_amax_150,bid_ask_spread1_amin_150,bid_ask_spread1_std_150,bid_ask_spread2_sum_150,bid_ask_spread2_amax_150,bid_ask_spread2_amin_150,bid_ask_spread2_std_150,order_flow_imbalance1_sum_150,order_flow_imbalance1_amax_150,order_flow_imbalance1_amin_150,order_flow_imbalance1_std_150,order_flow_imbalance2_sum_150,order_flow_imbalance2_amax_150,order_flow_imbalance2_amin_150,order_flow_imbalance2_std_150,total_volume_sum_150,seconds_gap_mean_150,log_return1_realized_volatility_300,log_return2_realized_volatility_300,wap_balance_sum_300,wap_balance_amax_300,wap_balance_amin_300,wap_balance_std_300,bid_ask_spread1_sum_300,bid_ask_spread1_amax_300,bid_ask_spread1_amin_300,bid_ask_spread1_std_300,bid_ask_spread2_sum_300,bid_ask_spread2_amax_300,bid_ask_spread2_amin_300,bid_ask_spread2_std_300,order_flow_imbalance1_sum_300,order_flow_imbalance1_amax_300,order_flow_imbalance1_amin_300,order_flow_imbalance1_std_300,order_flow_imbalance2_sum_300,order_flow_imbalance2_amax_300,order_flow_imbalance2_amin_300,order_flow_imbalance2_std_300,total_volume_sum_300,seconds_gap_mean_300,log_return1_realized_volatility_450,log_return2_realized_volatility_450,wap_balance_sum_450,wap_balance_amax_450,wap_balance_amin_450,wap_balance_std_450,bid_ask_spread1_sum_450,bid_ask_spread1_amax_450,bid_ask_spread1_amin_450,bid_ask_spread1_std_450,bid_ask_spread2_sum_450,bid_ask_spread2_amax_450,bid_ask_spread2_amin_450,bid_ask_spread2_std_450,order_flow_imbalance1_sum_450,order_flow_imbalance1_amax_450,order_flow_imbalance1_amin_450,order_flow_imbalance1_std_450,order_flow_imbalance2_sum_450,order_flow_imbalance2_amax_450,order_flow_imbalance2_amin_450,order_flow_imbalance2_std_450,total_volume_sum_450,seconds_gap_mean_450,trade_price_log_return_realized_volatility,trade_volumes_sum,trade_volumes_amax,trade_volumes_std,trade_order_count_sum,trade_seconds_gap_mean,trade_price_log_return_realized_volatility_150,trade_volumes_sum_150,trade_volumes_amax_150,trade_volumes_std_150,trade_order_count_sum_150,trade_seconds_gap_mean_150,trade_price_log_return_realized_volatility_300,trade_volumes_sum_300,trade_volumes_amax_300,trade_volumes_std_300,trade_order_count_sum_300,trade_seconds_gap_mean_300,trade_price_log_return_realized_volatility_450,trade_volumes_sum_450,trade_volumes_amax_450,trade_volumes_std_450,trade_order_count_sum_450,trade_seconds_gap_mean_450,trade_tendency,trade_energy
0,0,5,0.004136,0-5,303.125061,0.000693,303.10553,0.000781,0.004499,0.006999,0.117051,0.001414,1.192093e-07,0.000295,0.257371,0.001394,0.000361,0.000212,0.355666,0.001701,0.00067,0.000213,7.596392,6.006353,5.308268,4.224421,7.478735,5.991465,5.991465,4.497738,10.452333,10.452333,0.652249,-0.01,-0.343972,3.353146,-0.013245,-0.337334,-0.128628,-0.00018,0.000106,-0.177725,-0.000335,0.000107,2.893845,1.52718,0.758014,11.489616,0.674767,0.003796,0.006087,0.091996,0.001184,1.192093e-07,0.000281,0.199148,0.001392,0.000361,0.000221,0.276655,0.001649,0.00067,0.000208,6.925595,6.006353,5.308268,4.222584,6.251904,5.57973,5.710427,4.39824,11.238015,0.649087,0.002953,0.004864,0.051757,0.001034,1.192093e-07,0.000273,0.114324,0.00134,0.000361,0.000237,0.167702,0.001599,0.00067,0.000214,4.204693,5.298317,5.308268,4.056979,4.543295,5.438079,5.710427,4.427955,10.621205,0.752501,0.001722,0.004114,0.024869,0.001034,1.192093e-07,0.000277,0.053258,0.001135,0.000361,0.000181,0.082251,0.001599,0.00067,0.00025,5.537334,4.836282,5.308268,3.988823,5.384495,5.438079,5.710427,4.540973,9.795234,0.750306,0.002006,8.06782,6.215792,4.77726,4.70048,2.682732,0.001701,7.638484,6.215792,4.809663,4.290459,2.681021,0.001308,7.373136,6.215792,4.957082,3.988984,2.682325,0.00106,6.952122,6.215792,4.929568,3.610918,2.371578,2.958298,1.007459
1,0,11,0.001445,0-11,200.047775,0.000262,200.041168,0.000272,0.001204,0.002476,0.042309,0.000639,4.768372e-07,0.000155,0.078856,0.000904,0.000151,0.000157,0.134233,0.001105,0.000301,0.0002,7.24065,4.919981,5.886104,3.898109,4.60517,5.886104,6.22059,4.485531,9.850611,9.850611,3.509541,-0.006623,-0.353675,-4.434955,-0.006645,-0.555974,-0.039418,-7.5e-05,7.9e-05,-0.067091,-0.000151,0.0001,4.69023,1.4354,0.298921,11.318005,1.091923,0.001058,0.002262,0.035451,0.000639,4.768372e-07,0.000158,0.061031,0.000753,0.000151,0.000112,0.107468,0.001054,0.000301,0.000164,7.010312,4.919981,5.886104,3.893688,4.26268,5.886104,6.22059,4.534274,11.191824,0.980106,0.000981,0.002009,0.027443,0.000639,2.241135e-05,0.000158,0.040599,0.000753,0.000151,0.000121,0.07363,0.001054,0.000301,0.000175,6.75227,4.919981,5.886104,4.047241,6.028279,5.886104,6.22059,4.70822,10.928094,0.955511,0.000918,0.001883,0.014525,0.000639,2.264977e-05,0.000175,0.018817,0.000753,0.000201,0.000144,0.039147,0.001054,0.000502,0.000206,6.228511,4.663439,5.886104,4.267981,7.060476,5.886104,5.934894,4.669374,10.093736,1.060872,0.000901,7.161896,5.634861,4.35459,4.043051,2.939162,0.000813,7.067619,5.634861,4.428789,3.89182,2.899588,0.000587,6.802757,5.634861,4.505685,3.583519,2.847812,0.000501,6.719386,5.634861,4.675397,3.091042,2.76001,1.757711,1.000413
2,0,16,0.002168,0-16,187.913849,0.000864,187.939819,0.000862,0.002369,0.004801,0.06223,0.001135,4.887581e-06,0.000246,0.13638,0.00115,0.000384,0.000164,0.210687,0.001917,0.000575,0.000295,7.340836,5.32301,5.717028,4.102399,7.536364,5.717028,5.752573,4.360462,9.983113,9.983113,2.987343,-0.020001,-0.538069,3.68942,-0.010753,-0.543551,-0.068165,-0.000192,8.2e-05,-0.105282,-0.000288,0.000147,1.308376,1.348124,0.257778,11.267971,1.157149,0.002138,0.004019,0.044348,0.001135,1.829863e-05,0.000276,0.080841,0.00115,0.000384,0.000164,0.128644,0.001917,0.000575,0.000301,7.015712,5.32301,5.717028,4.160287,6.700731,5.717028,5.752573,4.411731,10.839502,1.327899,0.001295,0.003196,0.029308,0.001135,1.829863e-05,0.000294,0.046882,0.000958,0.000384,0.000162,0.073504,0.001825,0.000575,0.000321,5.950643,5.32301,5.717028,4.307268,7.41397,5.717028,4.795791,4.310067,10.340322,1.480936,0.001158,0.002972,0.016056,0.001135,1.829863e-05,0.000282,0.026617,0.000815,0.000384,0.000105,0.043991,0.001825,0.000575,0.000314,5.68358,5.32301,5.717028,4.495272,7.397562,5.717028,4.60517,4.448165,9.913487,1.303407,0.001961,7.677219,5.967283,4.731364,4.219508,3.174715,0.001621,7.604718,5.967283,4.794383,4.077537,3.171784,0.001137,7.07908,5.967283,4.87099,3.637586,3.218876,0.001048,6.987423,5.967283,4.964332,3.496508,3.001272,3.909281,0.998409
3,0,31,0.002195,0-31,119.859779,0.000757,119.835945,0.000656,0.002574,0.003637,0.04561,0.001082,1.507998e-05,0.000248,0.103301,0.001624,0.000324,0.00028,0.139155,0.002041,0.000648,0.000366,7.674153,5.913503,5.303305,4.386731,7.210818,5.31812,5.298317,4.362316,10.739769,10.739769,1.911819,-0.016195,-0.343264,3.690901,-0.131336,-0.645408,-0.051626,-0.000162,0.00014,-0.069533,-0.000324,0.000183,3.197619,1.303165,0.172368,10.863451,1.58412,0.002196,0.003273,0.029322,0.001082,1.507998e-05,0.000247,0.074589,0.001624,0.000324,0.000297,0.097188,0.002041,0.000648,0.000381,7.415175,5.7301,5.298317,4.385569,6.864848,5.298317,5.298317,4.319056,10.444736,1.683193,0.001776,0.002713,0.017524,0.00084,1.507998e-05,0.000228,0.04418,0.001624,0.000324,0.000278,0.057634,0.001716,0.000648,0.000316,6.45047,5.703782,5.298317,4.419545,6.280396,5.298317,5.225747,4.170158,10.00618,1.692669,0.000993,0.001424,0.006441,0.00084,5.167723e-05,0.000253,0.019057,0.00116,0.000973,7.4e-05,0.022027,0.001392,0.001113,0.000113,5.442418,5.703782,5.298317,4.687749,3.583519,4.836282,5.225747,4.276439,9.181941,2.029609,0.001561,7.580498,6.107218,4.973959,4.077537,3.648057,0.001401,7.395552,6.107218,5.075901,3.912023,3.731155,0.001089,7.348488,6.107218,5.120764,3.828641,3.532875,0.000802,6.240201,6.107218,5.49036,2.397895,4.127134,2.705595,0.998041
4,0,62,0.001747,0-62,175.932861,0.000258,175.93425,0.000317,0.001894,0.003257,0.044783,0.000724,3.278255e-06,0.000188,0.069916,0.000793,9.3e-05,0.00013,0.122743,0.001166,0.000373,0.000185,5.451038,5.32301,5.31812,3.910492,7.921536,5.808142,5.347108,4.211845,10.122023,10.122023,2.747648,-0.00995,-0.393141,3.472107,-0.006667,-0.390728,-0.034951,-4.7e-05,6.5e-05,-0.061349,-0.000186,9.3e-05,5.01564,1.52257,0.493275,11.00886,1.202836,0.001609,0.002927,0.032718,0.000724,3.278255e-06,0.000193,0.053359,0.000793,9.3e-05,0.000137,0.094379,0.00112,0.000373,0.000197,5.365976,5.288267,5.31812,3.79169,7.642524,5.808142,5.347108,4.266524,10.822195,1.179225,0.00152,0.002188,0.022397,0.00066,3.278255e-06,0.000188,0.037829,0.000793,0.000187,0.00014,0.065599,0.00112,0.00042,0.000174,6.120297,5.288267,4.644391,3.850977,7.042286,5.808142,5.31812,4.193934,10.498884,1.170846,0.001378,0.000966,0.013088,0.00066,2.205372e-05,0.000203,0.018705,0.000793,0.000233,0.000138,0.030654,0.001027,0.000606,0.00013,4.418841,5.288267,4.615121,3.857062,2.079441,5.267858,5.31812,4.145817,9.554639,1.350955,0.000871,7.490113,5.83131,4.769509,4.488636,3.278858,0.00055,7.358388,5.83131,4.845378,4.276666,3.32773,0.000453,7.105312,5.83131,4.941584,3.988984,3.302548,0.00036,3.760895,3.52611,2.74613,2.639057,3.637586,2.354738,0.999237


In [13]:
# Process agg by kmeans
def get_kmeans_idx(n_clusters=7):
    train_p = pd.read_csv("../../input/optiver-realized-volatility-prediction/train.csv")
    train_p = train_p.pivot(index="time_id", columns="stock_id", values="target")

    corr = train_p.corr()

    ids = corr.index

    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(corr.values)

    kmeans_clusters = []
    for n in range(n_clusters):
        kmeans_clusters.append ([(x - 1) for x in ((ids + 1)*(kmeans.labels_ == n)) if x > 0])
        
    return kmeans_clusters
    

def agg_stat_features_by_clusters(df, n_clusters=7, function=np.nanmean, post_fix="_cluster_mean"):

    kmeans_clusters = get_kmeans_idx(n_clusters=n_clusters)

    clusters = []
    agg_columns = [
        "time_id",
        "stock_id",
        "log_return1_realized_volatility",
        "log_return2_realized_volatility",
        "order_flow_imbalance1_sum",
        "order_flow_imbalance2_sum",
        "order_book_slope_mean",
        "depth_imbalance1_std",
        "depth_imbalance2_std",
        "height_imbalance1_sum",
        "height_imbalance2_sum",
        "pressure_imbalance_std",
        "total_volume_sum",
        "seconds_gap_mean",
        "trade_price_log_return_realized_volatility",
        "trade_volumes_sum",
        "trade_order_count_sum",
        "trade_seconds_gap_mean",
        "trade_tendency",
        "trade_energy"
    ]

    for cluster_idx, ind in enumerate(kmeans_clusters):
        cluster_df = df.loc[df["stock_id"].isin(ind), agg_columns].groupby(["time_id"]).agg(function)
        cluster_df.loc[:, "stock_id"] = str(cluster_idx) + post_fix
        clusters.append(cluster_df)

    clusters_df = pd.concat(clusters).reset_index()
    # multi index (column, c1)
    clusters_df = clusters_df.pivot(index="time_id", columns="stock_id")
    # ravel multi index to list of tuple [(target, c1), ...]
    clusters_df.columns = ["_".join(x) for x in clusters_df.columns.ravel()]
    clusters_df.reset_index(inplace=True)

    postfixes = [
        "0" + post_fix,
        "1" + post_fix,
        "3" + post_fix,
        "4" + post_fix,
        "6" + post_fix,
    ]
    merge_columns = []
    for column in agg_columns:
        if column == "time_id":
            merge_columns.append(column)
        elif column == "stock_id":
            continue
        else:
            for postfix in postfixes:
                merge_columns.append(column + "_" + postfix)
                
    not_exist_columns = [column for column in merge_columns if column not in clusters_df.columns]
    clusters_df[not_exist_columns] = 0
    
    df = pd.merge(df, clusters_df[merge_columns], how="left", on="time_id")

    return df


# Function to get group stats for the time_id
def agg_stat_features_by_market(df, operations=None, operations_names=None):
    def percentile(n):
        def percentile_(x):
            return np.percentile(x, n)

        percentile_.__name__ = "percentile_%s" % n
        return percentile_

    if operations is None:
        operations = [
            np.nanmean,
        ]
        operations_names = [
            "mean",
        ]

    # Get realized volatility columns
    vol_cols = [
        "log_return1_realized_volatility",
        "log_return1_realized_volatility_150",
        "log_return1_realized_volatility_300",
        "log_return1_realized_volatility_450",
    ]

    # Group by the stock id
    df_stock_id = df.groupby(["stock_id"])[vol_cols].agg(operations).reset_index()
    # Rename columns joining suffix
    df_stock_id.columns = ["_".join(col) for col in df_stock_id.columns]
    df_stock_id = df_stock_id.add_suffix("_" + "stock")

    # Group by the stock id
    df_time_id = df.groupby(["time_id"])[vol_cols].agg(operations).reset_index()
    # Rename columns joining suffix
    df_time_id.columns = ["_".join(col) for col in df_time_id.columns]
    df_time_id = df_time_id.add_suffix("_" + "time")

    # Merge with original dataframe
    df = df.merge(df_stock_id, how="left", left_on=["stock_id"], right_on=["stock_id__stock"])
    df.drop("stock_id__stock", axis=1, inplace=True)

    df = df.merge(df_time_id, how="left", left_on=["time_id"], right_on=["time_id__time"])
    df.drop("time_id__time", axis=1, inplace=True)

    return df


In [14]:
test.head()

Unnamed: 0,stock_id,time_id,row_id,wap1_sum,wap1_std,wap2_sum,wap2_std,log_return1_realized_volatility,log_return2_realized_volatility,wap_balance_sum,wap_balance_amax,wap_balance_amin,wap_balance_std,bid_ask_spread1_sum,bid_ask_spread1_amax,bid_ask_spread1_amin,bid_ask_spread1_std,bid_ask_spread2_sum,bid_ask_spread2_amax,bid_ask_spread2_amin,bid_ask_spread2_std,order_flow_imbalance1_sum,order_flow_imbalance1_amax,order_flow_imbalance1_amin,order_flow_imbalance1_std,order_flow_imbalance2_sum,order_flow_imbalance2_amax,order_flow_imbalance2_amin,order_flow_imbalance2_std,order_book_slope_mean,order_book_slope_amax,depth_imbalance1_sum,depth_imbalance1_amax,depth_imbalance1_std,depth_imbalance2_sum,depth_imbalance2_amax,depth_imbalance2_std,height_imbalance1_sum,height_imbalance1_amax,height_imbalance1_std,height_imbalance2_sum,height_imbalance2_amax,height_imbalance2_std,pressure_imbalance_sum,pressure_imbalance_amax,pressure_imbalance_std,total_volume_sum,seconds_gap_mean,log_return1_realized_volatility_150,log_return2_realized_volatility_150,wap_balance_sum_150,wap_balance_amax_150,wap_balance_amin_150,wap_balance_std_150,bid_ask_spread1_sum_150,bid_ask_spread1_amax_150,bid_ask_spread1_amin_150,bid_ask_spread1_std_150,bid_ask_spread2_sum_150,bid_ask_spread2_amax_150,bid_ask_spread2_amin_150,bid_ask_spread2_std_150,order_flow_imbalance1_sum_150,order_flow_imbalance1_amax_150,order_flow_imbalance1_amin_150,order_flow_imbalance1_std_150,order_flow_imbalance2_sum_150,order_flow_imbalance2_amax_150,order_flow_imbalance2_amin_150,order_flow_imbalance2_std_150,total_volume_sum_150,seconds_gap_mean_150,log_return1_realized_volatility_300,log_return2_realized_volatility_300,wap_balance_sum_300,wap_balance_amax_300,wap_balance_amin_300,wap_balance_std_300,bid_ask_spread1_sum_300,bid_ask_spread1_amax_300,bid_ask_spread1_amin_300,bid_ask_spread1_std_300,bid_ask_spread2_sum_300,bid_ask_spread2_amax_300,bid_ask_spread2_amin_300,bid_ask_spread2_std_300,order_flow_imbalance1_sum_300,order_flow_imbalance1_amax_300,order_flow_imbalance1_amin_300,order_flow_imbalance1_std_300,order_flow_imbalance2_sum_300,order_flow_imbalance2_amax_300,order_flow_imbalance2_amin_300,order_flow_imbalance2_std_300,total_volume_sum_300,seconds_gap_mean_300,log_return1_realized_volatility_450,log_return2_realized_volatility_450,wap_balance_sum_450,wap_balance_amax_450,wap_balance_amin_450,wap_balance_std_450,bid_ask_spread1_sum_450,bid_ask_spread1_amax_450,bid_ask_spread1_amin_450,bid_ask_spread1_std_450,bid_ask_spread2_sum_450,bid_ask_spread2_amax_450,bid_ask_spread2_amin_450,bid_ask_spread2_std_450,order_flow_imbalance1_sum_450,order_flow_imbalance1_amax_450,order_flow_imbalance1_amin_450,order_flow_imbalance1_std_450,order_flow_imbalance2_sum_450,order_flow_imbalance2_amax_450,order_flow_imbalance2_amin_450,order_flow_imbalance2_std_450,total_volume_sum_450,seconds_gap_mean_450,trade_price_log_return_realized_volatility,trade_volumes_sum,trade_volumes_amax,trade_volumes_std,trade_order_count_sum,trade_seconds_gap_mean,trade_price_log_return_realized_volatility_150,trade_volumes_sum_150,trade_volumes_amax_150,trade_volumes_std_150,trade_order_count_sum_150,trade_seconds_gap_mean_150,trade_price_log_return_realized_volatility_300,trade_volumes_sum_300,trade_volumes_amax_300,trade_volumes_std_300,trade_order_count_sum_300,trade_seconds_gap_mean_300,trade_price_log_return_realized_volatility_450,trade_volumes_sum_450,trade_volumes_amax_450,trade_volumes_std_450,trade_order_count_sum_450,trade_seconds_gap_mean_450,trade_tendency,trade_energy
0,0,4,0-4,3.001215,0.00017,3.00165,0.000153,0.000294,0.000252,0.000436,0.000169,0.000125,2.2e-05,0.001672,0.00059,0.000541,2.8e-05,0.003197,0.00123,0.000984,0.000142,2.995732,-17.727534,2.995732,2.446426,2.70805,-17.727534,2.70805,2.158744,11.666106,11.666106,-0.252666,-0.13815,-0.634768,0.703567,-0.299243,-2.740748,-0.000836,-0.00027,1.4e-05,-0.001598,-0.000492,7.1e-05,1.194424,0.935679,0.226678,6.958448,0.510826,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.000295,5.30336,4.605229,4.045865,2.397895,2.197225,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,1.047792,1.000302
1,0,32,0-32,389.911957,0.00111,389.912109,0.001149,0.004233,0.005808,0.092346,0.001088,2e-06,0.000216,0.222789,0.001408,0.000244,0.00022,0.374884,0.001955,0.000558,0.000266,8.243604,6.474128,6.664381,5.029068,8.392779,6.719712,6.578603,5.159263,9.808663,9.808663,3.365088,-0.021464,-0.633288,3.316562,-0.108935,-0.761674,-0.111319,-0.000122,0.00011,-0.187273,-0.000279,0.000133,4.268167,1.186723,0.08646,12.579914,0.496146,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.002669,9.42212,7.169559,5.509533,5.435508,2.198447,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,3.922755,1.000005
2,0,34,0-34,389.911957,0.00111,389.912109,0.001149,0.004233,0.005808,0.092346,0.001088,2e-06,0.000216,0.222789,0.001408,0.000244,0.00022,0.374884,0.001955,0.000558,0.000266,8.243604,6.474128,6.664381,5.029068,8.392779,6.719712,6.578603,5.159263,9.808663,9.808663,3.365088,-0.021464,-0.633288,3.316562,-0.108935,-0.761674,-0.111319,-0.000122,0.00011,-0.187273,-0.000279,0.000133,4.268167,1.186723,0.08646,12.579914,0.496146,0.003572,0.004899,0.067365,0.001003,2e-06,0.000208,0.162831,0.001307,0.000256,0.000208,0.275176,0.001825,0.000573,0.000249,7.965228,6.377706,6.54466,5.006596,8.105398,6.620631,6.427413,5.135831,12.278644,0.510187,0.002854,0.003917,0.04407,0.000926,4e-06,0.000203,0.106615,0.001217,0.000275,0.000198,0.180765,0.001716,0.000596,0.000237,7.578421,6.229574,6.359447,4.966197,7.709727,6.476402,6.197384,5.094582,11.861316,0.517084,0.001968,0.002703,0.021708,0.000815,9e-06,0.000195,0.052529,0.001094,0.000315,0.000182,0.089294,0.001573,0.000645,0.000218,6.922547,5.896404,5.926516,4.865447,7.040618,6.190763,5.689076,4.995432,11.149388,0.524177,0.002669,9.42212,7.169559,5.509533,5.435508,2.198447,0.002287,9.102952,7.037229,5.46627,5.124539,2.234447,0.001829,8.658495,6.85117,5.400575,4.694283,2.246024,0.001259,7.878889,6.499428,5.250235,3.955146,2.258757,3.922755,1.000005


In [15]:
train.head()

Unnamed: 0,stock_id,time_id,target,row_id,wap1_sum,wap1_std,wap2_sum,wap2_std,log_return1_realized_volatility,log_return2_realized_volatility,wap_balance_sum,wap_balance_amax,wap_balance_amin,wap_balance_std,bid_ask_spread1_sum,bid_ask_spread1_amax,bid_ask_spread1_amin,bid_ask_spread1_std,bid_ask_spread2_sum,bid_ask_spread2_amax,bid_ask_spread2_amin,bid_ask_spread2_std,order_flow_imbalance1_sum,order_flow_imbalance1_amax,order_flow_imbalance1_amin,order_flow_imbalance1_std,order_flow_imbalance2_sum,order_flow_imbalance2_amax,order_flow_imbalance2_amin,order_flow_imbalance2_std,order_book_slope_mean,order_book_slope_amax,depth_imbalance1_sum,depth_imbalance1_amax,depth_imbalance1_std,depth_imbalance2_sum,depth_imbalance2_amax,depth_imbalance2_std,height_imbalance1_sum,height_imbalance1_amax,height_imbalance1_std,height_imbalance2_sum,height_imbalance2_amax,height_imbalance2_std,pressure_imbalance_sum,pressure_imbalance_amax,pressure_imbalance_std,total_volume_sum,seconds_gap_mean,log_return1_realized_volatility_150,log_return2_realized_volatility_150,wap_balance_sum_150,wap_balance_amax_150,wap_balance_amin_150,wap_balance_std_150,bid_ask_spread1_sum_150,bid_ask_spread1_amax_150,bid_ask_spread1_amin_150,bid_ask_spread1_std_150,bid_ask_spread2_sum_150,bid_ask_spread2_amax_150,bid_ask_spread2_amin_150,bid_ask_spread2_std_150,order_flow_imbalance1_sum_150,order_flow_imbalance1_amax_150,order_flow_imbalance1_amin_150,order_flow_imbalance1_std_150,order_flow_imbalance2_sum_150,order_flow_imbalance2_amax_150,order_flow_imbalance2_amin_150,order_flow_imbalance2_std_150,total_volume_sum_150,seconds_gap_mean_150,log_return1_realized_volatility_300,log_return2_realized_volatility_300,wap_balance_sum_300,wap_balance_amax_300,wap_balance_amin_300,wap_balance_std_300,bid_ask_spread1_sum_300,bid_ask_spread1_amax_300,bid_ask_spread1_amin_300,bid_ask_spread1_std_300,bid_ask_spread2_sum_300,bid_ask_spread2_amax_300,bid_ask_spread2_amin_300,bid_ask_spread2_std_300,order_flow_imbalance1_sum_300,order_flow_imbalance1_amax_300,order_flow_imbalance1_amin_300,order_flow_imbalance1_std_300,order_flow_imbalance2_sum_300,order_flow_imbalance2_amax_300,order_flow_imbalance2_amin_300,order_flow_imbalance2_std_300,total_volume_sum_300,seconds_gap_mean_300,log_return1_realized_volatility_450,log_return2_realized_volatility_450,wap_balance_sum_450,wap_balance_amax_450,wap_balance_amin_450,wap_balance_std_450,bid_ask_spread1_sum_450,bid_ask_spread1_amax_450,bid_ask_spread1_amin_450,bid_ask_spread1_std_450,bid_ask_spread2_sum_450,bid_ask_spread2_amax_450,bid_ask_spread2_amin_450,bid_ask_spread2_std_450,order_flow_imbalance1_sum_450,order_flow_imbalance1_amax_450,order_flow_imbalance1_amin_450,order_flow_imbalance1_std_450,order_flow_imbalance2_sum_450,order_flow_imbalance2_amax_450,order_flow_imbalance2_amin_450,order_flow_imbalance2_std_450,total_volume_sum_450,seconds_gap_mean_450,trade_price_log_return_realized_volatility,trade_volumes_sum,trade_volumes_amax,trade_volumes_std,trade_order_count_sum,trade_seconds_gap_mean,trade_price_log_return_realized_volatility_150,trade_volumes_sum_150,trade_volumes_amax_150,trade_volumes_std_150,trade_order_count_sum_150,trade_seconds_gap_mean_150,trade_price_log_return_realized_volatility_300,trade_volumes_sum_300,trade_volumes_amax_300,trade_volumes_std_300,trade_order_count_sum_300,trade_seconds_gap_mean_300,trade_price_log_return_realized_volatility_450,trade_volumes_sum_450,trade_volumes_amax_450,trade_volumes_std_450,trade_order_count_sum_450,trade_seconds_gap_mean_450,trade_tendency,trade_energy
0,0,5,0.004136,0-5,303.125061,0.000693,303.10553,0.000781,0.004499,0.006999,0.117051,0.001414,1.192093e-07,0.000295,0.257371,0.001394,0.000361,0.000212,0.355666,0.001701,0.00067,0.000213,7.596392,6.006353,5.308268,4.224421,7.478735,5.991465,5.991465,4.497738,10.452333,10.452333,0.652249,-0.01,-0.343972,3.353146,-0.013245,-0.337334,-0.128628,-0.00018,0.000106,-0.177725,-0.000335,0.000107,2.893845,1.52718,0.758014,11.489616,0.674767,0.003796,0.006087,0.091996,0.001184,1.192093e-07,0.000281,0.199148,0.001392,0.000361,0.000221,0.276655,0.001649,0.00067,0.000208,6.925595,6.006353,5.308268,4.222584,6.251904,5.57973,5.710427,4.39824,11.238015,0.649087,0.002953,0.004864,0.051757,0.001034,1.192093e-07,0.000273,0.114324,0.00134,0.000361,0.000237,0.167702,0.001599,0.00067,0.000214,4.204693,5.298317,5.308268,4.056979,4.543295,5.438079,5.710427,4.427955,10.621205,0.752501,0.001722,0.004114,0.024869,0.001034,1.192093e-07,0.000277,0.053258,0.001135,0.000361,0.000181,0.082251,0.001599,0.00067,0.00025,5.537334,4.836282,5.308268,3.988823,5.384495,5.438079,5.710427,4.540973,9.795234,0.750306,0.002006,8.06782,6.215792,4.77726,4.70048,2.682732,0.001701,7.638484,6.215792,4.809663,4.290459,2.681021,0.001308,7.373136,6.215792,4.957082,3.988984,2.682325,0.00106,6.952122,6.215792,4.929568,3.610918,2.371578,2.958298,1.007459
1,0,11,0.001445,0-11,200.047775,0.000262,200.041168,0.000272,0.001204,0.002476,0.042309,0.000639,4.768372e-07,0.000155,0.078856,0.000904,0.000151,0.000157,0.134233,0.001105,0.000301,0.0002,7.24065,4.919981,5.886104,3.898109,4.60517,5.886104,6.22059,4.485531,9.850611,9.850611,3.509541,-0.006623,-0.353675,-4.434955,-0.006645,-0.555974,-0.039418,-7.5e-05,7.9e-05,-0.067091,-0.000151,0.0001,4.69023,1.4354,0.298921,11.318005,1.091923,0.001058,0.002262,0.035451,0.000639,4.768372e-07,0.000158,0.061031,0.000753,0.000151,0.000112,0.107468,0.001054,0.000301,0.000164,7.010312,4.919981,5.886104,3.893688,4.26268,5.886104,6.22059,4.534274,11.191824,0.980106,0.000981,0.002009,0.027443,0.000639,2.241135e-05,0.000158,0.040599,0.000753,0.000151,0.000121,0.07363,0.001054,0.000301,0.000175,6.75227,4.919981,5.886104,4.047241,6.028279,5.886104,6.22059,4.70822,10.928094,0.955511,0.000918,0.001883,0.014525,0.000639,2.264977e-05,0.000175,0.018817,0.000753,0.000201,0.000144,0.039147,0.001054,0.000502,0.000206,6.228511,4.663439,5.886104,4.267981,7.060476,5.886104,5.934894,4.669374,10.093736,1.060872,0.000901,7.161896,5.634861,4.35459,4.043051,2.939162,0.000813,7.067619,5.634861,4.428789,3.89182,2.899588,0.000587,6.802757,5.634861,4.505685,3.583519,2.847812,0.000501,6.719386,5.634861,4.675397,3.091042,2.76001,1.757711,1.000413
2,0,16,0.002168,0-16,187.913849,0.000864,187.939819,0.000862,0.002369,0.004801,0.06223,0.001135,4.887581e-06,0.000246,0.13638,0.00115,0.000384,0.000164,0.210687,0.001917,0.000575,0.000295,7.340836,5.32301,5.717028,4.102399,7.536364,5.717028,5.752573,4.360462,9.983113,9.983113,2.987343,-0.020001,-0.538069,3.68942,-0.010753,-0.543551,-0.068165,-0.000192,8.2e-05,-0.105282,-0.000288,0.000147,1.308376,1.348124,0.257778,11.267971,1.157149,0.002138,0.004019,0.044348,0.001135,1.829863e-05,0.000276,0.080841,0.00115,0.000384,0.000164,0.128644,0.001917,0.000575,0.000301,7.015712,5.32301,5.717028,4.160287,6.700731,5.717028,5.752573,4.411731,10.839502,1.327899,0.001295,0.003196,0.029308,0.001135,1.829863e-05,0.000294,0.046882,0.000958,0.000384,0.000162,0.073504,0.001825,0.000575,0.000321,5.950643,5.32301,5.717028,4.307268,7.41397,5.717028,4.795791,4.310067,10.340322,1.480936,0.001158,0.002972,0.016056,0.001135,1.829863e-05,0.000282,0.026617,0.000815,0.000384,0.000105,0.043991,0.001825,0.000575,0.000314,5.68358,5.32301,5.717028,4.495272,7.397562,5.717028,4.60517,4.448165,9.913487,1.303407,0.001961,7.677219,5.967283,4.731364,4.219508,3.174715,0.001621,7.604718,5.967283,4.794383,4.077537,3.171784,0.001137,7.07908,5.967283,4.87099,3.637586,3.218876,0.001048,6.987423,5.967283,4.964332,3.496508,3.001272,3.909281,0.998409
3,0,31,0.002195,0-31,119.859779,0.000757,119.835945,0.000656,0.002574,0.003637,0.04561,0.001082,1.507998e-05,0.000248,0.103301,0.001624,0.000324,0.00028,0.139155,0.002041,0.000648,0.000366,7.674153,5.913503,5.303305,4.386731,7.210818,5.31812,5.298317,4.362316,10.739769,10.739769,1.911819,-0.016195,-0.343264,3.690901,-0.131336,-0.645408,-0.051626,-0.000162,0.00014,-0.069533,-0.000324,0.000183,3.197619,1.303165,0.172368,10.863451,1.58412,0.002196,0.003273,0.029322,0.001082,1.507998e-05,0.000247,0.074589,0.001624,0.000324,0.000297,0.097188,0.002041,0.000648,0.000381,7.415175,5.7301,5.298317,4.385569,6.864848,5.298317,5.298317,4.319056,10.444736,1.683193,0.001776,0.002713,0.017524,0.00084,1.507998e-05,0.000228,0.04418,0.001624,0.000324,0.000278,0.057634,0.001716,0.000648,0.000316,6.45047,5.703782,5.298317,4.419545,6.280396,5.298317,5.225747,4.170158,10.00618,1.692669,0.000993,0.001424,0.006441,0.00084,5.167723e-05,0.000253,0.019057,0.00116,0.000973,7.4e-05,0.022027,0.001392,0.001113,0.000113,5.442418,5.703782,5.298317,4.687749,3.583519,4.836282,5.225747,4.276439,9.181941,2.029609,0.001561,7.580498,6.107218,4.973959,4.077537,3.648057,0.001401,7.395552,6.107218,5.075901,3.912023,3.731155,0.001089,7.348488,6.107218,5.120764,3.828641,3.532875,0.000802,6.240201,6.107218,5.49036,2.397895,4.127134,2.705595,0.998041
4,0,62,0.001747,0-62,175.932861,0.000258,175.93425,0.000317,0.001894,0.003257,0.044783,0.000724,3.278255e-06,0.000188,0.069916,0.000793,9.3e-05,0.00013,0.122743,0.001166,0.000373,0.000185,5.451038,5.32301,5.31812,3.910492,7.921536,5.808142,5.347108,4.211845,10.122023,10.122023,2.747648,-0.00995,-0.393141,3.472107,-0.006667,-0.390728,-0.034951,-4.7e-05,6.5e-05,-0.061349,-0.000186,9.3e-05,5.01564,1.52257,0.493275,11.00886,1.202836,0.001609,0.002927,0.032718,0.000724,3.278255e-06,0.000193,0.053359,0.000793,9.3e-05,0.000137,0.094379,0.00112,0.000373,0.000197,5.365976,5.288267,5.31812,3.79169,7.642524,5.808142,5.347108,4.266524,10.822195,1.179225,0.00152,0.002188,0.022397,0.00066,3.278255e-06,0.000188,0.037829,0.000793,0.000187,0.00014,0.065599,0.00112,0.00042,0.000174,6.120297,5.288267,4.644391,3.850977,7.042286,5.808142,5.31812,4.193934,10.498884,1.170846,0.001378,0.000966,0.013088,0.00066,2.205372e-05,0.000203,0.018705,0.000793,0.000233,0.000138,0.030654,0.001027,0.000606,0.00013,4.418841,5.288267,4.615121,3.857062,2.079441,5.267858,5.31812,4.145817,9.554639,1.350955,0.000871,7.490113,5.83131,4.769509,4.488636,3.278858,0.00055,7.358388,5.83131,4.845378,4.276666,3.32773,0.000453,7.105312,5.83131,4.941584,3.988984,3.302548,0.00036,3.760895,3.52611,2.74613,2.639057,3.637586,2.354738,0.999237


In [18]:
# Function to calculate the root mean squared percentage error
def rmspe(y_true, y_pred):
    return np.sqrt(np.mean(np.square((y_true - y_pred) / y_true)))

# Function to early stop with root mean squared percentage error
def feval_rmspe(y_pred, lgb_train):
    y_true = lgb_train.get_label()
    return "RMSPE", rmspe(y_true, y_pred), False

def train_and_evaluate(train, test):
    
    # scale
    scaler = StandardScaler()
    
    # Split features and target
    x = train.drop(["row_id", "target"], axis=1)
    y = train["target"]
    
    # x_test with train feature
    x_test = test.drop("row_id", axis=1)
    x_test = agg_stat_features_by_market(x_test)
    x_test = agg_stat_features_by_clusters(x_test, n_clusters=CONFIG["n_clusters"], function=np.nanmean, post_fix="_cluster_mean")
    x_test = agg_stat_features_by_clusters(x_test, n_clusters=CONFIG["n_clusters"], function=np.nanmax, post_fix="_cluster_max")
    x_test = agg_stat_features_by_clusters(x_test, n_clusters=CONFIG["n_clusters"], function=np.nanmin, post_fix="_cluster_min")
    x_test = agg_stat_features_by_clusters(x_test, n_clusters=CONFIG["n_clusters"], function=np.nanstd, post_fix="_cluster_std")
    
    # define normalize columns
    except_columns = ["stock_id", "time_id", "target", "row_id"]
    normalized_columns = [column for column in x_test.columns if column not in except_columns]
    x_test.drop("time_id", axis=1, inplace=True)
    
    # Transform stock id to a numeric value
    x["stock_id"] = x["stock_id"].astype(int)
    x_test["stock_id"] = x_test["stock_id"].astype(int)
    
    # Create out of folds array
    oof_predictions = np.zeros(x.shape[0])
    
    # Create test array to store predictions
    test_predictions = np.zeros(x_test.shape[0])
    
    # Create a KFold object
    kfold = GroupKFold(n_splits=CONFIG["n_splits"])
    
    # Iterate through each fold
    for fold, (trn_ind, val_ind) in enumerate(kfold.split(x, groups=x["time_id"])):
        
        print(f"Training fold {fold + 1}")
        
        scaler = load(os.path.join(CONFIG["ckpt_dir"], "std_scaler_fold_{}.bin".format(fold + 1)))
        
        x_val = x.iloc[val_ind]
        x_val = agg_stat_features_by_market(x_val)
        x_val = agg_stat_features_by_clusters(x_val, n_clusters=CONFIG["n_clusters"], function=np.nanmean, post_fix="_cluster_mean")
        x_val = agg_stat_features_by_clusters(x_val, n_clusters=CONFIG["n_clusters"], function=np.nanmax, post_fix="_cluster_max")
        x_val = agg_stat_features_by_clusters(x_val, n_clusters=CONFIG["n_clusters"], function=np.nanmin, post_fix="_cluster_min")
        x_val = agg_stat_features_by_clusters(x_val, n_clusters=CONFIG["n_clusters"], function=np.nanstd, post_fix="_cluster_std")
        x_val.drop("time_id", axis=1, inplace=True)
        x_val[normalized_columns] = scaler.transform(x_val[normalized_columns])
        
        y_train, y_val = y.iloc[trn_ind], y.iloc[val_ind]
        
        model = lgb.Booster(model_file=os.path.join(CONFIG["ckpt_dir"], "lgbm_fold_{}.txt".format(fold + 1)))
    
        # Add predictions to the out of folds array
        oof_predictions[val_ind] = model.predict(x_val)
        
        # Predict the test set
        x_test_ = x_test.copy()
        x_test_[normalized_columns] = scaler.transform(x_test_[normalized_columns])
        test_predictions += model.predict(x_test_) / CONFIG["n_splits"]
        
    rmspe_score = rmspe(y, oof_predictions)
    print(f"Our out of folds RMSPE is {rmspe_score}")
    
    # Return test predictions
    return test_predictions

In [19]:
# Traing and evaluate
test_predictions = train_and_evaluate(train, test)

# Save test predictions
# test["target"] = test_predictions
# test[["row_id", "target"]].to_csv("submission.csv",index = False)

Training fold 1
Training fold 2
Training fold 3
Training fold 4
Training fold 5
Our out of folds RMSPE is 0.21541103996743916


In [None]:
# GroupKfold Our out of folds RMSPE is 0.21541103996743916