# Analysis of Random Forest Models
The purpose of this notebook is to analyze a trained random forest model to find out which features are used at all and how they are split. The individual decision trees can be traversed, and the most common features can be determined. 

In [1]:
import re
from collections import deque

from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext

from pyspark.mllib.tree import RandomForestModel

In [2]:
# Local Spark session: "local[*]" runs with as many worker threads as there are
# logical cores. (sql_context is not used by the cells visible here.)
conf = SparkConf()
conf.setMaster("local[*]")
conf.setAppName('pyspark')

sc = SparkContext(conf=conf)
sql_context = SQLContext(sc)

## Configuration
You can specify the model to be analyzed by setting the `MODEL_LOCATION` variable. The path can point to a directory in the local file system, HDFS or S3. 

In [18]:
# Directory of the saved RandomForestModel to analyze (local path, HDFS or S3 URI).
MODEL_LOCATION = "../models/random_forest_n5_d15/model_40.8_-73.95/"

In [7]:
# Human-readable feature names: position i is the name of "feature i" as it
# appears in the model's debug string.
FEATURE_MAPPING = (
    # Pickup/dropoff counts, ordered by scope suffix, then time suffix, then
    # direction: Pickup_Count_Dis_1h, Dropoff_Count_Dis_1h, Pickup_Count_Dis_4h, ...
    ["%s_Count_%s_%s" % (direction, scope, window)
     for scope in ("Dis", "Nb", "Nyc")
     for window in ("1h", "4h")
     for direction in ("Pickup", "Dropoff")]
    # Calendar features.
    + ["Hour", "Day_Of_Week", "Day_Of_Year", "IsHoliday"]
    # Weather measurements per GHCND station, measurement-major order
    # (all AWND stations first, then PRCP, TMAX, TMIN).
    + ["%s_GHCND:%s" % (measure, station)
       for measure in ("AWND", "PRCP", "TMAX", "TMIN")
       for station in ("US1NJBG0018", "US1NYKN0003", "US1NYKN0025",
                       "US1NYNS0007", "US1NYQN0002", "US1NYRC0001",
                       "US1NYRC0002", "USC00300961", "USW00014732",
                       "USW00094728", "USW00094789")]
    # Venue features: all 2434 venues for each suffix, in the order
    # (0h), (-3), (-2), (-1), (1), (2), (3).
    + ["Venue %d (%s)" % (venue, offset)
       for offset in ("0h", "-3", "-2", "-1", "1", "2", "3")
       for venue in range(2434)]
)

## Load Model & Parse Debug String
Since PySpark does not provide a built-in API for traversing and analyzing a random forest model, we parse the model's debug string instead, which contains the structure of every tree in the forest. 

In [8]:
model = RandomForestModel.load(sc, MODEL_LOCATION)

# toDebugString() returns one big string. Keep it as a list of lines, skipping
# the first two lines (the ensemble summary header, presumably followed by a
# blank line) and the empty string produced by the trailing newline.
debug_string = model.toDebugString().split('\n')[2:-1]

# Preview only the first ten lines; the full dump is very long.
print("%s\n..." % "\n".join(debug_string[:10]))

  Tree 0:
    If (feature 4 <= 214.0)
     If (feature 12 in {6.0,7.0,16.0,10.0,8.0,15.0,5.0,13.0,12.0,9.0,14.0,11.0,17.0,4.0,3.0,18.0,2.0,19.0})
      If (feature 2 <= 8.0)
       If (feature 4 <= 18.0)
        If (feature 8 <= 11001.0)
         If (feature 8 <= 3575.0)
          If (feature 10 <= 20289.0)
           If (feature 13 in {4.0,5.0,2.0,1.0,6.0})
            If (feature 1 <= 10.0)
...


In [9]:
class InternalNode(object):
    """A split (non-leaf) node of a decision tree parsed from a debug string.

    Attributes:
        is_leaf: always False for this class.
        parent_split_feature, parent_split_value: the parent's condition that
            leads to this node (None for the root of a tree).
        left, right: the two child nodes.
        split_feature, split_value: the condition this node splits on. It is
            taken from the children, which must agree on it.
    """

    def __init__(self, parent_split_feature, parent_split_value, left, right):
        # Both children were reached through this node's split (the "If" and
        # the "Else" branch), so they must name the same feature and value.
        assert left.parent_split_feature == right.parent_split_feature
        assert left.parent_split_value == right.parent_split_value

        self.is_leaf = False
        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.left, self.right = left, right
        self.split_feature = left.parent_split_feature
        self.split_value = left.parent_split_value

    def __str__(self):
        return self.to_string(0)

    def to_string(self, indention):
        """Render this subtree one node per line, children one space deeper."""
        header = "%sInternalNode: %s, %s" % (" " * indention, self.split_feature, self.split_value)
        child_lines = [child.to_string(indention + 1) for child in (self.left, self.right)]
        return "\n".join([header] + child_lines)
    
        
class LeafNode(object):
    """A leaf of a decision tree parsed from a debug string.

    Attributes:
        is_leaf: always True for this class.
        parent_split_feature, parent_split_value: the parent's condition that
            leads to this leaf (None when the whole tree is a single leaf).
        prediction: the value predicted at this leaf.
    """

    def __init__(self, parent_split_feature, parent_split_value, prediction):

        self.is_leaf = True

        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.prediction = prediction

    def __str__(self):
        # Same contract as InternalNode.__str__, so str() renders any node,
        # including a tree consisting of a single leaf.
        return self.to_string(0)

    def to_string(self, indention):
        """Render this leaf on one line, indented by `indention` spaces."""
        return (" " * indention) + "LeafNode: %s" % self.prediction

In [10]:
def split_debug_string(debug_string):
    """Group debug-string lines into sibling sections by indentation.

    The leading indentation of the first line defines the sibling level. Every
    line that is not indented deeper than that starts a new section. Every
    deeper-indented line is appended to the current section.

    Parameters:
        debug_string: a list of lines (not a single string).

    Returns:
        A list of sections. Each section is a list of lines whose first element
        is the section's header line. An empty input yields an empty list.
    """
    if not debug_string:
        return []

    # Count only the *leading* spaces. Searching for the first run of spaces
    # anywhere in the line would measure an inner gap (or find nothing) when
    # the first line is not indented.
    first_line = debug_string[0]
    indention = len(first_line) - len(first_line.lstrip(" "))
    child_prefix = " " * (indention + 1)

    sections = []
    for line in debug_string:
        if not sections or not line.startswith(child_prefix):
            sections.append([])
        sections[-1].append(line)
    return sections

In [11]:
# One group of lines per tree; each group starts with its "Tree N:" header line.
tree_debug_strings = split_debug_string(debug_string=debug_string)

In [12]:
def get_node_from_debug_string(debug_string):
    """Recursively parse the debug-string lines of one (sub)tree into nodes.

    Parameters:
        debug_string: a list of lines. The first line is either a "Tree N:"
            header or an "If (...)"/"Else (...)" condition. The remaining lines
            are the deeper-indented body belonging to it.

    Returns:
        A LeafNode when the body is a single "Predict: ..." line, otherwise an
        InternalNode built from the two child sections. Each node records the
        parent condition leading to it, with the feature index translated to a
        name via FEATURE_MAPPING.

    Raises:
        ValueError: if a condition or prediction line cannot be parsed.
    """
    node_debug_string = debug_string[0].strip()
    if node_debug_string.startswith("Tree"):
        # The root of a tree is not reached through any parent split.
        parent_split_feature = None
        parent_split_value = None
    else:
        # Raw strings: "\(" and "\d" are regex escapes, not string escapes.
        match = re.match(r"(If|Else) \(feature (\d+) (<=|>|in|not in) (.+)\)", node_debug_string)
        if match is None:
            raise ValueError("Cannot parse split condition: %r" % node_debug_string)
        feature_index = int(match.group(2))
        parent_split_feature = FEATURE_MAPPING[feature_index]
        parent_split_value = match.group(4)

    split = split_debug_string(debug_string[1:])
    if len(split) == 1:
        # A single one-line child section is the leaf's prediction.
        assert len(split[0]) == 1
        prediction_line = split[0][0].strip()
        # Capture the whole token after "Predict:". The JVM may print doubles in
        # exponent notation (e.g. "1.0E-4") or as "NaN"/"Infinity". A pattern
        # like "\d+\.\d+" would truncate or reject those.
        prediction_match = re.match(r"Predict: (\S+)", prediction_line)
        if prediction_match is None:
            raise ValueError("Cannot parse prediction: %r" % prediction_line)
        prediction_value = float(prediction_match.group(1))
        return LeafNode(parent_split_feature, parent_split_value, prediction_value)

    # Otherwise there are exactly two child sections, in text order.
    assert len(split) == 2
    left_child = get_node_from_debug_string(split[0])
    right_child = get_node_from_debug_string(split[1])

    return InternalNode(parent_split_feature, parent_split_value, left_child, right_child)

In [13]:
# Parse each tree's group of lines into a node structure rooted at its header.
trees = list(map(get_node_from_debug_string, tree_debug_strings))

In [14]:
# Show only the beginning of the first tree's rendering; the full text is very long.
print("%s..." % str(trees[0])[:2000])

InternalNode: Pickup_Count_Nb_1h, 214.0
 InternalNode: Hour, {6.0,7.0,16.0,10.0,8.0,15.0,5.0,13.0,12.0,9.0,14.0,11.0,17.0,4.0,3.0,18.0,2.0,19.0}
  InternalNode: Pickup_Count_Dis_4h, 8.0
   InternalNode: Pickup_Count_Nb_1h, 18.0
    InternalNode: Pickup_Count_Nyc_1h, 11001.0
     InternalNode: Pickup_Count_Nyc_1h, 3575.0
      InternalNode: Pickup_Count_Nyc_4h, 20289.0
       InternalNode: Day_Of_Week, {4.0,5.0,2.0,1.0,6.0}
        InternalNode: Dropoff_Count_Dis_1h, 10.0
         InternalNode: Day_Of_Year, {0.0,2.0,3.0,6.0,9.0,12.0,13.0,17.0,18.0,20.0,25.0,27.0,30.0,35.0,39.0,46.0,49.0,53.0,56.0,59.0,60.0,61.0,63.0,64.0,65.0,66.0,67.0,74.0,75.0,81.0,82.0,83.0,90.0,96.0,97.0,102.0,103.0,109.0,116.0,117.0,118.0,122.0,152.0,156.0,159.0,181.0,187.0,192.0,199.0,213.0,214.0,215.0,221.0,222.0,236.0,237.0,238.0,242.0,249.0,250.0,265.0,271.0,273.0,274.0,276.0,277.0,278.0,283.0,284.0,285.0,298.0,299.0,303.0,307.0,313.0,319.0,330.0,331.0,332.0,340.0,353.0,360.0,362.0,21.0,11.0,239.0,4.0,32.0,358.

## Get Most Common Features
The most common features are determined by collecting all features that appear in any tree of the random forest up to a specific depth (4 by default). Note that the result is the set of such features; it does not count how often each feature occurs.

In [15]:
def get_features_and_levels(tree, maxlevel=10):
    """List (split_feature, depth) for every internal node at depth <= maxlevel.

    Nodes are visited breadth-first (the root has depth 0; within a depth,
    left children come before right children). Leaves contribute nothing.
    """
    features_with_levels = []
    frontier = [tree]
    depth = 0

    while frontier and depth <= maxlevel:
        next_frontier = []
        for node in frontier:
            if node.is_leaf:
                continue
            features_with_levels.append((node.split_feature, depth))
            next_frontier.extend((node.left, node.right))
        frontier = next_frontier
        depth += 1

    return features_with_levels

In [16]:
def get_top_features(trees, maxlevel=4):
    """Return the set of features used for a split at depth <= maxlevel in any tree.

    This is a membership set over the whole forest, not a frequency ranking.
    """
    top_features = set()
    for tree in trees:
        for feature, _level in get_features_and_levels(tree, maxlevel):
            top_features.add(feature)
    return top_features

In [17]:
# Alphabetical list of features split on at depth <= 4 (root = 0) in any tree.
sorted(get_top_features(trees, maxlevel=4))

['Day_Of_Week',
 'Day_Of_Year',
 'Dropoff_Count_Dis_1h',
 'Dropoff_Count_Dis_4h',
 'Dropoff_Count_Nb_1h',
 'Dropoff_Count_Nb_4h',
 'Dropoff_Count_Nyc_1h',
 'Dropoff_Count_Nyc_4h',
 'Hour',
 'Pickup_Count_Dis_1h',
 'Pickup_Count_Dis_4h',
 'Pickup_Count_Nb_1h',
 'Pickup_Count_Nb_4h',
 'Pickup_Count_Nyc_1h',
 'Pickup_Count_Nyc_4h',
 'Venue 1326 (2)',
 'Venue 1910 (2)',
 'Venue 2020 (-1)']