## Imports

In [40]:
import json
import os
import pickle

import pandas as pd

import treelite
from treelite import Model, ModelBuilder
import treelite.sklearn

## Paths

In [41]:
# Directories (relative to the notebook's working directory) that are scanned
# for pickled model files further below. One directory per combination of the
# "daily"/"5days" and "10_trees"/"100_trees" sub-folders.
model_dir_daily_10_trees = "models/sklearn/daily/10_trees"
model_dir_daily_100_trees = "models/sklearn/daily/100_trees"
model_dir_5days_10_trees = "models/sklearn/5days/10_trees"
model_dir_5days_100_trees = "models/sklearn/5days/100_trees"

## Load Models (Daily and 5-Day)

In [42]:
def get_model_jsons(model_dir):
    """Load every pickled scikit-learn model in ``model_dir`` as Treelite JSON.

    Each regular file in the directory is unpickled, imported into Treelite
    with ``treelite.sklearn.import_model`` and dumped to its JSON
    representation, which is then parsed into plain Python objects.

    Args:
        model_dir: Path of the directory containing the pickled model files.

    Returns:
        dict mapping each file name to the parsed JSON dump of its Treelite
        model. Entries are inserted in sorted file-name order. Returns an
        empty dict when the directory contains no regular files.
    """
    models = {}

    # os.listdir() yields entries in arbitrary order; sorting keeps the dict
    # (and every table derived from it) in the same order on every run.
    for model_name in sorted(os.listdir(model_dir)):
        model_path = os.path.join(model_dir, model_name)

        # Skip sub-directories such as Jupyter's ".ipynb_checkpoints"; they
        # cannot be opened as a pickle file and would abort the whole load.
        if not os.path.isfile(model_path):
            continue

        # Load Model from SKLearn
        # SECURITY: pickle.load can execute arbitrary code. Only point this
        # function at directories whose files come from a trusted source.
        with open(model_path, "rb") as f:
            model = pickle.load(f)

        # Load Model into Treelite
        treelite_model = treelite.sklearn.import_model(model)
        # Load JSON Representation
        treelite_model_json = json.loads(treelite_model.dump_as_json(pretty_print=False))

        models[model_name] = treelite_model_json

    return models

In [43]:
# Load the Treelite JSON dumps for every model file in each directory.
# Each variable is a dict keyed by model file name (see get_model_jsons).
daily_model_10_trees = get_model_jsons(model_dir_daily_10_trees)
daily_model_100_trees = get_model_jsons(model_dir_daily_100_trees)

fiveday_model_10_trees = get_model_jsons(model_dir_5days_10_trees)
fiveday_model_100_trees = get_model_jsons(model_dir_5days_100_trees)

## Comparisons

In [44]:
def remove_unneeded_node_keys(node):
    """Drop the statistics keys from a node dict so nodes can be compared.

    The keys ``"data_count"``, ``"sum_hess"`` and ``"gain"`` are removed from
    ``node`` in place when present; missing keys are ignored.

    Args:
        node: Node dictionary to strip (modified in place).

    Returns:
        The same ``node`` object that was passed in.
    """
    for stat_key in ("data_count", "sum_hess", "gain"):
        node.pop(stat_key, None)
    return node


def duplicate_tree(tree1, tree2):
    """Measure how many nodes of ``tree1`` also occur in ``tree2``.

    Two trees are only compared when their ``"num_nodes"`` and
    ``"has_categorical_split"`` entries agree. Nodes are compared as whole
    dictionaries after the ``"data_count"``, ``"sum_hess"`` and ``"gain"``
    keys have been stripped via ``remove_unneeded_node_keys``.

    Args:
        tree1: Tree dictionary with ``"num_nodes"``,
            ``"has_categorical_split"`` and ``"nodes"`` (list of node dicts).
        tree2: Tree dictionary with the same layout.

    Returns:
        Tuple ``(num_nodes, duplicate_ratio)``. ``num_nodes`` is
        ``len(tree1["nodes"])`` and ``duplicate_ratio`` is the fraction of
        ``tree1`` nodes that have an equal node in ``tree2`` (between 0.0 and
        1.0). ``(0, 0.0)`` is returned when the sanity checks fail or when
        ``tree1`` has no nodes.
    """
    # Sanity Checks: trees of different size/kind are never treated as duplicates.
    if (
        tree1["num_nodes"] != tree2["num_nodes"]
        or tree1["has_categorical_split"] != tree2["has_categorical_split"]
    ):
        return 0, 0.0

    num_nodes = len(tree1["nodes"])
    if num_nodes == 0:
        # Nothing to compare; also avoids dividing by zero below.
        return 0, 0.0

    # Strip the statistics keys once per node, on shallow copies, so the
    # caller's tree dictionaries are left untouched and the work is not
    # repeated for every node pair.
    nodes1 = [remove_unneeded_node_keys(dict(node)) for node in tree1["nodes"]]
    nodes2 = [remove_unneeded_node_keys(dict(node)) for node in tree2["nodes"]]

    # Count each node of tree1 at most once so the ratio cannot exceed 1.0.
    duplicate_nodes = sum(1 for node1 in nodes1 if node1 in nodes2)

    return num_nodes, duplicate_nodes / num_nodes


def duplicate_trees(trees1, trees2):
    """Compare every tree of one model against every tree of another.

    Args:
        trees1: Sequence of tree dictionaries of the first model.
        trees2: Sequence of tree dictionaries of the second model.

    Returns:
        List with one dict per ``(tree1, tree2)`` pair, holding the keys
        ``"num_tree_model1"`` and ``"num_tree_model2"`` (tree indices),
        ``"num_nodes"`` and ``"duplicate_ratio"`` as returned by
        ``duplicate_tree``. Empty when either sequence is empty.
    """
    duplicate_data = []
    # Iterate over Trees
    for i, tree1 in enumerate(trees1):
        for j, tree2 in enumerate(trees2):
            num_nodes, duplicate_ratio = duplicate_tree(tree1, tree2)

            # A ratio of 0.0 means the trees share no nodes (or were not
            # comparable); only a full match counts as a duplicate tree.
            if duplicate_ratio >= 1.0:
                print(f"Found duplicate_tree for tree {i} and tree {j}")

            duplicate_data.append({
                "num_tree_model1": i,
                "num_tree_model2": j,
                "num_nodes": num_nodes,
                "duplicate_ratio": duplicate_ratio
            })

    return duplicate_data
   

def duplicate_model(model1: dict, model2: dict) -> list:
    """Compare the trees of two models that share the same model header.

    The models are only compared when their ``"num_feature"``,
    ``"task_type"``, ``"average_tree_output"``, ``"task_param"`` and
    ``"model_param"`` entries are all equal.

    Args:
        model1: Parsed model dictionary containing the header keys above and
            a ``"trees"`` list.
        model2: Parsed model dictionary with the same layout.

    Returns:
        The list of per-tree-pair records produced by ``duplicate_trees``, or
        an empty list when the headers differ. The empty list is falsy like
        the previously returned ``False`` but can be iterated by callers.
    """
    header_keys = (
        "num_feature",
        "task_type",
        "average_tree_output",
        "task_param",
        "model_param",
    )
    # Checked in order; the first differing header entry ends the comparison.
    for key in header_keys:
        if model1[key] != model2[key]:
            return []

    return duplicate_trees(model1["trees"], model2["trees"])


def compare_models(models: dict) -> pd.DataFrame:
    """Compare every model against every other model, tree by tree.

    Every ordered pair of distinct models is compared, so each unordered pair
    contributes two sets of rows: (a, b) and (b, a).

    Args:
        models: Mapping of model name to parsed model dictionary, as accepted
            by ``duplicate_model``.

    Returns:
        DataFrame with one row per compared tree pair: the record from
        ``duplicate_model`` extended with the ``"model1"`` and ``"model2"``
        model names. Empty when fewer than two models are given.
    """
    duplicate_data = []

    for model_name_a, model_a in models.items():
        for model_name_b, model_b in models.items():
            # Keys of a dict are unique, so equal names mean the same model.
            if model_name_a == model_name_b:
                continue

            # Incompatible models yield a falsy result (no records); treat it
            # as an empty list so it can be iterated safely.
            tree_records = duplicate_model(model_a, model_b) or []

            duplicate_data.extend(
                {**record, "model1": model_name_a, "model2": model_name_b}
                for record in tree_records
            )

    return pd.DataFrame(duplicate_data)

In [45]:
# Pairwise tree comparison across the daily models loaded from the 10-tree
# directory: summary statistics first, then how often each ratio occurs.
duplicate_data = compare_models(daily_model_10_trees)
display(duplicate_data.describe())
print(duplicate_data['duplicate_ratio'].value_counts())

Unnamed: 0,num_tree_model1,num_tree_model2,num_nodes,duplicate_ratio
count,93000.0,93000.0,93000.0,93000.0
mean,4.5,4.5,6.338839,0.046955
std,2.872297,2.872297,2.043891,0.082459
min,0.0,0.0,0.0,0.0
25%,2.0,2.0,7.0,0.0
50%,4.5,4.5,7.0,0.0
75%,7.0,7.0,7.0,0.142857
max,9.0,9.0,7.0,0.428571


duplicate_ratio
0.000000    67356
0.142857    21028
0.285714     4308
0.428571      308
Name: count, dtype: int64


In [49]:
# Same comparison for the daily models loaded from the 100-tree directory.
# NOTE(review): the saved notebook shows no output for this cell and its
# execution count is out of order - re-run with Restart & Run All to confirm
# it completes.
duplicate_data = compare_models(daily_model_100_trees)
display(duplicate_data.describe())
print(duplicate_data['duplicate_ratio'].value_counts())

: 

In [47]:
# Pairwise tree comparison across the 5-day models loaded from the 10-tree
# directory: summary statistics first, then how often each ratio occurs.
duplicate_data = compare_models(fiveday_model_10_trees)
display(duplicate_data.describe())
print(duplicate_data['duplicate_ratio'].value_counts())

Unnamed: 0,num_tree_model1,num_tree_model2,num_nodes,duplicate_ratio
count,70200.0,70200.0,70200.0,70200.0
mean,4.5,4.5,10.144359,0.053971
std,2.872302,2.872302,6.982538,0.078218
min,0.0,0.0,0.0,0.0
25%,2.0,2.0,0.0,0.0
50%,4.5,4.5,15.0,0.0
75%,7.0,7.0,15.0,0.066667
max,9.0,9.0,15.0,0.461538


duplicate_ratio
0.000000    41476
0.066667    11366
0.133333     9524
0.200000     5050
0.266667     1852
0.333333      402
0.153846      130
0.230769      102
0.076923       98
0.307692       74
0.384615       64
0.400000       38
0.461538       18
0.222222        6
Name: count, dtype: int64


In [48]:
# Pairwise tree comparison across the 5-day models loaded from the 100-tree
# directory: summary statistics first, then how often each ratio occurs.
duplicate_data = compare_models(fiveday_model_100_trees)
display(duplicate_data.describe())
print(duplicate_data['duplicate_ratio'].value_counts())

Unnamed: 0,num_tree_model1,num_tree_model2,num_nodes,duplicate_ratio
count,560000.0,560000.0,560000.0,560000.0
mean,49.5,49.5,14.625,0.136953
std,28.866096,28.866096,2.341876,0.077812
min,0.0,0.0,0.0,0.0
25%,24.75,24.75,15.0,0.066667
50%,49.5,49.5,15.0,0.133333
75%,74.25,74.25,15.0,0.2
max,99.0,99.0,15.0,0.466667


duplicate_ratio
0.133333    186598
0.066667    147296
0.200000    119940
0.266667     51290
0.000000     42144
0.333333     11488
0.400000      1214
0.466667        30
Name: count, dtype: int64
