In [136]:
# Toy loan-approval dataset: one dict per applicant, built from a compact
# table so the columns line up. Key order follows the header tuple.
X = [
    dict(zip(("Marital", "Job", "Credit", "Age", "Salary"), row))
    for row in (
        ("Married", "Stable",   "Good", 45, 6000),
        ("Married", "Stable",   "Good", 38, 7000),
        ("Married", "Stable",   "Good", 30, 3500),
        ("Single",  "Unstable", "Bad",  28, 3000),
        ("Married", "Unstable", "Bad",  22, 2500),
        ("Single",  "Unstable", "Bad",  25, 2000),
        ("Married", "Stable",   "Bad",  35, 5000),
        ("Single",  "Unstable", "Good", 32, 4500),
        ("Married", "Stable",   "Good", 27, 3500),
        ("Single",  "Unstable", "Good", 29, 4800),
        ("Married", "Unstable", "Bad",  40, 3700),
    )
]

# Target labels, positionally aligned with the rows of X.
y = [
    "Approved", "Approved", "Approved",
    "Denied", "Denied", "Denied",
    "Approved", "Approved",
    "Denied",
    "Approved",
    "Denied",
]

In [167]:
from math import log2

def mode(collection):
    """Return the element that occurs most often in `collection`.

    Elements must be hashable. When several elements share the highest
    count, the one encountered first wins. An empty collection raises
    ValueError (from `max` on an empty mapping).
    """
    counts = {}
    for item in collection:
        if item in counts:
            counts[item] += 1
        else:
            counts[item] = 1
    # max() scans keys in insertion order and keeps the first maximum.
    return max(counts, key=lambda item: counts[item])

def entropy(y):
    """Return the Shannon entropy, in bits, of the labels in `y`.

    `y` must be a sequence that supports `count` and `len` (list or tuple).
    An empty sequence yields 0 because there are no labels to sum over.
    """
    # Relative frequency of every distinct label.
    probabilities = [y.count(cls) / len(y) for cls in set(y)]
    weighted_logs = 0
    for p in probabilities:
        weighted_logs += p * log2(p)
    return -weighted_logs

def get_splits(X, feature):
    """Enumerate the candidate binary splits of `X` on `feature`.

    The feature's kind is decided from the first sample only:

    * ``str`` values give one equality test per distinct category, in
      first-seen order;
    * ``int``/``float`` values give one ``>=`` test per midpoint between
      consecutive distinct sorted values.

    Args:
        X: sequence of dict samples.
        feature: key to split on; must be present in every sample.

    Returns:
        A list of ``{'feature', 'operation', 'value'}`` dicts. The list is
        empty when `X` is empty, when a numeric feature has a single distinct
        value, or when the first sample's value is neither str nor a number.
    """
    if not X:
        return []

    probe = X[0][feature]
    if isinstance(probe, str):
        # dict.fromkeys de-duplicates while keeping first-seen order. A set
        # would order strings by their randomised hashes, so callers that keep
        # the first of several equally good candidates could pick a different
        # one on every interpreter start.
        categories = dict.fromkeys(x[feature] for x in X)
        return [
            {'feature': feature, 'operation': '==', 'value': category}
            for category in categories
        ]
    if isinstance(probe, (int, float)):
        ordered = sorted({x[feature] for x in X})
        return [
            {'feature': feature, 'operation': '>=', 'value': (low + high) / 2}
            for low, high in zip(ordered, ordered[1:])
        ]
    # Unsupported value type: no candidates, rather than an implicit None that
    # callers cannot iterate over.
    return []

def ig_and_children(X, y, split, parent_entropy):
    """Apply `split` to the samples and measure its information gain.

    Args:
        X: sequence of dict samples.
        y: labels aligned with `X`; must be non-empty because the child
            weights divide by ``len(y)``.
        split: dict with 'feature', 'operation' ('==' or '>=') and 'value'.
        parent_entropy: entropy of `y` before splitting.

    Returns:
        ``(ig, (left_data, left_entropy), (right_data, right_entropy))`` where
        each ``*_data`` is a list of ``(sample, label)`` pairs. The left child
        holds the samples for which the split test is true.

    Raises:
        ValueError: if the split's operation is not '==' or '>='.
    """
    feature, value = split['feature'], split['value']
    operation = split['operation']

    # Resolve the test once, outside the loop. An unrecognised operation used
    # to fall through silently, leaving both children empty and reporting the
    # maximum possible gain (parent_entropy) for a split that separates nothing.
    if operation == '==':
        def passes(sample):
            return sample[feature] == value
    elif operation == '>=':
        def passes(sample):
            return sample[feature] >= value
    else:
        raise ValueError(f"Unknown operation: {operation}")

    left_data, right_data = [], []
    for x, label in zip(X, y):
        bucket = left_data if passes(x) else right_data
        bucket.append((x, label))

    left_entropy = entropy([label for _, label in left_data])
    right_entropy = entropy([label for _, label in right_data])
    left_w, right_w = len(left_data) / len(y), len(right_data) / len(y)
    ig = parent_entropy - left_entropy * left_w - right_entropy * right_w
    return ig, (left_data, left_entropy), (right_data, right_entropy)

def grow_tree(X, y, parent_entropy=None):
    """Recursively build a decision tree by greedy information gain.

    Args:
        X: non-empty sequence of dict samples (all sharing the first sample's keys).
        y: labels aligned with `X`.
        parent_entropy: entropy of `y` if the caller already knows it;
            computed here when None.

    Returns:
        A label when the node is a leaf, otherwise a single-key dict
        ``{"<feature> <op> <value>": {True: subtree, False: subtree}}``.
        A node is a leaf when all labels agree, or when no candidate split
        has positive gain (e.g. identical samples with different labels);
        in the latter case the leaf is the most frequent label.

    Raises:
        ValueError: if `X` and `y` differ in length or are empty.
    """
    if len(X) != len(y):
        # zip() would silently drop the surplus and skew the child weights.
        raise ValueError(
            f"X and y must have the same length (got {len(X)} and {len(y)})"
        )
    if len(y) == 0:
        raise ValueError("Cannot grow a tree from an empty dataset")
    if len(set(y)) == 1:
        return y[0]
    if parent_entropy is None:
        parent_entropy = entropy(y)

    best_split, max_ig = None, 0
    best_left, best_right = None, None
    for feature in X[0].keys():
        # `or ()` tolerates a splitter that yields nothing for this feature.
        for split in get_splits(X, feature) or ():
            ig, left, right = ig_and_children(X, y, split, parent_entropy)
            # Strict '>' keeps the first of equally good splits. Requiring both
            # children to be non-empty guards against a rounding-noise "gain"
            # from a split that separates nothing.
            if ig > max_ig and left[0] and right[0]:
                best_split, max_ig = split, ig
                best_left, best_right = left, right

    if best_split is None:
        # Nothing separates the remaining labels: fall back to a majority leaf
        # instead of crashing on an undefined split.
        return mode(y)

    left_data, left_entropy = best_left
    right_data, right_entropy = best_right
    left_X, left_y = zip(*left_data)
    right_X, right_y = zip(*right_data)
    split_str = f'{best_split["feature"]} {best_split["operation"]} {best_split["value"]}'
    return {
        split_str: {
            True: grow_tree(left_X, left_y, left_entropy),
            False: grow_tree(right_X, right_y, right_entropy),
        }
    }

def _parse_split(split_str):
    """Break a ``"<feature> <op> <value>"`` node key into its three parts."""
    known_ops = ("==", ">=")
    feature, op, value = split_str.split(" ", 2)
    if op in known_ops:
        return feature, op, value
    # The second token is not an operator, so the feature name itself may
    # contain spaces: cut at the first " <op> " token instead.
    hits = [(split_str.find(f" {cand} "), cand) for cand in known_ops]
    hits = [(pos, cand) for pos, cand in hits if pos > 0]
    if not hits:
        raise ValueError(f"Unknown operation: {op}")
    pos, op = min(hits)
    return split_str[:pos], op, split_str[pos + len(op) + 2:]


def _as_number(text):
    """Return `text` as int/float when it parses as a number, else unchanged."""
    try:
        number = float(text)
    except ValueError:
        return text
    return int(number) if number.is_integer() else number


def classify(tree, sample):
    """Walk `tree` with `sample` and return the label of the leaf reached.

    Args:
        tree: a leaf label, or a single-key dict
            ``{"<feature> <op> <value>": {True: subtree, False: subtree}}``
            where ``<op>`` is ``==`` or ``>=``.
        sample: dict mapping feature names to values.

    Returns:
        The leaf label.

    Raises:
        ValueError: if a node key is malformed or uses an unknown operation.
        KeyError: if `sample` lacks the feature a node tests.
    """
    if not isinstance(tree, dict):
        return tree

    (split_str, children), = tree.items()
    feature, op, value = _parse_split(split_str)

    if op == "==":
        actual = sample[feature]
        # A string feature is compared with the raw text. Coercing it first
        # would turn categories such as "90210", "inf" or "nan" into numbers
        # that can never equal the sample's string.
        expected = value if isinstance(actual, str) else _as_number(value)
        branch = actual == expected
    else:  # ">=" is the only other operation _parse_split lets through
        branch = sample[feature] >= _as_number(value)

    return classify(children[branch], sample)

def decision_tree(X, y, sample):
    """Fit a tree on (X, y), print it, and return the label predicted for `sample`."""
    fitted = grow_tree(X, y)
    print(fitted)
    prediction = classify(fitted, sample)
    return prediction

# Fit the tree on the toy dataset above and classify one unseen applicant.
# decision_tree prints the fitted tree; the returned label is the cell's value.
decision_tree(X, y, {"Marital": "Married", "Job": "Stable", "Credit": "Good", "Age": 31, "Salary": 5000})

{'Age >= 28.5': {True: {'Credit == Good': {True: 'Approved', False: {'Job == Unstable': {True: 'Denied', False: 'Approved'}}}}, False: 'Denied'}}


'Approved'