## A PivotTreeClassifier can be instantiated and fitted as a RuleTreeClassifier that uses pivot-based stumps as its base stumps

In [59]:
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, pairwise_distances
from sklearn.model_selection import train_test_split

from ruletree.tree.RuleTreeClassifier import RuleTreeClassifier

# Load the breast-cancer dataset bundled with scikit-learn.
breast = load_breast_cancer()
feature_names = breast.feature_names
X, y = breast.data, breast.target

# Hold out 30% of the samples for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Pre-compute the train-vs-train Euclidean distance matrix so it can be handed
# to the tree instead of being recomputed at fit time.
X_train_matrix = pairwise_distances(X_train, metric='euclidean')

In [60]:
# Import the pivot-tree stump builder; pt_stump_call() returns the stump instance
# that is passed to RuleTreeClassifier through `base_stumps`.
from ruletree.stumps.instance_stumps import pt_stump_call
pt_stump = pt_stump_call()

In [61]:
# Display the stump object (its repr) to inspect it before building the tree.
pt_stump

In [62]:
# A Pivot Tree is a RuleTreeClassifier whose base stumps are pivot-based stumps.
# `distance_measure` names the metric between instances. A pre-computed distance
# matrix may be supplied via `distance_matrix`; if omitted, one is computed at fit time.
# X_train_matrix was built from X_train with the same 'euclidean' metric, and
# X_train is the data passed to fit in the next cell.
rt_pivot = RuleTreeClassifier(
    # stump / distance configuration
    base_stumps=[pt_stump],
    distance_measure='euclidean',
    distance_matrix=X_train_matrix,
    # tree growth and pruning
    criterion='gini',
    max_depth=3,
    prune_useless_leaves=True,
    random_state=42,
)
                              

In [63]:
# Fit the pivot tree on the training split (the same samples X_train_matrix was computed from).
rt_pivot.fit(X_train,y_train)

In [64]:
# Predict class labels for the held-out samples and preview the first ten.
y_pred = rt_pivot.predict(X_test)
print('Prediction results', y_pred[:10])

Prediction results [1 0 0 1 1 0 0 0 1 1]


In [65]:
# Fraction of held-out samples whose predicted label matches the true label.
print('Accuracy result: ', accuracy_score(y_true=y_test, y_pred=y_pred))

Accuracy result:  0.9590643274853801


In [66]:
# Per-sample class probabilities for the held-out split; one row per test sample,
# one column per class (column 1 is high where the predicted label above is 1).
y_pred_proba = rt_pivot.predict_proba(X_test)
print('Prediction probability results', y_pred_proba[:10])

Prediction probability results [[0.05179283 0.94820717]
 [0.99130435 0.00869565]
 [0.99130435 0.00869565]
 [0.05179283 0.94820717]
 [0.05179283 0.94820717]
 [0.99130435 0.00869565]
 [0.99130435 0.00869565]
 [0.99130435 0.00869565]
 [0.05179283 0.94820717]
 [0.05179283 0.94820717]]


In [67]:
# Extract the learned rule structure from the fitted tree; it is passed to print_rules below.
rules = rt_pivot.get_rules()

In [68]:
# Print the rules as an indented tree.
# NOTE(review): in the output, `P_<i>` appears to denote the distance to training
# pivot i, and the trailing integer appears to be the node's sample count
# (the root shows 398, the training-set size) — confirm against ruletree docs.
rt_pivot.print_rules(rules)

|--- P_324 <= 3838.697	398
|   |--- P_252 <= 61.638	121
|   |   |--- P_252 <= 11.482	4
|   |   |    output: 0
|   |   |--- P_252 > 11.482
|   |   |    output: 1
|   |--- P_252 > 61.638
|   |   |--- P_92 <= 55.931	117
|   |   |    output: 1
|   |   |--- P_92 > 55.931
|   |   |    output: 0
|--- P_324 > 3838.697
|   |--- P_303 <= 161.668	277
|   |    output: 0
|   |--- P_303 > 161.668
|   |   |--- P_308 <= 17.541	253
|   |   |    output: 0
|   |   |--- P_308 > 17.541
|   |   |    output: 1
