From 269af6b5083e00022a1e6f072417d556d19379c8 Mon Sep 17 00:00:00 2001 From: meganjkurka <20548431+meganjkurka@users.noreply.github.com> Date: Thu, 19 Nov 2020 15:55:22 -0500 Subject: [PATCH] updating rulefit tutorial to use rulefit estimator (#148) Co-authored-by: Megan Kurka --- best-practices/explainable-models/__init__.py | 0 best-practices/explainable-models/rulefit.py | 500 ------- .../explainable-models/rulefit_analysis.ipynb | 1288 ++--------------- 3 files changed, 141 insertions(+), 1647 deletions(-) delete mode 100644 best-practices/explainable-models/__init__.py delete mode 100644 best-practices/explainable-models/rulefit.py diff --git a/best-practices/explainable-models/__init__.py b/best-practices/explainable-models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/best-practices/explainable-models/rulefit.py b/best-practices/explainable-models/rulefit.py deleted file mode 100644 index b7d459817..000000000 --- a/best-practices/explainable-models/rulefit.py +++ /dev/null @@ -1,500 +0,0 @@ -"""H2O-3 RuleFit""" - -# Contributors: Megan Kurka - megan.kurka@h2o.ai -# Created: February 18th, 2020 -# Last Updated: April 30th, 2020 - -import pandas as pd -import numpy as np -import os -import warnings - -import h2o -from h2o.estimators import H2OGeneralizedLinearEstimator -from h2o.exceptions import H2OValueError -from h2o.tree import H2OTree - -class H2ORuleFit(): - """ - H2O RuleFit - Builds a Distributed RuleFit model on a parsed dataset, for regression or - classification. - :param algorithm: The algorithm to use to generate rules. Options are "DRF", "XGBoost", "GBM" - :param min_rule_len: Minimum length of rules. Defaults to 1. - :param max_rule_len: Maximum length of rules. Defaults to 10. - :param max_num_rules: The maximum number of rules to return. - Defaults to None which means the number of rules is selected by diminishing returns in model deviance. - :param nfolds: Number of folds for K-fold cross-validation. Defaults to 5. - :param seed: Seed for pseudo random number generator. Defaults to -1. - :param tree_params: Additional parameters that can be passed to the tree model. Defaults to None. - :param glm_params: Additional parameters that can be passed to the linear model. Defaults to None. - :returns: a set of rules and coefficients - :examples: - >>> rulefit = H2ORuleFit() - >>> training_data = h2o.import_file("smalldata/gbm_test/titanic.csv", - ... col_types = {'pclass': "enum", 'survived': "enum"}) - >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] - >>> rulefit.train(x=x,y="survived",training_frame=training_data) - >>> rulefit - """ - - def __init__(self, algorithm, - min_rule_len=1, - max_rule_len=10, - max_num_rules=None, - nfolds=5, - seed=-1, - tree_params={}, - glm_params={} - ): - - if algorithm not in ["DRF", "XGBoost", "GBM"]: - raise H2OValueError("{} is not a supported algorithm".format(algorithm)) - self.algorithm = algorithm - self.min_rule_len = min_rule_len - self.max_rule_len = max_rule_len - self.max_num_rules = max_num_rules - self.nfolds = nfolds - self.seed = seed - - if tree_params: - tree_params.pop("model_id", None) - if 'max_depth' in tree_params.keys(): - self.min_rule_len = tree_params.get("max_depth") - self.max_rule_len = tree_params.get("max_depth") - tree_params.pop("max_depth") - warnings.warn('max_depth provided in tree_params - min_rule_len and max_rule_len will be ignored') - if 'nfolds' in tree_params.keys(): - tree_params.pop('nfolds') - warnings.warn('seed provided in tree_params but will be ignored') - if 'seed' in tree_params.keys(): - tree_params.pop('seed') - warnings.warn('seed provided in tree_params but will be ignored') - - - if glm_params: - glm_params.pop("model_id", None) - if 'max_active_predictors' in glm_params.keys(): - self.max_num_rules = glm_params.get("max_active_predictors") - 1 - glm_params.pop("max_active_predictors") - warnings.warn('max_active_predictors provided in glm_params - max_num_rules will be ignored') - if 'nfolds' in glm_params.keys(): - glm_params.pop('nfolds') - warnings.warn('seed provided in glm_params but will be ignored') - if 'seed' in glm_params.keys(): - glm_params.pop('seed') - warnings.warn('seed provided in glm_params but will be ignored') - if 'alpha' in glm_params.keys(): - glm_params.pop('alpha') - warnings.warn('alpha ignored - set to 1 by rulefit') - if 'lambda_' in glm_params.keys(): - glm_params.pop('lambda_') - warnings.warn('lambda_ ignored by rulefit') - - self.tree_params = tree_params - self.glm_params = glm_params - - - def train(self, x=None, y=None, training_frame=None): - """ - Train the rulefit model. - :param x: A list of column names or indices indicating the predictor columns. - :param y: An index or a column name indicating the response column. - :param training_frame: The H2OFrame having the columns indicated by x and y (as well as any - additional columns specified by fold, offset, and weights). - :examples: - >>> rulefit = H2ORuleFit() - >>> training_data = h2o.import_file("smalldata/gbm_test/titanic.csv", - ... col_types = {'pclass': "enum", 'survived': "enum"}) - >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] - >>> rulefit.train(x=x,y="survived",training_frame=training_data) - >>> rulefit - """ - - if (training_frame.type(y) == "enum"): - if training_frame[y].unique().nrow > 2: - family = "multinomial" - raise H2OValueError("multinomial use cases not yet supported") - else: - family = "binomial" - else: - if self.glm_params.get("family") is not None: - family = self.glm_params.get("family") - self.glm_params.pop("family") - else: - family = "gaussian" - - - # Get paths from random forest models - paths_frame = training_frame[y] - depths = range(self.min_rule_len, self.max_rule_len + 1) - tree_models = dict() - for model_idx in range(len(depths)): - - # Train tree models - tree_model = _tree_model(self.algorithm, - depths[model_idx], - self.seed, - model_idx, - self.tree_params - ) - tree_model.train(y = y, x = x, training_frame = training_frame) - tree_models[model_idx] = tree_model - - paths = tree_model.predict_leaf_node_assignment(training_frame) - paths.col_names = ["tree_{0}.{1}".format(str(model_idx), x) for x in paths.col_names] - paths_frame = paths_frame.cbind(paths) - - if self.max_num_rules: - # Train GLM with chosen lambda - glm = H2OGeneralizedLinearEstimator(model_id = "glm.hex", - seed = self.seed, - family = family, - alpha = 1, - max_active_predictors = self.max_num_rules + 1, - **self.glm_params - ) - glm.train(y = y, training_frame=paths_frame) - - else: - # Get optimal lambda - glm = H2OGeneralizedLinearEstimator(model_id = "glm.hex", - nfolds = self.nfolds, - seed = self.seed, - family = family, - alpha = 1, - lambda_search = True, - **self.glm_params - ) - glm.train(y = y, training_frame=paths_frame) - - lambda_ = _get_glm_lambda(glm) - - # Train GLM with chosen lambda - glm = H2OGeneralizedLinearEstimator(model_id = "glm.hex", - seed = self.seed, - family = family, - alpha = 1, - lambda_ = lambda_, - solver = "COORDINATE_DESCENT", - **self.glm_params - ) - glm.train(y = y, training_frame=paths_frame) - - # Get Intercept - intercept = _get_intercept(glm) - - # Get Rules - rule_importance = _get_rules(glm, tree_models, self.algorithm) - - self.intercept = intercept - self.rule_importance = rule_importance - self.glm = glm - self.tree_models = tree_models - - def predict(self, test_data): - """ - Predict on a dataset. - - :param H2OFrame test_data: Data on which to make predictions. - - :returns: A new H2OFrame of predictions. - """ - paths_frame = test_data[0] - for model_idx in self.tree_models.keys(): - - paths = self.tree_models.get(model_idx).predict_leaf_node_assignment(test_data) - paths.col_names = ["tree_{0}.{1}".format(str(model_idx), x) for x in paths.col_names] - paths_frame = paths_frame.cbind(paths) - - paths_frame = paths_frame[1::] - - return self.glm.predict(paths_frame) - - def filter_by_rule(self, test_data, rule): - """ - Returns records that match a provided rule. - - :param H2OFrame test_data: Data on which to find rule assignment. - :param rule: The rule to use. - - :returns: A new H2OFrame of records that match the rule. - """ - family = self.glm.params.get('family').get('actual') - model_idx, tree_num, tree_class, path = _map_column_name(rule, family, self.algorithm) - paths = self.tree_models.get(model_idx).predict_leaf_node_assignment(test_data) - - paths_col = ".".join(rule.split(".")[1:-1]) - paths_path = rule.split(".")[-1] - - return paths[paths[paths_col] == paths_path] - - def coverage_table(self, test_data): - """ - Returns table of coverage per rule - - :param H2OFrame test_data: Data on which to find rule assignment. - - :returns: A new table with rule coefficients plus coverage - """ - rules = self.rule_importance.copy(deep = True) - coverage = [len(self.filter_by_rule(test_data, x)) for x in rules.variable.values] - coverage_percent = [x/len(test_data) for x in coverage] - - rules["coverage_count"] = coverage - rules["coverage_percent"] = coverage_percent - - return rules - - def varimp_plot(self, num_rules = 10): - """ - Generate variable importanec plot of rules - :param num_rules: The number of rule to graph. Defaults to 10. - :examples: - >>> rulefit = H2ORuleFit() - >>> training_data = h2o.import_file("smalldata/gbm_test/titanic.csv", - ... col_types = {'pclass': "enum", 'survived': "enum"}) - >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] - >>> rulefit.train(x=x,y="survived",training_frame=training_data) - >>> rulefit.varimp_plot() - """ - import plotly.graph_objects as go - plot_data = self.rule_importance.copy(deep = True) - if len(plot_data) > num_rules: - plot_data = plot_data.iloc[0:num_rules] - plot_data["color"] = np.where(plot_data.coefficient > 0, 'crimson', 'lightslategray') - plot_data = plot_data.iloc[::-1] - fig = go.Figure([go.Bar(x=plot_data.coefficient, y=plot_data.rule, marker_color = plot_data.color, orientation='h')]) - fig.update_layout(showlegend=False) - return fig - - - def save(self, path): - """ - Save the rulefit model. - :param path: The path to the directory where the models should be saved. - :examples: - >>> rulefit = H2ORuleFit() - >>> training_data = h2o.import_file("smalldata/gbm_test/titanic.csv", - ... col_types = {'pclass': "enum", 'survived': "enum"}) - >>> x = ["age", "sibsp", "parch", "fare", "sex", "pclass"] - >>> rulefit.train(x=x,y="survived",training_frame=training_data) - >>> rulefit.save(dir_path = "/home/user/my_rulefit/") - """ - # save random forest models - for tree_model in self.tree_models.values(): - h2o.save_model(tree_model, path=path) - - # save glm model - h2o.save_model(self.glm, path=path) - - return path - - def load(self, path): - """ - Load the saved rulefit model. - :param path: The path to the rulefit model. - :examples: - >>> rulefit = H2ORuleFit() - >>> rulefit.load(path) - """ - # load GLM model - glm = h2o.load_model(os.path.join(path, 'glm.hex')) - - # load tree models - depths = range(self.min_rule_len, self.max_rule_len + 1) - tree_models = dict() - for model_idx in range(len(depths)): - tree_models[model_idx] = h2o.load_model(os.path.join(path, "tree_{}.hex".format(model_idx))) - - # Get Intercept - intercept = _get_intercept(glm) - - # Get Rules - rule_importance = _get_rules(glm, tree_models, self.algorithm) - - self.intercept = intercept - self.rule_importance = rule_importance - self.glm = glm - self.tree_models = tree_models - - -def _tree_model(algorithm, max_depth, seed, model_idx, tree_params): - - if algorithm == "DRF": - # Train random forest models - from h2o.estimators.random_forest import H2ORandomForestEstimator - model = H2ORandomForestEstimator(seed = seed, - model_id = "tree_{}.hex".format(str(model_idx)), - max_depth = max_depth, - **tree_params - ) - elif algorithm == "GBM": - - from h2o.estimators.gbm import H2OGradientBoostingEstimator - model = H2OGradientBoostingEstimator(seed = seed, - model_id = "tree_{}.hex".format(str(model_idx)), - max_depth = max_depth, - **tree_params - ) - - elif algorithm == "XGBoost": - from h2o.estimators.xgboost import H2OXGBoostEstimator - model = H2OXGBoostEstimator(seed = seed, - model_id = "tree_{}.hex".format(str(model_idx)), - max_depth = max_depth, - **tree_params - ) - - else: - raise H2OValueError("{} algorithm not supported".format(algorithm)) - - return model - - -def _get_glm_lambda(glm): - """ - Get the best GLM lambda by choosing one diminishing returns on explained deviance - """ - r = H2OGeneralizedLinearEstimator.getGLMRegularizationPath(glm) - deviance = r.get('explained_deviance_train') - if len(deviance) < 5: - lambda_index = len(deviance) - 1 - else: - lambda_index = [i*3 for i, x in enumerate(np.diff(np.sign(np.diff(deviance, 2)))) if x != 0 and i > 0][0] - - return r.get('lambdas')[lambda_index] - -def _tree_traverser(node, split_path): - """ - Traverse the tree to get the rules for a specific split_path - """ - rule = [] - splits = [char for char in split_path] - for i in splits: - if i == "R": - if np.isnan(node.threshold): - rule = rule + [{'split_feature': node.split_feature, - 'value': node.right_levels, - 'operator': 'in'}] - else: - rule = rule + [{'split_feature': node.split_feature, - 'value': node.threshold, - 'operator': '>='}] - - node = node.right_child - if i == "L": - if np.isnan(node.threshold): - rule = rule + [{'split_feature': node.split_feature, - 'value': node.left_levels, - 'operator': 'in'}] - - else: - rule = rule + [{'split_feature': node.split_feature, - 'value': node.threshold, - 'operator': '<'}] - - node = node.left_child - consolidated_rules = _consolidate_rules(rule) - consolidated_rules = " AND ".join(consolidated_rules.values()) - return consolidated_rules - -def _consolidate_rules(rules): - """ - Consolidate rules to remove redundancies - """ - rules = [x for x in rules if x.get("value")] - features = set([x.get('split_feature') for x in rules]) - consolidated_rules = {} - for i in features: - feature_rules = [x for x in rules if x.get('split_feature') == i] - if feature_rules[0].get('operator') == 'in': - cleaned_rules = i + " is in " + ", ".join(sum([x.get('value') for x in feature_rules], [])) - else: - cleaned_rules = [] - operators = set([x.get('operator') for x in feature_rules]) - for op in operators: - vals = [x.get('value') for x in feature_rules if x.get('operator') == op] - if '>' in op: - constraint = max(vals) - else: - constraint = min(vals) - cleaned_rules = " and ".join([op + " " + str(round(constraint, 3))]) - cleaned_rules = i + " " + cleaned_rules - consolidated_rules[i] = cleaned_rules - - return consolidated_rules - -def _get_intercept(glm): - """ - Get Intercept from GLM model - """ - family = glm.params.get('family').get('actual') - # Get paths - if family == "multinomial": - intercept = {k: {k1: v1 for k1, v1 in v.items() if k1 == "Intercept"} for k, v in glm.coef().items()} - else: - intercept = {k: v for k, v in glm.coef().items() if k == 'Intercept'} - return intercept - -def _get_rules(glm, tree_models, algorithm): - """ - Get Rules from GLM model - """ - - family = glm.params.get('family').get('actual') - - if family != "multinomial": - coefs = {'coefs_class_0': glm.coef()} - else: - coefs = glm.coef() - - coefs = {k: {k1: v1 for k1, v1 in v.items() if abs(v1) > 0 and k1 != "Intercept"} for k, v in coefs.items()} - - rule_importance = dict() - for k,v in coefs.items(): - rules_pd = pd.DataFrame.from_dict(v, orient = "index").reset_index() - if len(rules_pd) > 0: - rules_pd.columns = ["variable", "coefficient"] - rule_importance[k] = rules_pd - - # Convert paths to rules - for k,v in rule_importance.items(): - class_rules = [] - if len(v) > 0: - for i in v.variable: - model_idx, tree_num, tree_class, path = _map_column_name(i, family, algorithm) - tree = H2OTree(tree_models[model_idx], tree_num, tree_class = tree_class) - class_rules = class_rules + [_tree_traverser(tree.root_node, path)] - - # Add rules and order by absolute coefficient - v["rule"] = class_rules - v["abs_coefficient"] = v["coefficient"].abs() - v = v.loc[v.groupby(["rule"])["abs_coefficient"].idxmax()] - v = v.sort_values(by = "abs_coefficient", ascending = False) - v = v.drop("abs_coefficient", axis = 1) - - rule_importance[k] = v - - if family != "multinomial": - rule_importance = list(rule_importance.values())[0] - - return rule_importance - -def _map_column_name(column_name, family, algorithm): - """ - Take column name from paths frame and return the model_idx, tree_num, tree_class, and path - """ - if family == "binomial": - if algorithm == "XGBoost": - model_idx, tree_num, path = column_name.replace("tree_", "").replace("T", "").split(".") - tree_class = int(0) - else: - model_idx, tree_num, tree_class, path = column_name.replace("tree_", "").replace("T", "").replace("C", "").split(".") - tree_class = int(tree_class) - 1 - else: - model_idx, tree_num, path = column_name.replace("tree_", "").replace("T", "").split(".") - tree_class = None - - return int(model_idx), int(tree_num) - 1, tree_class, path - diff --git a/best-practices/explainable-models/rulefit_analysis.ipynb b/best-practices/explainable-models/rulefit_analysis.ipynb index 978f9de32..56c5221ac 100644 --- a/best-practices/explainable-models/rulefit_analysis.ipynb +++ b/best-practices/explainable-models/rulefit_analysis.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -31,68 +31,67 @@ "Attempting to start a local H2O server...\n", " Java Version: java version \"12.0.2\" 2019-07-16; Java(TM) SE Runtime Environment (build 12.0.2+10); Java HotSpot(TM) 64-Bit Server VM (build 12.0.2+10, mixed mode, sharing)\n", " Starting server from /Users/megankurka/env2/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar\n", - " Ice root: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp1zopzi4z\n", - " JVM stdout: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp1zopzi4z/h2o_megankurka_started_from_python.out\n", - " JVM stderr: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp1zopzi4z/h2o_megankurka_started_from_python.err\n", + " Ice root: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp9tuw0k38\n", + " JVM stdout: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp9tuw0k38/h2o_megankurka_started_from_python.out\n", + " JVM stderr: /var/folders/fk/z2fjbsq163scfcsq9fhsw7r00000gn/T/tmp9tuw0k38/h2o_megankurka_started_from_python.err\n", " Server is running at http://127.0.0.1:54321\n", - "Connecting to H2O server at http://127.0.0.1:54321 ... successful.\n", - "Warning: Your H2O cluster version is too old (4 months and 9 days)! Please download and install the latest version from http://h2o.ai/download/\n" + "Connecting to H2O server at http://127.0.0.1:54321 ... successful.\n" ] }, { "data": { "text/html": [ - "
\n", + "
H2O cluster uptime:
\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "
H2O_cluster_uptime:01 secs
H2O cluster timezone:
H2O_cluster_timezone:America/New_York
H2O data parsing timezone:
H2O_data_parsing_timezone:UTC
H2O cluster version:3.28.0.4
H2O cluster version age:4 months and 9 days !!!
H2O cluster name:H2O_from_python_megankurka_n9klxg
H2O cluster total nodes:
H2O_cluster_version:3.32.0.2
H2O_cluster_version_age:1 day
H2O_cluster_name:H2O_from_python_megankurka_gr81uf
H2O_cluster_total_nodes:1
H2O cluster free memory:
H2O_cluster_free_memory:4 Gb
H2O cluster total cores:
H2O_cluster_total_cores:16
H2O cluster allowed cores:
H2O_cluster_allowed_cores:16
H2O cluster status:
H2O_cluster_status:accepting new members, healthy
H2O connection url:
H2O_connection_url:http://127.0.0.1:54321
H2O connection proxy:{'http': None, 'https': None}
H2O internal security:
H2O_connection_proxy:{\"http\": null, \"https\": null}
H2O_internal_security:False
H2O API Extensions:
H2O_API_Extensions:Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4
Python version:
Python_version:3.6.8 final
" ], "text/plain": [ "-------------------------- ------------------------------------------------------------------\n", - "H2O cluster uptime: 01 secs\n", - "H2O cluster timezone: America/New_York\n", - "H2O data parsing timezone: UTC\n", - "H2O cluster version: 3.28.0.4\n", - "H2O cluster version age: 4 months and 9 days !!!\n", - "H2O cluster name: H2O_from_python_megankurka_n9klxg\n", - "H2O cluster total nodes: 1\n", - "H2O cluster free memory: 4 Gb\n", - "H2O cluster total cores: 16\n", - "H2O cluster allowed cores: 16\n", - "H2O cluster status: accepting new members, healthy\n", - "H2O connection url: http://127.0.0.1:54321\n", - "H2O connection proxy: {'http': None, 'https': None}\n", - "H2O internal security: False\n", - "H2O API Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4\n", - "Python version: 3.6.8 final\n", + "H2O_cluster_uptime: 01 secs\n", + "H2O_cluster_timezone: America/New_York\n", + "H2O_data_parsing_timezone: UTC\n", + "H2O_cluster_version: 3.32.0.2\n", + "H2O_cluster_version_age: 1 day\n", + "H2O_cluster_name: H2O_from_python_megankurka_gr81uf\n", + "H2O_cluster_total_nodes: 1\n", + "H2O_cluster_free_memory: 4 Gb\n", + "H2O_cluster_total_cores: 16\n", + "H2O_cluster_allowed_cores: 16\n", + "H2O_cluster_status: accepting new members, healthy\n", + "H2O_connection_url: http://127.0.0.1:54321\n", + "H2O_connection_proxy: {\"http\": null, \"https\": null}\n", + "H2O_internal_security: False\n", + "H2O_API_Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4\n", + "Python_version: 3.6.8 final\n", "-------------------------- ------------------------------------------------------------------" ] }, @@ -107,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -146,7 +145,7 @@ "data": { "text/plain": [] }, - "execution_count": 28, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -173,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -182,1111 +181,105 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", - "glm Model Build progress: |███████████████████████████████████████████████| 100%\n", - "glm Model Build progress: |███████████████████████████████████████████████| 100%\n" + "rulefit Model Build progress: |███████████████████████████████████████████| 100%\n" ] } ], "source": [ - "from rulefit import H2ORuleFit\n", - "\n", "x = [\"age\", \"sibsp\", \"parch\", \"fare\", \"sex\", \"pclass\"]\n", - "rulefit_model = H2ORuleFit(algorithm=\"DRF\", seed=1234)\n", - "rulefit_model.train(training_frame=train, x=x, y=\"survived\")" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Intercept: -0.2731588568\n", - "\n", - "\n", - "\n", - "Coefficient:-0.556416351820959\n", - "Rule: age >= 8.536 AND sex is in male\n", - "\n", - "\n", - "Coefficient:0.497239476714762\n", - "Rule: pclass is in 1, 2 AND sex is in female\n", - "\n", - "\n", - "Coefficient:0.324584211702948\n", - "Rule: sex is in female AND sibsp < 2.5\n", - "\n", - "\n", - "Coefficient:-0.211297942964276\n", - "Rule: age >= 9.569 AND sex is in male AND pclass is in 2, 3\n", - "\n", - "\n", - "Coefficient:0.19296700901211\n", - "Rule: sex is in female AND pclass is in 1, 2\n", - "\n", - "\n", - "Coefficient:-0.137666293866152\n", - "Rule: age >= 5.141 AND sex is in male AND fare < 52.033 AND pclass is in 2, 3\n", - "\n", - "\n" - ] - } - ], - "source": [ - "print(\"Intercept: \" + str(round(rulefit_model.intercept.get(\"Intercept\"), 10)))\n", - "print(\"\\n\\n\")\n", "\n", - "rules = rulefit_model.rule_importance\n", - "for i in range(len(rules)):\n", - " print(\"Coefficient:\" + str(round(rules.iloc[i][\"coefficient\"], 15)) \n", - " + \"\\nRule: \" + rules.iloc[i][\"rule\"] + \"\\n\\n\")" + "from h2o.estimators import H2ORuleFitEstimator\n", + "rfit = H2ORuleFitEstimator(seed=1234)\n", + "rfit.train(training_frame=train, x=x, y=\"survived\")" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 31, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.plotly.v1+json": { - "config": { - "plotlyServerURL": "https://plot.ly" - }, - "data": [ - { - "marker": { - "color": [ - "lightslategray", - "crimson", - "lightslategray", - "crimson", - "crimson", - "lightslategray" - ] - }, - "orientation": "h", - "type": "bar", - "x": [ - -0.13766629386615237, - 0.1929670090121102, - -0.2112979429642755, - 0.324584211702948, - 0.49723947671476165, - -0.5564163518209589 - ], - "y": [ - "age >= 5.141 AND sex is in male AND fare < 52.033 AND pclass is in 2, 3", - "sex is in female AND pclass is in 1, 2", - "age >= 9.569 AND sex is in male AND pclass is in 2, 3", - "sex is in female AND sibsp < 2.5", - "pclass is in 1, 2 AND sex is in female", - "age >= 8.536 AND sex is in male" - ] - } - ], - "layout": { - "showlegend": false, - "template": { - "data": { - "bar": [ - { - "error_x": { - "color": "#2a3f5f" - }, - "error_y": { - "color": "#2a3f5f" - }, - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - } - }, - "type": "bar" - } - ], - "barpolar": [ - { - "marker": { - "line": { - "color": "#E5ECF6", - "width": 0.5 - } - }, - "type": "barpolar" - } - ], - "carpet": [ - { - "aaxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "baxis": { - "endlinecolor": "#2a3f5f", - "gridcolor": "white", - "linecolor": "white", - "minorgridcolor": "white", - "startlinecolor": "#2a3f5f" - }, - "type": "carpet" - } - ], - "choropleth": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "choropleth" - } - ], - "contour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "contour" - } - ], - "contourcarpet": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "contourcarpet" - } - ], - "heatmap": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmap" - } - ], - "heatmapgl": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "heatmapgl" - } - ], - "histogram": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "histogram" - } - ], - "histogram2d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2d" - } - ], - "histogram2dcontour": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "histogram2dcontour" - } - ], - "mesh3d": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "type": "mesh3d" - } - ], - "parcoords": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "parcoords" - } - ], - "pie": [ - { - "automargin": true, - "type": "pie" - } - ], - "scatter": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter" - } - ], - "scatter3d": [ - { - "line": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatter3d" - } - ], - "scattercarpet": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattercarpet" - } - ], - "scattergeo": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergeo" - } - ], - "scattergl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattergl" - } - ], - "scattermapbox": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scattermapbox" - } - ], - "scatterpolar": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolar" - } - ], - "scatterpolargl": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterpolargl" - } - ], - "scatterternary": [ - { - "marker": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "type": "scatterternary" - } - ], - "surface": [ - { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - }, - "colorscale": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "type": "surface" - } - ], - "table": [ - { - "cells": { - "fill": { - "color": "#EBF0F8" - }, - "line": { - "color": "white" - } - }, - "header": { - "fill": { - "color": "#C8D4E3" - }, - "line": { - "color": "white" - } - }, - "type": "table" - } - ] - }, - "layout": { - "annotationdefaults": { - "arrowcolor": "#2a3f5f", - "arrowhead": 0, - "arrowwidth": 1 - }, - "coloraxis": { - "colorbar": { - "outlinewidth": 0, - "ticks": "" - } - }, - "colorscale": { - "diverging": [ - [ - 0, - "#8e0152" - ], - [ - 0.1, - "#c51b7d" - ], - [ - 0.2, - "#de77ae" - ], - [ - 0.3, - "#f1b6da" - ], - [ - 0.4, - "#fde0ef" - ], - [ - 0.5, - "#f7f7f7" - ], - [ - 0.6, - "#e6f5d0" - ], - [ - 0.7, - "#b8e186" - ], - [ - 0.8, - "#7fbc41" - ], - [ - 0.9, - "#4d9221" - ], - [ - 1, - "#276419" - ] - ], - "sequential": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ], - "sequentialminus": [ - [ - 0, - "#0d0887" - ], - [ - 0.1111111111111111, - "#46039f" - ], - [ - 0.2222222222222222, - "#7201a8" - ], - [ - 0.3333333333333333, - "#9c179e" - ], - [ - 0.4444444444444444, - "#bd3786" - ], - [ - 0.5555555555555556, - "#d8576b" - ], - [ - 0.6666666666666666, - "#ed7953" - ], - [ - 0.7777777777777778, - "#fb9f3a" - ], - [ - 0.8888888888888888, - "#fdca26" - ], - [ - 1, - "#f0f921" - ] - ] - }, - "colorway": [ - "#636efa", - "#EF553B", - "#00cc96", - "#ab63fa", - "#FFA15A", - "#19d3f3", - "#FF6692", - "#B6E880", - "#FF97FF", - "#FECB52" - ], - "font": { - "color": "#2a3f5f" - }, - "geo": { - "bgcolor": "white", - "lakecolor": "white", - "landcolor": "#E5ECF6", - "showlakes": true, - "showland": true, - "subunitcolor": "white" - }, - "hoverlabel": { - "align": "left" - }, - "hovermode": "closest", - "mapbox": { - "style": "light" - }, - "paper_bgcolor": "white", - "plot_bgcolor": "#E5ECF6", - "polar": { - "angularaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "radialaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "scene": { - "xaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "yaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - }, - "zaxis": { - "backgroundcolor": "#E5ECF6", - "gridcolor": "white", - "gridwidth": 2, - "linecolor": "white", - "showbackground": true, - "ticks": "", - "zerolinecolor": "white" - } - }, - "shapedefaults": { - "line": { - "color": "#2a3f5f" - } - }, - "ternary": { - "aaxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "baxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - }, - "bgcolor": "#E5ECF6", - "caxis": { - "gridcolor": "white", - "linecolor": "white", - "ticks": "" - } - }, - "title": { - "x": 0.05 - }, - "xaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - }, - "yaxis": { - "automargin": true, - "gridcolor": "white", - "linecolor": "white", - "ticks": "", - "title": { - "standoff": 15 - }, - "zerolinecolor": "white", - "zerolinewidth": 2 - } - } - } - } - }, - "text/html": [ - "
\n", - " \n", - " \n", - "
\n", - " \n", - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "rulefit_model.varimp_plot()" + "rule_importance = rfit.rule_importance()\n", + "(table, nr, is_pandas) = rule_importance._as_show_table()" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", - "\n", "\n", " \n", " \n", " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - "
variablecoefficientrulecoverage_countcoverage_percent
4tree_1.T22.C1.RR-0.556416age >= 8.536 AND sex is in male8050.614973
2tree_1.T2.C1.LL0.497239pclass is in 1, 2 AND sex is in female2500.1909850M0T18N120.766079(pclass in {1, 2}) & (age < 60.49951934814453 or age is NA) & (sex in {female})
3tree_1.T17.C1.LL0.324584sex is in female AND sibsp < 2.54410.3368981M0T44N21-0.441729(sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) & (age >= 9.569365501403809 or age is NA)
1tree_2.T45.C1.RRR-0.211298age >= 9.569 AND sex is in male AND pclass is ...6240.4767002linear.sex.female0.408274
5tree_1.T24.C1.LL0.192967sex is in female AND pclass is in 1, 22500.1909853linear.sex.male-0.171397
0tree_3.T35.C1.RRLR-0.137666age >= 5.141 AND sex is in male AND fare < 52....6150.4698244M0T28N20-0.139191(sex in {male} or sex is NA) & (fare < 51.03279113769531 or fare is NA) & (age >= 6.469268798828125 or age is NA)
\n", - "
" + "" ], "text/plain": [ - " variable coefficient \\\n", - "4 tree_1.T22.C1.RR -0.556416 \n", - "2 tree_1.T2.C1.LL 0.497239 \n", - "3 tree_1.T17.C1.LL 0.324584 \n", - "1 tree_2.T45.C1.RRR -0.211298 \n", - "5 tree_1.T24.C1.LL 0.192967 \n", - "0 tree_3.T35.C1.RRLR -0.137666 \n", - "\n", - " rule coverage_count \\\n", - "4 age >= 8.536 AND sex is in male 805 \n", - "2 pclass is in 1, 2 AND sex is in female 250 \n", - "3 sex is in female AND sibsp < 2.5 441 \n", - "1 age >= 9.569 AND sex is in male AND pclass is ... 624 \n", - "5 sex is in female AND pclass is in 1, 2 250 \n", - "0 age >= 5.141 AND sex is in male AND fare < 52.... 615 \n", - "\n", - " coverage_percent \n", - "4 0.614973 \n", - "2 0.190985 \n", - "3 0.336898 \n", - "1 0.476700 \n", - "5 0.190985 \n", - "0 0.469824 " + "" ] }, - "execution_count": 34, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "rulefit_model.coverage_table(df)" + "from IPython.display import display, HTML\n", + "import pandas as pd\n", + "pd.options.display.max_colwidth = 0\n", + "display(HTML(table.head().to_html()))" ] }, { @@ -1297,13 +290,13 @@ "\n", "**Highest Likelihood of Survival**\n", "\n", - "1. Women in class 1 or 2\n", - "2. Women with 2 siblings + spouses or less\n", + "1. Women\n", + "1. Women in class 1 or 2 who are less than 60 years old\n", "\n", "**Lowest Likelihood of Survival**\n", - "1. Men age 9+\n", - "2. Men age 10+ in class 2 or 3\n", - "3. Men 6+ in class 2 or 3 with fare < $52\n", + "1. Men\n", + "2. Men age 6+ who paid less than 51 dollars for their tickets\n", + "4. Men age 10+ in class 2 or 3\n", "\n", "Note: The rules are additive. That means that if a passenger is described by multiple rules, their probability is added together from those rules.\n", "\n", @@ -1312,14 +305,14 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "glm prediction progress: |████████████████████████████████████████████████| 100%\n" + "rulefit prediction progress: |████████████████████████████████████████████| 100%\n" ] }, { @@ -1330,16 +323,16 @@ " survived predict p0 p1\n", "\n", "\n", - " 0 00.6962650.303735\n", - " 0 00.6962650.303735\n", - " 1 00.6962650.303735\n", - " 1 10.3226470.677353\n", - " 1 10.3226470.677353\n", - " 0 00.6962650.303735\n", - " 1 00.6962650.303735\n", - " 1 10.3226470.677353\n", - " 1 00.6962650.303735\n", - " 1 00.6962650.303735\n", + " 0 00.6294290.370571\n", + " 0 00.6756440.324356\n", + " 1 00.6783550.321645\n", + " 1 10.2938040.706196\n", + " 1 10.3168810.683119\n", + " 0 00.64067 0.35933 \n", + " 1 00.6439830.356017\n", + " 1 10.3201560.679844\n", + " 1 00.6383490.361651\n", + " 1 00.6788330.321167\n", "\n", "" ] @@ -1351,28 +344,28 @@ "data": { "text/plain": [] }, - "execution_count": 35, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "predictions = rulefit_model.predict(test)\n", + "predictions = rfit.predict(test)\n", "predictions = test[\"survived\"].cbind(predictions)\n", "predictions.head()" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "How many times we correctly predicted survived: 75.45%\n", - "How many times we correctly predicted not survived: 82.11%\n" + "How many times we correctly predicted survived: 76.00%\n", + "How many times we correctly predicted not survived: 79.82%\n" ] } ], @@ -1385,14 +378,14 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Accuracy with RuleFit Model: 79.88%\n", + "Accuracy with RuleFit Model: 78.66%\n", "Accuracy with Constant Model: 62.80%\n" ] } @@ -1408,45 +401,46 @@ "source": [ "## Step 4: Customizations\n", "\n", - "In this section, we train a new rulefit model to predict `parch`. We customize the rulefit model so that the distribution used in the tree model/linear model is `poisson` and that the tree model used is XGBoost with LightGBM emulation mode." + "In this section, we train a new rulefit model to predict `parch`. We customize the rulefit model so that the distribution used is `poisson` and that the tree model used is GBM. We also restrict is so that it is limited to 10 rules and the rules cannot be longer than 2." ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", - "glm Model Build progress: |███████████████████████████████████████████████| 100%\n" - ] - } - ], + "outputs": [], "source": [ "x = [\"age\", \"pclass\", \"sibsp\", \"fare\", \"sex\", \"survived\"]\n", - "rulefit_parch_model = H2ORuleFit(algorithm=\"XGBoost\", seed=1234,\n", - " max_num_rules=10,\n", - " tree_params={'tree_method': \"hist\",\n", - " 'grow_policy': \"lossguide\",\n", - " 'distribution': \"poisson\"\n", - " },\n", - " glm_params={'family': \"poisson\"}\n", - " )\n", + "rulefit_parch_model = H2ORuleFitEstimator(algorithm=\"GBM\", seed=1234,\n", + " max_num_rules=10,\n", + " max_rule_length=2,\n", + " distribution=\"poisson\"\n", + " )\n", "rulefit_parch_model.train(training_frame=train, x=x, y=\"parch\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rule_importance = rulefit_parch_model.rule_importance()\n", + "(table, nr, is_pandas) = rule_importance._as_show_table()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "import pandas as pd\n", + "pd.options.display.max_colwidth = 0\n", + "display(HTML(table.head().to_html()))" + ] + }, { "cell_type": "code", "execution_count": 44,