In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import sys
import os
import numpy as np
import pandas as pd
import lightgbm
import json

In [11]:
# Path to a saved LightGBM model file, relative to the notebook's working directory.
model_file = '../out/models/census/adv-boosting_census_B30_T200_S0050_L24_R200.T20.model'
# Load the saved model into a Booster so it can be re-saved and dumped below.
model = lightgbm.Booster(model_file=model_file)

In [12]:
# Write the loaded booster to 'test.txt' in the current working directory.
# The call returns the Booster itself, which is why its repr is echoed as the cell output.
model.save_model('test.txt')

<lightgbm.basic.Booster at 0x7f2b7b879630>

In [13]:
# Despite the `_json` name, dump_model() yields a nested Python dict
# (see the displayed value in the next cell), not a JSON string.
model_json = model.dump_model()

In [14]:
# Display the dumped model to inspect the key layout of a LightGBM tree
# ('tree_info' -> 'tree_structure' -> nested 'left_child' / 'right_child').
# NOTE(review): this echoes the whole dict; the output is very large.
model_json

{'name': 'tree',
 'version': 'v2',
 'num_class': 1,
 'num_tree_per_iteration': 1,
 'label_index': 0,
 'max_feature_idx': 12,
 'average_output': False,
 'feature_names': ['Column_0',
  'Column_1',
  'Column_2',
  'Column_3',
  'Column_4',
  'Column_5',
  'Column_6',
  'Column_7',
  'Column_8',
  'Column_9',
  'Column_10',
  'Column_11',
  'Column_12'],
 'tree_info': [{'tree_index': 0,
   'num_leaves': 24,
   'num_cat': 2,
   'shrinkage': 0.05,
   'tree_structure': {'split_index': 0,
    'split_feature': 9,
    'split_gain': 0.10160499811172485,
    'threshold': 5036.500000000001,
    'decision_type': '<=',
    'default_left': True,
    'missing_type': 'None',
    'internal_value': 0,
    'internal_count': 27132,
    'left_child': {'split_index': 1,
     'split_feature': 3,
     'split_gain': 0.055960699915885925,
     'threshold': 12.500000000000002,
     'decision_type': '<=',
     'default_left': True,
     'missing_type': 'None',
     'internal_value': -1.14938,
     'internal_count'

# Robust model

In [15]:
# 
# Load a serialized robust-forest model from disk.
# NOTE(review): the star import hides where names come from. `dill` is not
# imported in any visible cell, so it is presumably provided by this star
# import - confirm, and prefer explicit imports.
from robust_forest import *

# Rebinds `model`, which referred to the LightGBM Booster in the cells above.
model = None
with open("../out/models/census/robust_census_B0_T100_D8_I20_20_20.model", 'rb') as mf:
    # NOTE(review): pickle-style deserialization can run arbitrary code;
    # only load model files from a trusted source.
    model = dill.load(mf)

In [16]:
# Take the first entry of the model's `estimators` collection as a sample tree;
# its repr in the next cell shows it is a RobustDecisionTree.
tree = model.estimators[0]

In [17]:
# Show the repr of the selected tree (its constructor parameters).
tree

RobustDecisionTree(attacker=<robust_forest.Attacker object at 0x7f2b7b843908>,
          feature_blacklist={}, max_depth=8, max_features=0.8,
          max_samples=0.8, min_instances_per_node=20,
          replace_features=False, replace_samples=False, seed=0,
          split_optimizer=<robust_forest.SplitOptimizer object at 0x7f2af9eddb00>,
          tree_id=0)

In [39]:
def count_leaves(n):
    """Return the number of leaf nodes in the subtree rooted at ``n``.

    ``n`` must provide ``is_leaf()``; every non-leaf node must also provide
    ``left`` and ``right`` child nodes.
    """
    total = 0
    pending = [n]
    while pending:
        node = pending.pop()
        if node.is_leaf():
            total += 1
            continue
        # Push the right child first so the left subtree is visited first.
        pending.append(node.right)
        pending.append(node.left)
    return total


def node_json(n, is_numerical):
    """Convert node ``n`` and its whole subtree into a nested dict.

    The key names follow the ones visible in the LightGBM ``dump_model()``
    output displayed earlier in this notebook.

    Parameters
    ----------
    n : node exposing ``is_leaf()``; leaves also expose
        ``get_node_prediction()``, internal nodes expose ``node_id``,
        ``best_split_feature_id``, ``best_split_feature_value``, ``left``
        and ``right``.
    is_numerical : indexable by feature id; a truthy entry selects the
        ``'<='`` decision type, a falsy one selects ``'=='``.

    Returns
    -------
    dict
        ``{'leaf_value': ...}`` for a leaf, otherwise the split description
        with recursively converted ``'left_child'`` and ``'right_child'``.
    """
    if n.is_leaf():
        # 'leaf_index' and 'leaf_count' are not emitted.
        return {'leaf_value': n.get_node_prediction()}

    # 'split_gain', 'internal_value' and 'internal_count' are not emitted.
    split = {
        'split_index': n.node_id[0],  # to be fixed
        'split_feature': n.best_split_feature_id,
        'threshold': n.best_split_feature_value,
        # TODO: check that '==' is acceptable for LightGBM
        'decision_type': '<=' if is_numerical[n.best_split_feature_id] else '==',
        'default_left': True,
        'missing_type': 'None',
    }
    split['left_child'] = node_json(n.left, is_numerical)
    split['right_child'] = node_json(n.right, is_numerical)
    return split
    
def tree_json(t):
    """Convert tree ``t`` into a dict with LightGBM-style ``tree_info`` keys.

    Parameters
    ----------
    t : tree exposing ``tree_id``, ``root`` and ``numerical_idx`` (the
        per-feature numerical flags handed to ``node_json``).

    Returns
    -------
    dict
        Keys ``'tree_index'``, ``'num_leaves'`` and ``'tree_structure'``.
        ``'num_cat'`` and ``'shrinkage'`` are not emitted yet.
    """
    # Everything is read from the argument `t`. The previous version built
    # 'tree_structure' from the notebook-global `tree`, so converting any other
    # tree silently returned the wrong structure (or raised NameError when the
    # global did not exist).
    return {
        'tree_index': t.tree_id,
        'num_leaves': count_leaves(t.root),
        # TODO: 'num_cat'
        # TODO: 'shrinkage' (in scikit-learn terms, 1/num_trees)
        'tree_structure': node_json(t.root, t.numerical_idx),
    }

# Convert the sample tree and display the resulting nested dict.
tree_json(tree)

{'tree_index': 0,
 'num_leaves': 71,
 'tree_structure': {'split_index': -1,
  'split_feature': 4,
  'threshold': 'Married-civ-spouse',
  'decision_type': '==',
  'default_left': True,
  'missing_type': 'None',
  'left_child': {'split_index': -1,
   'split_feature': 3,
   'threshold': 12,
   'decision_type': '<=',
   'default_left': True,
   'missing_type': 'None',
   'left_child': {'split_index': -1,
    'split_feature': 9,
    'threshold': 5013,
    'decision_type': '<=',
    'default_left': True,
    'missing_type': 'None',
    'left_child': {'split_index': -1,
     'split_feature': 10,
     'threshold': 1740,
     'decision_type': '<=',
     'default_left': True,
     'missing_type': 'None',
     'left_child': {'split_index': -1,
      'split_feature': 3,
      'threshold': 9,
      'decision_type': '<=',
      'default_left': True,
      'missing_type': 'None',
      'left_child': {'split_index': -1,
       'split_feature': 3,
       'threshold': 8,
       'decision_type': '<=',
  

In [38]:
# Leaf count of the sample tree; it matches the 'num_leaves' value shown above.
# NOTE(review): this cell's execution count (38) is lower than that of the
# definition cell above (39), so the cells were run out of order - verify with
# Restart Kernel -> Run All.
count_leaves(tree.root)

71