In [1]:
# Load IPython's autoreload extension so that edits to imported modules
# (e.g. a locally developed rule_estimator) are picked up without a kernel restart.
%load_ext autoreload

# Mode 2: reload all modules before executing each cell.
%autoreload 2

# Imports

In [2]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

In [3]:
# Import exactly the rule classes this notebook uses instead of `import *`,
# so that it is obvious where each name comes from and nothing else leaks
# into the notebook namespace.
from rule_estimator import (
    CaseWhen,
    DummyRule,
    GreaterThan,
    LesserThan,
    LesserThanNode,
    RuleClassifier,
)

In [4]:
# Load the iris dataset as pandas objects (as_frame=True) and unpack it
# straight into the feature table X and the target y (return_X_y=True).
X, y = load_iris(return_X_y=True, as_frame=True)
# Preview the first rows of the feature table.
X.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2


In [5]:
# Show the class names; position i in this array is the name of target label i.
load_iris().target_names

array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

# Instantiate RuleClassifier

In [6]:
# Hand-written rule set for iris: a binary split on petal length, followed by
# a CaseWhen list of threshold rules for the rows on the "false" side.
model = RuleClassifier(
    LesserThanNode(
        "petal length (cm)",
        1.9,
        if_true=DummyRule(default=0),
        if_false=CaseWhen(
            [
                LesserThan("petal length (cm)", 4.5, prediction=1),
                GreaterThan("petal length (cm)", 5.1, prediction=2),
                LesserThan("petal width (cm)", 1.4, prediction=1),
                GreaterThan("petal width (cm)", 1.8, prediction=2),
            ],
            default=1,
        ),
    ),
    default=2,
)

In [20]:
# Annotated definition of the rule-based classifier.
model = RuleClassifier(
    # Root node: a BinaryDecisionNode that splits on petal length < 1.9.
    # NOTE(review): the comparison is strict, so a flower with petal length
    # exactly 1.9 is routed to the if_false branch -- confirm this is intended.
    LesserThanNode(
        "petal length (cm)",
        1.9,
        # Condition holds -> DummyRule: always predict 0.
        if_true=DummyRule(default=0),
        # Condition fails -> CaseWhen: go through the rules below and, if one
        # applies, assign its prediction.
        if_false=CaseWhen(
            [
                LesserThan("petal length (cm)", 4.5, prediction=1),
                GreaterThan("petal length (cm)", 5.1, prediction=2),
                LesserThan("petal width (cm)", 1.4, prediction=1),
                GreaterThan("petal width (cm)", 1.8, prediction=2),
            ],
            default=1,  # none of the rules in the list applied -> predict 1
        ),
    ),
    default=2,  # classifier-level fallback when no rule applied -> predict 2
)

In [21]:
# Score the rule-based predictions against the true labels on the full dataset.
print(
    classification_report(
        y_true=y,
        y_pred=model.predict(X),
    )
)

              precision    recall  f1-score   support

           0       1.00      0.96      0.98        50
           1       0.83      1.00      0.91        50
           2       1.00      0.84      0.91        50

    accuracy                           0.93       150
   macro avg       0.94      0.93      0.93       150
weighted avg       0.94      0.93      0.93       150



In [22]:
# Print the human-readable, indented description of the rule tree.
print(model.describe())

RulesClassifier
   BinaryDecisionNode petal length (cm) < 1.9
     DummyRule: Always predict 0
      Default: 0 
     CaseWhen
        If petal length (cm) < 4.5 then predict 1
        If petal length (cm) > 5.1 then predict 2
        If petal width (cm) < 1.4 then predict 1
        If petal width (cm) > 1.8 then predict 2
      Default: 1 
 Default: 2 



# Storing the model to `.yaml` and loading it back

In [23]:
# Called without a path, to_yaml() returns the YAML text, which is printed here
# so the serialized structure can be inspected.
print(model.to_yaml())

# RulesClassifier
#    BinaryDecisionNode petal length (cm) < 1.9
#      DummyRule: Always predict 0
#       Default: 0 
#      CaseWhen
#         If petal length (cm) < 4.5 then predict 1
#         If petal length (cm) > 5.1 then predict 2
#         If petal width (cm) < 1.4 then predict 1
#         If petal width (cm) > 1.8 then predict 2
#       Default: 1 
#  Default: 2 
__businessrule__:
  module: rule_estimator.core
  name: RuleClassifier
  description: RulesClassifier
  params:
    rules:
      __businessrule__:
        module: rule_estimator.business_rules
        name: LesserThanNode
        description: BinaryDecisionNode petal length (cm) < 1.9
        params:
          col: petal length (cm)
          cutoff: 1.9
          if_true:
            __businessrule__:
              module: rule_estimator.business_rules
              name: DummyRule
              description: 'DummyRule: Always predict 0'
              params:
                default: 0
          if_false:
        

In [13]:
# Save the model definition to iris_rules.yaml in the current working directory;
# the file is read back with RuleClassifier.from_yaml further down.
model.to_yaml("iris_rules.yaml")

In [14]:
# Rebuild a RuleClassifier from the YAML file written by model.to_yaml(...).
loaded_model = RuleClassifier.from_yaml("iris_rules.yaml")

In [15]:
# Round-trip check: the model reloaded from YAML must reproduce the predictions
# of the in-memory model exactly, so a broken (de)serialization fails loudly
# instead of relying on a visual comparison of 150 values.
assert (loaded_model.predict(X) == model.predict(X)).all(), (
    "model loaded from iris_rules.yaml predicts differently from the original"
)

# Display the reloaded model's predictions.
loaded_model.predict(X)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2.,
       2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       1., 2., 2., 2., 1., 2., 2., 1., 1., 2., 2., 2., 2., 2., 1., 2., 2.,
       2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.])

# Including a `final_estimator`

In [24]:
# Same rule tree as before, but without `default` fallbacks on the CaseWhen and
# on the classifier; a scikit-learn estimator is supplied as `final_estimator`
# instead.
# NOTE(review): presumably the final_estimator predicts the rows that no rule
# covers -- confirm against the rule_estimator documentation.
rules_plus_final_estimator = RuleClassifier(
    LesserThanNode(
        "petal length (cm)",
        1.9,
        if_true=DummyRule(default=0),
        if_false=CaseWhen(
            [
                LesserThan("petal length (cm)", 4.5, prediction=1),
                GreaterThan("petal length (cm)", 5.1, prediction=2),
                LesserThan("petal width (cm)", 1.4, prediction=1),
                GreaterThan("petal width (cm)", 1.8, prediction=2),
            ]
        ),
    ),
    # Fix the seed: DecisionTreeClassifier permutes features at every split, so
    # without random_state the fitted tree (and the report below) can change
    # between runs.
    final_estimator=DecisionTreeClassifier(random_state=0),
    # NOTE(review): presumably False means the final estimator is fitted on all
    # rows rather than only on the rows left unlabelled by the rules -- confirm.
    fit_remaining_only=False,
)

In [25]:
# Fit the combined model (this is what trains the final_estimator), then
# evaluate its predictions on the same data it was fitted on.
rules_plus_final_estimator.fit(X, y)
print(
    classification_report(
        y_true=y,
        y_pred=rules_plus_final_estimator.predict(X),
    )
)

Fitting final_estimator...
              precision    recall  f1-score   support

           0       1.00      0.96      0.98        50
           1       0.96      1.00      0.98        50
           2       1.00      1.00      1.00        50

    accuracy                           0.99       150
   macro avg       0.99      0.99      0.99       150
weighted avg       0.99      0.99      0.99       150



In [19]:
# Print the description of the combined model: the rule tree followed by the
# final_estimator that backs it up.
print(rules_plus_final_estimator.describe())

RulesClassifier
   BinaryDecisionNode petal length (cm) < 1.9
     DummyRule: Always predict 0
      Default: 0 
     CaseWhen
        If petal length (cm) < 4.5 then predict 1
        If petal length (cm) > 5.1 then predict 2
        If petal width (cm) < 1.4 then predict 1
        If petal width (cm) > 1.8 then predict 2
final_estimator = DecisionTreeClassifier()

