In [1]:
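# Install the dependency once if it is missing (use %pip so it targets this kernel's environment):
# %pip install -U ruletree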

In [2]:
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
from IPython.display import IFrame, JSON
from RuleTree import RuleTreeClassifier
from RuleTree.encoding.ruletree_to_jakowski_tree_encoder import ruletree_to_jakowski, jakowski_to_ruletree

In [3]:
# Build a DataFrame view of Iris: one column per feature plus an integer 'target' column.
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names).assign(target=iris.target)

# Feature matrix / label vector as plain NumPy arrays.
X = iris_df.drop(columns='target').to_numpy()
y = iris_df['target'].to_numpy()

# Hold out 30% of the rows for testing; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=0,
)

In [4]:
# Fit the reference tree whose predictions the encoded/decoded tree is later compared against.
rt = RuleTreeClassifier(max_depth=8, prune_useless_leaves=False, random_state=0)
rt.fit(X_train, y_train)

# Predict once and reuse the result for both the report and the later round-trip comparison.
y_pred_before_encoding = rt.predict(X_test)
y_pred_proba_before_encoding = rt.predict_proba(X_test)

print(classification_report(y_true=y_test, y_pred=y_pred_before_encoding))

# Render the fitted tree and embed it below.
# NOTE(review): assumes export_graphviz(filename="demo") writes demo.pdf — confirm.
# NOTE(review): the original author hinted at passing columns_names=iris_df.columns[:-1]
# to label the splits with feature names; left disabled as in the original.
rt.export_graphviz(filename="demo")
IFrame("demo.pdf", width=600, height=300)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      0.94      0.97        18
           2       0.92      1.00      0.96        11

    accuracy                           0.98        45
   macro avg       0.97      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45



In [5]:
# Serialize the fitted tree to a plain dict.
# NOTE(review): presumably the filename argument also writes the dict to demo.json — confirm.
# NOTE(review): the name `json` would shadow the standard-library `json` module if that
# module were ever imported in this notebook; consider renaming together with its uses below.
json = rt.to_dict(filename="demo.json")



In [18]:
# Display the serialized tree dict returned by rt.to_dict.
json

{'tree_type': 'RuleTree.tree.RuleTreeClassifier',
 'args': {'max_leaf_nodes': inf,
  'min_samples_split': 2,
  'max_depth': 8,
  'prune_useless_leaves': False,
  'base_stumps': None,
  'stump_selection': 'random',
  'random_state': 0},
 'classes_': [0, 1, 2],
 'n_classes_': 3,
 'nodes': [{'node_id': 'R',
   'is_leaf': False,
   'prediction': 2,
   'prediction_probability': [0.3238095238095238,
    0.3047619047619048,
    0.37142857142857144],
   'prediction_classes_': [0, 1, 2],
   'left_node': 'Rl',
   'right_node': 'Rr',
   'feature_idx': 3,
   'threshold': 0.75,
   'is_categorical': None,
   'samples': 105,
   'feature_name': 'X_3',
   'textual_rule': 'X_3 <= 0.75\t105',
   'blob_rule': 'X_3 <= 0.75',
   'graphviz_rule': {'label': 'X_3 ≤ 0.75'},
   'not_textual_rule': 'X_3 > 0.75',
   'not_blob_rule': 'X_3 > 0.75',
   'not_graphviz_rule': {'label': 'X_3 > 0.75'},
   'stump_type': 'RuleTree.stumps.classification.DecisionTreeStumpClassifier',
   'impurity': 0.6643083900226757,
   'arg

In [6]:
# Encode the serialized tree dict with ruletree_to_jakowski and display the encoding.
# NOTE(review): the meaning of each row/column of the encoding is defined by the encoder
# module, not by this notebook — check its documentation before interpreting the values.
enc = ruletree_to_jakowski(json)
enc

array([[ 4.        ,  4.        ,  3.        ,  4.        ,  4.        ,
         4.        ,  4.        ,  4.        ,  4.        ,  4.        ,
         4.        ,  4.        ,  2.        ,  4.        ,  4.        ,
        -1.        , -1.        , -1.        , -1.        , -1.        ,
        -1.        , -1.        , -1.        , -1.        , -1.        ,
        -1.        , -1.        , -1.        , -1.        , -1.        ,
        -1.        ],
       [ 0.75      ,  0.75      ,  4.95000005,  0.75      ,  0.75      ,
         1.65000004,  1.75      ,  0.75      ,  0.75      ,  0.75      ,
         0.75      ,  1.65000004,  3.10000002,  1.65000004,  1.75      ,
         1.        ,  1.        ,  1.        ,  1.        ,  1.        ,
         1.        ,  1.        ,  1.        ,  2.        ,  2.        ,
         3.        ,  2.        ,  3.        ,  2.        ,  3.        ,
         3.        ]])

In [7]:
# Decode the encoding back into a RuleTree dict, then rebuild a classifier from that dict.
decoded_tree = jakowski_to_ruletree(enc)
rt_2 = RuleTreeClassifier.from_dict(decoded_tree)

# Carry the class labels over from the original fitted tree.
rt_2.classes_ = rt.classes_

RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier
RuleTree.stumps.classification.DecisionTreeStumpClassifier


In [8]:
# Display the dict that jakowski_to_ruletree reconstructs from the encoding.
jakowski_to_ruletree(enc)

{'tree_type': 'RuleTree.tree.RuleTreeClassifier',
 'nodes': [{'node_id': 'R',
   'stump_type': 'RuleTree.stumps.classification.DecisionTreeStumpClassifier',
   'feature_idx': 3.0,
   'threshold': 0.75,
   'is_leaf': False,
   'left_node': 'Rl',
   'right_node': 'Rr',
   'is_categorical': False},
  {'node_id': 'Rl',
   'stump_type': 'RuleTree.stumps.classification.DecisionTreeStumpClassifier',
   'feature_idx': 3.0,
   'threshold': 0.75,
   'is_leaf': False,
   'left_node': 'Rll',
   'right_node': 'Rlr',
   'is_categorical': False},
  {'node_id': 'Rr',
   'stump_type': 'RuleTree.stumps.classification.DecisionTreeStumpClassifier',
   'feature_idx': 2.0,
   'threshold': 4.950000047683716,
   'is_leaf': False,
   'left_node': 'Rrl',
   'right_node': 'Rrr',
   'is_categorical': False},
  {'node_id': 'Rll',
   'stump_type': 'RuleTree.stumps.classification.DecisionTreeStumpClassifier',
   'feature_idx': 3.0,
   'threshold': 0.75,
   'is_leaf': False,
   'left_node': 'Rlll',
   'right_node': '

In [9]:
# Render the rebuilt tree as decoded (no simplification applied yet) and embed it.
# NOTE(review): assumes export_graphviz(filename="demo_complete") writes demo_complete.pdf — confirm.
rt_2.export_graphviz(filename="demo_complete")
IFrame("demo_complete.pdf", width=600, height=300)

In [10]:
# Replace the rebuilt tree's root with the node returned by simplify().
# NOTE(review): presumably this collapses redundant splits produced by decoding — confirm
# against the RuleTree documentation.
rt_2.root = rt_2.root.simplify()

In [11]:
# Render the rebuilt tree again after its root was replaced by simplify(), and embed it.
# NOTE(review): assumes export_graphviz writes demo_complete_simplified.pdf — confirm.
rt_2.export_graphviz(filename="demo_complete_simplified")
IFrame("demo_complete_simplified.pdf", width=600, height=300)

In [19]:
# Evaluate the rebuilt tree on the same held-out test set used for the original tree.
print(classification_report(y_pred=rt_2.predict(X_test), y_true=y_test))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      0.94      0.97        18
           2       0.92      1.00      0.96        11

    accuracy                           0.98        45
   macro avg       0.97      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45



In [20]:
# Capture labels and class probabilities from the rebuilt tree so they can be
# compared with the values recorded from the original tree.
y_pred_after_encoding = rt_2.predict(X_test)
y_pred_proba_after_encoding = rt_2.predict_proba(X_test)

In [21]:
# Round-trip check, displayed as (labels identical?, probabilities numerically equal?).
# Labels are discrete, so they are compared exactly; probabilities are floats, so a
# tolerance-based comparison is used.
# NOTE(review): a False in the second slot means predict_proba is not preserved by the
# encode/decode round trip — investigate before relying on rt_2.predict_proba.
(
    np.array_equal(y_pred_before_encoding, y_pred_after_encoding),
    np.allclose(y_pred_proba_before_encoding, y_pred_proba_after_encoding),
)

(True, False)

In [22]:
# Display the class probabilities predicted by the original tree for X_test.
y_pred_proba_before_encoding

array([[0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.]])