# Analysis of Random Forest Models
The purpose of this notebook is to analyse a trained random forest model to find out which features are used at all how they were splitted. The single decision trees can be traversed and most common features can be determined. 

In [18]:
import re
import itertools
from collections import deque

from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext

from pyspark.mllib.tree import RandomForestModel

In [2]:
conf = (SparkConf().setMaster("local[*]").setAppName('pyspark'))
sc = SparkContext(conf=conf)
sql_context = SQLContext(sc)

## Configuration
You can specifying the model to be analyzed by settings the ```MODEL_LOCATION``` variable. The path can point to a directory in the local file system, HDFS or S3. 

In [3]:
MODEL_LOCATION = "../models/random_forest_n5_d15/model_40.8_-73.95/"

In [4]:
FEATURE_MAPPING = ["Pickup_Count_Dis_1h", "Dropoff_Count_Dis_1h",
                   "Pickup_Count_Dis_4h", "Dropoff_Count_Dis_4h",
                   "Pickup_Count_Nb_1h", "Dropoff_Count_Nb_1h",
                   "Pickup_Count_Nb_4h", "Dropoff_Count_Nb_4h",
                   "Pickup_Count_Nyc_1h", "Dropoff_Count_Nyc_1h",
                   "Pickup_Count_Nyc_4h", "Dropoff_Count_Nyc_4h",] \
                    + ["Hour", "Day_Of_Week", "Day_Of_Year", "IsHoliday"] \
                    + [
                   "AWND_GHCND:US1NJBG0018", "AWND_GHCND:US1NYKN0003",
                   "AWND_GHCND:US1NYKN0025", "AWND_GHCND:US1NYNS0007",
                   "AWND_GHCND:US1NYQN0002", "AWND_GHCND:US1NYRC0001",
                   "AWND_GHCND:US1NYRC0002", "AWND_GHCND:USC00300961",
                   "AWND_GHCND:USW00014732", "AWND_GHCND:USW00094728",
                   "AWND_GHCND:USW00094789", "PRCP_GHCND:US1NJBG0018",
                   "PRCP_GHCND:US1NYKN0003", "PRCP_GHCND:US1NYKN0025",
                   "PRCP_GHCND:US1NYNS0007", "PRCP_GHCND:US1NYQN0002",
                   "PRCP_GHCND:US1NYRC0001", "PRCP_GHCND:US1NYRC0002",
                   "PRCP_GHCND:USC00300961", "PRCP_GHCND:USW00014732",
                   "PRCP_GHCND:USW00094728", "PRCP_GHCND:USW00094789",
                   "TMAX_GHCND:US1NJBG0018", "TMAX_GHCND:US1NYKN0003",
                   "TMAX_GHCND:US1NYKN0025", "TMAX_GHCND:US1NYNS0007",
                   "TMAX_GHCND:US1NYQN0002", "TMAX_GHCND:US1NYRC0001",
                   "TMAX_GHCND:US1NYRC0002", "TMAX_GHCND:USC00300961",
                   "TMAX_GHCND:USW00014732", "TMAX_GHCND:USW00094728",
                   "TMAX_GHCND:USW00094789", "TMIN_GHCND:US1NJBG0018",
                   "TMIN_GHCND:US1NYKN0003", "TMIN_GHCND:US1NYKN0025",
                   "TMIN_GHCND:US1NYNS0007", "TMIN_GHCND:US1NYQN0002",
                   "TMIN_GHCND:US1NYRC0001", "TMIN_GHCND:US1NYRC0002",
                   "TMIN_GHCND:USC00300961", "TMIN_GHCND:USW00014732",
                   "TMIN_GHCND:USW00094728", "TMIN_GHCND:USW00094789"] \
                    + ["Venue %d (0h)" % i for i in range(2434)] \
                    + ["Venue %d (-3)" % i for i in range(2434)] \
                    + ["Venue %d (-2)" % i for i in range(2434)] \
                    + ["Venue %d (-1)" % i for i in range(2434)] \
                    + ["Venue %d (1)" % i for i in range(2434)] \
                    + ["Venue %d (2)" % i for i in range(2434)] \
                    + ["Venue %d (3)" % i for i in range(2434)]

## Load Model & Parse Debug String
Since PySpark does not provide an API for traversing and analyzing a random forest model inherently, the debug string containing the structure of the random forest, is parsed. 

In [5]:
model = RandomForestModel.load(sc, MODEL_LOCATION)
debug_string = model.toDebugString().split('\n')[2:-1]
print("\n".join(debug_string[:10]) + "\n...")

  Tree 0:
    If (feature 10 <= 127478.0)
     If (feature 4 <= 304.0)
      If (feature 12 in {3.0,4.0,2.0,5.0,1.0,20.0,0.0,19.0,11.0,12.0,16.0,13.0})
       If (feature 4 <= 165.0)
        If (feature 15 in {2.0,1.0,3.0,0.0})
         If (feature 6 <= 478.0)
          If (feature 10 <= 24922.0)
           If (feature 12 in {2.0,3.0,4.0,0.0,1.0})
            If (feature 13 in {22.0,20.0,7.0,11.0,0.0,9.0,29.0,21.0,2.0,13.0,15.0})
...


In [6]:
class InternalNode(object):
    
    def __init__(self, parent_split_feature, parent_split_value, left, right):
        
        self.is_leaf = False
        
        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.left = left
        self.right = right
        
        assert left.parent_split_feature == right.parent_split_feature
        assert left.parent_split_value == right.parent_split_value
        
        self.split_feature = left.parent_split_feature
        self.split_value = left.parent_split_value
        
    def __str__(self):
        
        return self.to_string(0)
        
    def to_string(self, indention):
        
        return (" " * indention) + "InternalNode: %s, %s" % (self.split_feature, self.split_value) + "\n" \
                + self.left.to_string(indention + 1) + "\n" \
                + self.right.to_string(indention + 1)
    
        
class LeafNode(object):
    
    def __init__(self, parent_split_feature, parent_split_value, prediction):
        
        self.is_leaf = True
        
        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.prediction = prediction
        
    def to_string(self, indention):
        
        return (" " * indention) + "LeafNode: %s" % self.prediction

In [7]:
def split_debug_string(debug_string):
    indention = len(re.search(" +", debug_string[0]).group())

    result = []
    currentList = None
    for i in range(len(debug_string)):
        if not debug_string[i].startswith(" " * (indention + 1)):
            currentList = []
            result.append(currentList)
        currentList.append(debug_string[i])
    
    return result

In [8]:
tree_debug_strings = split_debug_string(debug_string)

In [9]:
def get_node_from_debug_string(debug_string):
    node_debug_string = debug_string[0].strip()
    if node_debug_string.startswith("Tree"):
        parent_split_feature = None
        parent_split_value = None
    else:
        match = re.match("(If|Else) \(feature (\d+) (<=|>|in|not in) (.+)\)", node_debug_string)
        if match is None:
            print node_debug_string
        feature_index = int(match.group(2))
        parent_split_feature = FEATURE_MAPPING[feature_index]
        parent_split_value = match.group(4)
    
    split = split_debug_string(debug_string[1:])
    if len(split) == 1:
        assert len(split[0]) == 1
        prediction_value = float(re.match("Predict: (-?\d+\.\d+)", split[0][0].strip()).group(1))
        return LeafNode(parent_split_feature, parent_split_value, prediction_value)
    
    assert len(split) == 2
    left_child = get_node_from_debug_string(split[0])
    right_child = get_node_from_debug_string(split[1])
    
    return InternalNode(parent_split_feature, parent_split_value, left_child, right_child)

In [10]:
trees = [get_node_from_debug_string(tree_debug_string) for tree_debug_string in tree_debug_strings]

In [11]:
print(str(trees[0])[:2000] + "...")

InternalNode: Pickup_Count_Nyc_4h, 127478.0
 InternalNode: Pickup_Count_Nb_1h, 304.0
  InternalNode: Hour, {3.0,4.0,2.0,5.0,1.0,20.0,0.0,19.0,11.0,12.0,16.0,13.0}
   InternalNode: Pickup_Count_Nb_1h, 165.0
    InternalNode: IsHoliday, {2.0,1.0,3.0,0.0}
     InternalNode: Pickup_Count_Nb_4h, 478.0
      InternalNode: Pickup_Count_Nyc_4h, 24922.0
       InternalNode: Hour, {2.0,3.0,4.0,0.0,1.0}
        InternalNode: Day_Of_Week, {22.0,20.0,7.0,11.0,0.0,9.0,29.0,21.0,2.0,13.0,15.0}
         InternalNode: Dropoff_Count_Nyc_4h, 16929.0
          InternalNode: Dropoff_Count_Dis_4h, 72.0
           InternalNode: Day_Of_Week, {29.0}
            LeafNode: 1.0
            InternalNode: Day_Of_Year, {1.0}
             LeafNode: 1.0
             LeafNode: 2.0
           LeafNode: 3.0
          InternalNode: Day_Of_Year, {3.0,1.0,5.0,6.0,7.0,9.0,0.0,10.0,2.0}
           InternalNode: IsHoliday, {2.0,1.0}
            InternalNode: Pickup_Count_Dis_1h, 5.0
             InternalNode: Day_Of_Week, {0.0

## Get Most Common Features
The most common features are determined by collecting all features, which appear in any tree of the random forest up to specific level (per default 4).

In [12]:
def get_features_and_levels(tree, maxlevel=10):
    result = []
    queue = deque([(tree, 0)])

    while len(queue):
        node, level = queue.popleft()
        if level > maxlevel:
            break
        if not node.is_leaf:
            result.append((node.split_feature, level))
            queue.append((node.left, level + 1))
            queue.append((node.right, level + 1))
    
    return result

In [13]:
def get_top_features(trees, maxlevel=4):
    return set([feature for tree in trees
                        for feature, level in get_features_and_levels(tree, maxlevel)])

In [14]:
sorted(get_top_features(trees))

['AWND_GHCND:US1NJBG0018',
 'Day_Of_Week',
 'Day_Of_Year',
 'Dropoff_Count_Dis_1h',
 'Dropoff_Count_Dis_4h',
 'Dropoff_Count_Nb_1h',
 'Dropoff_Count_Nb_4h',
 'Dropoff_Count_Nyc_1h',
 'Dropoff_Count_Nyc_4h',
 'Hour',
 'IsHoliday',
 'Pickup_Count_Dis_1h',
 'Pickup_Count_Dis_4h',
 'Pickup_Count_Nb_1h',
 'Pickup_Count_Nb_4h',
 'Pickup_Count_Nyc_1h',
 'Pickup_Count_Nyc_4h',
 'Venue 1187 (1)',
 'Venue 1755 (-3)',
 'Venue 1972 (-3)',
 'Venue 428 (-2)']

# Get Feature Importance

Here, we compute an importance score for each feature, based on what level the feature occurs in.

For this, we define the *node value* of a decision tree node to be: $$2^{-level}$$

Assuming that each node splits the data in half, this node fraction corresponds to the fraction of data that is affected by the decision made in this node.

The *importance score* of a feature is then the sum of all node values where the feature is used.

In [15]:
def get_counts(input):
    return {name : len(list(occurrences)) for name, occurrences in itertools.groupby(input, lambda x: x)}

In [16]:
def get_features_by_level(tree):
    grouped = itertools.groupby(get_features_and_levels(tree), lambda (feature, level): level)
    return {level : get_counts([feature for feature, level in features])
            for level, features in grouped}

In [19]:
scores = {}
for tree in trees:
    for level, features in get_features_by_level(tree).iteritems():
        for feature, count in features.iteritems():
            if not feature in scores:
                scores[feature] = 0
            scores[feature] += count * 2**(-level)

In [22]:
sorted(scores.iteritems(), key=lambda (feature, score) : -score)[:20]

[('Pickup_Count_Nb_1h', 4.7470703125),
 ('Pickup_Count_Dis_1h', 4.1552734375),
 ('Pickup_Count_Nyc_4h', 2.3056640625),
 ('Pickup_Count_Nyc_1h', 1.87109375),
 ('Hour', 1.8310546875),
 ('Dropoff_Count_Nyc_4h', 1.4365234375),
 ('Pickup_Count_Nb_4h', 1.4033203125),
 ('Dropoff_Count_Nyc_1h', 1.2607421875),
 ('Pickup_Count_Dis_4h', 1.2177734375),
 ('Dropoff_Count_Nb_4h', 1.0439453125),
 ('Day_Of_Week', 1.0185546875),
 ('Dropoff_Count_Dis_4h', 0.8076171875),
 ('IsHoliday', 0.7763671875),
 ('Day_Of_Year', 0.7158203125),
 ('Dropoff_Count_Dis_1h', 0.57421875),
 ('Dropoff_Count_Nb_1h', 0.3056640625),
 ('Venue 1972 (-3)', 0.2568359375),
 ('AWND_GHCND:US1NJBG0018', 0.244140625),
 ('Venue 428 (-2)', 0.1328125),
 ('Venue 1755 (-3)', 0.072265625)]