In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import *

from sklearn.metrics import accuracy_score
from src.ShapeCARTClassifier import ShapeCARTClassifier

# suppress pandas' SettingWithCopyWarning (chained-assignment check disabled)
pd.options.mode.chained_assignment = None
from src.data_utils import *
from sklearn.datasets import load_breast_cancer

import six
import sys
sys.modules['sklearn.externals.six'] = six

np.float = float

In [None]:
# Name of the dataset under study; later cells reuse it (e.g. for the output folder).
dataset = 'bean'

# Build the data factory for that dataset with caching switched off.
data_factory = DataFactory_clf(dataset=dataset, cache=False)

# get_data(0) returns the six train / validation / test arrays in this order.
# NOTE(review): the meaning of the argument 0 (fold? seed?) is not visible here — confirm.
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = data_factory.get_data(0)

# Maps each feature name to the list of column indices it occupies in X.
feature_dict = data_factory.feature_dict

# Number of distinct labels present in the training targets.
n_classes = np.unique(y_train).size
print(f"n_classes: {n_classes}")

In [None]:
# Echo the training targets (the y_train returned by get_data above) as this cell's output.
y_train

In [None]:
# List every feature name next to the value stored for it in feature_dict,
# one feature per line.
for feature_name, column_idxs in feature_dict.items():
    print(f"{feature_name}: {column_idxs}")

In [None]:

# Fit a depth-2 ShapeCART classifier on the training split.
clf = ShapeCARTClassifier(
    max_depth=2,
    criterion='gini',
    min_samples_split=16,
    min_impurity_decrease=0.0,
    inner_max_leaf_nodes=32,
    random_state=42,
    k=2,
    verbose=False
)
clf.fit(X_train, y_train, feature_dict=feature_dict)

# Predict on the three splits (train, validation, test — in that order) ...
train_pred, val_pred, test_pred = (
    clf.predict(features) for features in (X_train, X_val, X_test)
)
# ... and score each prediction against its ground truth.
train_accuracy, val_accuracy, test_accuracy = (
    accuracy_score(truth, predicted)
    for truth, predicted in (
        (y_train, train_pred),
        (y_val, val_pred),
        (y_test, test_pred),
    )
)

# For every populated node slot, print the feature key it uses followed by
# the corresponding entry of clf.children.
for node_id, tree_node in enumerate(clf.nodes):
    if tree_node is None:
        continue
    print(f"{node_id}: {tree_node.final_key}")
    print(f"\t {clf.children[node_id]}")

print(f"train_accuracy: {train_accuracy}")
print(f"val_accuracy: {val_accuracy}")
# 82.77  -- NOTE(review): presumably a previously observed score; confirm which run/split it refers to.
print(f"test_accuracy: {test_accuracy}")

In [None]:
# Echo the training feature matrix (the X_train returned by get_data above) as this cell's output.
X_train

In [None]:
import os
from matplotlib.gridspec import GridSpec
from matplotlib.colors import ListedColormap, BoundaryNorm

# Fill colours of the two branch bands (branching == 0 -> "Left", == 1 -> "Right").
LEFT_COLOR = '#D98CB9'
RIGHT_COLOR = '#FFA860'


def _branching_profile(tree, mapping, observed, key):
    """Evaluate one node's single-feature split on a dense, sorted set of values.

    The probe values are the union of
      * 1000 evenly spaced points between the smallest and largest observed value,
      * the observed values themselves, and
      * every split threshold of ``tree`` together with threshold +/- 1e-5, so
        that both sides of each cut are sampled.

    Parameters
    ----------
    tree : fitted tree object
        Must expose ``tree_.threshold``, ``tree_.children_left`` and ``apply``
        (this looks like the scikit-learn decision-tree interface — confirm).
    mapping : indexable
        Indexed with the leaf id returned by ``tree.apply``; yields the branch
        code of that leaf.
    observed : numpy.ndarray
        Values of the feature for the samples that reach the node.
    key : str
        Feature name; used as the name of the value column.

    Returns
    -------
    pandas.DataFrame
        Columns ``[key, 'branching']`` with rows in ascending order of ``key``.
    """
    # Internal nodes are the ones whose left child is not -1; only they carry a threshold.
    thresholds = tree.tree_.threshold[tree.tree_.children_left != -1]
    near_thresholds = np.concatenate(
        (thresholds + 1e-5, thresholds - 1e-5, thresholds)
    ).reshape(-1, 1)

    observed = observed.reshape(-1, 1)
    grid = np.linspace(np.min(observed), np.max(observed), 1000).reshape(-1, 1)

    # A single column sorted along axis 0 is already in ascending order, so no
    # further sorting of the resulting frame is needed.
    probes = np.sort(np.vstack((grid, observed, near_thresholds)), axis=0)

    leaf_ids = tree.apply(probes)
    branching = np.array([mapping[leaf] for leaf in leaf_ids]).reshape(-1, 1)
    return pd.DataFrame(np.hstack((probes, branching)), columns=[key, 'branching'])


def _save_branching_plot(profile, key, boundary_x, idx, out_dir):
    """Draw the Left/Right bands of one node and save them as ``node_<idx>.png``.

    Parameters
    ----------
    profile : pandas.DataFrame
        Frame with columns ``[key, 'branching']`` sorted by ``key``.
    key : str
        Feature name; column to plot on the x axis and x-axis label.
    boundary_x : numpy.ndarray
        x positions at which the branch code changes (including the first row).
    idx : int
        Node index, used in the figure title and the file name.
    out_dir : str
        Directory the PNG is written to.
    """
    x = profile[key]
    y = profile['branching']

    fig, ax = plt.subplots(nrows=1, figsize=(5, 2))
    try:
        # One band per branch code; step='post' gives clean vertical edges.
        for code, color, label in ((0, LEFT_COLOR, 'Left'), (1, RIGHT_COLOR, 'Right')):
            ax.fill_between(x, 0, 1,
                            where=(y == code),
                            step='post',
                            alpha=0.8,
                            color=color,
                            label=label)

        # Zoom to the region containing the boundaries, padded by a fifth of the
        # standard deviation of x and clipped to the range of x.
        # NOTE(review): boundary_x always contains the first (smallest) x, so the
        # lower limit always equals min(x) — confirm this is intended.
        margin = np.std(x) / 5
        x_low = np.min(x).astype(np.float64)
        x_high = np.max(x).astype(np.float64)
        lower = np.maximum(boundary_x.min().astype(np.float64) - margin, x_low)
        upper = np.minimum(boundary_x.max().astype(np.float64) + margin, x_high)
        ax.set_xlim(lower, upper)

        ax.set_yticks([])                        # the y axis carries no information
        ax.tick_params(axis='x', labelsize=14)
        ax.set_xlabel(key, fontsize=14)
        fig.suptitle(f"Node {idx}", fontsize=14, y = 0.9)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f"node_{idx}.png"), dpi=800, format='png')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


# Output folder: visualizations/ShapeCART/<dataset>. exist_ok avoids the
# check-then-create race and makes the cell safe to re-run.
output_root = 'visualizations/ShapeCART'
_path = os.path.join(output_root, dataset)
os.makedirs(_path, exist_ok=True)

# Two-colour map with a boundary at 0.5: values <= 0.5 get color[0], > 0.5 get color[1].
# Not used by the band plots below; kept defined in the notebook namespace.
cmap = ListedColormap([LEFT_COLOR, RIGHT_COLOR])
norm = BoundaryNorm(boundaries=[-0.5, 0.5, 1.5], ncolors=cmap.N)

nodes = clf.nodes
# Positional [:, columns] indexing needs an ndarray; asarray is a no-op if X_train already is one.
X_train_values = np.asarray(X_train)

discontinuities = 0
for idx, node in enumerate(nodes):
    print(f"Node {idx}")
    if node is None:
        continue

    key = node.final_key
    idxs = feature_dict[key]
    if len(idxs) > 1:
        # Features spanning several columns are counted as one discontinuity and not plotted.
        discontinuities += 1
        continue

    # Restrict the feature column to the training samples recorded for this node.
    rel_points = clf.point_idxs[idx]
    observed = X_train_values[:, idxs][rel_points]

    comb_df = _branching_profile(node.final_tree, node.mapping, observed, key)

    # True wherever the branch code differs from the previous row; the first row
    # is always True because it is compared against NaN.
    change_mask = comb_df['branching'].ne(comb_df['branching'].shift())
    branching_changes = change_mask.sum() - 1  # subtract 1 for the first element
    print(f"branching changes: {branching_changes}")
    discontinuities += branching_changes

    boundary_x = comb_df[key][change_mask].values
    print(boundary_x.min(), boundary_x.max())

    _save_branching_plot(comb_df, key, boundary_x, idx, _path)
    

In [None]:
# nodes = clf.nodes
# import os
# from matplotlib.gridspec import GridSpec
# from matplotlib.colors import ListedColormap, BoundaryNorm

# cmap = ListedColormap(['#D98CB9', '#FFA860'])

# # boundary at 0.5: values ≤0.5 get color[0], >0.5 get color[1]
# norm = BoundaryNorm(boundaries=[-0.5, 0.5, 1.5], ncolors=cmap.N)

# dir = 'visualizations/Shape2CART'
# _path = os.path.join(dir, dataset)
# if not os.path.exists(_path):
#     os.makedirs(_path)
# for idx, node in enumerate(nodes):
#     if node is None:
#         continue
#     key1 = node.final_key[0]
#     key2 = node.final_key[1]
#     rel_idx1 = feature_dict[key1][0]
#     rel_idx2 = feature_dict[key2][0]
#     X_1 = X_train.values[:, rel_idx1]
#     X_2 = X_train.values[:, rel_idx2]
#     # percentiles_x1 = np.percentile(X_1, np.linspace(1, 100, 400))
#     # percentiles_x2 = np.percentile(X_2, np.linspace(1, 100, 400))
#     percentiles_x1 = np.linspace(X_1.min(), X_1.max(), 500)
#     percentiles_x2 = np.linspace(X_2.min(), X_2.max(), 500)

#     # percentiles_x1 = np.unique(percentiles_x1).astype(np.float32).round(1)
#     # percentiles_x2 = np.unique(percentiles_x2).astype(np.float32).round(1)
#     x_1_n = len(percentiles_x1)
#     x_2_n = len(percentiles_x2)

#     X1, X2 = np.meshgrid(percentiles_x1, percentiles_x2, indexing='ij')    # both shape (n, m)
#     Z1 = np.stack([X1.ravel(), X2.ravel()], axis=1) # shape (n*m, 2)
#     tree = node.final_tree
#     pred_ = tree.apply(Z1)
#     branching = node.mapping[pred_]
#     branching = branching.reshape(X1.shape)
#     percentiles_x1 = percentiles_x1.astype(np.float32).round(1)
#     percentiles_x2 = percentiles_x2.astype(np.float32).round(1)
#     branching = pd.DataFrame(branching, index=percentiles_x1, columns=percentiles_x2)
#     branching.sort_index(inplace=True, ascending=False)
#     fig, ax = plt.subplots(nrows=1, figsize=(5, 5))
#     sns.heatmap(branching, cbar = False, cmap=cmap, norm=norm, ax=ax)
#     ax.set_xlabel(key2, fontsize=14)
#     ax.set_ylabel(key1, fontsize=14)
#     ax.set_title(f"Node {idx} - {node.final_key}")
#     fig.tight_layout()
#     fig.savefig(os.path.join(_path, f"node_{idx}.png"), dpi=800, format='png')
#     plt.close(fig)

In [None]:
# nodes = clf.nodes
# import os
# from matplotlib.gridspec import GridSpec

# dir = 'visualizations/ShapeCART'
# _path = os.path.join(dir, dataset)
# if not os.path.exists(_path):
#     os.makedirs(_path)


# for idx, node in enumerate(nodes):
#     print(f"Node {idx}")
#     if node is None:
#         continue
#     key = node.final_key
#     tree = node.final_tree
#     value = clf.values[idx].round(2)
#     tree_thresholds = tree.tree_.threshold
#     non_child = np.where(tree.tree_.children_left != -1)[0]
#     tree_thresholds = tree_thresholds[non_child]
#     thresholds_pos  = tree_thresholds + 1e-5
#     thresholds_neg  = tree_thresholds - 1e-5
#     stacked = np.concatenate((thresholds_pos, thresholds_neg, tree_thresholds))

#     stacked = stacked.reshape(-1, 1)
#     rel_col = X_train.iloc[:,feature_dict[key]].values
    
#     rel_points = clf.point_idxs[idx]
#     n_samples = len(rel_points) / len(rel_col)
#     rel_col = rel_col[rel_points]
#     # viz_info = node.viz_store[key]
#     sorted_thresholds = np.sort(tree_thresholds)
#     # rel_col = rel_col[(rel_col > np.quantile(rel_col, 0.000)) & (rel_col < np.quantile(rel_col, 0.999))]
#     min_rel_col = np.min(rel_col)
#     max_rel_col = np.max(rel_col)
#     rel_col = rel_col.reshape(-1, 1)
    
#     stacked = np.vstack((rel_col, stacked))
#     linspace_ = np.linspace(min_rel_col, max_rel_col, 1000)
#     linspace_ = linspace_.reshape(-1, 1)
#     stacked = np.vstack((linspace_, stacked))
#     stacked = np.sort(stacked, axis=0) 
#     tree_apply = tree.apply(stacked)
#     mapping = node.mapping  
#     branching = np.array([mapping[tree_apply[i]] for i in range(len(tree_apply))]).reshape(-1, 1)
#     rel_col_df = pd.DataFrame(rel_col, columns=[key])

#     comb = np.hstack((stacked, branching))
#     comb_df = pd.DataFrame(comb, columns=[key, 'branching'])
#     comb_df = comb_df.sort_values(by=key)
#     # sns.lineplot(data=comb_df, x=key, y='branching', ax=axs[0], linewidth=4)


#     fig, ax = plt.subplots(nrows=1, figsize=(5, 2))
#     x = comb_df[key]
#     y = comb_df['branching']
#     # draw a band where branching == 0
#     ax.fill_between(x, 0, 1,
#                     where=(y == 0),
#                     step='post',
#                     alpha=0.8,
#                     color='#D98CB9',
#                     label='Left')
#     # draw a band where branching == 1
#     # ax.fill_between(x, 0, 1,
#     #                 where=(y == 1),
#     #                 step='post',
#     #                 alpha=0.8,
#     #                 color='#67AB9F',
#     #                 label='Center')
#     ax.fill_between(x, 0, 1,
#                     where=(y == 1),
#                     step='post',       # ensures clean vertical edges
#                     alpha=0.8,
#                     color='#FFA860',
#                     label='Right')
#     change_mask = y.ne(y.shift()).fillna(True)   # True wherever y != y.shift(), plus first element
#     boundary_x = x[change_mask].values          # numpy array of X positions at those boundaries
#     std_x = np.std(x) / 5
#     min_x = np.min(x).astype(np.float64)
#     max_x = np.max(x).astype(np.float64)
#     min_boundary =np.maximum(boundary_x.min().astype(np.float64) - std_x, min_x)
#     max_boundary =np.minimum(boundary_x.max().astype(np.float64) + std_x, max_x)

#     print(boundary_x.min(), boundary_x.max())
#     ax.set_xlim(min_boundary, max_boundary)

#     # min_boundary_x = np.max(boundary_x.min(), x.min())
#     # max_boundary_x = np.min(boundary_x.max(), x.max())
#     # break
#     # # set those as your xticks
#     # ax.set_xticks(boundary_x)
#     # ax.set_xticklabels([f"{val: 0.3f}" for val in boundary_x],  ha='right', rotation=30, fontsize=14)

#     # ax.set_title(f"Node {idx}", fontsize=14)
#     # # remove yticks
#     ax.set_yticks([])
#     # # increase x tick font size
#     ax.tick_params(axis='x', labelsize=14)
#     # ax.set_xlim(x.min(), x.max())
#     ax.set_xlabel(key, fontsize=14)
#     # leg = ax.legend(
#     #     loc="lower center",
#     #     bbox_to_anchor=(0.5, 1.0),
#     #     ncol=6,
#     #     frameon=False,
#     #     fontsize=14,
#     # )
#     fig.suptitle(f"Node {idx}", fontsize=14, y = 0.9)
#     fig.tight_layout()
#     fig.savefig(os.path.join(_path, f"node_{idx}.png"), dpi=800, format='png')
#     plt.close(fig)
    

In [None]:
# {'date': [0],
#  'period': [1],
#  'nswprice': [2],
#  'nswdemand': [3],
#  'vicprice': [4],
#  'vicdemand': [5],
#  'transfer': [6],
#  'day': [7, 8, 9, 10, 11, 12]}
# array([0, 1])
# n_classes: 2
# 0: nswprice
# 	 [1, 2]
# 1: date
# 	 [5, 6]
# 2: nswprice
# 	 [3, 4]
# train_accuracy: 0.7572356390692981
# val_accuracy: 0.7645111454425072
# test_accuracy: 0.7554893523115966

In [None]:
# Human-readable class names, keyed by the `dataset` identifier set earlier in
# the notebook. Previously four lists were assigned to `bean_types` one after
# another, so only the last (two-element) list survived; with the seven-class
# 'bean' data that mislabels classes 0/1 and raises IndexError for any class >= 2.
CLASS_NAMES_BY_DATASET = {
    'bean': ["Seker", "Barbunya", "Bombay", "Cali", "Dermosan", "Horoz", "Sira"],
}
# Label sets that were used with other datasets. Their dataset identifiers are
# not visible in this notebook; add them to the dict above under the right key:
#   ["Irrelevant", "Relevant", "Correct"]
#   ["<=50k", ">50k"]
#   ["Edible", "Poisonous"]

# Unknown datasets get an empty list and fall back to generic "class <i>" names below.
bean_types = CLASS_NAMES_BY_DATASET.get(dataset, [])

# For every entry of clf.values, print the rounded values, the index of the
# largest one and the class name at that index.
for i, val in enumerate(clf.values):
    top_class = int(np.argmax(val))
    class_name = bean_types[top_class] if top_class < len(bean_types) else f"class {top_class}"
    print(f"Node {i}:", val.round(3), "-> argmax:", top_class, "->", class_name)