In [None]:
import argparse, os, logging, random, time
import numpy as np
import math
import time
import scipy.sparse
import lightgbm as lgb
# import data_helpers as dh
import pickle 

In [None]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [None]:
def countSplitNodes(tree):
    """Return how many split (internal) nodes a dumped tree contains.

    Args:
        tree: mapping with a ``'tree_structure'`` entry holding the nested
            node dicts. A node is a split node when it has a
            ``'split_index'`` key, in which case it must also provide
            ``'left_child'`` and ``'right_child'``.

    Returns:
        The number of nodes carrying a ``'split_index'`` key.
    """
    structure = tree['tree_structure']
    logging.info(f'Func: countSplitNodes.\n parameter-root: {structure}.')
    # Walk the tree with an explicit stack; only split nodes are expanded.
    total = 0
    pending = [structure]
    while pending:
        node = pending.pop()
        if 'split_index' not in node:
            continue
        total += 1
        left, right = node['left_child'], node['right_child']
        pending.extend((right, left))
    logging.info(f'Func: countSplitNodes.\n return: {total}.')
    return total

def getItemByTree(tree, item='split_feature'):
    """Flatten one per-node attribute of a tree into a 1-D numpy array.

    The result has ``tree.split_nodes + tree.raw['num_leaves']`` entries.
    Split nodes occupy positions ``[0, split_nodes)`` addressed by their
    ``'split_index'``; leaves occupy the remaining positions addressed by
    ``leaf_index + split_nodes``.

    What is stored depends on ``item``:

    * name contains ``'child'`` (e.g. ``'left_child'``): for a split node, the
      position of that child (its ``'split_index'``, or
      ``leaf_index + split_nodes`` when the child is a leaf); ``-1`` for leaves.
    * name contains ``'value'``: ``node['internal_' + item]`` for split nodes
      and ``node['leaf_' + item]`` for leaves.
    * anything else (e.g. ``'split_feature'``, ``'threshold'``,
      ``'split_gain'``): ``node[item]`` for split nodes and ``-2`` for leaves.

    Args:
        tree: object exposing ``raw`` (the dumped tree dict with
            ``'tree_structure'`` and ``'num_leaves'``) and ``split_nodes``.
        item: name of the node attribute to extract.

    Returns:
        ``float64`` array when ``item`` contains ``'value'``, ``'threshold'``
        or ``'split_gain'``; ``int32`` array otherwise.
    """
    logging.info(f'Func: getItemByTree.\n parameter-item: {item}.')
    logging.info(f'Func: getItemByTree.\n parameter-tree: {tree}.')
    root = tree.raw['tree_structure']
    logging.info(f'Func: getItemByTree.\n parameter-root: {root}.')
    split_nodes = tree.split_nodes

    # Allocate with the final dtype directly instead of converting afterwards.
    wants_float = 'value' in item or 'threshold' in item or 'split_gain' in item
    res = np.zeros(split_nodes + tree.raw['num_leaves'],
                   dtype=np.float64 if wants_float else np.int32)

    # The 'child' test takes precedence over the 'value' test, as before.
    is_child_item = 'child' in item
    is_value_item = (not is_child_item) and 'value' in item

    def leaf_position(node):
        # A tree that consists of a single leaf can be dumped without a
        # 'leaf_index' key; treat that lone leaf as leaf 0 instead of
        # failing with KeyError.
        return node.get('leaf_index', 0) + split_nodes

    def getFeature(node):
        is_split = 'split_index' in node
        if is_child_item:
            if is_split:
                child = node[item]
                if 'split_index' in child:
                    res[node['split_index']] = child['split_index']
                else:
                    res[node['split_index']] = leaf_position(child)
            else:
                res[leaf_position(node)] = -1
        elif is_value_item:
            if is_split:
                res[node['split_index']] = node['internal_' + item]
            else:
                res[leaf_position(node)] = node['leaf_' + item]
        else:
            if is_split:
                res[node['split_index']] = node[item]
            else:
                res[leaf_position(node)] = -2
        if 'left_child' in node:
            getFeature(node['left_child'])
        if 'right_child' in node:
            getFeature(node['right_child'])

    getFeature(root)
    logging.info(f'Func: getItemByTree.\n return: {res}.')
    return res

def getTreeSplits(model):
    """Build one TreeInterpreter per tree and collect features and thresholds.

    Args:
        model: mapping with a ``'tree_info'`` entry that iterates over the
            per-tree dicts accepted by ``TreeInterpreter``.

    Returns:
        A tuple ``(trees, featurelist, threhlist)`` of three parallel lists:
        the ``TreeInterpreter`` objects, each interpreter's ``feature`` array,
        and the array returned by ``getItemByTree(interpreter, 'threshold')``.
    """
    logging.info(f'Func: getTreeSplits.\n parameter-model: {model["tree_info"]}.')
    featurelist = []
    threhlist = []
    trees = []
    for tree in model['tree_info']:
        # Build the interpreter and the threshold array exactly once and reuse
        # them for logging; recomputing them only to log doubled the work.
        interpreter = TreeInterpreter(tree)
        trees.append(interpreter)
        logging.info(f'Func: getTreeSplits.\n parameter-trees-tree: {interpreter}.')
        featurelist.append(interpreter.feature)
        logging.info(f'Func: getTreeSplits.\n parameter-featurelist.append(trees[-1].feature): \
        {interpreter.feature}.')
        thresholds = getItemByTree(interpreter, 'threshold')
        threhlist.append(thresholds)
        logging.info(f'Func: getTreeSplits.\n parameter-threhlist.threshold: \
        {thresholds}.')
    logging.info(f'Func: getTreeSplits.\n return trees: {trees}.\n return featurelist: {featurelist}.\n \
        return threhlist: {threhlist}.')
    return (trees, featurelist, threhlist)


def getChildren(trees):
    """Collect the left-child and right-child arrays of every tree.

    For each tree, ``getItemByTree`` is queried for ``'left_child'`` and then
    for ``'right_child'``.

    Returns:
        ``(listcl, listcr)``: two lists parallel to ``trees`` holding the
        left-child and right-child results respectively.
    """
    logging.info(f'Func: getChildren.\n parameter-trees: {trees}.')
    # Tuple elements are evaluated left to right, so every tree is still
    # queried left-then-right before moving on to the next tree.
    pairs = [
        (getItemByTree(tree, 'left_child'), getItemByTree(tree, 'right_child'))
        for tree in trees
    ]
    listcl = [left for left, _ in pairs]
    listcr = [right for _, right in pairs]
    logging.info(f'Func: getChildren.\n return listcl: {listcl}.\n return listcr: {listcr}.')
    return (listcl, listcr)

class TreeInterpreter:
    """Flattened view of a single dumped tree.

    Attributes:
        raw: the tree dict passed to the constructor, stored as-is.
        split_nodes: result of ``countSplitNodes(tree)``.
        node_count: same number as ``split_nodes`` (leaves are not added).
        value: ``getItemByTree(self, item='value')``.
        feature: ``getItemByTree(self)`` (default item ``'split_feature'``).
        gain: ``getItemByTree(self, 'split_gain')``.
    """

    def __init__(self, tree):
        n_splits = countSplitNodes(tree)
        # raw and split_nodes must exist before getItemByTree(self, ...) runs,
        # because it reads both attributes from this instance.
        self.raw = tree
        self.split_nodes = n_splits
        self.node_count = n_splits # + tree['num_leaves']
        self.value = getItemByTree(self, item='value')
        self.feature = getItemByTree(self)
        self.gain = getItemByTree(self, 'split_gain')
        # self.leaf_value = getLeafValue(tree)
        logging.info(f'Class: TreeInterpreter.\n return self.raw: {tree}.\n \
            return self.split_nodes: {self.split_nodes}. \
            \n return self.node_count: {self.node_count}.\n return self.value: {self.value}. \
            \n return self.feature: {self.feature}.\n return self.gain: {self.gain}.')

class ModelInterpreter(object):
    """Per-tree view of a dumped boosted model plus a tree-grouping helper.

    Attributes:
        tree_model: the ``tree_model`` label passed to the constructor.
        n_features_: ``max_feature_idx + 1`` taken from the model dump.
        trees, featurelist, threshlist: the three lists returned by
            ``getTreeSplits`` for the dump.
        listcl, listcr: the two lists returned by ``getChildren(self.trees)``.
    """

    def __init__(self, model, tree_model='lightgbm'):
        """Dump ``model`` via ``model.dump_model()`` and interpret every tree.

        Args:
            model: object whose ``dump_model()`` returns a mapping with
                ``'max_feature_idx'`` and ``'tree_info'`` entries.
            tree_model: label stored on the instance; not otherwise used here.
        """
        print("Model Interpreting...")
        self.tree_model = tree_model
        model = model.dump_model()
        self.n_features_ = model['max_feature_idx'] + 1
        self.trees, self.featurelist, self.threshlist = getTreeSplits(model)
        self.listcl, self.listcr = getChildren(self.trees)
        logging.info(f'Class: ModelInterpreter.\n return self.tree_model: {self.tree_model}.\n \
            return self.n_features_: {self.n_features_}. \
            \n return self.trees: {self.trees}.\n return self.featurelist: {self.featurelist}. \
            \n return self.threshlist: {self.threshlist}.\n return self.listcl: {self.listcl}. \
            \n return self.listcr: {self.listcr}.')

    def GetTreeSplits(self):
        """Return ``(trees, featurelist, threshlist)`` computed at construction."""
        return (self.trees, self.featurelist, self.threshlist)

    def GetChildren(self):
        """Return ``(listcl, listcr)`` computed at construction."""
        return (self.listcl, self.listcr)

    def EqualGroup(self, n_clusters):
        """Randomly split the trees into ``n_clusters`` groups of near-equal size.

        Tree indices are shuffled with the ``random`` module, then dealt out so
        that every cluster gets ``len(trees) // n_clusters`` trees and the
        first ``len(trees) % n_clusters`` clusters get one extra.

        Args:
            n_clusters: number of groups to form.

        Returns:
            Float array with one entry per tree holding that tree's cluster id.
        """
        logging.info(f'Class: ModelInterpreter : func EqualGroup.\n \
            n_clusters : {n_clusters}.')
        vectors = {}
        # n_feature = 256
        # This message was previously a bare f-string expression and was never
        # actually logged.
        logging.info(f'Class: ModelInterpreter : func EqualGroup.\n \
            self.featurelist: {self.featurelist}.')
        for idx, features in enumerate(self.featurelist):
            # Leaves are marked with negative values; feature index 0 is a
            # real feature and must be kept, hence >= 0 rather than > 0.
            vectors[idx] = set(features[features >= 0])
            logging.info(f'Class: ModelInterpreter : func EqualGroup.\n \
            featurelist vectors[idx]: {vectors[idx]}.')
        # random.sample needs a sequence; passing a dict keys view raises
        # TypeError on Python 3.11+.
        keys = random.sample(list(vectors), len(vectors))
        logging.info(f'Class: ModelInterpreter : func EqualGroup.\n \
            keys : {keys}.')
        clusterIdx = np.zeros(len(vectors))
        logging.info(f'Class: ModelInterpreter : func EqualGroup.\n \
            clusterIdx : {clusterIdx}.')
        # groups = [[] for i in range(n_clusters)]
        trees_per_cluster = len(vectors)//n_clusters
        mod_per_cluster = len(vectors) % n_clusters
        begin = 0
        for idx in range(n_clusters):
            # The first mod_per_cluster clusters absorb the remainder.
            size = trees_per_cluster + (1 if idx < mod_per_cluster else 0)
            for key in keys[begin:begin + size]:
                clusterIdx[key] = idx
            begin += size
        print([np.where(clusterIdx==i)[0].shape for i in range(n_clusters)])
        logging.info(f'Class: ModelInterpreter : func EqualGroup.\n return clusterIdx: {clusterIdx}.')
        return clusterIdx

In [None]:
def SubGBDTLeaf_cls(train_x, test_x, gbm, maxleaf=5, num_slices=2, 
                    group_method="Random", feat_per_group=4, tree_model='lightgbm'):
    """Split a boosted model's trees into groups and yield per-group data.

    This is a generator: nothing runs (and no error is raised) until it is
    iterated.

    Args:
        train_x: 2-D array of training rows; ``train_x.shape[1]`` is used as
            the padding feature id.
        test_x: 2-D array of test rows.
        gbm: model exposing ``predict(x, pred_leaf=True)``,
            ``get_leaf_output(tree_id, leaf_id)`` and whatever
            ``ModelInterpreter`` requires (``dump_model()``).
        maxleaf: minimum number of leaf slots reserved per tree in the
            leaf-output table. The table is widened automatically when the
            training predictions reference a higher leaf id.
        num_slices: number of tree groups to produce.
        group_method: ``'Equal'`` or ``'Random'``; both use
            ``ModelInterpreter.EqualGroup``.
        feat_per_group: number of feature ids reported per group.
        tree_model: label forwarded to ``ModelInterpreter``.

    Yields:
        One tuple per group:
        ``(used_features, new_train_y, cur_leaf_preds, cur_test_leaf_preds,
        group_mean, overall_mean)`` where

        * ``used_features``: up to ``feat_per_group`` feature ids ordered by
          decreasing summed split gain within the group, padded with
          ``train_x.shape[1]`` when the group uses fewer features;
        * ``new_train_y``: ``(n_train, 1)`` float32 sum of the group's leaf
          outputs for each training row;
        * ``cur_leaf_preds`` / ``cur_test_leaf_preds``: leaf-id columns of the
          group's trees for the train / test rows;
        * ``group_mean``: mean of the leaf-output table rows of this group;
        * ``overall_mean``: mean of the whole leaf-output table.

    Raises:
        ValueError: if ``group_method`` is not ``'Equal'`` or ``'Random'``.
    """
    # Fail early with a clear message; previously an unsupported method only
    # surfaced later as an unbound local variable.
    if group_method not in ('Equal', 'Random'):
        raise ValueError(
            f"Unsupported group_method {group_method!r}; expected 'Equal' or 'Random'.")

    MAX = train_x.shape[1]
    leaf_preds = gbm.predict(train_x, pred_leaf=True).reshape(train_x.shape[0], -1)
    test_leaf_preds = gbm.predict(test_x, pred_leaf=True).reshape(test_x.shape[0], -1)
    n_trees = leaf_preds.shape[1]

    # Highest leaf id reached by the training rows, per tree, plus one.
    leaves_per_tree = np.max(leaf_preds, axis=0) + 1
    # Keep the table at `maxleaf` columns whenever that is enough, and only
    # widen it when a tree has more leaves (which used to raise IndexError).
    table_width = maxleaf
    if leaves_per_tree.size:
        table_width = max(maxleaf, int(leaves_per_tree.max()))
    leaf_output = np.zeros([n_trees, table_width], dtype=np.float32)
    for tid in range(n_trees):
        for lid in range(int(leaves_per_tree[tid])):
            leaf_output[tid][lid] = gbm.get_leaf_output(tid, lid)

    # Forward tree_model (not group_method) as the interpreter's model label.
    modelI = ModelInterpreter(gbm, tree_model)
    clusterIdx = modelI.EqualGroup(num_slices)
    n_feature = feat_per_group
    treeI = modelI.trees

    for n_idx in range(num_slices):
        tree_indices = np.where(clusterIdx == n_idx)[0]

        # Accumulate split gain per feature over the trees of this group;
        # entries equal to -2 mark leaves and are skipped.
        gain_by_feature = {}
        for tree in tree_indices:
            interpreter = treeI[tree]
            for kdx, f in enumerate(interpreter.feature):
                if f == -2:
                    continue
                gain_by_feature[f] = gain_by_feature.get(f, 0) + interpreter.gain[kdx]

        ranked = sorted(gain_by_feature.items(), key=lambda kv: -kv[1])
        used_features = [feature for feature, _ in ranked[:n_feature]]
        # Pad with the out-of-range id MAX so every group reports n_feature ids.
        used_features.extend([MAX] * max(0, n_feature - len(used_features)))

        cur_leaf_preds = leaf_preds[:, tree_indices]
        cur_test_leaf_preds = test_leaf_preds[:, tree_indices]

        new_train_y = np.zeros(train_x.shape[0])
        for jdx in tree_indices:
            new_train_y += np.take(leaf_output[jdx,:].reshape(-1), leaf_preds[:,jdx].reshape(-1))
        new_train_y = new_train_y.reshape(-1,1).astype(np.float32)

        yield (used_features, new_train_y, cur_leaf_preds, cur_test_leaf_preds,
               np.mean(np.take(leaf_output, tree_indices, 0)), np.mean(leaf_output))


In [None]:
# Load the iris dataset and hold out 33% of the rows as a test set.
iris = datasets.load_iris()
X, y = iris.data, iris.target
train_x, test_x, train_y, test_y = train_test_split(
    X,
    y,
    test_size=0.33,
    random_state=42,  # fixed seed for the split
)

In [None]:
from importlib import reload 

# Reload the logging module before configuring the root logger.
# NOTE(review): presumably this discards handlers attached by earlier runs of
# this cell — confirm.
reload(logging)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# One formatter shared by both handlers.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# File handler: records INFO and above in iris-2.log.
fh = logging.FileHandler('iris-2.log')
# Console handler: only ERROR and above.
ch = logging.StreamHandler()

# Configure and attach the handlers (console first, then file).
for handler, level in ((ch, logging.ERROR), (fh, logging.INFO)):
    handler.setLevel(level)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

In [None]:
with open('gbm_dump_iris_2_tree.pickle', 'rb') as f:
    # Load the previously pickled model object from disk.
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only open pickle files that come from a trusted source.
    gbm = pickle.load(f)

In [None]:
# Display the type of the unpickled object as a sanity check.
type(gbm)

lightgbm.basic.Booster

In [None]:
# Iterate over the per-group results of SubGBDTLeaf_cls and collect them.
gbms = SubGBDTLeaf_cls(train_x, test_x, gbm)
min_len_features = train_x.shape[1]
n_output = train_y.shape[0]
max_ntree_per_split = 0
used_features, tree_outputs = [], []
leaf_preds, test_leaf_preds = [], []
group_average = []
for used_feature, new_train_y, leaf_pred, test_leaf_pred, avg, all_avg in gbms:
    logging.info(f'Func: SubGBDTLeaf_cls\n return used_feature: {used_feature}. \
        \n return new_train_y: {new_train_y}. \
        \n return leaf_pred: {leaf_pred}. \
        \n return test_leaf_pred: {test_leaf_pred}. \
        \n return avg: {avg}. \
        \n return all_avg: {all_avg}.')

    used_features.append(used_feature)
    tree_outputs.append(new_train_y)
    leaf_preds.append(leaf_pred)
    test_leaf_preds.append(test_leaf_pred)
    group_average.append(avg)
    # Track the shortest feature list and the largest number of trees per group.
    min_len_features = min(min_len_features, len(used_feature))
    max_ntree_per_split = max(max_ntree_per_split, leaf_pred.shape[1])

# Truncate every feature list to the shortest length seen, then sort it.
used_features = [sorted(features[:min_len_features]) for features in used_features]
n_models = len(used_features)
group_average = np.asarray(group_average).reshape(n_models, 1, 1)

# Groups with fewer trees than the widest group get their leaf ids shifted by
# one and are right-padded with zero columns, so all groups end up equally wide.
for i in range(n_models):
    if leaf_preds[i].shape[1] >= max_ntree_per_split:
        continue
    for per_group in (leaf_preds, test_leaf_preds):
        block = per_group[i]
        filler = np.zeros([block.shape[0], max_ntree_per_split - block.shape[1]],
                          dtype=np.int32)
        per_group[i] = np.concatenate([block + 1, filler], axis=1)

# Stack the groups side by side.
leaf_preds = np.concatenate(leaf_preds, axis=1)
test_leaf_preds = np.concatenate(test_leaf_preds, axis=1)
logging.info(f'Return: leaf_preds: {leaf_preds}. \
        \n return test_leaf_preds: {test_leaf_preds}.')


Model Interpreting...
[(1,), (1,)]


In [None]:

# Persist every computed artifact to its own pickle file, in a fixed order.
artifacts = [
    ('n_models_iris_2.pickle', n_models),
    ('max_ntree_per_split_iris_2.pickle', max_ntree_per_split),
    ('group_average_iris_2.pickle', group_average),
    ('leaf_preds_iris_2.pickle', leaf_preds),
    ('test_leaf_preds_iris_2.pickle', test_leaf_preds),
    ('tree_outputs_iris_2.pickle', tree_outputs),
    ('used_features_iris_2.pickle', used_features),
]
for path, payload in artifacts:
    with open(path, 'wb') as f:
        # Pickle using the highest protocol available.
        pickle.dump(payload, f, pickle.HIGHEST_PROTOCOL)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=de072003-a9db-4342-8067-19a4b45feff1' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>