In [2]:
import numpy as np

class DecisionTreeStump:
    """Depth-1 decision tree (a "stump") on a single numeric feature.

    ``fit`` picks one threshold ``t`` and remembers the majority class on each
    side of it; ``predict`` then returns ``left_class`` for ``x <= t`` and
    ``right_class`` for ``x > t``.  Candidate thresholds are the midpoints
    between consecutive distinct feature values, and the one minimising the
    size-weighted Gini impurity of the two sides wins (ties keep the smallest
    threshold).
    """

    def __init__(self):
        self.threshold = None    # split point; None until fit() has run
        self.left_class = None   # label predicted when x <= threshold
        self.right_class = None  # label predicted when x > threshold

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum_c p_c**2`` of the labels *y*.

        ``p_c`` is the fraction of *y* equal to class ``c``.  An empty *y* is
        reported as 0.0 (pure); ``fit`` never scores an empty side.
        """
        y = np.asarray(y)
        if y.size == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        proportions = counts / y.size
        return 1.0 - np.sum(proportions ** 2)

    def fit(self, X, y):
        """Choose the best split threshold and the class on each side of it.

        Parameters
        ----------
        X : array-like
            The single feature, one value per sample; anything that flattens
            to length ``n`` works (e.g. shape ``(n,)`` or ``(n, 1)``).
        y : array-like
            The ``n`` class labels.

        Returns
        -------
        DecisionTreeStump
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If there are no samples, or ``X`` and ``y`` differ in length.
        """
        X = np.asarray(X).ravel()
        y = np.asarray(y).ravel()
        if X.size == 0:
            raise ValueError("fit() requires at least one sample")
        if X.size != y.size:
            raise ValueError(f"X has {X.size} values but y has {y.size} labels")

        # Start from a clean slate so a refit never keeps a stale split.
        self.threshold = None
        self.left_class = None
        self.right_class = None

        # np.unique already returns its values sorted.
        unique_values = np.unique(X)
        # Candidate thresholds: midpoints between consecutive distinct values.
        thresholds = (unique_values[:-1] + unique_values[1:]) / 2

        n_samples = y.size
        best_gini = float("inf")
        for threshold in thresholds:
            left_y = y[X <= threshold]
            right_y = y[X > threshold]
            # Guard against a degenerate split (e.g. a NaN midpoint leaves
            # the left side empty).
            if left_y.size == 0 or right_y.size == 0:
                continue

            weighted_gini = (
                left_y.size / n_samples * self._gini(left_y)
                + right_y.size / n_samples * self._gini(right_y)
            )
            # Strict "<" keeps the smallest threshold among equal scores.
            if weighted_gini < best_gini:
                best_gini = weighted_gini
                self.threshold = threshold
                self.left_class = self._majority(left_y)
                self.right_class = self._majority(right_y)

        if self.threshold is None:
            # No usable split (e.g. every X value is identical): predict the
            # overall majority class whichever side a sample falls on, instead
            # of leaving the model unusable.
            majority = self._majority(y)
            self.threshold = unique_values[0]
            self.left_class = majority
            self.right_class = majority

        return self

    def _majority(self, y):
        """Return the most frequent label in *y* (the smallest label on ties)."""
        values, counts = np.unique(y, return_counts=True)
        # np.unique sorts the values and argmax returns the first maximum.
        return values[np.argmax(counts)]

    def predict(self, X):
        """Return one predicted label per value of *X* (flattened).

        Raises
        ------
        RuntimeError
            If the stump has not been fitted yet.
        """
        if self.threshold is None:
            raise RuntimeError("DecisionTreeStump must be fitted before calling predict()")
        X = np.asarray(X).ravel()
        # Vectorised form of: left_class if x <= threshold else right_class.
        return np.where(X <= self.threshold, self.left_class, self.right_class)


# =====================================
#  Example: hours of study -> pass/fail
# =====================================

# One feature (hours studied) per sample, as a column of shape (6, 1).
X = np.array([2, 3, 5, 7, 9, 10]).reshape(-1, 1)
# Labels: 0 = fail, 1 = pass.
y = np.array([0] * 3 + [1] * 3)

model = DecisionTreeStump()
model.fit(X, y)
print("Best Threshold:", model.threshold)

# Two unseen samples: 4 and 8 hours of study.
test_data = np.array([4, 8]).reshape(-1, 1)
predictions = model.predict(test_data)

print("Test Data:", test_data.flatten())
print("Predictions:", predictions)


Best Threshold: 6.0
Test Data: [4 8]
Predictions: [0 1]
