In [1]:
import argparse
from copy import deepcopy
import logging
import random
from collections import defaultdict
from os.path import join
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score, r2_score
from sklearn.model_selection import train_test_split
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
import joblib
import imodels
import inspect
import os.path
import imodelsx.cache_save_utils
import sys
import torch
#path_to_repo = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

#os.chdir(path_to_repo)
#os.chdir('/home/mattyshen/interpretableDistillation')
sys.path.append('..')

import idistill.model
import idistill.data
from idistill.ftd import FTDistillRegressorCV
from idistill.figs_distiller import FIGSRegressor

sys.path.append('/home/mattyshen/iCBM')

from CUB.template_model import End2EndModel, Inception3, MLP

class ARGS:
    def __init__(self, a_dict):
        for k in a_dict.keys():
            exec(f'self.{k} = a_dict["{k}"]')
            
def fit_model(model, X_train, y_train, feature_names, r):
    # fit the model
    fit_parameters = inspect.signature(model.fit).parameters.keys()
    if "feature_names" in fit_parameters and feature_names is not None:
        model.fit(X_train, y_train, feature_names=feature_names)
    else:
        model.fit(X_train, y_train)

    return r, model

def evaluate_model(model, X_train, X_val, y_train, y_val, comp, seed, r):
    """Evaluate model performance on each split"""
    metrics = {
            "accuracy": accuracy_score,
        }
    for split_name, (X_, y_) in zip(
        ["trainval", "test"], [(X_train, y_train), (X_val, y_val)]
    ):
        y_pred_ = model.predict(X_)
        if len(y_pred_.shape) > 1 and y_pred_.shape[1] > 1:
            #handle regressors
            y_pred_ = np.argmax(y_pred_, axis=1)
        for i, (metric_name, metric_fn) in enumerate(metrics.items()):
            print(metric_fn(y_, y_pred_))
            r[f"{comp}_seed{seed}_{metric_name}_{split_name}"] = metric_fn(y_, y_pred_)

    return r

def load_csvs(path):

    X_train = pd.read_csv(f'{path}/X_trainval.csv', index_col=0)
    X_train_hat = pd.read_csv(f'{path}/X_trainval_hat.csv', index_col=0)
    X_test = pd.read_csv(f'{path}/X_test.csv', index_col=0)
    X_test_hat = pd.read_csv(f'{path}/X_test_hat.csv', index_col=0)
    y_train = pd.read_csv(f'{path}/y_trainval.csv', index_col=0)
    y_train_hat = pd.read_csv(f'{path}/y_trainval_hat.csv', index_col=0)
    y_test = pd.read_csv(f'{path}/y_test.csv', index_col=0)
    y_test_hat = pd.read_csv(f'{path}/y_test_hat.csv', index_col=0)

    return X_train, X_train_hat, X_test, X_test_hat, y_train, y_train_hat, y_test, y_test_hat

def find_optimal_threshold(y_true, y_probs):
    precisions, recalls, thresholds = precision_recall_curve(y_true, y_probs)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls)
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    return optimal_threshold

def find_thresh(linkage_matrix, min_clusters=10, max_clusters=15, step=0.1, count = 0):
    if count > 3:
        print(max_clusters)
        return find_thresh(linkage_matrix, min_clusters=min_clusters, max_clusters=(max_clusters-5*4)-1, step=step, count = 0)
    threshold = 4.9
    while threshold < 10:
        clusters = fcluster(linkage_matrix, t=threshold, criterion='distance')
        num_clusters = len(set(clusters))
        if min_clusters <= num_clusters <= max_clusters:
            return threshold, num_clusters
        threshold += step
    print('find_thresh recursive call beginning')
    return find_thresh(linkage_matrix, min_clusters=min_clusters, max_clusters=max_clusters+5, step=0.1, count = count+1)
    #return None, 0

def cluster_concepts(X, num_clusters):
    distance_matrix = 1 - X_train_hat.corr().abs()
    linkage_matrix = linkage(squareform(distance_matrix), method='ward')
    
    threshold, _ = find_thresh(linkage_matrix, min_clusters=num_clusters-5, max_clusters=num_clusters, step=0.1)
        
    clusters = fcluster(linkage_matrix, t=threshold, criterion='distance')
    
    feature_groups = {}
    for i, cluster_id in enumerate(clusters):
        feature_groups.setdefault(cluster_id, []).append(distance_matrix.columns[i])
    
    return feature_groups

def process_X(X_train, X_train_hat, X_test, X_test_hat, prepro, num_clusters, thresh=0):
    if prepro == "probs":
        return X_train_hat, X_test_hat, None
    elif prepro == 'cluster':
        f_gs = cluster_concepts(X_train_hat, num_clusters)
        optimal_thresholds = np.zeros(X_train.shape[1])
        
        for k in f_gs.keys():
            idxs = [int(s[1:]) - 1 for s in f_gs[k]]
            optimal_thresholds[idxs] = find_optimal_threshold(X_train[f_gs[k]].values.reshape(-1, ), X_train_hat[f_gs[k]].values.reshape(-1, ))

        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
    elif prepro == 'global':
        f_gs = cluster_concepts(X_train_hat, num_clusters)
        opt_thresh = find_optimal_threshold(X_train.values.reshape(-1, ), X_train_hat.values.reshape(-1, ))
        
        return (X_train_hat > opt_thresh).astype(int), (X_test_hat > opt_thresh).astype(int), f_gs
    elif prepro == 'gpt1':
        f_gs = {1:['c'+str(i) for i in range(1, 5)]+['c'+str(i) for i in range(53, 55)]+['c'+str(i) for i in range(100, 104)],
                2:['c'+str(i) for i in range(5, 11)]+['c'+str(i) for i in range(110, 113)]+['c78'],
                3:['c'+str(i) for i in range(11, 17)]+['c'+str(i) for i in range(26, 32)]+['c'+str(i) for i in range(85, 88)]+['c'+str(i) for i in range(65, 71)]+['c'+str(i) for i in range(104, 110)],
                4:['c'+str(i) for i in range(17, 24)]+['c'+str(i) for i in range(40, 51)]+['c'+str(i) for i in range(24, 26)]+['c'+str(i) for i in range(71, 78)]+['c'+str(i) for i in range(60, 65)],
                5:['c'+str(i) for i in range(32, 38)]+['c'+str(i) for i in range(88, 91)],
                6:['c38', 'c39', 'c51','c52']+['c'+str(i) for i in range(55, 60)],
                7:['c'+str(i) for i in range(97, 100)],
                8:['c'+str(i) for i in range(91, 97)]+['c'+str(i) for i in range(79, 85)]
        }
        
        optimal_thresholds = np.zeros(X_train.shape[1])
        
        for k in f_gs.keys():
            idxs = [int(s[1:]) - 1 for s in f_gs[k]]
            optimal_thresholds[idxs] = find_optimal_threshold(X_train[f_gs[k]].values.reshape(-1, ), X_train_hat[f_gs[k]].values.reshape(-1, ))

        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
    elif prepro == 'gpt2':
        f_gs = {1:['c'+str(i) for i in range(1, 5)]+['c'+str(i) for i in range(53, 55)]+['c32']+['c'+str(i) for i in range(78, 85)],
                2:['c'+str(i) for i in range(5, 11)]+['c'+str(i) for i in range(110, 113)]+['c'+str(i) for i in range(33, 38)]+['c'+str(i) for i in range(88, 91)],
                3:['c'+str(i) for i in range(91, 97)]+['c'+str(i) for i in range(11, 17)]+['c'+str(i) for i in range(26, 32)]+['c'+str(i) for i in range(85, 88)]+['c'+str(i) for i in range(17, 24)]+['c'+str(i) for i in range(60, 65)],
                4:['c'+str(i) for i in range(40, 51)]+['c24', 'c25']+['c'+str(i) for i in range(104, 110)]+['c'+str(i) for i in range(55, 60)]+['c'+str(i) for i in range(65, 78)],
                5:['c38', 'c39', 'c51','c52'],
                6:['c'+str(i) for i in range(100, 104)]+['c'+str(i) for i in range(97, 100)]
        }
        
        optimal_thresholds = np.zeros(X_train.shape[1])
        
        for k in f_gs.keys():
            idxs = [int(s[1:]) - 1 for s in f_gs[k]]
            optimal_thresholds[idxs] = find_optimal_threshold(X_train[f_gs[k]].values.reshape(-1, ), X_train_hat[f_gs[k]].values.reshape(-1, ))

        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
    elif prepro == 'gpt3':
        f_gs = {1:['c'+str(i) for i in range(1, 5)]+['c53', 'c54']+['c'+str(i) for i in range(100, 104)],
                2:['c78', 'c32']+['c'+str(i) for i in range(5, 11)]+['c'+str(i) for i in range(110, 113)]+['c'+str(i) for i in range(88, 91)]+['c'+str(i) for i in range(33, 38)],
                3:['c'+str(i) for i in range(91, 97)]+['c'+str(i) for i in range(11, 24)]+['c'+str(i) for i in range(26, 32)]+['c'+str(i) for i in range(85, 88)]+['c'+str(i) for i in range(55, 78)]+['c'+str(i) for i in range(104, 110)],
                4:['c'+str(i) for i in range(40, 51)]+['c24', 'c25', 'c38', 'c39', 'c51', 'c52'],
                5:['c'+str(i) for i in range(79, 85)],
                6:['c'+str(i) for i in range(97, 100)]
        }
        
        optimal_thresholds = np.zeros(X_train.shape[1])
        
        for k in f_gs.keys():
            idxs = [int(s[1:]) - 1 for s in f_gs[k]]
            optimal_thresholds[idxs] = find_optimal_threshold(X_train[f_gs[k]].values.reshape(-1, ), X_train_hat[f_gs[k]].values.reshape(-1, ))

        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
    
    elif prepro == 'gpt4':
        f_gs = {1:[1,3,5,7,8,11,12,15,16,19,20,21,22,24,26,27,40,41,42,43,46,49,50,53,55,56,60,62,64,65,66,69,70,71,72,77,78,79,80,83,84,85,87,88,90,91,92,93,94,95,102,106,107,108,110],
                2:[4,6,10,13,14,17,18,23,25,29,54,73,76,86,89,97,98,104,105,111],
                3:[2,9,28,30,32,33,34,35,36,37,38,39,47,48,51,52,81,96,99],
                4:[31,44,45,57,58,59,61,63,67,68,74,75,82,100,101,103,109,112]
        }
        
        for k in f_gs.keys():
            f_gs[k] = ['c'+str(i) for i in f_gs[k]]
        
        optimal_thresholds = np.zeros(X_train.shape[1])
        
        for k in f_gs.keys():
            idxs = [int(s[1:]) - 1 for s in f_gs[k]]
            optimal_thresholds[idxs] = find_optimal_threshold(X_train[f_gs[k]].values.reshape(-1, ), X_train_hat[f_gs[k]].values.reshape(-1, ))

        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
        
    elif prepro == 'binary' and thresh > 0:
        f_gs = cluster_concepts(X_train_hat, num_clusters)
        return (X_train_hat > thresh).astype(int), (X_test_hat > thresh).astype(int), f_gs
    else:
        f_gs = cluster_concepts(X_train_hat, num_clusters)
        optimal_thresholds = []
        for class_idx in range(X_train_hat.shape[1]):
            y_true_class = X_train.iloc[:, class_idx]
            y_probs_class = X_train_hat.iloc[:, class_idx]
            optimal_thresholds.append(find_optimal_threshold(y_true_class, y_probs_class))
        optimal_thresholds = np.array(optimal_thresholds)
        
        return (X_train_hat > optimal_thresholds).astype(int), (X_test_hat > optimal_thresholds).astype(int), f_gs
    
def process_y(y_train, y_train_hat, y_test, y_test_hat, prepro):
    if prepro == "probs":
        return softmax(y_train_hat, axis=1), softmax(y_test_hat, axis=1)
    elif prepro == "classes":
        return pd.DataFrame(y_train_hat.idxmax(axis=1).astype(int)), pd.DataFrame(y_test_hat.idxmax(axis=1).astype(int))
    else:
        return y_train_hat, y_test_hat

  def noop(*args, **kwargs):  # type: ignore
2025-01-12 23:57:04.589322: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from urllib3.contrib.pyopenssl import orig_util_SSLContext as SSLContext
  torch.utils._pytree._register_pytree_node(


In [2]:
args_dict = {}
args_dict['task_type'] = 'regression'
args_dict['model_name'] = 'FIGSRegressor'
args_dict['X_type'] = 'binary'
args_dict['thresh'] = 0.5
args_dict['Y_type'] = 'logits'
args_dict['max_rules'] = 90
args_dict['max_trees'] = 60
args_dict['max_depth'] = 3
args_dict['device'] = 'cuda:0'
args_dict['num_clusters'] = 5
args_dict['num_bootstraps'] = 2

args = ARGS(args_dict)
r = {}

In [71]:
def extract_interactions(model):
    """
    Extracts all feature interactions from the FIGS model by parsing through each additive tree.

    Parameters:
        model: A FIGS model containing an attribute `trees_`.
               Each tree is comprised of hierarchically linked `Node` objects.

    Returns:
        interactions: A list of sets, where each set contains the features involved in an interaction.
    """
    interactions = []

    def traverse_tree(node, current_features, current_depth):
        """
        Recursively traverse a tree to collect feature interactions.

        Parameters:
            node: The current `Node` object in the tree.
            current_features: A set of features encountered so far in the current path.
        """
        print(current_features, current_depth)
        if node.left is None and node.right is None:
            interactions.append(current_features)
            return

        # Add the current feature to the set of features for this path
        current_features.add('c' + str(node.feature+1))

        # If the node has children, traverse them
        if node.left is not None:
            traverse_tree(node.left, current_features.copy(), current_depth=current_depth+1)
        if node.right is not None:
            traverse_tree(node.right, current_features.copy(), current_depth=current_depth+1)

    # Loop through each tree in the model
    traverse_tree(model.trees_[0], set(), current_depth=0)
    return interactions
    for tree in model.trees_:
        # Start traversal for each tree
        traverse_tree(tree, set(), current_depth=0)

    return interactions

In [4]:
np.random.seed(0)

all_interactions = {}

for i in range(args.num_bootstraps):
    print(f'bootstrap: {i}')
    X_train, X_train_hat, X_test, X_test_hat, y_train, y_train_hat, y_test, y_test_hat = load_csvs(f'/home/mattyshen/DistillationEdit/data/cub_tabular/seed0_Joint0.01SigmoidModel__Seed1')
    X_train_model, X_test_model, clusters = process_X(X_train, X_train_hat, X_test, X_test_hat, args.X_type, args.num_clusters, args.thresh)
    y_train_model, y_test_model = process_y(y_train, y_train_hat, y_test, y_test_hat, args.Y_type)
    
    cur_bootstrap = pd.concat([X_train_model, y_train_model], axis = 1).sample(X_train_model.shape[0], replace=True)
    
    X_bs = cur_bootstrap.iloc[:, np.arange(0, X_train_model.shape[1])]
    y_bs = cur_bootstrap.iloc[:, np.arange(X_train_model.shape[1], X_train_model.shape[1]+y_train_model.shape[1])]
    
    model = idistill.model.get_model(args.task_type, args.model_name, args)
    r, model = fit_model(model, X_bs, y_bs, None, r)
    
    cur_interactions = extract_interactions(model)
    
    all_interactions = {}
    for inter in cur_interactions:
        if frozenset(inter) not in all_interactions.keys():
            all_interactions[frozenset(inter)] = 1
        else:
            all_interactions[frozenset(inter)] = all_interactions[frozenset(inter)]+1

bootstrap: 0
bootstrap: 1


In [8]:
import matplotlib.pyplot as plt

In [17]:
X_train_model.corr().values[19, 42]

0.8156192569712072

In [6]:
all_interactions

{frozenset({'c37', 'c62', 'c8', 'c94'}): 1,
 frozenset({'c62', 'c69', 'c8', 'c94'}): 1,
 frozenset({'c35', 'c43', 'c8', 'c94'}): 1,
 frozenset({'c30', 'c35', 'c8', 'c94'}): 1,
 frozenset({'c29', 'c65', 'c87', 'c94'}): 1,
 frozenset({'c29', 'c62', 'c65', 'c94'}): 1,
 frozenset({'c20', 'c29', 'c35', 'c94'}): 1,
 frozenset({'c12', 'c29', 'c35', 'c94'}): 1,
 frozenset({'c12', 'c5', 'c6'}): 2,
 frozenset({'c33', 'c5', 'c55'}): 1,
 frozenset({'c37', 'c5', 'c55'}): 1,
 frozenset({'c44', 'c69', 'c73'}): 1,
 frozenset({'c32', 'c44', 'c73'}): 1,
 frozenset({'c30', 'c44', 'c95'}): 2,
 frozenset({'c30', 'c44', 'c86'}): 1,
 frozenset({'c4', 'c52', 'c80'}): 1,
 frozenset({'c4', 'c80', 'c96'}): 1,
 frozenset({'c2', 'c50', 'c80'}): 1,
 frozenset({'c-1', 'c50', 'c80'}): 1,
 frozenset({'c109', 'c87', 'c96'}): 1,
 frozenset({'c109', 'c49', 'c96'}): 1,
 frozenset({'c109', 'c34'}): 2,
 frozenset({'c20', 'c66', 'c94'}): 1,
 frozenset({'c43', 'c66', 'c94'}): 1,
 frozenset({'c112', 'c43', 'c94'}): 1,
 frozens

In [20]:
import pickle

with open(f'/home/mattyshen/DistillationEdit/results/figs_stability/stafigs_seed69_nbootstraps2.p', 'wb') as fp:
    pickle.dump(all_interactions, fp, protocol=pickle.HIGHEST_PROTOCOL)

In [82]:
with open('/home/mattyshen/DistillationEdit/results/figs_stability/stafigs_seed2_nbootstraps20.p', 'rb') as fp:
    data = pickle.load(fp)

In [90]:
sorted([(k, data[k]) for k in data.keys()], key= lambda x: x[1])[::-1]

[(frozenset({'c17', 'c90'}), 17),
 (frozenset({'c109', 'c31'}), 17),
 (frozenset({'c73', 'c75'}), 14),
 (frozenset({'c63', 'c75'}), 12),
 (frozenset({'c5', 'c62', 'c94'}), 12),
 (frozenset({'c44', 'c75'}), 12),
 (frozenset({'c87', 'c89'}), 11),
 (frozenset({'c32', 'c52'}), 11),
 (frozenset({'c77', 'c78'}), 11),
 (frozenset({'c109', 'c87'}), 10),
 (frozenset({'c3', 'c80'}), 10),
 (frozenset({'c109'}), 10),
 (frozenset({'c48', 'c5', 'c94'}), 9),
 (frozenset({'c1', 'c52'}), 8),
 (frozenset({'c12', 'c92'}), 8),
 (frozenset({'c112', 'c5'}), 8),
 (frozenset({'c54', 'c90'}), 7),
 (frozenset({'c56', 'c79'}), 7),
 (frozenset({'c107', 'c83'}), 7),
 (frozenset({'c12', 'c27'}), 7),
 (frozenset({'c103', 'c90'}), 7),
 (frozenset({'c5', 'c58', 'c65'}), 7),
 (frozenset({'c14', 'c29'}), 6),
 (frozenset({'c84'}), 6),
 (frozenset({'c52', 'c79'}), 6),
 (frozenset({'c34', 'c64'}), 6),
 (frozenset({'c14', 'c94'}), 6),
 (frozenset({'c79', 'c80'}), 6),
 (frozenset({'c39', 'c81'}), 5),
 (frozenset({'c12', 'c6'

In [76]:
cur_interactions

[frozenset({'c62', 'c8', 'c94'}),
 frozenset({'c29', 'c65', 'c94'}),
 frozenset({'c35', 'c8', 'c94'}),
 frozenset({'c29', 'c35', 'c94'})]

In [75]:
cur_interactions=list(set(frozenset(item) for item in cur_interactions))

In [61]:
all_interactions = {}
for inter in cur_interactions:
    if inter not in all_interactions.keys():
        all_interactions[inter] = 1
    else:
        all_interactions[frozenset(inter)] +=1

In [62]:
all_interactions

{frozenset({'c62', 'c8', 'c94'}): 1,
 frozenset({'c112', 'c65'}): 1,
 frozenset({'c52', 'c80'}): 1,
 frozenset({'c48', 'c97'}): 1,
 frozenset({'c70', 'c79'}): 1,
 frozenset({'c111', 'c36'}): 1,
 frozenset({'c1'}): 1,
 frozenset({'c58', 'c87'}): 1,
 frozenset({'c110', 'c51'}): 1,
 frozenset({'c25', 'c9'}): 1,
 frozenset({'c2', 'c47'}): 1,
 frozenset({'c32'}): 1,
 frozenset({'c66', 'c94'}): 1,
 frozenset({'c39'}): 1,
 frozenset({'c48', 'c81'}): 1,
 frozenset({'c20', 'c77'}): 1,
 frozenset({'c3', 'c57'}): 1,
 frozenset({'c108', 'c41'}): 1,
 frozenset({'c112', 'c5'}): 1,
 frozenset({'c100', 'c87'}): 1,
 frozenset({'c103', 'c78'}): 1,
 frozenset({'c21', 'c44'}): 1,
 frozenset({'c84', 'c98'}): 1,
 frozenset({'c109', 'c96'}): 1,
 frozenset({'c5', 'c55'}): 1,
 frozenset({'c35', 'c51'}): 1,
 frozenset({'c12', 'c5'}): 1,
 frozenset({'c30', 'c44'}): 1,
 frozenset({'c41', 'c53'}): 1,
 frozenset({'c14', 'c23'}): 1,
 frozenset({'c32', 'c76'}): 1,
 frozenset({'c109'}): 1,
 frozenset({'c81', 'c89'}): 

In [32]:
cur_interactions = extract_interactions(model)

set() 0
{'c94'} 1
{'c94', 'c8'} 2
{'c62', 'c94', 'c8'} 3
{'c62', 'c94', 'c8'} 3
{'c94', 'c8'} 2
{'c35', 'c94', 'c8'} 3
{'c35', 'c94', 'c8'} 3
{'c94'} 1
{'c94', 'c29'} 2
{'c65', 'c94', 'c29'} 3
{'c65', 'c94', 'c29'} 3
{'c94', 'c29'} 2
{'c35', 'c94', 'c29'} 3
{'c35', 'c94', 'c29'} 3
set() 0
{'c5'} 1
{'c12', 'c5'} 2
{'c12', 'c5'} 2
{'c5'} 1
{'c55', 'c5'} 2
{'c55', 'c5'} 2
set() 0
{'c44'} 1
{'c73', 'c44'} 2
{'c73', 'c44'} 2
{'c44'} 1
{'c30', 'c44'} 2
{'c30', 'c44'} 2
set() 0
{'c80'} 1
{'c4', 'c80'} 2
{'c4', 'c80'} 2
{'c80'} 1
{'c50', 'c80'} 2
{'c50', 'c80'} 2
set() 0
{'c109'} 1
{'c96', 'c109'} 2
{'c96', 'c109'} 2
{'c109'} 1
set() 0
{'c94'} 1
{'c66', 'c94'} 2
{'c66', 'c94'} 2
{'c94'} 1
{'c43', 'c94'} 2
{'c43', 'c94'} 2
set() 0
{'c98'} 1
{'c84', 'c98'} 2
{'c84', 'c98'} 2
{'c98'} 1
set() 0
{'c32'} 1
{'c32', 'c12'} 2
{'c32', 'c12'} 2
{'c32'} 1
set() 0
{'c107'} 1
{'c107', 'c17'} 2
{'c107', 'c17'} 2
{'c107'} 1
set() 0
{'c51'} 1
{'c51', 'c110'} 2
{'c51', 'c110'} 2
{'c51'} 1
{'c35', 'c51'} 2
{'c35

In [33]:
model.max_depth

3