Skip to content

Commit

Permalink
add tree_gam minimal
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Apr 22, 2024
1 parent 82db112 commit 9443821
Showing 1 changed file with 2 additions and 17 deletions.
19 changes: 2 additions & 17 deletions imodels/algebraic/tree_gam_minimal.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from copy import deepcopy
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.linear_model import ElasticNetCV, LinearRegression, RidgeCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.validation import check_is_fitted
from sklearn.utils import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y
from sklearn.utils.validation import _check_sample_weight
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.utils.validation import check_X_y, check_is_fitted, _check_sample_weight
from sklearn.model_selection import train_test_split
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

import imodels

from sklearn.base import RegressorMixin, ClassifierMixin


class TreeGAMMinimal(BaseEstimator):
"""Tree-based GAM classifier.
Expand All @@ -31,7 +23,6 @@ def __init__(
self,
n_boosting_rounds=100,
max_leaf_nodes=3,
reg_param=0.0,
learning_rate: float = 0.01,
boosting_strategy="cyclic",
validation_frac=0.15,
Expand All @@ -44,8 +35,6 @@ def __init__(
Number of boosting rounds for the cyclic boosting.
max_leaf_nodes : int
Maximum number of leaf nodes for the trees in the cyclic boosting.
reg_param : float
Regularization parameter for the cyclic boosting.
learning_rate: float
Learning rate for the cyclic boosting.
boosting_strategy : str ["cyclic", "greedy"]
Expand All @@ -57,7 +46,6 @@ def __init__(
"""
self.n_boosting_rounds = n_boosting_rounds
self.max_leaf_nodes = max_leaf_nodes
self.reg_param = reg_param
self.learning_rate = learning_rate
self.boosting_strategy = boosting_strategy
self.validation_frac = validation_frac
Expand Down Expand Up @@ -129,9 +117,6 @@ def _cyclic_boost(
)
if not succesfully_split_on_feature:
continue
if self.reg_param > 0:
est = imodels.HSTreeRegressor(
est, reg_param=self.reg_param)
self.estimators_.append(est)
residuals_train_new = (
residuals_train - self.learning_rate * est.predict(X_train)
Expand Down

0 comments on commit 9443821

Please sign in to comment.