In [136]:
import os
import re
from collections import deque

from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext

from pyspark.mllib.tree import RandomForestModel

In [2]:
# Local Spark context on all available cores. getOrCreate() reuses a context
# that already exists in this kernel, so re-running this cell does not fail with
# "Cannot run multiple SparkContexts at once".
conf = SparkConf().setMaster("local[*]").setAppName('pyspark')
sc = SparkContext.getOrCreate(conf=conf)
sql_context = SQLContext(sc)

# Configuration

In [3]:
# Directory of the saved Spark RandomForestModel. Set the MODEL_LOCATION
# environment variable to point elsewhere instead of editing this
# machine-specific default.
MODEL_LOCATION = os.environ.get(
    "MODEL_LOCATION",
    "/Users/georg/Downloads/random_forest_d15_n5/model_40.72_-73.94/")

In [201]:
# Feature names in the column order of the model's input vectors:
# "feature k" in the tree debug string is looked up as FEATURE_MAPPING[k].
#
# NOTE(review): parsing the trees below warns about feature indices up to 19342,
# but this list has only 17100 entries, so it does not match the feature vector
# the model was trained on. Out-of-range indices are reported as "OOR"; if the
# mismatch comes from a wrong venue count or offset list, the names of in-range
# venue features may be shifted too -- TODO confirm against the training pipeline.
WEATHER_MEASURES = ["AWND", "PRCP", "TMAX", "TMIN"]
WEATHER_STATIONS = ["US1NJBG0018", "US1NYKN0003", "US1NYKN0025", "US1NYNS0007",
                    "US1NYQN0002", "US1NYRC0001", "US1NYRC0002", "USC00300961",
                    "USW00014732", "USW00094728", "USW00094789"]
VENUE_OFFSETS = ["0h", "-3", "-2", "-1", "1", "2", "3"]
NUM_VENUES = 2434

FEATURE_MAPPING = (
    # Pickup/Dropoff counts per area ("Dis", "Nb", "Nyc") and window (1h, 4h),
    # ordered area -> window -> direction.
    ["%s_Count_%s_%s" % (direction, area, window)
     for area in ("Dis", "Nb", "Nyc")
     for window in ("1h", "4h")
     for direction in ("Pickup", "Dropoff")]
    + ["Is_Holiday", "Hour", "Day", "Month", "Year", "Weekday"]
    # One column per (measure, station) pair, measure-major.
    + ["%s_GHCND:%s" % (measure, station)
       for measure in WEATHER_MEASURES
       for station in WEATHER_STATIONS]
    # One block of NUM_VENUES columns per offset, offset-major.
    + ["Venue %d (%s)" % (venue, offset)
       for offset in VENUE_OFFSETS
       for venue in range(NUM_VENUES)]
)
assert len(FEATURE_MAPPING) == len(set(FEATURE_MAPPING)), "duplicate feature names"
len(FEATURE_MAPPING)

17100

# Load Model & Get Debug String

In [202]:
# Deserialize the trained random forest stored under MODEL_LOCATION.
model = RandomForestModel.load(sc=sc, path=MODEL_LOCATION)

In [203]:
# Keep only the indented lines of the dump. Every tree line is indented
# ('  Tree 0:', '    If (...)', ... as shown in the output below), while the
# surrounding summary/blank lines are assumed not to be. Selecting by indentation
# does not depend on exactly how many such lines precede or follow the trees.
debug_string = [line for line in model.toDebugString().split('\n')
                if line.startswith(' ')]
print(debug_string[:10])

[u'  Tree 0:', u'    If (feature 5 <= 368.0)', u'     If (feature 5 <= 153.0)', u'      If (feature 4 <= 26.0)', u'       If (feature 4 <= 14.0)', u'        If (feature 2 <= 8.0)', u'         If (feature 2 <= 3.0)', u'          If (feature 10 <= 127953.0)', u'           If (feature 1 <= 5.0)', u'            If (feature 4 <= 5.0)']


# Parse Debug String

In [204]:
class InternalNode(object):
    """Non-leaf node of a tree parsed from a Spark debug string.

    ``parent_split_feature`` / ``parent_split_value`` describe the split on the
    way *into* this node (None for a tree root). The split this node itself
    performs is taken from its children, which must both carry the same
    (feature, value) pair in their own ``parent_split_*`` attributes.
    """

    def __init__(self, parent_split_feature, parent_split_value, left, right):
        self.is_leaf = False
        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.left, self.right = left, right

        # The If and Else branches of one split must name the same feature/threshold.
        assert left.parent_split_feature == right.parent_split_feature
        assert left.parent_split_value == right.parent_split_value

        self.split_feature, self.split_value = (left.parent_split_feature,
                                                left.parent_split_value)

    def __str__(self):
        return self.to_string(0)

    def to_string(self, indention):
        """Render this subtree, one node per line, indenting one space per level."""
        header = "%sInternalNode: %s, %s" % (" " * indention,
                                              self.split_feature, self.split_value)
        children = [child.to_string(indention + 1) for child in (self.left, self.right)]
        return "\n".join([header] + children)
    
        
class LeafNode(object):
    """Leaf of a tree parsed from a Spark debug string, holding its prediction.

    ``parent_split_feature`` / ``parent_split_value`` describe the split on the
    way into this leaf; both are None when the whole tree is a single leaf.
    """

    def __init__(self, parent_split_feature, parent_split_value, prediction):

        self.is_leaf = True

        self.parent_split_feature = parent_split_feature
        self.parent_split_value = parent_split_value
        self.prediction = prediction

    def __str__(self):
        # Same entry point as InternalNode.__str__, so str() works on any tree
        # root, including a tree that never split.
        return self.to_string(0)

    def to_string(self, indention):
        """Render this leaf on one line, indented by ``indention`` spaces."""
        return (" " * indention) + "LeafNode: %s" % self.prediction

In [205]:
def split_debug_string(debug_string):
    """Partition debug-string lines into sibling subtrees.

    The number of *leading* spaces on the first line defines the sibling
    indentation. Any line indented by no more than that starts a new group;
    more deeply indented lines belong to the most recently opened group.

    Args:
        debug_string: list of lines from a Spark tree debug dump.

    Returns:
        A list of line groups (each a list of str); ``[]`` for empty input.
    """
    if not debug_string:
        return []

    # Count only leading spaces: searching for the first run of spaces anywhere
    # would pick up an inner space (e.g. "Tree 0:") when the line is unindented.
    first_line = debug_string[0]
    indention = len(first_line) - len(first_line.lstrip(" "))
    child_prefix = " " * (indention + 1)

    result = []
    for line in debug_string:
        if not line.startswith(child_prefix):
            result.append([])
        result[-1].append(line)

    return result

In [206]:
# One group of lines per tree; each group starts at its "Tree N:" line.
tree_debug_strings = split_debug_string(debug_string)

In [207]:
# A number as printed in the debug string. Scala's Double.toString switches to
# exponent form for large/small magnitudes ("1.2E7", "5.0E-4"), so accept that
# in addition to plain "368.0" / "-1.5".
_NUMBER_PATTERN = r"-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?"
_SPLIT_RE = re.compile(r"(If|Else) \(feature (\d+) (<=|>) (" + _NUMBER_PATTERN + r")\)")
_PREDICT_RE = re.compile(r"Predict: (" + _NUMBER_PATTERN + r")")


def get_node_from_debug_string(debug_string):
    """Recursively build a node tree from one subtree's debug-string lines.

    Args:
        debug_string: list of lines. The first is either a "Tree N:" header or an
            "If (...)" / "Else (...)" split line; the remaining lines are that
            node's children, indented more deeply.

    Returns:
        A LeafNode when the only child line is "Predict: <value>", otherwise an
        InternalNode built from the two child subtrees. Feature indices beyond
        FEATURE_MAPPING are named "OOR" (a warning is printed).

    Raises:
        ValueError: if a split or prediction line has a format this parser does
            not understand (presumably e.g. categorical splits, which Spark
            prints as "feature k in {...}").
    """
    node_debug_string = debug_string[0].strip()
    if node_debug_string.startswith("Tree"):
        # Tree root: no split leads into it.
        parent_split_feature = None
        parent_split_value = None
    else:
        match = _SPLIT_RE.match(node_debug_string)
        if match is None:
            raise ValueError("Unsupported split line: %r" % node_debug_string)
        feature_index = int(match.group(2))
        if feature_index >= len(FEATURE_MAPPING):
            print("WARN: feature index out of range: %d" % feature_index)
            parent_split_feature = "OOR"
        else:
            parent_split_feature = FEATURE_MAPPING[feature_index]
        parent_split_value = float(match.group(4))

    split = split_debug_string(debug_string[1:])
    if len(split) == 1:
        # A single child group must be exactly one "Predict: ..." line.
        assert len(split[0]) == 1
        predict_line = split[0][0].strip()
        match = _PREDICT_RE.match(predict_line)
        if match is None:
            raise ValueError("Unsupported leaf line: %r" % predict_line)
        return LeafNode(parent_split_feature, parent_split_value, float(match.group(1)))

    # Otherwise: the If branch followed by the Else branch.
    assert len(split) == 2
    left_child = get_node_from_debug_string(split[0])
    right_child = get_node_from_debug_string(split[1])

    return InternalNode(parent_split_feature, parent_split_value, left_child, right_child)

In [208]:
# Parse each tree's line group into a node hierarchy rooted at its "Tree N:" line.
trees = list(map(get_node_from_debug_string, tree_debug_strings))

WARN: feature index out of range: 17369
WARN: feature index out of range: 17369
WARN: feature index out of range: 17297
WARN: feature index out of range: 19134
WARN: feature index out of range: 19134
WARN: feature index out of range: 17297
WARN: feature index out of range: 18061
WARN: feature index out of range: 18061
WARN: feature index out of range: 18573
WARN: feature index out of range: 18573
WARN: feature index out of range: 19127
WARN: feature index out of range: 19127
WARN: feature index out of range: 18316
WARN: feature index out of range: 18316
WARN: feature index out of range: 19342
WARN: feature index out of range: 19342
WARN: feature index out of range: 18932
WARN: feature index out of range: 18289
WARN: feature index out of range: 18289
WARN: feature index out of range: 18932
WARN: feature index out of range: 17807
WARN: feature index out of range: 17807
WARN: feature index out of range: 17996
WARN: feature index out of range: 17996
WARN: feature index out of range: 17340


In [209]:
# Preview the start of the first tree; append "..." only when the text was cut.
first_tree_text = str(trees[0])
print(first_tree_text[:500] + ("..." if len(first_tree_text) > 500 else ""))

InternalNode: Dropoff_Count_Nb_1h, 368.0
 InternalNode: Dropoff_Count_Nb_1h, 153.0
  InternalNode: Pickup_Count_Nb_1h, 26.0
   InternalNode: Pickup_Count_Nb_1h, 14.0
    InternalNode: Pickup_Count_Dis_4h, 8.0
     InternalNode: Pickup_Count_Dis_4h, 3.0
      InternalNode: Pickup_Count_Nyc_4h, 127953.0
       InternalNode: Dropoff_Count_Dis_1h, 5.0
        InternalNode: Pickup_Count_Nb_1h, 5.0
         InternalNode: Dropoff_Count_Nyc_1h, 16502.0
          InternalNode: Dropoff_Count_Nyc_4h, 39174...


In [210]:
# Number of trees parsed from the ensemble's debug string.
len(trees)

5

# Get Features Used in the Top Levels of the Trees

In [211]:
def get_features_and_levels(tree, maxlevel=10):
    """List ``(split_feature, depth)`` for every internal node with depth <= maxlevel.

    The root has depth 0 and leaves are skipped. Nodes are listed breadth-first:
    shallower nodes first, and left-to-right within a depth.
    """
    result = []
    frontier = [tree]
    depth = 0
    # Walk the tree one depth at a time instead of with a node queue.
    while frontier and depth <= maxlevel:
        next_frontier = []
        for node in frontier:
            if node.is_leaf:
                continue
            result.append((node.split_feature, depth))
            next_frontier.extend((node.left, node.right))
        frontier = next_frontier
        depth += 1

    return result

In [214]:
def get_top_features(trees, maxlevel=4):
    """Return the distinct split features used at depth <= ``maxlevel`` in any tree.

    Depth is counted from the root (depth 0), as in ``get_features_and_levels``.
    """
    top_features = set()
    for tree in trees:
        for feature, _level in get_features_and_levels(tree, maxlevel):
            top_features.add(feature)
    return top_features

In [215]:
# Distinct split features within depth 4 (root = depth 0) of any tree.
get_top_features(trees, maxlevel=4)

{'AWND_GHCND:US1NJBG0018',
 'Day',
 'Dropoff_Count_Dis_1h',
 'Dropoff_Count_Dis_4h',
 'Dropoff_Count_Nb_1h',
 'Dropoff_Count_Nb_4h',
 'Dropoff_Count_Nyc_1h',
 'Dropoff_Count_Nyc_4h',
 'Month',
 'OOR',
 'PRCP_GHCND:US1NYNS0007',
 'PRCP_GHCND:USW00094728',
 'Pickup_Count_Dis_1h',
 'Pickup_Count_Dis_4h',
 'Pickup_Count_Nb_1h',
 'Pickup_Count_Nb_4h',
 'Pickup_Count_Nyc_1h',
 'Pickup_Count_Nyc_4h',
 'TMIN_GHCND:US1NYKN0025',
 'TMIN_GHCND:USW00094789',
 'Venue 1535 (3)',
 'Venue 1575 (3)',
 'Venue 1916 (2)',
 'Venue 2026 (0h)',
 'Venue 2030 (0h)',
 'Venue 2031 (0h)',
 'Venue 2160 (1)',
 'Venue 2272 (0h)',
 'Venue 2322 (3)',
 'Venue 675 (2)',
 'Venue 910 (-3)',
 'Year'}