In [120]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Data

#### Define data

In [121]:
df = pd.DataFrame({ 'weight': [10,60,70,20,80,30],
                    'size':['S','L','L','S','L','S']
                     })

df

Unnamed: 0,weight,size
0,10,S
1,60,L
2,70,L
3,20,S
4,80,L
5,30,S


#### Sort data

In [122]:
df = df.sort_values('weight').reset_index(drop=True)
df

Unnamed: 0,weight,size
0,10,S
1,20,S
2,30,S
3,60,L
4,70,L
5,80,L


## Calculate midpoints = max. possible depth

In [123]:
sorted_values = df['weight'].to_numpy()
sorted_values

array([10, 20, 30, 60, 70, 80], dtype=int64)

In [124]:
midpoints = [(sorted_values[i]+sorted_values[i+1])/2 for i in range (len(sorted_values)-1)]
midpoints

[15.0, 25.0, 45.0, 65.0, 75.0]

## Calculate Gini impurities for all midpoints

In [125]:
''' 
Tree stops when both sides of a tree have GINI impurities equal zero
'''

' \nTree stops when the bith sides have GINI impurities equal zero\n'

#### Define the function

In [131]:
def tree(midpoint,df):
    # Split into left and right based on a midpoint
    df_left = df[df['weight']<=midpoint]
    df_right = df[df['weight']>=midpoint]

    print(f'left df: \n {df_left}')
    print(f'right df: \n {df_right}')
    print(f'******')

    # Total number of nodes in leafs
    n_left = len(df_left)
    n_right = len(df_right)

    print(f'Left total number: {n_left}')
    print(f'Right total number: {n_right}')
    print(f'******')

    # Probability of samples belonging to the class at a specific node.
    p_right_S = (df_right['size']=='S').sum()
    p_right_L = (df_right['size']=='L').sum()

    p_left_S = (df_left['size']=='S').sum()
    p_left_L = (df_left['size']=='L').sum()

    print(f'Right-S: {p_right_S}')
    print(f'Right-L: {p_right_L}')

    print(f'Left-S: {p_left_S}')
    print(f'Left-L: {p_left_L}')
    print(f'******')
    
    # Gini impurities
    gini_right = 1-((p_right_L/n_right)**2+(p_right_S/n_right)**2)
    gini_left = 1-((p_left_L/n_left)**2+(p_left_S/n_left)**2)

    print(f'gini_left: {gini_left}')
    print(f'gini_right: {gini_right}')

    return df_left, df_right
    

### Run the algorithms

<img src="pics/gini.png" style="width: 20%;"/>

In [132]:
first_branch_L, first_branch_R  = tree(midpoint=25, df=df)

left df: 
    weight size
0      10    S
1      20    S
right df: 
    weight size
2      30    S
3      60    L
4      70    L
5      80    L
******
Left total number: 2
Right total number: 4
******
Right-S: 1
Right-L: 3
Left-S: 2
Left-L: 0
******
gini_left: 0.0
gini_right: 0.375


In [133]:
second_branch_L, second_branch_R = tree(midpoint=65, df=first_branch_R)

left df: 
    weight size
2      30    S
3      60    L
right df: 
    weight size
4      70    L
5      80    L
******
Left total number: 2
Right total number: 2
******
Right-S: 0
Right-L: 2
Left-S: 1
Left-L: 1
******
gini_left: 0.5
gini_right: 0.0


In [134]:
third_branch_L, third_branch_R = tree(midpoint=35, df=second_branch_L)


left df: 
    weight size
2      30    S
right df: 
    weight size
3      60    L
******
Left total number: 1
Right total number: 1
******
Right-S: 0
Right-L: 1
Left-S: 1
Left-L: 0
******
gini_left: 0.0
gini_right: 0.0


## Results

<img src="pics/diagram.png" style="width: 40%;"/>