In [78]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn.tree import export_text
from sklearn.model_selection import train_test_split

In [79]:
def get_rules(tree, tree_id: int , features: list, sas_table: str, max_depth=100, spacing=2):
    """ 
    Extract the rules of a decision tree and translate them to SAS code.
    Create a SAS dataset representing those rules.
    
    
    Parameters:
    -----------
    tree: sklearn DecisionTreeClassifier
        the tree whose decision rules we want to extract
    tree_id: int
        tree identifier (0 to numbers of trees -1)
    features: list
        list of model features
    sas_table: str
        name of the SAS dataset containing the decision tree features
    max_depth: int
        number of levels of the tree considered (default is 100 - must be greater than 1)
        more information : https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_text.html
    spacing: int
        number of spaces between edges (default is 2)
        more information : https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_text.html      
    """
    if spacing < 2:
        raise ValueError('spacing must be > 1')
    # export decision tree to text, using sklearn.tree function
    rules = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)
    # translate text to SAS code
    rules_in_sas = translate_text_to_sas(tree, tree_id, sas_table, features, rules, spacing)
    return rules_in_sas

In [80]:
def translate_text_to_sas(tree, tree_id: int, sas_table: str, features: list, text: str, spacing=2):
    """ 
    Translate tree rules to SAS code, into a dataset.
    
    
    Parameters:
    -----------
    tree: sklearn DecisionTreeClassifier
        the tree whose decision rules we want to extract
    tree_id: int
        tree identifier (0 to numbers of trees -1)
    features: list
        list of model features
    sas_table: str
        name of the SAS dataset containing the decision tree features
        a column "PREDICTED_VALUE_i" will be added to this dataset (i matches tree identifier)
    text: str
        rules obtained with export_text function
    spacing: int
        number of spaces between edges (default is 2)
        more information : https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_text.html      
    """
    skip, dash = ' '*spacing, '-'*(spacing-1) # handling spacing
    sas_rules = 'DATA DECISION_TREE_' + str(tree_id + 1) + ';\n'
    sas_rules += 'SET {};\n'.format(sas_table)
    splitted_rules = text.split('\n') # Make a list of rules
    dict_space_count = {} # keys: number of eft spaces, values: number of rows with this number
    current_elseif_count = [] # for handling "END;" ASSOCIATED WITH "ELSE IF" conditions. Add as many "END;" as nested "ELSE IFs"
    #add_end = 0
    # Iterate through rules
    for i, line in enumerate(splitted_rules):
        line = line.rstrip().replace('|',' ').replace('-', ' ')
        n_spaces = len(line) - len(line.lstrip(' ')) # get spaces from left
        # Update dictionary for handling whether IF or ELSE IF
        if str(n_spaces) not in dict_space_count.keys():
            dict_space_count[str(n_spaces)] = 1
        else: 
            dict_space_count[str(n_spaces)] += 1
        # Pair --> ELSE
        if 'class' in line:
            dict_space_count[str(n_spaces)] -= 1 # do not count rows where predicted value is computed
        front_add = get_front_add(dict_space_count[str(n_spaces)], n_spaces)
        add_end_front = ''
        if len(current_elseif_count):
            if n_spaces < current_elseif_count[-1]:
                add_end_front = ''
                while(n_spaces < current_elseif_count[-1]):
                    add_end_front += 'END;\n'
                    current_elseif_count.pop() # last element corresponds to current front spaces in line
                    if not len(current_elseif_count):
                        break
        if 'ELSE' in front_add and 'class' not in line:
            current_elseif_count.append(n_spaces)
        # Handling rows for IF conditions
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' '*n_spaces, ' '*n_spaces + front_add)
            line = '{} {:g} THEN DO;'.format(line, float(val))
        # Handling rows for PREDICTED_VALUE_i
        if 'class' in line:
            line = line.replace('class:', 'PREDICTED_VALUE_' + str(tree_id + 1) + ' =')
            line += ';'
            line += '\n'
        line = add_end_front + line
        sas_rules += skip + line + '\n'
    sas_rules = sas_rules[:-1]
    sas_rules += 'RUN;'
    return sas_rules

# String to add before row for IF or ELSE IF conditions
def get_front_add(count, n_spaces):
    """
        Docstring to be completed
    """
    if count%2 == 0:
        toRet = 'END;\n'+ ' '*(n_spaces+2)+ 'ELSE IF '
    else:
        toRet = 'IF '
    return toRet



In [81]:
def get_trees_rules(input_table, rfclassifier, labels, path, file):
    # nouvelle fonction qui exportera toutes les règles dans 1 seul fichier
    file = open(path + filename, "w")
    for index, tree in enumerate(rfclassifier):
        rules = get_rules(input_table, tree, index, labels)
        file.write(rules + "\n")
    file.close()

In [82]:
# Example 
iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
# sklearn provides the iris species as integer values since this is required for classification
X_train, X_test, y_train, y_test = train_test_split(df[iris.feature_names], iris.target, test_size=0.5, stratify=iris.target, random_state=123456)
rf = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=123456)
rf.fit(X_train, y_train)

RandomForestClassifier(oob_score=True, random_state=123456)

In [83]:
t = export_text(rf[0], feature_names=X_train.columns.to_list(), 
                        max_depth=100,
                        decimals=2,
                        spacing=3)

In [84]:
get_rules(tree=rf[0], tree_id=0, features=X_train.columns.to_list(), sas_table="DATASET")

'DATA DECISION_TREE_1;\nSET DATASET;\n     IF petal width (cm) <= 1.7 THEN DO;\n       IF petal length (cm) <= 2.45 THEN DO;\n         PREDICTED_VALUE_1 = 0.0;\n\n       END;\n       ELSE IF petal length (cm) > 2.45 THEN DO;\n         PREDICTED_VALUE_1 = 1.0;\n\n  END;\n   END;\n     ELSE IF petal width (cm) > 1.7 THEN DO;\n       IF sepal length (cm) <= 6 THEN DO;\n         IF sepal width (cm) <= 3.1 THEN DO;\n           PREDICTED_VALUE_1 = 2.0;\n\n         END;\n         ELSE IF sepal width (cm) > 3.1 THEN DO;\n           PREDICTED_VALUE_1 = 1.0;\n\n  END;\n     END;\n       ELSE IF sepal length (cm) > 6 THEN DO;\n         PREDICTED_VALUE_1 = 2.0;\n\n  END;\nEND;\nRUN;'