In [2]:
%load_ext autoreload
%autoreload 2
import os
import pickle as pkl
from typing import Dict, Any

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
# Render all figures at 250 dpi.
mpl.rcParams['figure.dpi'] = 250

# Change working directory to the project root (two levels up) when the kernel was
# started from a directory named `notebooks` -- presumably so that the relative
# `experiments/...` imports and paths used below resolve.
# os.path.basename respects the platform's path separator, unlike splitting on '/'.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir(os.path.join(os.pardir, os.pardir))
from experiments.notebooks import viz
from experiments.util import get_clean_dataset
    
# Relative directory path (resolved against the working directory set above).
MODEL_COMPARISON_PATH = 'experiments/comparison_data/'

# (name, id) pairs. The dataset-stats cell feeds an identical ("breast-cancer", 13)
# pair to fetch_openml(data_id=...), so the integers look like OpenML data ids
# -- TODO confirm for the remaining entries.
datasets = [
    ("breast-cancer", 13),
    ("breast-w", 15),
    ("credit-g", 31),
    ("haberman", 43),
    ("heart", 1574),
    ("labor", 4),
    ("vote", 56),
]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# dataset stats

In [22]:
# Build a summary table with one row per dataset: sample count, feature count,
# per-class counts and the majority-class share in percent.
# NOTE(review): only breast-cancer is summarised here, not the full `datasets` list
# from the setup cell. Later cells fit on the X and y this loop leaves in scope, so
# widening the loop would also change what they train on.
metadata = []
columns = ['Name', 'Samples', 'Features', 'Class 0', 'Class 1', 'Majority class %']
for dataset_name, data_id in [("breast-cancer", 13),]:
    dataset = fetch_openml(data_id=data_id, as_frame=False)
    X, y = dataset.data, dataset.target
    # Replace NaNs with 0. The fill value is passed by keyword because the second
    # positional parameter of np.nan_to_num is `copy`: the earlier `np.nan_to_num(X, 0)`
    # meant copy=False (overwrite dataset.data in place), not "fill with 0".
    X = np.nan_to_num(X, nan=0.0)
    shape = X.shape
    # np.unique returns labels sorted, so counts are in sorted-label order;
    # 'Class 0' / 'Class 1' are the first two of them.
    class_counts = np.unique(y, return_counts=True)[1]
    majority_pct = np.round(100 * np.max(class_counts) / np.sum(class_counts), decimals=1)
    metadata.append([dataset_name.capitalize(), shape[0], shape[1],
                     class_counts[0], class_counts[1], majority_pct])

# The row list is replaced by a DataFrame under the same name.
metadata = pd.DataFrame(metadata, columns=columns)

In [23]:
# Print the summary table as LaTeX source, without the index column.
latex_table = metadata.to_latex(index=False)
print(latex_table)

\begin{tabular}{lrrrrr}
\toprule
         Name &  Samples &  Features &  Class 0 &  Class 1 &  Majority class \% \\
\midrule
Breast-cancer &      286 &         9 &      201 &       85 &              70.3 \\
\bottomrule
\end{tabular}



In [24]:
from experiments.models.supercart import  SuperCART
from sklearn.model_selection import cross_validate

In [25]:
# Instantiate SuperCART with no constructor arguments (i.e. its defaults).
m = SuperCART()

In [26]:
# Fit on the full X, y left in scope by the dataset-stats loop (no held-out split here).
m.fit(X, y)

SuperCART()

In [27]:
# Cross-validate the estimator on the same X, y with cv=3, scored by accuracy.
# Left as the cell's last expression so the result dict is displayed.
cross_validate(
    estimator=m,
    X=X,
    y=y,
    scoring='accuracy',
    cv=3,
)

{'fit_time': array([0.0014832 , 0.00312304, 0.00114727]),
 'score_time': array([0.00567698, 0.00265789, 0.00114465]),
 'test_score': array([0.70833333, 0.67368421, 0.72631579])}