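# Extracting statistics on the number of lattice splits among the split conditions

For every dataset, the lattice splits and all splits are counted for decision trees and random forests (classifiers and regressors). Counts are taken once for the model fitted on the full data, and once summed over the models fitted on every training fold of the repeated k-fold scheme. The results are written to `splits_<model>.csv` in `data_dir`.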

In [5]:
import os

import datetime
import numpy as np
import pandas as pd

from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import RepeatedStratifiedKFold, RepeatedKFold

from flipping_random_forest import count_lattice_splits

from datasets import regr_datasets, binclas_datasets
from config import data_dir, random_seed, n_repeats_ms, n_splits_ms

In [6]:
# Model families: decision tree / random forest, classifier / regressor.
labels = ['dtc', 'dtr', 'rfc', 'rfr']
# params[label][dataset_name] -> tuned hyper-parameter dict (filled from params_<label>.csv).
params = dict()

In [7]:
# Load the tuned hyper-parameters of every model family, keyed by dataset name.
for label in labels:
    tuned = pd.read_csv(os.path.join(data_dir, f'params_{label}.csv'))
    params[label] = {}
    per_dataset = params[label]
    # NOTE(review): `eval` executes arbitrary code stored in the CSV. This is only
    # acceptable if these files are trusted outputs of this project (confirm).
    # `ast.literal_eval` would be safer if every stored value is a plain literal.
    for dataset_name, param_repr in zip(tuned['name'], tuned['params']):
        per_dataset[dataset_name] = eval(param_repr)

In [8]:
# Decision tree classifiers: count lattice splits and all splits. This is done once
# for the tree fitted on the full dataset, and summed over the trees fitted on every
# training fold of the repeated stratified k-fold scheme.
counts = {}

for _, row in binclas_datasets.iterrows():
    dataset = row['data_loader_function']()
    name = dataset['name']

    # 'mode' is stored with the tuned parameters but is not a DecisionTreeClassifier argument.
    param = {key: value for key, value in params['dtc'][name].items() if key != 'mode'}
    estimator_params = param | {'random_state': random_seed}

    X = dataset['data']
    y = dataset['target']

    estimator = DecisionTreeClassifier(**estimator_params).fit(X, y)
    n_lattice_splits, n_splits = count_lattice_splits(X, estimator)
    entry = {'n_lattice_splits': n_lattice_splits, 'n_splits': n_splits,
             'n_lattice_splits_kfold': 0, 'n_splits_kfold': 0}
    counts[name] = entry

    # Seed the splitter too (not only the estimators); otherwise the fold
    # assignment, and hence the k-fold totals, change on every run.
    cv = RepeatedStratifiedKFold(n_splits=n_splits_ms, n_repeats=n_repeats_ms,
                                 random_state=random_seed)
    for train, _ in cv.split(X, y):
        X_train = X[train]
        y_train = y[train]

        estimator = DecisionTreeClassifier(**estimator_params).fit(X_train, y_train)
        n_lattice_splits, n_splits = count_lattice_splits(X_train, estimator)
        entry['n_lattice_splits_kfold'] += n_lattice_splits
        entry['n_splits_kfold'] += n_splits

pd.DataFrame.from_dict(counts).T.to_csv(os.path.join(data_dir, 'splits_dtc.csv'))

In [9]:
# Decision tree regressors: count lattice splits and all splits. This is done once
# for the tree fitted on the full dataset, and summed over the trees fitted on every
# training fold of the repeated k-fold scheme.
counts = {}

for _, row in regr_datasets.iterrows():
    dataset = row['data_loader_function']()
    name = dataset['name']

    # 'mode' is stored with the tuned parameters but is not a DecisionTreeRegressor argument.
    param = {key: value for key, value in params['dtr'][name].items() if key != 'mode'}
    estimator_params = param | {'random_state': random_seed}

    X = dataset['data']
    y = dataset['target']

    estimator = DecisionTreeRegressor(**estimator_params).fit(X, y)
    n_lattice_splits, n_splits = count_lattice_splits(X, estimator)
    entry = {'n_lattice_splits': n_lattice_splits, 'n_splits': n_splits,
             'n_lattice_splits_kfold': 0, 'n_splits_kfold': 0}
    counts[name] = entry

    # Seed the splitter too (not only the estimators); otherwise the fold
    # assignment, and hence the k-fold totals, change on every run.
    cv = RepeatedKFold(n_splits=n_splits_ms, n_repeats=n_repeats_ms,
                       random_state=random_seed)
    for train, _ in cv.split(X):
        X_train = X[train]
        y_train = y[train]

        estimator = DecisionTreeRegressor(**estimator_params).fit(X_train, y_train)
        n_lattice_splits, n_splits = count_lattice_splits(X_train, estimator)
        entry['n_lattice_splits_kfold'] += n_lattice_splits
        entry['n_splits_kfold'] += n_splits

pd.DataFrame.from_dict(counts).T.to_csv(os.path.join(data_dir, 'splits_dtr.csv'))

In [10]:
# Random forest classifiers: count lattice splits and all splits, summed over the
# member trees. This is done once for the forest fitted on the full dataset, and
# summed over the forests fitted on every training fold of the repeated
# stratified k-fold scheme.
counts = {}

for _, row in binclas_datasets.iterrows():
    dataset = row['data_loader_function']()
    name = dataset['name']

    print(datetime.datetime.now(), name)

    # 'mode' is stored with the tuned parameters but is not a RandomForestClassifier argument.
    param = {key: value for key, value in params['rfc'][name].items() if key != 'mode'}
    estimator_params = param | {'random_state': random_seed}

    X = dataset['data']
    y = dataset['target']

    entry = {'n_lattice_splits': 0, 'n_splits': 0,
             'n_lattice_splits_kfold': 0, 'n_splits_kfold': 0}
    counts[name] = entry

    estimator = RandomForestClassifier(**estimator_params).fit(X, y)
    for tree in estimator.estimators_:
        n_lattice_splits, n_splits = count_lattice_splits(X, tree)
        entry['n_lattice_splits'] += n_lattice_splits
        entry['n_splits'] += n_splits

    # Seed the splitter too (not only the estimators); otherwise the fold
    # assignment, and hence the k-fold totals, change on every run.
    cv = RepeatedStratifiedKFold(n_splits=n_splits_ms, n_repeats=n_repeats_ms,
                                 random_state=random_seed)
    for train, _ in cv.split(X, y):
        X_train = X[train]
        y_train = y[train]

        estimator = RandomForestClassifier(**estimator_params).fit(X_train, y_train)
        for tree in estimator.estimators_:
            n_lattice_splits, n_splits = count_lattice_splits(X_train, tree)
            entry['n_lattice_splits_kfold'] += n_lattice_splits
            entry['n_splits_kfold'] += n_splits

pd.DataFrame.from_dict(counts).T.to_csv(os.path.join(data_dir, 'splits_rfc.csv'))

2023-12-11 20:34:33.903230 appendicitis
2023-12-11 20:34:53.667639 haberman
2023-12-11 20:35:09.137469 new_thyroid1
2023-12-11 20:35:27.480012 glass0
2023-12-11 20:36:04.338121 shuttle-6_vs_2-3
2023-12-11 20:36:26.834603 bupa
2023-12-11 20:37:03.039830 cleveland-0_vs_4
2023-12-11 20:37:27.059797 ecoli1
2023-12-11 20:37:52.549153 poker-9_vs_7
2023-12-11 20:38:14.161753 monk-2
2023-12-11 20:38:38.693712 hepatitis
2023-12-11 20:39:02.539090 yeast-0-3-5-9_vs_7-8
2023-12-11 20:39:31.283486 mammographic
2023-12-11 20:39:56.236259 saheart
2023-12-11 20:40:29.067881 page-blocks-1-3_vs_4
2023-12-11 20:41:05.524729 lymphography-normal-fibrosis
2023-12-11 20:41:43.552439 pima
2023-12-11 20:42:24.241294 wisconsin
2023-12-11 20:42:52.666715 abalone9_18
2023-12-11 20:43:28.351234 winequality-red-3_vs_5


In [11]:
# Random forest regressors: count lattice splits and all splits, summed over the
# member trees. This is done once for the forest fitted on the full dataset, and
# summed over the forests fitted on every training fold of the repeated k-fold scheme.
counts = {}

for _, row in regr_datasets.iterrows():
    dataset = row['data_loader_function']()
    name = dataset['name']

    print(datetime.datetime.now(), name)

    # 'mode' is stored with the tuned parameters but is not a RandomForestRegressor argument.
    param = {key: value for key, value in params['rfr'][name].items() if key != 'mode'}
    estimator_params = param | {'random_state': random_seed}

    X = dataset['data']
    y = dataset['target']

    entry = {'n_lattice_splits': 0, 'n_splits': 0,
             'n_lattice_splits_kfold': 0, 'n_splits_kfold': 0}
    counts[name] = entry

    estimator = RandomForestRegressor(**estimator_params).fit(X, y)
    for tree in estimator.estimators_:
        n_lattice_splits, n_splits = count_lattice_splits(X, tree)
        entry['n_lattice_splits'] += n_lattice_splits
        entry['n_splits'] += n_splits

    # Seed the splitter too (not only the estimators); otherwise the fold
    # assignment, and hence the k-fold totals, change on every run.
    cv = RepeatedKFold(n_splits=n_splits_ms, n_repeats=n_repeats_ms,
                       random_state=random_seed)
    for train, _ in cv.split(X):
        X_train = X[train]
        y_train = y[train]

        estimator = RandomForestRegressor(**estimator_params).fit(X_train, y_train)
        for tree in estimator.estimators_:
            n_lattice_splits, n_splits = count_lattice_splits(X_train, tree)
            entry['n_lattice_splits_kfold'] += n_lattice_splits
            entry['n_splits_kfold'] += n_splits

pd.DataFrame.from_dict(counts).T.to_csv(os.path.join(data_dir, 'splits_rfr.csv'))

2023-12-11 20:43:55.138809 diabetes
2023-12-11 20:44:09.551083 o-ring
2023-12-11 20:44:22.497570 stock_portfolio_performance
2023-12-11 20:44:47.232702 wsn-ale
2023-12-11 20:45:33.957894 daily-demand
2023-12-11 20:47:49.516810 slump_test
2023-12-11 20:48:29.365168 servo
2023-12-11 20:48:47.654150 yacht_hydrodynamics
2023-12-11 20:49:57.276747 autoMPG6
2023-12-11 20:51:25.047695 excitation_current
2023-12-11 20:53:17.770422 real_estate_valuation
2023-12-11 20:54:50.218391 wankara
2023-12-11 20:56:17.643599 plastic
2023-12-11 20:57:14.057795 laser
2023-12-11 21:02:24.882530 qsar-aquatic-toxicity
2023-12-11 21:06:05.777302 baseball
2023-12-11 21:06:55.678323 maternal_health_risk
2023-12-11 21:08:38.343943 cpu_performance
2023-12-11 21:10:17.500702 airfoil
2023-12-11 21:17:43.883331 medical_cost


In [13]:
# Fraction of splits that are lattice splits, for the full-data fit and for the
# k-fold totals of the random forest classifiers.
data = pd.read_csv(os.path.join(data_dir, 'splits_rfc.csv')).assign(
    perc=lambda frame: frame['n_lattice_splits'] / frame['n_splits'],
    perc_kfold=lambda frame: frame['n_lattice_splits_kfold'] / frame['n_splits_kfold'],
)
data

Unnamed: 0.1,Unnamed: 0,n_lattice_splits,n_splits,n_lattice_splits_kfold,n_splits_kfold,perc,perc_kfold
0,appendicitis,3,133,434,14994,0.022556,0.028945
1,haberman,28,711,3463,68489,0.039381,0.050563
2,new_thyroid1,41,570,4709,51657,0.07193,0.091159
3,glass0,118,2401,9073,202491,0.049146,0.044807
4,shuttle-6_vs_2-3,46,339,3986,31453,0.135693,0.126729
5,bupa,1095,4397,87153,357469,0.249033,0.243806
6,cleveland-0_vs_4,41,432,4192,39979,0.094907,0.104855
7,ecoli1,150,1273,14692,120687,0.117832,0.121736
8,poker-9_vs_7,101,709,9974,59397,0.142454,0.167921
9,monk-2,1,1814,521,189095,0.000551,0.002755
