In [17]:
import numpy as np
from sklearn.tree import DecisionTreeRegressor

In [24]:
import numpy as np
from sklearn.tree import DecisionTreeRegressor

class GradientBoostingClassifier:
  """Binary gradient-boosting classifier built from shallow regression trees.

  The raw score F starts at the log-odds of the positive-class rate. Each
  round fits a DecisionTreeRegressor to the pseudo-residuals y - sigmoid(F)
  (the negative gradient of the log loss w.r.t. F) and adds
  learning_rate * tree.predict(x) to F.

  Note: each leaf contributes the tree's mean residual (a plain gradient
  step), not the Newton step sum(r) / sum(p * (1 - p)) used by Friedman's
  algorithm and by sklearn.ensemble.GradientBoostingClassifier. Results will
  therefore differ from sklearn even at equal depth.

  Parameters
  ----------
  n_estimators : int
      Number of boosting rounds (one tree per round).
  learning_rate : float
      Shrinkage applied to every tree's contribution.
  max_depth : int
      Depth of each regression tree; 1 means decision stumps.
  random_state : int or None
      Forwarded to every DecisionTreeRegressor so tie-breaking between
      equally good splits is reproducible.
  """

  def __init__(self, n_estimators=100, learning_rate=0.1, max_depth=1, random_state=None):
    self.n_estimators = n_estimators
    self.learning_rate = learning_rate
    self.max_depth = max_depth
    self.random_state = random_state
    self.trees = []
    self.initial_log_odd = None  # set by fit(); None means "not fitted"

  def _sigmoid(self, x):
    """Logistic function 1 / (1 + exp(-x)), evaluated without overflow.

    exp(-logaddexp(0, -x)) equals 1 / (1 + exp(-x)). Unlike np.exp(-x), it
    does not overflow for large negative x.
    """
    return np.exp(-np.logaddexp(0.0, -x))

  def fit(self, x, y):
    """Fit the ensemble on features ``x`` and binary labels ``y`` (0/1).

    Any trees from a previous call are discarded, so refitting starts fresh.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If there are no samples, if ``x`` and ``y`` differ in length, or if
        ``y`` contains values other than 0 and 1.
    """
    y = np.asarray(y, dtype=float)
    if len(y) == 0:
      raise ValueError("fit() requires at least one sample")
    if len(x) != len(y):
      raise ValueError("x and y must have the same number of samples")
    if not np.all((y == 0) | (y == 1)):
      raise ValueError("y must contain only 0/1 labels")

    # Clip p to avoid log(0) / division by zero when y is all 0s or all 1s.
    p = np.clip(np.mean(y), 1e-10, 1 - 1e-10)
    self.initial_log_odd = np.log(p / (1 - p))

    # Reset so a second fit() does not stack new trees on top of old ones.
    self.trees = []
    F = np.full(len(y), self.initial_log_odd, dtype=float)

    for _ in range(self.n_estimators):
      # Negative gradient of the log loss with respect to the raw score F.
      residuals = y - self._sigmoid(F)

      tree = DecisionTreeRegressor(max_depth=self.max_depth, random_state=self.random_state)
      tree.fit(x, residuals)
      self.trees.append(tree)

      F += self.learning_rate * tree.predict(x)
    return self

  def predict_proba(self, x):
    """Return P(y = 1) for each row of ``x`` as a 1-D array of length len(x).

    Unlike sklearn's predict_proba, only the positive-class probability is
    returned (no (n_samples, 2) matrix).

    Raises
    ------
    RuntimeError
        If called before fit().
    """
    if self.initial_log_odd is None:
      raise RuntimeError("GradientBoostingClassifier is not fitted; call fit() first")

    F = np.full(len(x), self.initial_log_odd, dtype=float)
    for tree in self.trees:
      F += self.learning_rate * tree.predict(x)
    return self._sigmoid(F)

  def predict(self, x, threshold=0.5):
    """Return 0/1 class labels: 1 where P(y = 1) >= ``threshold``."""
    probs = self.predict_proba(x)
    return (probs >= threshold).astype(int)

In [25]:
from sklearn.datasets import make_classification

In [26]:
# Synthetic binary problem: 1000 samples x 10 features, seeded for reproducibility.
X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
print(X.shape, y.shape, sep="\n")

(1000, 10)
(1000,)


In [27]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the samples for evaluation (seeded split).
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
)


In [28]:
# Train the hand-written booster; fit() returns self, so the call chains.
Gb = GradientBoostingClassifier(learning_rate=0.1, n_estimators=100).fit(X_train, y_train)
y_pred = Gb.predict(X_test)


In [29]:
from sklearn.metrics import accuracy_score
print(accuracy_score(y_true=y_test, y_pred=y_pred))

0.84


In [30]:
# Alias the import so it does not shadow the custom GradientBoostingClassifier
# defined earlier in this notebook. Without the alias, re-running the
# custom-model cell after this one would silently train sklearn's model.
from sklearn.ensemble import GradientBoostingClassifier as SklearnGradientBoostingClassifier

# Same n_estimators / learning_rate as the custom model. Note that sklearn's
# trees default to max_depth=3 (vs. stumps in the custom model) and use
# Newton-step leaf values, so the comparison is not strictly like-for-like.
# random_state pins sklearn's internal feature permutation for reproducibility.
gd = SklearnGradientBoostingClassifier(n_estimators=100, learning_rate=0.1, random_state=42)
gd.fit(X_train, y_train)
# Separate name so the custom model's y_pred is not overwritten.
y_pred_sklearn = gd.predict(X_test)
print(accuracy_score(y_test, y_pred_sklearn))

0.9
