In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from load_data import load_data_from_csv
from data_preprocessor.data_preprocessor import CompositeDataPreprocessor, ReduceMemUsageDataPreprocessor, FillNaPreProcessor
from data_preprocessor.feature_engineering import BasicFeaturesPreprocessor, DupletsTripletsPreprocessor, MovingAvgPreProcessor, RemoveIrrelevantFeaturesDataPreprocessor, DropTargetNADataPreprocessor, DTWKMeansPreprocessor
from data_preprocessor.polynomial_features import PolynomialFeaturesPreProcessor
from data_preprocessor.stockid_features import StockIdFeaturesPreProcessor
from data_preprocessor.deep_feature_synthesis import DfsPreProcessor
from data_generator.data_generator import DefaultTrainEvalDataGenerator, ManualKFoldDataGenerator, TimeSeriesKFoldDataGenerator

from model_pipeline.lgb_pipeline import LGBModelPipelineFactory

from model_post_processor.model_post_processor import CompositeModelPostProcessor, SaveModelPostProcessor

from train_pipeline.train_pipeline import DefaultTrainPipeline
from train_pipeline.train_optuna_pipeline import DefaultOptunaTrainPipeline

from train_pipeline.train_pipeline_callbacks import MAECallback
from utils.scoring_utils import ScoringUtils
from model_pipeline.dummy_models import BaselineEstimator

import optuna.integration.lightgbm as lgb
import optuna

import numpy as np

import sys

Install h5py to use hdf5 features: http://docs.h5py.org/
  warn(h5py_msg)


In [24]:
import pandas as pd

In [3]:
# Notebook-level configuration.
# NOTE(review): N_fold and model_save_dir are not referenced by any cell visible
# in this notebook section — presumably used by the training cells; confirm.
N_fold = 5
model_save_dir = './models/'

# Ordered preprocessing steps. Commented-out entries are steps that are
# currently switched off for this experiment.
processors = [    
    ReduceMemUsageDataPreprocessor(verbose=True),
    # BasicFeaturesPreprocessor(),
    # DupletsTripletsPreprocessor(),
    # MovingAvgPreProcessor("wap"),   
    # StockIdFeaturesPreProcessor(),   
    # DTWKMeansPreprocessor(),
    DfsPreProcessor(),
    DropTargetNADataPreprocessor(),    
    # Identifier columns are removed last, after the steps above have run.
    RemoveIrrelevantFeaturesDataPreprocessor(['stock_id', 'date_id','time_id', 'row_id']),
    # FillNaPreProcessor(),
    # PolynomialFeaturesPreProcessor(),
]


# Single preprocessor object wrapping the list above.
processor = CompositeDataPreprocessor(processors)



In [4]:
# Data location: '..' for a local checkout; switch to the commented path when
# running on Kaggle.
# DATA_PATH = '/kaggle/input'
DATA_PATH = '..'
df_train, df_test, revealed_targets, sample_submission = load_data_from_csv(DATA_PATH)
print(df_train.columns)

# NOTE(review): this is a second name for the same DataFrame object, not a copy.
# It only preserves the raw data if later steps never modify df_train in place.
raw_data = df_train
# Uncomment to iterate quickly on a subset of rows.
# df_train = df_train[:100000]


Index(['stock_id', 'date_id', 'seconds_in_bucket', 'imbalance_size',
       'imbalance_buy_sell_flag', 'reference_price', 'matched_size',
       'far_price', 'near_price', 'bid_price', 'bid_size', 'ask_price',
       'ask_size', 'wap', 'target', 'time_id', 'row_id'],
      dtype='object')


In [None]:
# df_train = raw_data

In [5]:
# Downcast column dtypes to shrink the frame (recorded run: 679.36 MB -> 304.72 MB).
# NOTE(review): rebinding df_train makes this cell depend on its own previous
# result when re-run; a stage-specific name would keep it idempotent.
df_train = ReduceMemUsageDataPreprocessor(verbose=True).apply(df_train)

Memory usage of dataframe is 679.36 MB
Memory usage after optimization is: 304.72 MB
Decreased by 55.15%
dtypes:
stock_id                     int16
date_id                      int16
seconds_in_bucket            int16
imbalance_size             float32
imbalance_buy_sell_flag       int8
reference_price            float32
matched_size               float32
far_price                  float32
near_price                 float32
bid_price                  float32
bid_size                   float32
ask_price                  float32
ask_size                   float32
wap                        float32
target                     float32
time_id                      int16
row_id                      object
dtype: object


In [7]:
import featuretools as ft

In [81]:
# Work on a copy so the featuretools experiments below do not touch df_train.
df_ = df_train.copy()

# Empty EntitySet; dataframes and relationships are added in the following cells.
es = ft.EntitySet(id = 'closing_movements_data')
# Legacy (pre-1.0 featuretools) way of registering a dataframe, kept for reference:
# es = es.entity_from_dataframe(entity_id = 'df', dataframe = df_, index = 'row_id')


In [82]:
from woodwork.logical_types import Categorical, PostalCode

In [83]:
# Register the row-level auction data: row_id is the unique index and time_id
# the time index. imbalance_buy_sell_flag is forced to Categorical; all other
# column types are left to woodwork's inference.
# NOTE(review): the "Could not infer format" warnings recorded below presumably
# come from that type inference — confirm.
es = es.add_dataframe(
    dataframe_name="closing_movements",
    dataframe=df_,
    index="row_id",
    time_index="time_id",
    logical_types={
        "imbalance_buy_sell_flag": Categorical,
    },
)


Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.
Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.


In [84]:
# Show the EntitySet summary (dataframes and relationships registered so far).
es

Entityset: closing_movements_data
  DataFrames:
    closing_movements [Rows: 5237980, Columns: 17]
  Relationships:
    No relationships

In [85]:
# One row per distinct stock, in order of first appearance in df_.
# "dummy" is a placeholder column that simply repeats the stock id.
unique_stock_ids = pd.unique(df_["stock_id"])
stocks_df = pd.DataFrame(
    {
        "stock_id": unique_stock_ids,
        "dummy": unique_stock_ids,
    }
)
stocks_df

Unnamed: 0,stock_id,dummy
0,0,0
1,1,1
2,2,2
3,3,3
4,4,4
...,...,...
195,153,153
196,199,199
197,79,79
198,135,135


In [86]:
# Register the per-stock table, indexed by stock_id, then display the
# EntitySet summary.
es = es.add_dataframe(
    dataframe_name="stocks", dataframe=stocks_df, index="stock_id"
)

es

Entityset: closing_movements_data
  DataFrames:
    closing_movements [Rows: 5237980, Columns: 17]
    stocks [Rows: 200, Columns: 2]
  Relationships:
    No relationships

In [87]:
# Link the two tables on stock_id: stocks is the parent, closing_movements the
# child (recorded output: closing_movements.stock_id -> stocks.stock_id).
es = es.add_relationship("stocks", "stock_id", "closing_movements", "stock_id")
es

Entityset: closing_movements_data
  DataFrames:
    closing_movements [Rows: 5237980, Columns: 17]
    stocks [Rows: 200, Columns: 2]
  Relationships:
    closing_movements.stock_id -> stocks.stock_id

In [88]:
# Inspect the logical types and semantic tags assigned to each column.
es["closing_movements"].ww.schema

Unnamed: 0_level_0,Logical Type,Semantic Tag(s)
Column,Unnamed: 1_level_1,Unnamed: 2_level_1
stock_id,Integer,"['foreign_key', 'numeric']"
date_id,Integer,['numeric']
seconds_in_bucket,Integer,['numeric']
imbalance_size,Double,['numeric']
imbalance_buy_sell_flag,Categorical,['category']
reference_price,Double,['numeric']
matched_size,Double,['numeric']
far_price,Double,['numeric']
near_price,Double,['numeric']
bid_price,Double,['numeric']


In [91]:
# Aggregation primitives applied to child rows when building per-stock features.
default_agg_primitives =  ["sum", "std", "max", "skew", "min", "mean"]
# Transform primitives; the dfs call in the next cell has this argument commented out.
# NOTE(review): these look like date / lat-long / text primitives, and the schema
# shown above lists only numeric and categorical columns — confirm they apply.
default_trans_primitives =  ["day", "year", "month", "weekday", "haversine", "numwords", "characters"]

In [92]:
# Run Deep Feature Synthesis with "stocks" as the target: each output row is one
# stock_id and each column an aggregate over that stock's closing_movements rows.
# NOTE(review): the recorded columns include aggregates of `target`
# (e.g. MAX(closing_movements.target)). If these are merged back as model
# features they would expose the label to the model — confirm this is intended.
feature_matrix, feature_defs = ft.dfs(entityset=es, 
                                    target_dataframe_name="stocks",
                                    # trans_primitives = default_trans_primitives,
                                    agg_primitives=default_agg_primitives, 
                                    max_depth = 2)
feature_matrix

The provided callable <function min at 0x00000149671BC670> is currently using SeriesGroupBy.min. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "min" instead.
The provided callable <function mean at 0x00000149671BCE50> is currently using SeriesGroupBy.mean. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "mean" instead.
The provided callable <function std at 0x00000149671BCF70> is currently using SeriesGroupBy.std. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "std" instead.
The provided callable <function max at 0x00000149671BC550> is currently using SeriesGroupBy.max. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "max" instead.
The provided callable <function sum at 0x000001496718FEB0> is currently using Ser

Unnamed: 0_level_0,dummy,MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_price),MAX(closing_movements.bid_size),MAX(closing_movements.date_id),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_size),MAX(closing_movements.matched_size),MAX(closing_movements.near_price),...,SUM(closing_movements.date_id),SUM(closing_movements.far_price),SUM(closing_movements.imbalance_size),SUM(closing_movements.matched_size),SUM(closing_movements.near_price),SUM(closing_movements.reference_price),SUM(closing_movements.seconds_in_bucket),SUM(closing_movements.target),SUM(closing_movements.time_id),SUM(closing_movements.wap)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,1.009462,1.365950e+06,1.009097,5.899042e+06,480.0,1.174934,133453864.0,2.256941e+08,1.103863,...,6349200.0,11900.113714,1.113933e+11,6.813479e+11,11992.981230,26450.626731,7142850.0,-6904.957294,349920285.0,26450.826897
1,1,1.018351,7.483277e+05,1.016511,6.939674e+05,480.0,1.246998,29274852.0,3.122824e+08,1.109980,...,6349200.0,11773.638891,1.371727e+10,1.117034e+11,12024.151295,26453.051268,7142850.0,-3171.616168,349920285.0,26453.122419
2,2,1.022605,1.733813e+06,1.019387,1.069922e+06,480.0,1.353206,619560640.0,1.494537e+09,1.098223,...,6349200.0,11827.957806,4.936810e+10,2.617821e+11,12028.561642,26458.947830,7142850.0,661.582899,349920285.0,26458.844050
3,3,1.012767,1.277225e+06,1.012299,1.929015e+06,480.0,1.068764,314800768.0,8.262816e+08,1.102969,...,6349200.0,11977.538373,2.468673e+11,1.850779e+12,12013.644835,26454.420663,7142850.0,-3649.805199,349920285.0,26454.400392
4,4,1.009318,8.845469e+05,1.008878,1.604196e+06,480.0,1.101952,77445264.0,3.428062e+08,1.101019,...,6349200.0,11933.699412,1.007934e+11,7.229124e+11,12014.354280,26451.257714,7142850.0,-5923.464309,349920285.0,26451.367152
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,153,1.018965,1.746364e+06,1.017415,9.941822e+05,480.0,1.174682,90502128.0,2.407256e+08,1.103799,...,6216375.0,9960.324740,1.894796e+10,1.159091e+11,10285.603993,22595.542040,6103350.0,-12855.575029,342510960.0,22594.496180
199,199,1.013363,1.412633e+06,1.013005,4.564530e+06,480.0,1.105067,311616448.0,6.911707e+08,1.107342,...,6138660.0,9745.474416,6.826221e+10,5.264856e+11,9821.062983,21617.019395,5836050.0,2861.233957,338209905.0,21617.067432
79,79,1.013261,2.352600e+06,1.012422,1.424681e+06,480.0,1.145861,208590048.0,7.087096e+08,1.095605,...,5453250.0,7484.772269,5.917951e+10,3.768174e+11,7506.546914,16502.956904,4455000.0,2867.971111,300374250.0,16502.883964
135,135,1.024372,3.413509e+06,1.023618,4.011558e+06,480.0,1.116020,63495328.0,1.304246e+08,1.103567,...,5351225.0,7139.776514,1.651260e+10,8.557969e+10,7254.829291,15948.914329,4306500.0,1048.204993,294748025.0,15948.675643


In [93]:
# List the generated feature names.
feature_matrix.columns

Index(['dummy', 'MAX(closing_movements.ask_price)',
       'MAX(closing_movements.ask_size)', 'MAX(closing_movements.bid_price)',
       'MAX(closing_movements.bid_size)', 'MAX(closing_movements.date_id)',
       'MAX(closing_movements.far_price)',
       'MAX(closing_movements.imbalance_size)',
       'MAX(closing_movements.matched_size)',
       'MAX(closing_movements.near_price)',
       'MAX(closing_movements.reference_price)',
       'MAX(closing_movements.seconds_in_bucket)',
       'MAX(closing_movements.target)', 'MAX(closing_movements.time_id)',
       'MAX(closing_movements.wap)', 'MEAN(closing_movements.ask_price)',
       'MEAN(closing_movements.ask_size)', 'MEAN(closing_movements.bid_price)',
       'MEAN(closing_movements.bid_size)', 'MEAN(closing_movements.date_id)',
       'MEAN(closing_movements.far_price)',
       'MEAN(closing_movements.imbalance_size)',
       'MEAN(closing_movements.matched_size)',
       'MEAN(closing_movements.near_price)',
       'MEAN(closing_m

In [96]:
# Preview the matrix after dropping mostly-null features. The result is only
# displayed, not assigned, so feature_matrix itself is left as it was.
ft.selection.remove_highly_null_features(feature_matrix)

Unnamed: 0_level_0,dummy,MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_price),MAX(closing_movements.bid_size),MAX(closing_movements.date_id),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_size),MAX(closing_movements.matched_size),MAX(closing_movements.near_price),...,SUM(closing_movements.date_id),SUM(closing_movements.far_price),SUM(closing_movements.imbalance_size),SUM(closing_movements.matched_size),SUM(closing_movements.near_price),SUM(closing_movements.reference_price),SUM(closing_movements.seconds_in_bucket),SUM(closing_movements.target),SUM(closing_movements.time_id),SUM(closing_movements.wap)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,1.009462,1.365950e+06,1.009097,5.899042e+06,480.0,1.174934,133453864.0,2.256941e+08,1.103863,...,6349200.0,11900.113714,1.113933e+11,6.813479e+11,11992.981230,26450.626731,7142850.0,-6904.957294,349920285.0,26450.826897
1,1,1.018351,7.483277e+05,1.016511,6.939674e+05,480.0,1.246998,29274852.0,3.122824e+08,1.109980,...,6349200.0,11773.638891,1.371727e+10,1.117034e+11,12024.151295,26453.051268,7142850.0,-3171.616168,349920285.0,26453.122419
2,2,1.022605,1.733813e+06,1.019387,1.069922e+06,480.0,1.353206,619560640.0,1.494537e+09,1.098223,...,6349200.0,11827.957806,4.936810e+10,2.617821e+11,12028.561642,26458.947830,7142850.0,661.582899,349920285.0,26458.844050
3,3,1.012767,1.277225e+06,1.012299,1.929015e+06,480.0,1.068764,314800768.0,8.262816e+08,1.102969,...,6349200.0,11977.538373,2.468673e+11,1.850779e+12,12013.644835,26454.420663,7142850.0,-3649.805199,349920285.0,26454.400392
4,4,1.009318,8.845469e+05,1.008878,1.604196e+06,480.0,1.101952,77445264.0,3.428062e+08,1.101019,...,6349200.0,11933.699412,1.007934e+11,7.229124e+11,12014.354280,26451.257714,7142850.0,-5923.464309,349920285.0,26451.367152
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,153,1.018965,1.746364e+06,1.017415,9.941822e+05,480.0,1.174682,90502128.0,2.407256e+08,1.103799,...,6216375.0,9960.324740,1.894796e+10,1.159091e+11,10285.603993,22595.542040,6103350.0,-12855.575029,342510960.0,22594.496180
199,199,1.013363,1.412633e+06,1.013005,4.564530e+06,480.0,1.105067,311616448.0,6.911707e+08,1.107342,...,6138660.0,9745.474416,6.826221e+10,5.264856e+11,9821.062983,21617.019395,5836050.0,2861.233957,338209905.0,21617.067432
79,79,1.013261,2.352600e+06,1.012422,1.424681e+06,480.0,1.145861,208590048.0,7.087096e+08,1.095605,...,5453250.0,7484.772269,5.917951e+10,3.768174e+11,7506.546914,16502.956904,4455000.0,2867.971111,300374250.0,16502.883964
135,135,1.024372,3.413509e+06,1.023618,4.011558e+06,480.0,1.116020,63495328.0,1.304246e+08,1.103567,...,5351225.0,7139.776514,1.651260e+10,8.557969e+10,7254.829291,15948.914329,4306500.0,1048.204993,294748025.0,15948.675643


In [98]:
from featuretools.selection import (
    remove_highly_correlated_features,
    remove_highly_null_features,
    remove_single_value_features,
)

In [100]:
# Drop features that have the same value for every stock; passing the feature
# definitions returns the pruned definitions alongside the pruned matrix.
new_fm, new_features = remove_single_value_features(feature_matrix, features=feature_defs)
new_fm

Unnamed: 0_level_0,dummy,MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_price),MAX(closing_movements.bid_size),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_size),MAX(closing_movements.matched_size),MAX(closing_movements.near_price),MAX(closing_movements.reference_price),...,SUM(closing_movements.date_id),SUM(closing_movements.far_price),SUM(closing_movements.imbalance_size),SUM(closing_movements.matched_size),SUM(closing_movements.near_price),SUM(closing_movements.reference_price),SUM(closing_movements.seconds_in_bucket),SUM(closing_movements.target),SUM(closing_movements.time_id),SUM(closing_movements.wap)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,1.009462,1.365950e+06,1.009097,5.899042e+06,1.174934,133453864.0,2.256941e+08,1.103863,1.009097,...,6349200.0,11900.113714,1.113933e+11,6.813479e+11,11992.981230,26450.626731,7142850.0,-6904.957294,349920285.0,26450.826897
1,1,1.018351,7.483277e+05,1.016511,6.939674e+05,1.246998,29274852.0,3.122824e+08,1.109980,1.017256,...,6349200.0,11773.638891,1.371727e+10,1.117034e+11,12024.151295,26453.051268,7142850.0,-3171.616168,349920285.0,26453.122419
2,2,1.022605,1.733813e+06,1.019387,1.069922e+06,1.353206,619560640.0,1.494537e+09,1.098223,1.021975,...,6349200.0,11827.957806,4.936810e+10,2.617821e+11,12028.561642,26458.947830,7142850.0,661.582899,349920285.0,26458.844050
3,3,1.012767,1.277225e+06,1.012299,1.929015e+06,1.068764,314800768.0,8.262816e+08,1.102969,1.012715,...,6349200.0,11977.538373,2.468673e+11,1.850779e+12,12013.644835,26454.420663,7142850.0,-3649.805199,349920285.0,26454.400392
4,4,1.009318,8.845469e+05,1.008878,1.604196e+06,1.101952,77445264.0,3.428062e+08,1.101019,1.009001,...,6349200.0,11933.699412,1.007934e+11,7.229124e+11,12014.354280,26451.257714,7142850.0,-5923.464309,349920285.0,26451.367152
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,153,1.018965,1.746364e+06,1.017415,9.941822e+05,1.174682,90502128.0,2.407256e+08,1.103799,1.018577,...,6216375.0,9960.324740,1.894796e+10,1.159091e+11,10285.603993,22595.542040,6103350.0,-12855.575029,342510960.0,22594.496180
199,199,1.013363,1.412633e+06,1.013005,4.564530e+06,1.105067,311616448.0,6.911707e+08,1.107342,1.013363,...,6138660.0,9745.474416,6.826221e+10,5.264856e+11,9821.062983,21617.019395,5836050.0,2861.233957,338209905.0,21617.067432
79,79,1.013261,2.352600e+06,1.012422,1.424681e+06,1.145861,208590048.0,7.087096e+08,1.095605,1.012842,...,5453250.0,7484.772269,5.917951e+10,3.768174e+11,7506.546914,16502.956904,4455000.0,2867.971111,300374250.0,16502.883964
135,135,1.024372,3.413509e+06,1.023618,4.011558e+06,1.116020,63495328.0,1.304246e+08,1.103567,1.025126,...,5351225.0,7139.776514,1.651260e+10,8.557969e+10,7254.829291,15948.914329,4306500.0,1048.204993,294748025.0,15948.675643


In [101]:
# Prune highly correlated features from the already-reduced matrix and keep the
# matching feature definitions.
new_fm2, new_features2 = remove_highly_correlated_features(new_fm, features=new_features)
new_fm2.head()

Unnamed: 0_level_0,dummy,MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_size),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_size),MAX(closing_movements.matched_size),MAX(closing_movements.near_price),MAX(closing_movements.target),MEAN(closing_movements.ask_price),...,SKEW(closing_movements.far_price),SKEW(closing_movements.imbalance_size),SKEW(closing_movements.matched_size),SKEW(closing_movements.near_price),SKEW(closing_movements.target),STD(closing_movements.ask_price),STD(closing_movements.ask_size),STD(closing_movements.bid_size),STD(closing_movements.near_price),STD(closing_movements.target)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,1.009462,1365950.0,5899042.0,1.174934,133453864.0,225694100.0,1.103863,46.720505,0.999954,...,-2.050147,7.760194,4.226085,-1.091127,-0.434009,0.001679,57668.534384,75318.384942,0.013321,6.054234
1,1,1.018351,748327.7,693967.4,1.246998,29274852.0,312282400.0,1.10998,62.069893,1.000392,...,1.355987,6.298835,20.28013,0.624114,-0.093806,0.002793,33105.16694,33339.741636,0.010931,11.585558
2,2,1.022605,1733813.0,1069922.0,1.353206,619560640.0,1494537000.0,1.098223,103.969574,1.000608,...,2.975471,24.721674,19.269559,-1.698475,-0.648965,0.002663,40230.444159,35444.210118,0.010853,10.9796
3,3,1.012767,1277225.0,1929015.0,1.068764,314800768.0,826281600.0,1.102969,45.73941,1.000091,...,-0.409586,5.338071,5.122866,-0.605044,-0.135992,0.001453,48042.78273,45832.987755,0.010039,5.268299
4,4,1.009318,884546.9,1604196.0,1.101952,77445264.0,342806200.0,1.101019,37.4496,1.000002,...,-25.581169,3.942997,6.103034,-0.639007,-0.104974,0.001861,38529.121984,42141.152591,0.010009,5.927555


In [103]:
# Preview the pruned matrix without the placeholder "dummy" column. This is a
# non-mutating drop: the result is displayed only and new_fm2 keeps the column.
new_fm2.drop(columns=["dummy"])

Unnamed: 0_level_0,MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_size),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_size),MAX(closing_movements.matched_size),MAX(closing_movements.near_price),MAX(closing_movements.target),MEAN(closing_movements.ask_price),MEAN(closing_movements.ask_size),...,SKEW(closing_movements.far_price),SKEW(closing_movements.imbalance_size),SKEW(closing_movements.matched_size),SKEW(closing_movements.near_price),SKEW(closing_movements.target),STD(closing_movements.ask_price),STD(closing_movements.ask_size),STD(closing_movements.bid_size),STD(closing_movements.near_price),STD(closing_movements.target)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,1.009462,1.365950e+06,5.899042e+06,1.174934,133453864.0,2.256941e+08,1.103863,46.720505,0.999954,38356.347455,...,-2.050147,7.760194,4.226085,-1.091127,-0.434009,0.001679,57668.534384,75318.384942,0.013321,6.054234
1,1.018351,7.483277e+05,6.939674e+05,1.246998,29274852.0,3.122824e+08,1.109980,62.069893,1.000392,23091.037226,...,1.355987,6.298835,20.280130,0.624114,-0.093806,0.002793,33105.166940,33339.741636,0.010931,11.585558
2,1.022605,1.733813e+06,1.069922e+06,1.353206,619560640.0,1.494537e+09,1.098223,103.969574,1.000608,25819.680797,...,2.975471,24.721674,19.269559,-1.698475,-0.648965,0.002663,40230.444159,35444.210118,0.010853,10.979600
3,1.012767,1.277225e+06,1.929015e+06,1.068764,314800768.0,8.262816e+08,1.102969,45.739410,1.000091,35430.385339,...,-0.409586,5.338071,5.122866,-0.605044,-0.135992,0.001453,48042.782730,45832.987755,0.010039,5.268299
4,1.009318,8.845469e+05,1.604196e+06,1.101952,77445264.0,3.428062e+08,1.101019,37.449600,1.000002,29114.623243,...,-25.581169,3.942997,6.103034,-0.639007,-0.104974,0.001861,38529.121984,42141.152591,0.010009,5.927555
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,1.018965,1.746364e+06,9.941822e+05,1.174682,90502128.0,2.407256e+08,1.103799,139.269836,0.999855,36066.697132,...,1.024624,13.055524,7.417886,2.806757,-0.294145,0.003211,56364.484604,49912.453498,0.012680,14.081304
199,1.013363,1.412633e+06,4.564530e+06,1.105067,311616448.0,6.911707e+08,1.107342,79.900024,1.000278,83447.149526,...,-0.535993,19.699556,8.281383,1.887790,0.130295,0.002162,86952.577849,107257.479023,0.013319,9.060902
79,1.013261,2.352600e+06,1.424681e+06,1.145861,208590048.0,7.087096e+08,1.095605,100.320580,1.000554,148141.128335,...,0.043616,11.280358,9.677143,0.563881,-0.151660,0.002737,150455.777596,133100.599281,0.011609,11.550668
135,1.024372,3.413509e+06,4.011558e+06,1.116020,63495328.0,1.304246e+08,1.103567,160.100464,1.000225,106366.746110,...,-3.049728,9.332666,7.456330,4.728427,0.523074,0.003369,146936.614639,142747.352571,0.008932,12.270062


In [104]:
# Inspect the row-level frame after it was registered in the EntitySet.
# NOTE(review): the recorded output shows row_id values as the index, which
# suggests add_dataframe modified df_ in place — confirm before reusing df_.
df_

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,bid_size,ask_price,ask_size,wap,target,time_id,row_id
0,0,0,0,3.180603e+06,1,0.999812,13380277.00,,,0.999812,60651.500000,1.000026,8493.030273,1.000000,-3.029704,0,0
1,1,0,0,1.666039e+05,-1,0.999896,1642214.25,,,0.999896,3233.040039,1.000660,20605.089844,1.000000,-5.519986,0,1
2,2,0,0,3.028799e+05,-1,0.999561,1819368.00,,,0.999403,37956.000000,1.000298,18995.000000,1.000000,-8.389950,0,2
3,3,0,0,1.191768e+07,-1,1.000171,18389746.00,,,0.999999,2324.899902,1.000214,479032.406250,1.000000,-4.010201,0,3
4,4,0,0,4.475500e+05,-1,0.999532,17860614.00,,,0.999394,16485.539062,1.000016,434.100006,1.000000,-7.349849,0,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
480540195,195,480,540,2.440723e+06,-1,1.000317,28280362.00,0.999734,0.999734,1.000317,32257.039062,1.000434,319862.406250,1.000328,2.310276,26454,480540195
480540196,196,480,540,3.495105e+05,-1,1.000643,9187699.00,1.000129,1.000386,1.000643,205108.406250,1.000900,93393.070312,1.000819,-8.220077,26454,480540196
480540197,197,480,540,0.000000e+00,0,0.995789,12725436.00,0.995789,0.995789,0.995789,16790.660156,0.995883,180038.312500,0.995797,1.169443,26454,480540197
480540198,198,480,540,1.000899e+06,1,0.999210,94773272.00,0.999210,0.999210,0.998970,125631.718750,0.999210,669893.000000,0.999008,-1.540184,26454,480540198


In [105]:
# Broadcast the per-stock aggregates onto every row of the row-level frame.
# new_fm2 is indexed by stock_id, so the index level name serves as the join key.
# The placeholder "dummy" column only repeats stock_id, so it is dropped before
# merging; errors="ignore" keeps the cell re-runnable if it was already removed.
df_.merge(
    new_fm2.drop(columns=["dummy"], errors="ignore"),
    on="stock_id",
    how="left",
)

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,...,SKEW(closing_movements.far_price),SKEW(closing_movements.imbalance_size),SKEW(closing_movements.matched_size),SKEW(closing_movements.near_price),SKEW(closing_movements.target),STD(closing_movements.ask_price),STD(closing_movements.ask_size),STD(closing_movements.bid_size),STD(closing_movements.near_price),STD(closing_movements.target)
0,0,0,0,3.180603e+06,1,0.999812,13380277.00,,,0.999812,...,-2.050147,7.760194,4.226085,-1.091127,-0.434009,0.001679,57668.534384,75318.384942,0.013321,6.054234
1,1,0,0,1.666039e+05,-1,0.999896,1642214.25,,,0.999896,...,1.355987,6.298835,20.280130,0.624114,-0.093806,0.002793,33105.166940,33339.741636,0.010931,11.585558
2,2,0,0,3.028799e+05,-1,0.999561,1819368.00,,,0.999403,...,2.975471,24.721674,19.269559,-1.698475,-0.648965,0.002663,40230.444159,35444.210118,0.010853,10.979600
3,3,0,0,1.191768e+07,-1,1.000171,18389746.00,,,0.999999,...,-0.409586,5.338071,5.122866,-0.605044,-0.135992,0.001453,48042.782730,45832.987755,0.010039,5.268299
4,4,0,0,4.475500e+05,-1,0.999532,17860614.00,,,0.999394,...,-25.581169,3.942997,6.103034,-0.639007,-0.104974,0.001861,38529.121984,42141.152591,0.010009,5.927555
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5237975,195,480,540,2.440723e+06,-1,1.000317,28280362.00,0.999734,0.999734,1.000317,...,-1.444147,4.093805,4.185393,-1.078113,-0.324241,0.001526,50676.098700,47541.936211,0.011793,6.144210
5237976,196,480,540,3.495105e+05,-1,1.000643,9187699.00,1.000129,1.000386,1.000643,...,-0.289048,4.834395,3.986324,-0.837400,-0.061839,0.001707,41661.136075,36408.928198,0.013986,7.213802
5237977,197,480,540,0.000000e+00,0,0.995789,12725436.00,0.995789,0.995789,0.995789,...,0.930170,6.823695,9.423061,0.285158,-0.008281,0.002347,36390.231064,35573.943906,0.012686,8.674533
5237978,198,480,540,1.000899e+06,1,0.999210,94773272.00,0.999210,0.999210,0.998970,...,-0.246794,3.684398,4.633701,1.175606,-0.218825,0.001572,178691.093249,175991.694841,0.009116,5.956440


In [95]:
# Second DFS run on the same EntitySet with a smaller primitive set
# (mean / sum / mode), again producing one row per stock.
feature_matrix2, feature_defs2 = ft.dfs(
    entityset=es,
    target_dataframe_name="stocks",
    agg_primitives=["mean", "sum", "mode"],
    # trans_primitives=["month", "hour"],
    max_depth=2,
)
feature_matrix2

The provided callable <function sum at 0x000001496718FEB0> is currently using SeriesGroupBy.sum. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "sum" instead.
The provided callable <function mean at 0x00000149671BCE50> is currently using SeriesGroupBy.mean. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "mean" instead.


Unnamed: 0_level_0,dummy,MEAN(closing_movements.ask_price),MEAN(closing_movements.ask_size),MEAN(closing_movements.bid_price),MEAN(closing_movements.bid_size),MEAN(closing_movements.date_id),MEAN(closing_movements.far_price),MEAN(closing_movements.imbalance_size),MEAN(closing_movements.matched_size),MEAN(closing_movements.near_price),...,SUM(closing_movements.date_id),SUM(closing_movements.far_price),SUM(closing_movements.imbalance_size),SUM(closing_movements.matched_size),SUM(closing_movements.near_price),SUM(closing_movements.reference_price),SUM(closing_movements.seconds_in_bucket),SUM(closing_movements.target),SUM(closing_movements.time_id),SUM(closing_movements.wap)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,0.999954,38356.347455,0.999735,36045.936725,240.0,0.996826,4.210671e+06,2.575498e+07,0.997337,...,6349200.0,11900.113714,1.113933e+11,6.813479e+11,11992.981230,26450.626731,7142850.0,-6904.957294,349920285.0,26450.826897
1,1,1.000392,23091.037226,0.999480,22565.737433,240.0,0.999121,5.185133e+05,4.222393e+06,0.999929,...,6349200.0,11773.638891,1.371727e+10,1.117034e+11,12024.151295,26453.051268,7142850.0,-3171.616168,349920285.0,26453.122419
2,2,1.000608,25819.680797,0.999697,23600.347566,240.0,1.001351,1.866116e+06,9.895372e+06,1.000296,...,6349200.0,11827.957806,4.936810e+10,2.617821e+11,12028.561642,26458.947830,7142850.0,661.582899,349920285.0,26458.844050
3,3,1.000091,35430.385339,0.999872,32339.317298,240.0,0.998877,9.331595e+06,6.995952e+07,0.999056,...,6349200.0,11977.538373,2.468673e+11,1.850779e+12,12013.644835,26454.420663,7142850.0,-3649.805199,349920285.0,26454.400392
4,4,1.000002,29114.623243,0.999728,28348.285387,240.0,0.998720,3.809995e+06,2.732611e+07,0.999115,...,6349200.0,11933.699412,1.007934e+11,7.229124e+11,12014.354280,26451.257714,7142850.0,-5923.464309,349920285.0,26451.367152
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,153,0.999855,36066.697132,0.999198,38539.690720,275.0,1.002146,8.382197e+05,5.127586e+06,1.001032,...,6216375.0,9960.324740,1.894796e+10,1.159091e+11,10285.603993,22595.542040,6103350.0,-12855.575029,342510960.0,22594.496180
199,199,1.000278,83447.149526,0.999918,81383.169055,284.0,0.998819,3.158094e+06,2.435742e+07,0.999599,...,6138660.0,9745.474416,6.826221e+10,5.264856e+11,9821.062983,21617.019395,5836050.0,2861.233957,338209905.0,21617.067432
79,79,1.000554,148141.128335,0.999790,147523.106819,330.5,1.001307,3.586637e+06,2.283742e+07,1.000873,...,5453250.0,7484.772269,5.917951e+10,3.768174e+11,7506.546914,16502.956904,4455000.0,2867.971111,300374250.0,16502.883964
135,135,1.000225,106366.746110,0.999589,112588.670274,335.5,1.000669,1.035273e+06,5.365498e+06,1.000666,...,5351225.0,7139.776514,1.651260e+10,8.557969e+10,7254.829291,15948.914329,4306500.0,1048.204993,294748025.0,15948.675643


In [57]:
# feature_matrix, feature_defs = ft.dfs(entityset=es, 
#                                     target_dataframe_name="stocks")
# feature_matrix.columns

The provided callable <function min at 0x00000149671BC670> is currently using SeriesGroupBy.min. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "min" instead.
The provided callable <function mean at 0x00000149671BCE50> is currently using SeriesGroupBy.mean. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "mean" instead.
The provided callable <function std at 0x00000149671BCF70> is currently using SeriesGroupBy.std. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "std" instead.
The provided callable <function max at 0x00000149671BC550> is currently using SeriesGroupBy.max. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "max" instead.
The provided callable <function sum at 0x000001496718FEB0> is currently using Ser

Unnamed: 0_level_0,dummy,COUNT(closing_movements),MAX(closing_movements.ask_price),MAX(closing_movements.ask_size),MAX(closing_movements.bid_price),MAX(closing_movements.bid_size),MAX(closing_movements.date_id),MAX(closing_movements.far_price),MAX(closing_movements.imbalance_buy_sell_flag),MAX(closing_movements.imbalance_size),...,SUM(closing_movements.far_price),SUM(closing_movements.imbalance_buy_sell_flag),SUM(closing_movements.imbalance_size),SUM(closing_movements.matched_size),SUM(closing_movements.near_price),SUM(closing_movements.reference_price),SUM(closing_movements.seconds_in_bucket),SUM(closing_movements.target),SUM(closing_movements.time_id),SUM(closing_movements.wap)
stock_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,26455,1.009462,1.365950e+06,1.009097,5.899042e+06,480.0,1.174934,1.0,133453864.0,...,11900.113714,-4111.0,1.113933e+11,6.813479e+11,11992.981230,26450.626731,7142850.0,-6904.957294,349920285.0,26450.826897
1,1,26455,1.018351,7.483277e+05,1.016511,6.939674e+05,480.0,1.246998,1.0,29274852.0,...,11773.638891,-855.0,1.371727e+10,1.117034e+11,12024.151295,26453.051268,7142850.0,-3171.616168,349920285.0,26453.122419
2,2,26455,1.022605,1.733813e+06,1.019387,1.069922e+06,480.0,1.353206,1.0,619560640.0,...,11827.957806,1243.0,4.936810e+10,2.617821e+11,12028.561642,26458.947830,7142850.0,661.582899,349920285.0,26458.844050
3,3,26455,1.012767,1.277225e+06,1.012299,1.929015e+06,480.0,1.068764,1.0,314800768.0,...,11977.538373,-1696.0,2.468673e+11,1.850779e+12,12013.644835,26454.420663,7142850.0,-3649.805199,349920285.0,26454.400392
4,4,26455,1.009318,8.845469e+05,1.008878,1.604196e+06,480.0,1.101952,1.0,77445264.0,...,11933.699412,-2035.0,1.007934e+11,7.229124e+11,12014.354280,26451.257714,7142850.0,-5923.464309,349920285.0,26451.367152
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
153,153,22605,1.018965,1.746364e+06,1.017415,9.941822e+05,480.0,1.174682,1.0,90502128.0,...,9960.324740,4144.0,1.894796e+10,1.159091e+11,10285.603993,22595.542040,6103350.0,-12855.575029,342510960.0,22594.496180
199,199,21615,1.013363,1.412633e+06,1.013005,4.564530e+06,480.0,1.105067,1.0,311616448.0,...,9745.474416,-479.0,6.826221e+10,5.264856e+11,9821.062983,21617.019395,5836050.0,2861.233957,338209905.0,21617.067432
79,79,16500,1.013261,2.352600e+06,1.012422,1.424681e+06,480.0,1.145861,1.0,208590048.0,...,7484.772269,228.0,5.917951e+10,3.768174e+11,7506.546914,16502.956904,4455000.0,2867.971111,300374250.0,16502.883964
135,135,15950,1.024372,3.413509e+06,1.023618,4.011558e+06,480.0,1.116020,1.0,63495328.0,...,7139.776514,1283.0,1.651260e+10,8.557969e+10,7254.829291,15948.914329,4306500.0,1048.204993,294748025.0,15948.675643


In [None]:
import numpy as np
import pandas as pd
import featuretools as ft
from data_preprocessor.data_preprocessor import DataPreprocessor

class DfsPreProcessor(DataPreprocessor):
    """Preprocessor that inspects Deep Feature Synthesis (featuretools) features.

    The step registers the incoming frame in a featuretools ``EntitySet`` and
    prints the feature definitions DFS would generate for it. No feature matrix
    is computed and no columns are added to the returned frame.

    Note: this notebook-local class shadows the ``DfsPreProcessor`` imported
    from ``data_preprocessor.deep_feature_synthesis`` once its cell has run.
    """

    # Name under which the input frame is registered in the EntitySet; the same
    # name must be passed to ``ft.dfs`` as the target dataframe.
    TARGET_DATAFRAME_NAME = "closing_movements"

    def apply(self, df):
        """Print the woodwork schema and DFS feature definitions for ``df``.

        Args:
            df: DataFrame containing at least ``row_id`` (used as the index of
                the registered dataframe) and ``time_id`` (used as its time
                index).

        Returns:
            The copy of ``df`` that was registered in the EntitySet.
            NOTE(review): registering a frame appears to let woodwork modify it
            in place (the notebook output shows ``row_id`` as the index
            afterwards), so the returned frame may differ from ``df`` in index
            and dtypes — confirm whether downstream steps depend on that.
        """
        df_ = df.copy()

        es = ft.EntitySet(id = 'train_df')
        es = es.add_dataframe(
            dataframe_name=self.TARGET_DATAFRAME_NAME,
            dataframe=df_,
            index="row_id",
            time_index="time_id",
        )

        print(es[self.TARGET_DATAFRAME_NAME].ww.schema)

        default_agg_primitives =  ["sum", "std", "max", "skew", "min", "mean", "count", "percent_true", "num_unique", "mode"]
        default_trans_primitives =  ["day", "year", "month", "weekday", "haversine", "numwords", "characters"]

        # ft.dfs has to be told which dataframe to build features for; omitting
        # target_dataframe_name makes the call fail with a KeyError because no
        # target dataframe can be resolved in the EntitySet.
        # features_only=True returns the feature definitions without computing
        # the (expensive) feature matrix.
        # NOTE(review): with a single dataframe and no relationships the
        # aggregation primitives presumably have nothing to aggregate over, and
        # the transform primitives look like date / lat-long / text primitives
        # that may not match any column — confirm the primitive lists.
        feature_names = ft.dfs(
            entityset=es,
            target_dataframe_name=self.TARGET_DATAFRAME_NAME,
            trans_primitives=default_trans_primitives,
            agg_primitives=default_agg_primitives,
            max_depth=2,
            features_only=True,
        )

        print(feature_names)

        return df_

In [6]:
# Run the DFS preprocessor on the training frame.
# NOTE(review): the recorded KeyError mentions a 'transactions' dataframe, a name
# that does not appear in the class defined in this notebook — presumably the
# version imported from data_preprocessor.deep_feature_synthesis was the one
# executed; confirm which definition is in scope.
df_train = DfsPreProcessor().apply(df_train)

Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.
Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.


KeyError: 'DataFrame transactions does not exist in train_df'