# Imports

In [1]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

In [2]:
from rule_estimator import *

In [3]:
# Load iris as pandas objects (as_frame=True) so the rules defined later can
# address features by column name; X holds the features, y the class labels.
X, y = load_iris(return_X_y=True, as_frame=True)
# Preview the first rows to see the exact column names used by the rules.
X.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2


In [4]:
# Class names in label order: position i in this array names target value i.
load_iris().target_names

array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

# Instantiate RuleClassifier

In [5]:
# Hand-written rule tree for iris. Predictions are the integer class labels
# 0, 1, 2 (their names are listed by load_iris().target_names).

# "True" side of the first split: always predict class 0.
setosa_leaf = DefaultRule(0)

# "False" side: threshold rules on petal length / width, with class 1 as the
# fallback of this CaseWhen.
# Presumably the rules are tried top to bottom and the first match wins —
# confirm against the rule_estimator documentation.
other_species_rules = CaseWhen(
    [
        LesserThan("petal length (cm)", 4.5, 1),
        GreaterThan("petal length (cm)", 5.1, 2),
        LesserThan("petal width (cm)", 1.4, 1),
        GreaterThan("petal width (cm)", 1.8, 2),
    ],
    default=1,
)

# Root split on petal length.
# NOTE(review): the comparison is a strict "< 1.9" and the classification
# report for this model shows recall 0.96 for class 0, so a couple of class-0
# rows end up in the if_false branch — presumably rows whose petal length is
# exactly 1.9; confirm and consider a slightly higher cutoff.
root_split = LesserThanNode(
    "petal length (cm)",
    1.9,
    if_true=setosa_leaf,
    if_false=other_species_rules,
)

# default=2 is the classifier-level fallback prediction.
model = RuleClassifier(root_split, default=2)

In [6]:
# The rule-only model is used without any prior fit() call; the report scores
# its predictions against the true labels on the full iris data.
print(classification_report(y, model.predict(X)))

              precision    recall  f1-score   support

           0       1.00      0.96      0.98        50
           1       0.83      1.00      0.91        50
           2       1.00      0.84      0.91        50

    accuracy                           0.93       150
   macro avg       0.94      0.93      0.93       150
weighted avg       0.94      0.93      0.93       150



In [7]:
# Print the human-readable, indented summary of the rule tree.
print(model.describe())

RulesClassifier
   BinaryDecisionNode petal length (cm) < 1.9
     Always predict 0
      Default: 0 
     CaseWhen
        If petal length (cm) < 4.5 then predict 1
        If petal length (cm) > 5.1 then predict 2
        If petal width (cm) < 1.4 then predict 1
        If petal width (cm) > 1.8 then predict 2
      Default: 1 
 Default: 2 



# Storing model to `.yaml`

In [9]:
# Called without arguments, to_yaml() returns the YAML representation, which
# is printed here; the describe() text appears at the top as "#" comments.
print(model.to_yaml())

# RulesClassifier
#    BinaryDecisionNode petal length (cm) < 1.9
#      Always predict 0
#       Default: 0 
#      CaseWhen
#         If petal length (cm) < 4.5 then predict 1
#         If petal length (cm) > 5.1 then predict 2
#         If petal width (cm) < 1.4 then predict 1
#         If petal width (cm) > 1.8 then predict 2
#       Default: 1 
#  Default: 2 
__storable__:
  module: rule_estimator.core
  name: RuleClassifier
  params:
    rules:
      __storable__:
        module: rule_estimator.business_rules
        name: LesserThanNode
        params:
          col: petal length (cm)
          cutoff: 1.9
          if_true:
            __storable__:
              module: rule_estimator.business_rules
              name: DefaultRule
              params:
                default: 0
          if_false:
            __storable__:
              module: rule_estimator.business_rules
              name: CaseWhen
              params:
                rules:
                - __storable_

In [10]:
# Given a path, to_yaml presumably writes the YAML to that file instead of
# returning it (the same path is read back with from_yaml afterwards).
# The path is relative, so the file location depends on the working directory.
model.to_yaml("iris_rules.yaml")

In [11]:
# Rebuild a RuleClassifier from the stored YAML file.
loaded_model = RuleClassifier.from_yaml("iris_rules.yaml")

In [12]:
# Sanity check that the reloaded model can predict; labels come back as floats.
# NOTE(review): this displays all 150 predictions — comparing against
# model.predict(X) would verify the round trip more directly.
loaded_model.predict(X)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
       2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       1., 2., 2., 2., 1., 2., 2., 1., 1., 2., 2., 2., 2., 2., 1., 2., 2.,
       2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.])

# Including a `final_estimator`

In [13]:
# Same rule tree as the first model, but neither the CaseWhen nor the
# classifier gets a default; a final_estimator is supplied instead.

# "True" side of the root split: always predict class 0.
leaf_class_0 = DefaultRule(0)

# "False" side: threshold rules only, with no fallback prediction.
rules_without_default = CaseWhen(
    [
        LesserThan("petal length (cm)", 4.5, 1),
        GreaterThan("petal length (cm)", 5.1, 2),
        LesserThan("petal width (cm)", 1.4, 1),
        GreaterThan("petal width (cm)", 1.8, 2),
    ]
)

root_with_gaps = LesserThanNode(
    "petal length (cm)",
    1.9,
    if_true=leaf_class_0,
    if_false=rules_without_default,
)

# Presumably the final_estimator handles rows that no rule assigns a label
# to — confirm against the rule_estimator documentation.
# NOTE(review): DecisionTreeClassifier() is created without random_state;
# consider fixing it so repeated fits are reproducible.
model2 = RuleClassifier(
    root_with_gaps,
    final_estimator=DecisionTreeClassifier(),
)

In [14]:
# Unlike the rule-only model, this one must be fitted: fit() trains the
# final_estimator (it prints "Fitting final_estimator..." while doing so).
model2.fit(X, y)
# NOTE(review): the report is computed on the same X, y that were used for
# fit(), so these are in-sample scores, not an estimate of held-out accuracy.
print(classification_report(y, model2.predict(X)))

Fitting final_estimator...
              precision    recall  f1-score   support

           0       1.00      0.96      0.98        50
           1       0.96      1.00      0.98        50
           2       1.00      1.00      1.00        50

    accuracy                           0.99       150
   macro avg       0.99      0.99      0.99       150
weighted avg       0.99      0.99      0.99       150



In [15]:
# Print the rule tree summary; for this model it ends with the final_estimator
# line rather than "Default:" entries for the CaseWhen and the classifier.
print(model2.describe())

RulesClassifier
   BinaryDecisionNode petal length (cm) < 1.9
     Always predict 0
      Default: 0 
     CaseWhen
        If petal length (cm) < 4.5 then predict 1
        If petal length (cm) > 5.1 then predict 2
        If petal width (cm) < 1.4 then predict 1
        If petal width (cm) > 1.8 then predict 2
final_estimator = DecisionTreeClassifier()

