In [1]:
import sys
import os

sys.path.append('..')

from interpretDistill.fourierDistill import *
from interpretDistill.binaryTransformer import *

In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

X, y = load_iris(as_frame = True, return_X_y = True)
X.columns = X.columns.str.replace(' ', '_')
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)

In [3]:
bt_bit = BinaryTransformer(depth = 4, bit = True)
bt_bin = BinaryTransformer(depth = 4, bit = False)

In [4]:
X_train_bit = bt_bit.fit_and_transform(X_train, y_train)
X_val_bit = bt_bit.transform(X_val)
X_test_bit = bt_bit.transform(X_test)

In [5]:
X_train_bin = bt_bin.fit_and_transform(X_train, y_train)
X_val_bin = bt_bin.transform(X_val)
X_test_bin = bt_bin.transform(X_test)

In [6]:
print(f'bit number of features: {X_train_bit.shape[1]}')
print(f'bin number of features: {X_train_bin.shape[1]}')

bit number of features: 16
bin number of features: 38


In [7]:
rf_orig = RandomForestRegressor(max_depth = 6)
rf_bit = RandomForestRegressor(max_depth = 6)
rf_bin = RandomForestRegressor(max_depth = 6)

In [8]:
rf_orig.fit(X_train, y_train)
rf_bit.fit(X_train_bit, y_train)
rf_bin.fit(X_train_bin, y_train)

In [9]:
print(f'[orig] train MSE: {mean_squared_error(rf_orig.predict(X_train),y_train)}, val MSE: {mean_squared_error(rf_orig.predict(X_val),y_val)}, test MSE: {mean_squared_error(rf_orig.predict(X_test),y_test)}')
print(f'[bit] train MSE: {mean_squared_error(rf_bit.predict(X_train_bit),y_train)}, val MSE: {mean_squared_error(rf_bit.predict(X_val_bit),y_val)}, test MSE: {mean_squared_error(rf_bit.predict(X_test_bit),y_test)}')
print(f'[bin] train MSE: {mean_squared_error(rf_bin.predict(X_train_bin),y_train)}, val MSE: {mean_squared_error(rf_bin.predict(X_val_bin),y_val)}, test MSE: {mean_squared_error(rf_bin.predict(X_test_bin),y_test)}')

[orig] train MSE: 0.005806666666666667, val MSE: 0.04606333333333333, test MSE: 0.004636666666666667
[bit] train MSE: 0.005656666666666666, val MSE: 0.06445666666666665, test MSE: 0.025026666666666662
[bin] train MSE: 0.006773333333333333, val MSE: 0.0779024074074074, test MSE: 0.025484166666666665


In [10]:
print(f'[orig RF] train R2: {r2_score(rf_orig.predict(X_train),y_train)}, val R2: {r2_score(rf_orig.predict(X_val),y_val)}, test R2: {r2_score(rf_orig.predict(X_test),y_test)}')
print(f'[bit RF] train R2: {r2_score(rf_bit.predict(X_train_bit),y_train)}, val R2: {r2_score(rf_bit.predict(X_val_bit),y_val)}, test R2: {r2_score(rf_bit.predict(X_test_bit),y_test)}')
print(f'[bin RF] train R2: {r2_score(rf_bin.predict(X_train_bin),y_train)}, val R2: {r2_score(rf_bin.predict(X_val_bin),y_val)}, test R2: {r2_score(rf_bin.predict(X_test_bin),y_test)}')

[orig RF] train R2: 0.9913019331034808, val R2: 0.9203140001672259, test R2: 0.9932201032630426
[bit RF] train R2: 0.9915695402349713, val R2: 0.8879176714326757, test R2: 0.9645288880741851
[bin RF] train R2: 0.9898345776884396, val R2: 0.8683533141898429, test R2: 0.9630349676401346


In [11]:
ftd_bit = FTDistillCV(size_interactions = 3)
ftd_bin = FTDistillCV(size_interactions = 3)

In [12]:
ftd_bit.fit(X_val_bit, rf_orig.predict(X_val))
ftd_bin.fit(X_val_bin, rf_orig.predict(X_val), bt_bin.no_interaction)

3
3


<interpretDistill.fourierDistill.FTDistillCV at 0x7fde70c5d7f0>

In [19]:
sum(ftd_bit.regression_model.coef_ != 0), sum(ftd_bin.regression_model.coef_ != 0)

(51, 28)

In [20]:
print(f'[bit FTD, true y] train MSE: {mean_squared_error(ftd_bit.predict(X_train_bit),y_train)}, val MSE: {mean_squared_error(ftd_bit.predict(X_val_bit),y_val)}, test MSE: {mean_squared_error(ftd_bit.predict(X_test_bit),y_test)}')
print(f'[bit FTD, RF y] train MSE: {mean_squared_error(ftd_bit.predict(X_train_bit),rf_bit.predict(X_train_bit))}, val MSE: {mean_squared_error(ftd_bit.predict(X_val_bit),rf_bit.predict(X_val_bit))}, test MSE: {mean_squared_error(ftd_bit.predict(X_test_bit),rf_bit.predict(X_test_bit))}')
print(f'[bin FTD, true y] train MSE: {mean_squared_error(ftd_bin.predict(X_train_bin),y_train)}, val MSE: {mean_squared_error(ftd_bin.predict(X_val_bin),y_val)}, test MSE: {mean_squared_error(ftd_bin.predict(X_test_bin),y_test)}')
print(f'[bin FTD, RF y] train MSE: {mean_squared_error(ftd_bin.predict(X_train_bin),rf_bin.predict(X_train_bin))}, val MSE: {mean_squared_error(ftd_bin.predict(X_val_bin),rf_bin.predict(X_val_bin))}, test MSE: {mean_squared_error(ftd_bin.predict(X_test_bin),rf_bin.predict(X_test_bin))}')

[bit FTD, true y] train MSE: 0.07716219710794889, val MSE: 0.0462629279225592, test MSE: 0.039491436959414955
[bit FTD, RF y] train MSE: 0.05635333929487045, val MSE: 0.004049029682553336, test MSE: 0.050200425086815985
[bin FTD, true y] train MSE: 0.062068410780404905, val MSE: 0.04620940446982145, test MSE: 0.03753774797336353
[bin FTD, RF y] train MSE: 0.036558578171171315, val MSE: 0.012420321064519306, test MSE: 0.037758548483748114


In [14]:
print(f'[bit RF, true y] train R2: {r2_score(ftd_bit.predict(X_train_bit),y_train)}, val R2: {r2_score(ftd_bit.predict(X_val_bit),y_val)}, test R2: {r2_score(ftd_bit.predict(X_test_bit),y_test)}')
print(f'[bit RF, RF y] train R2: {r2_score(ftd_bit.predict(X_train_bit),rf_orig.predict(X_train))}, val R2: {r2_score(ftd_bit.predict(X_val_bit),rf_orig.predict(X_val))}, test R2: {r2_score(ftd_bit.predict(X_test_bit),rf_orig.predict(X_test))}')
print(f'[bin RF, true y] train R2: {r2_score(ftd_bin.predict(X_train_bin),y_train)}, val R2: {r2_score(ftd_bin.predict(X_val_bin),y_val)}, test R2: {r2_score(ftd_bin.predict(X_test_bin),y_test)}')
print(f'[bin RF, RF y] train R2: {r2_score(ftd_bin.predict(X_train_bin),rf_orig.predict(X_train))}, val R2: {r2_score(ftd_bin.predict(X_val_bin),rf_orig.predict(X_val))}, test R2: {r2_score(ftd_bin.predict(X_test_bin),rf_orig.predict(X_test))}')

[bit RF, true y] train R2: 0.8820437441636716, val R2: 0.9196170423932237, test R2: 0.9381004006320611
[bit RF, RF y] train R2: 0.9126951377673209, val R2: 0.9999842419781084, test R2: 0.9407679795627181
[bin RF, true y] train R2: 0.8966020286747852, val R2: 0.9196504992067717, test R2: 0.9399330718253924
[bin RF, RF y] train R2: 0.9326791591085134, val R2: 0.999989655514375, test R2: 0.9599270268371457


In [15]:
from itertools import compress

#sorted([i for i in zip(list(compress(ftd_bit.features, ftd_bit.regression_model.coef_ != 0)), list(compress(ftd_bit.regression_model.coef_, ftd_bit.regression_model.coef_ != 0)))], key = lambda x: abs(x[1]))
sorted([i for i in zip(list(compress(ftd_bin.features, ftd_bin.regression_model.coef_ != 0)), list(compress(ftd_bin.regression_model.coef_, ftd_bin.regression_model.coef_ != 0)))], key = lambda x: abs(x[1]))

[(('sepal_width_(cm)_7', 'petal_length_(cm)_12', 'sepal_length_(cm)_9'),
  4.1753498016944366e-05),
 (('petal_width_(cm)_1', 'sepal_width_(cm)_20', 'petal_length_(cm)_4'),
  -9.890771048789017e-05),
 (('sepal_length_(cm)_20', 'petal_length_(cm)_10', 'petal_width_(cm)_12'),
  0.0011962348892383694),
 (('petal_length_(cm)_12', 'sepal_length_(cm)_14'), -0.0014503511925869192),
 (('petal_width_(cm)_1', 'sepal_length_(cm)_16', 'petal_length_(cm)_4'),
  -0.0034061596538699),
 (('petal_length_(cm)_1', 'petal_width_(cm)_4', 'sepal_width_(cm)_22'),
  -0.003612152411889556),
 (('sepal_width_(cm)_4', 'petal_length_(cm)_1', 'petal_width_(cm)_10'),
  -0.004759486076198704),
 (('sepal_length_(cm)_13', 'petal_length_(cm)_1', 'petal_width_(cm)_4'),
  -0.006628793188281814),
 (('petal_length_(cm)_1', 'petal_width_(cm)_4'), 0.0076614499028041854),
 (('petal_width_(cm)_1', 'petal_length_(cm)_6', 'sepal_length_(cm)_9'),
  -0.009464007493966057),
 (('sepal_length_(cm)_7', 'petal_length_(cm)_1', 'petal_widt

In [16]:
kill

NameError: name 'kill' is not defined

In [None]:
from ucimlrepo import fetch_ucirepo 
  
# fetch dataset 
solar_flare = fetch_ucirepo(id=89) 
  
# data (as pandas dataframes) 
X = solar_flare.data.features 
y = solar_flare.data.targets 
  
# metadata 
X

In [None]:
import matplotlib.pyplot as plt