# UNDERSTANDING THE GINI INDEX

30/01/2020 - Davide di Nello - ECB S2S Training


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets.widgets import interact, IntSlider, fixed

In [2]:
def w_gini(weight, freq):
    """
    Weighted Gini index of a two-class node.

    Input: weight - relative weight given to class 1, in [0,1]
                    (class 2 is weighted by 1-weight)
           freq - fraction of the elements attributed to class 1, in [0,1]
                  (class 2 gets 1-freq)
    Output: 1 - p1**2 - p2**2, where p1 and p2 are the weighted class terms
            rescaled so that p1 + p2 = 1
    """
    # Un-normalised weighted contribution of each class.
    mass_1 = weight * freq
    mass_2 = (1 - weight) * (1 - freq)
    norm = mass_1 + mass_2

    p1 = mass_1 / norm
    p2 = mass_2 / norm
    return 1 - p1**2 - p2**2

def w_entropy(weight, freq):
    """
    Weighted entropy (natural logarithm) of a two-class node.

    Input: weight - relative weight given to class 1, in [0,1]
                    (class 2 is weighted by 1-weight)
           freq - fraction of the elements attributed to class 1, in [0,1]
                  (class 2 gets 1-freq)
    Output: -p1*log(p1) - p2*log(p2), where p1 and p2 are the weighted class
            terms rescaled so that p1 + p2 = 1. A class with zero probability
            contributes 0 (the limit of p*log(p) as p -> 0). If both weighted
            terms are zero the result is 0.
    """
    mass_1 = weight * freq
    mass_2 = (1 - weight) * (1 - freq)
    tot = mass_1 + mass_2

    # Degenerate node: nothing to normalise by. Return 0 instead of dividing
    # by zero (ZeroDivisionError for Python floats, NaN for numpy scalars).
    if tot == 0:
        return 0

    entropy = 0
    for mass in (mass_1, mass_2):
        p = mass / tot
        # Skip p == 0 explicitly: evaluating 0*log(0) would yield NaN and
        # trigger numpy RuntimeWarnings.
        if p > 0:
            entropy -= p * np.log(p)
    return entropy
    
def w_accuracy(weight, freq):
    """
    Weighted classification error of a two-class node.

    Input: weight - relative weight given to class 1, in [0,1]
                    (class 2 is weighted by 1-weight)
           freq - fraction of the elements attributed to class 1, in [0,1]
                  (class 2 gets 1-freq)
    Output: the smaller of the two weighted class terms after rescaling them
            so that they sum to 1
    """
    # Un-normalised weighted contribution of each class.
    mass_1 = weight * freq
    mass_2 = (1 - weight) * (1 - freq)
    norm = mass_1 + mass_2
    return min(mass_1 / norm, mass_2 / norm)

def wsum(a,w):
    """
    Weighted sum of the first two entries of a.

    Input: a - indexable with at least two elements; only a[0] and a[1] are read
           w - relative weight applied to a[0], in [0,1]; a[1] is weighted by 1-w
    Output: a[0]*w + a[1]*(1-w)
    """
    first = a[0]
    second = a[1]
    return first * w + second * (1 - w)

In [3]:
def interactive_plot(Function,weight,n1,n2,l1,l2):
    """ 
    Plots a splitting metric for Tree based methods in a binary classification
    setting, together with markers for one candidate split.

    The metric curve is drawn over class 1 fractions 0, 0.01, ..., 1. On top of it:
      - red crosses: the left and right child nodes
      - green circle: the average of the two children, each weighted by its
        weighted size (wsum) relative to the parent
      - red circle: the parent node
    A node with no elements has no defined class fraction, so its marker is
    skipped and it contributes nothing to the average.

    Input: Function - a metric function called as Function(weight, fraction) that returns a score
           weight - float containing the relative weight of class 1 [0,1]
           n1 - int containing the number of elements of class 1 in the parent node
           n2 - int containing the number of elements of class 2 in the parent node
           l1 - int containing the number of elements of class 1 in the left child node
           l2 - int containing the number of elements of class 2 in the left child node
    The right child node holds the remaining n1-l1 and n2-l2 elements."""
    fractions = (np.arange(101))/100
    curve = [Function(weight, y) for y in fractions]

    parent = [n1,n2]
    left = [l1,l2]
    right = [n1-l1,n2-l2]

    def class1_fraction(node):
        # Fraction of class 1 in the node, or None when the node is empty.
        size = sum(node)
        return node[0]/size if size else None

    f_parent = class1_fraction(parent)
    f_left = class1_fraction(left)
    f_right = class1_fraction(right)

    plt.plot(fractions, curve)

    # Child markers. The metric is always called as (weight, fraction), the
    # same order used for the curve above.
    for frac in (f_left, f_right):
        if frac is not None:
            plt.plot(frac, Function(weight, frac), 'rx')

    # Average of the children, each weighted by its share of the parent.
    parent_mass = wsum(parent, weight)
    if parent_mass:
        avg_x = 0
        avg_y = 0
        for node, frac in ((left, f_left), (right, f_right)):
            if frac is None:
                continue
            node_mass = wsum(node, weight)
            avg_x += frac*node_mass/parent_mass
            avg_y += node_mass/parent_mass*Function(weight, frac)
        plt.plot(avg_x, avg_y, 'go')

    if f_parent is not None:
        plt.plot(f_parent, Function(weight, f_parent), 'ro')


In [6]:
# Define interactive elements and their relationships

n1=IntSlider(min=0, max=50, step=1, value=25, description='Parent C1')
n2=IntSlider(min=0, max=50, step=1, value=25, description='Parent C2')
l1=IntSlider(min=0, max=50, step=1, value=10, description='Left C1')
l2=IntSlider(min=0, max=50, step=1, value=10, description='Left C2')


def update_l1_range(*args):
    """Cap the 'Left C1' slider at the current 'Parent C1' value."""
    l1.max = n1.value
def update_l2_range(*args):
    """Cap the 'Left C2' slider at the current 'Parent C2' value."""
    l2.max = n2.value


# The left-child ranges depend on the PARENT sliders, so the callbacks are
# registered on n1/n2: moving a parent slider re-caps the matching child slider.
n1.observe(update_l1_range, 'value')
n2.observe(update_l2_range, 'value')

# Apply the caps once so the initial child ranges already respect the initial
# parent values (otherwise the left child could exceed the parent).
update_l1_range()
update_l2_range()

#### See how the values of the metrics change with the number of elements of each class in the parent and child nodes:

In [7]:
# Build the interactive view: a dropdown selects the metric, the class 1
# weight is held fixed at 0.5, and the four sliders supply the node counts.
interact(
    interactive_plot,
    Function=[('Gini', w_gini), ('Classification Error', w_accuracy)],
    weight=fixed(0.5),
    n1=n1,
    n2=n2,
    l1=l1,
    l2=l2,
)

<function __main__.interactive_plot>