# 计算概率、交叉熵、基尼系数

In [9]:
from collections import Counter
import numpy as np

In [10]:
def get_probs(elements):
    """
    获取概率
    """
    counter = Counter(elements)
    pr = np.array([counter[c] / len(elements) for c in counter])
    
    return pr

In [12]:
def get_entropy(elements):
    """
    计算交叉熵
    """
    pr = get_probs(elements)
    
    return -np.sum(pr * np.log2(pr))

In [13]:
# 测试交叉熵
get_entropy(['a', 'b', 'c','d', 'a', 'b', 'c'])

1.950212064914747

## cart

In [16]:
def gini(elements):
    """
    计算基尼系数
    """
    pr = get_probs(elements)
    return 1 - np.sum(pr ** 2)

In [17]:
def cart_loss(left, right, pure_fn):
    """
    CART树
    """
    m_left, m_right = len(left), len(right)
    m = m_left + m_right
    
    return m_left / m * pure_fn(left) + m_right / m * pure_fn(right)

In [18]:
sales = {
    'gender': ['Female', 'Female', 'Female', 'Female', 'Male', 'Male', 'Male'],
    'income': ['H', 'M', 'H', 'M', 'H', 'H', 'L'],
    'family-number': [1, 1, 2, 1, 1, 1, 2],
    'bought': [1, 1, 1, 0, 0, 0, 1]
}

In [20]:
import pandas as pd
sales_dataset = pd.DataFrame.from_dict(sales)
target = 'bought'

In [39]:
def find_best_split(training_dataset, target):
    dataset = training_dataset
    fields = set(dataset.columns.tolist()) - {target}
    print(fields)
    
    mini_loss = float('inf')
    best_feature, best_split = None, None
    
    for x in fields:
        field_value = dataset[x]
        for v in field_value:
            split_left = dataset[dataset[x] == v][target].tolist()
            split_right = dataset[dataset[x] != v][target].tolist()
            
            loss = cart_loss(split_left, split_right, pure_fn=gini)
            # ic(x, v, cart_loss(split_left, split_right, pure_fn=gini))
            if loss < mini_loss:
                print(f"best_feature={best_feature}, best_split={best_split}, loss={mini_loss}")
                best_feature, best_split = x, v
                mini_loss = loss
                
    return best_feature, best_split

In [40]:
from icecream import ic
# ic(get_entropy([1, 1]))
# ic(gini([1, 1]))
# ic(get_entropy([0, 0, 0]))
# ic(gini([0, 0, 0]))
# ic(get_entropy([0, 0, 1, 1, 1, 1 ,1, 1]))
# ic(gini([0, 0, 1, 1, 1, 1 ,1, 1]))
# ic(get_entropy([0, 0, 0, 0, 0, 0, 0, 0]))
# ic(gini([0, 0, 0, 0, 0, 0, 0, 0]))
# ic(get_entropy([1, 2, 3, 4, 56, 7, 8, 1, 19]))
# ic(gini([1, 2, 3, 4, 56, 7, 8, 1, 19]))
# ic(get_entropy([1, 2, 3, 4, 65, 76, 87, 32, 21]))
# ic(gini([1, 2, 3, 4, 65, 76, 87, 32, 21]))

In [41]:
ic(find_best_split(sales_dataset, target='bought'))

ic| find_best_split(sales_dataset, target='bought'): ('family-number', 1)


{'income', 'family-number', 'gender'}
best_feature=None, best_split=None, loss=inf
best_feature=income, best_split=H, loss=0.47619047619047616
best_feature=income, best_split=L, loss=0.42857142857142855


('family-number', 1)