In [1]:
%load_ext autoreload
%autoreload 2
import os
import pickle as pkl
from typing import Dict, Any

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
from copy import deepcopy
# Default resolution for every matplotlib figure created in this notebook: 250 DPI.
mpl.rcParams.update({'figure.dpi': 250})

# Change the working directory to the project root (two levels up) when the
# kernel was started inside a directory named `notebooks`. The guard means the
# move is skipped when the current directory is not called `notebooks`, so
# re-running this cell normally does nothing once the move has happened.
# os.path.basename is used rather than splitting on '/' so the check does not
# depend on the platform's path separator (e.g. backslashes on Windows).
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('../..')
from experiments.notebooks import viz
from experiments.data_util import get_clean_dataset
from experiments.config.datasets import DATASETS_CLASSIFICATION
# Allow pandas to show up to 100 rows of a displayed frame before truncating.
pd.set_option('display.max_rows', 100)

/Volumes/GoogleDrive/My Drive/research/rules/imodels/experiments/notebooks


# classification dataset stats

In [4]:
# Summarise every dataset in DATASETS_CLASSIFICATION: sample/feature counts,
# the counts of the first two classes, and the majority-class share in percent.
# (The np.nan_to_num imputation that used to sit in this loop stays disabled.)
columns = ['Name', 'Samples', 'Features', 'Class 0', 'Class 1', 'Majority class %']
rows = []
for dset_name, dset_file in DATASETS_CLASSIFICATION:
    X, y, feat_names = get_clean_dataset(dset_file)
    # np.unique returns the labels sorted, so the counts follow ascending label order.
    class_counts = np.unique(y, return_counts=True)[1]
    majority_pct = np.round(100 * class_counts.max() / class_counts.sum(), decimals=1)
    rows.append([
        dset_name.capitalize(),
        X.shape[0],
        X.shape[1],
        class_counts[0],
        class_counts[1],
        majority_pct,
    ])

# One row per dataset; the default integer index is kept.
metadata = pd.DataFrame(rows, columns=columns)
metadata

Unnamed: 0,Name,Samples,Features,Class 0,Class 1,Majority class %
0,Recidivism,6172,20,3182,2990,51.6
1,Credit,30000,33,23364,6636,77.9
2,Juvenile,3640,286,3153,487,86.6
3,Readmission,101763,150,54861,46902,53.9
4,Breast-cancer,277,17,196,81,70.8
5,Credit-g,1000,60,300,700,70.0
6,Haberman,306,3,81,225,73.5
7,Heart,270,15,150,120,55.6


In [5]:
# Render the summary table as LaTeX source, leaving out the row index.
latex_table = metadata.to_latex(index=False)
print(latex_table)

\begin{tabular}{lrrrrr}
\toprule
         Name &  Samples &  Features &  Class 0 &  Class 1 &  Majority class \% \\
\midrule
   Recidivism &     6172 &        20 &     3182 &     2990 &              51.6 \\
       Credit &    30000 &        33 &    23364 &     6636 &              77.9 \\
     Juvenile &     3640 &       286 &     3153 &      487 &              86.6 \\
  Readmission &   101763 &       150 &    54861 &    46902 &              53.9 \\
Breast-cancer &      277 &        17 &      196 &       81 &              70.8 \\
     Credit-g &     1000 &        60 &      300 &      700 &              70.0 \\
     Haberman &      306 &         3 &       81 &      225 &              73.5 \\
        Heart &      270 &        15 &      150 &      120 &              55.6 \\
\bottomrule
\end{tabular}



# regression dataset names


In [2]:
from pmlb import fetch_data, classification_dataset_names

In [3]:
# Read the PMLB metadata table (path is relative to the current working
# directory) and order it largest-first: by number of observations, with the
# number of features as the tie-breaker.
pmlb_meta = pd.read_csv('../data/pmlb_data/pmlb_metadata.csv')
pmlb_meta = pmlb_meta.sort_values(by=['n_observations', 'n_features'], ascending=False)

In [4]:
# Keep regression tasks whose dataset name does not contain 'feynman',
# then show at most the first 100 rows of the (already sorted) table.
is_regression = pmlb_meta.Task == 'regression'
is_feynman = pmlb_meta.Dataset.str.contains('feynman')
pmlb_meta[is_regression & ~is_feynman].head(100)

Unnamed: 0,Dataset,n_observations,n_features,n_classes,Endpoint,Imbalance,Task,Metadata
174,1595_poker,1025010,10,,continuous,0.37,regression,
168,1191_BNG_pbc,1000000,18,,continuous,0.0,regression,
170,1196_BNG_pharynx,1000000,10,,continuous,0.0,regression,
173,1203_BNG_pwLinear,177147,10,,continuous,0.0,regression,
172,1201_BNG_breastTumor,116640,9,,continuous,0.04,regression,
181,215_2dplanes,40768,10,,continuous,0.0,regression,
189,344_mv,40768,10,,continuous,0.0,regression,
207,564_fried,40768,10,,continuous,0.0,regression,
169,1193_BNG_lowbwt,31104,9,,continuous,0.0,regression,
209,574_house_16H,22784,16,,continuous,0.02,regression,


In [22]:
# Load the '542_pollution' dataset from the 'pmlb' data source via the project loader.
# NOTE(review): the feature_names output recorded in this notebook ends with
# 'target' — confirm whether get_clean_dataset is meant to include that column.
pollution_data = get_clean_dataset('542_pollution', data_source='pmlb')
X, y, feature_names = pollution_data

In [23]:
# Display the column labels returned by get_clean_dataset for '542_pollution'.
feature_names

Index(['PREC', 'JANT', 'JULT', 'OVR65', 'POPN', 'EDUC', 'HOUS', 'DENS', 'NONW',
       'WWDRK', 'POOR', 'HC', 'NOX', 'SO2', 'HUMID', 'target'],
      dtype='object')

In [24]:
# Display the y array returned by get_clean_dataset for '542_pollution'.
# NOTE(review): this prints the whole array; a slice such as y[:10] would keep
# the saved output short.
y

array([ 921.86999512,  997.875     ,  962.35400391,  982.29101562,
       1071.28894043, 1030.38000488,  934.70001221,  899.5289917 ,
       1001.90197754,  912.34698486, 1017.61297607, 1024.88500977,
        970.46697998,  985.95001221,  958.83898926,  860.10101318,
        936.23400879,  871.76599121,  959.2210083 ,  941.18103027,
        891.70800781,  871.3380127 ,  971.12200928,  887.46600342,
        952.5289917 ,  968.66497803,  919.72900391,  844.05297852,
        861.83300781,  989.26501465, 1006.48999023,  861.43902588,
        929.15002441,  857.62200928,  961.00897217,  923.23400879,
       1113.15600586,  994.64801025, 1015.02301025,  991.28997803,
        893.99102783,  938.5       ,  946.18499756, 1025.50195312,
        874.28100586,  953.55999756,  839.70898438,  911.70098877,
        790.73297119,  899.26397705,  904.1550293 ,  950.67199707,
        972.46398926,  912.20202637,  967.80297852,  823.76397705,
       1003.50201416,  895.69598389,  911.8170166 ,  954.44201