In [131]:
from sklearn.datasets import load_iris
import numpy as np
# Load the Iris dataset bundled with scikit-learn and keep the whole object
# around so its description can be inspected later.
iris = load_iris()
# D holds the feature rows, y the class label of each row.
D = iris.data
y = iris.target

In [132]:
# Print the description text that ships with the loaded dataset object.
print(iris.DESCR)

.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

                    Min  Max   Mean    SD   Class Correlation
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :

In [133]:
def Gini_Impurity (y: list) -> float:
    """
    This function returns the Gini Impurity of the node.

    The impurity is 1 - sum(p_k ** 2), where p_k is the fraction of the
    labels in the node that belong to class k. The classes are taken from
    the labels themselves, so any number of classes and any hashable label
    values are supported (not only the labels 0, 1 and 2).

    :param y(list | np.ndarray): class labels of the samples in the node.

    :return gini (float): Gini Impurity; 0.0 for an empty or a pure node.

    """

    total = len(y)
    if total == 0:
        # An empty node contains no mixed labels; also avoids dividing by zero.
        return 0.0

    # Single pass over the labels instead of one full scan per class.
    counts = {}
    for label in y:
        counts[label] = counts.get(label, 0) + 1

    # Subtract the squared fractions in sorted label order so that the
    # floating-point result for labels 0/1/2 is the same as evaluating
    # 1 - p0**2 - p1**2 - p2**2 left to right.
    try:
        ordered = sorted(counts)
    except TypeError:
        # Labels that cannot be compared with each other: keep first-seen order.
        ordered = list(counts)

    gini = 1.0
    for label in ordered:
        gini -= (counts[label] / total) ** 2
    return gini

def split(D: np.ndarray, labels=None, feature: int = 0, threshold: float = 5.84) -> list:
    """
    This function does the first split of root_node to L0 and L1.

    A sample goes to L1 when its value in column `feature` is strictly
    greater than `threshold`, otherwise it goes to L0. The defaults
    (column 0, threshold 5.84) reproduce the original hard-coded split.

    :param D(np.ndarray): data, one row per sample.
    :param labels(sequence): class label of each row of D. When omitted, the
        notebook-level variable `y` is used, as in the original version.
    :param feature(int): index of the column to split on.
    :param threshold(float): split value for that column.

    :return L(list): list of two elements L0 and L1, each a list of the
        labels of the samples that fell on that side.

    """

    if labels is None:
        # Backward-compatible fallback for calls of the form split(D).
        labels = y

    L0=[]
    L1=[]
    for i, row in enumerate(D):
        if row[feature] > threshold:
            L1.append(labels[i])
        else:
            L0.append(labels[i])
    return [L0, L1]

def cost(L: list, root_node: list) -> float:
    """
    This function estimates the cost of the split function.

    The cost is the size-weighted Gini Impurity of the two child nodes minus
    the Gini Impurity of the parent node, so a negative value means the
    split lowered the impurity.

    :param L(list): list of two elements L0 and L1
    :param root_node(list): parent node

    :return cost(float): cost of the split.

    """

    def _impurity(node) -> float:
        # An empty node contributes no impurity; skipping the call keeps an
        # empty child (or parent) from causing a division by zero.
        return Gini_Impurity(node) if len(node) else 0.0

    left, right = L[0], L[1]
    n_samples = len(left) + len(right)

    if n_samples == 0:
        # Nothing was split, so the children carry no impurity.
        weighted_children = 0.0
    else:
        weighted_children = (
            len(left) * _impurity(left) / n_samples
            + len(right) * _impurity(right) / n_samples
        )

    return weighted_children - _impurity(root_node)

In [134]:
# Root node: the class labels of all samples, converted to a plain Python list.
root_node=y.tolist()
# Impurity of the unsplit data; shown as this cell's output.
Gini_Impurity(root_node)

0.6666666666666665

In [135]:
# Cost of the split of D relative to the root node: weighted child impurity
# minus root impurity, so a negative value means the split reduced impurity.
cost(split(D),root_node)

-0.17476190476190456