### Implement Gini Impurity Calculation for a Set of Classes

# Understanding Gini Impurity

Gini impurity measures the impurity, or disorder, of a set of class labels. Decision tree algorithms commonly use it to choose the best split at each node. It is calculated as follows, where $ p_i = \frac{n_i}{n} $ is the proportion of elements belonging to class $ i $, $ n_i $ is the number of elements in that class, $ n $ is the total number of elements, and $ C $ is the number of classes:

$ 
\text{Gini Impurity} = 1 - \sum_{i=1}^{C} p_i^2 
$

A Gini impurity of 0 indicates a node where all elements belong to the same class. The maximum occurs when elements are evenly distributed among the classes: it equals $ 1 - \frac{1}{C} $ for $ C $ classes, which is 0.5 in the binary case. Lower impurity means a more homogeneous set of elements, so decision trees prefer splits that minimize the impurity of the resulting nodes.
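The two extremes can be checked numerically. The following is a minimal sketch, not the exercise solution: the helper name `gini` is illustrative, and it returns the unrounded value of $ 1 - \sum p_i^2 $. For $ C $ evenly distributed classes the result should match $ 1 - \frac{1}{C} $.

```python
from collections import Counter

def gini(labels):
    """Unrounded Gini impurity, 1 - sum(p_i^2). Illustrative helper name."""
    n = len(labels)
    return 1 - sum((count / n) ** 2 for count in Counter(labels).values())

# A pure node (a single class) has impurity 0.
pure = gini(["a", "a", "a", "a"])
print(pure)

# Evenly distributed labels reach the maximum 1 - 1/C for C classes.
for n_classes in (2, 3, 4):
    labels = list(range(n_classes)) * 5  # each class appears 5 times
    print(n_classes, round(gini(labels), 4), round(1 - 1 / n_classes, 4))
```

With two classes the maximum is 0.5; with more classes it rises toward 1.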

## Advantages and Limitations

**Advantages:**

- Computationally efficient
- Works for binary and multi-class classification

**Limitations:**

- Biased toward larger classes
- May cause overfitting in deep decision trees

## Example Calculation

Suppose we have the set: [0, 1, 1, 1, 0]. The probability of each class is calculated as follows:

$ 
p_0 = \frac{2}{5} \quad p_1 = \frac{3}{5} 
$

The Gini Impurity is then calculated as follows:

$ 
\text{Gini Impurity} = 1 - (p_0^2 + p_1^2) = 1 - \left( \left( \frac{2}{5} \right)^2 + \left( \frac{3}{5} \right)^2 \right) = 0.48 
$
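The arithmetic above can be reproduced exactly with rational numbers. This is a small sketch using Python's standard `fractions` module; the variable names are chosen for illustration only.

```python
from fractions import Fraction

labels = [0, 1, 1, 1, 0]
n = len(labels)

# Exact class probabilities: p_0 = 2/5, p_1 = 3/5
p = {c: Fraction(labels.count(c), n) for c in set(labels)}

# 1 - (4/25 + 9/25) = 12/25
gini = 1 - sum(prob ** 2 for prob in p.values())
print(gini, float(gini))  # 12/25 0.48
```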

Your task is to implement a function that calculates the Gini impurity for a set of classes.

Write a function gini_impurity(y) that takes in a list of class labels y and returns the Gini Impurity rounded to three decimal places.

Example:
y = [0, 1, 1, 1, 0]
print(gini_impurity(y))

Expected Output:
0.48

In [5]:
from collections import Counter

def gini_impurity(y: list) -> float:
  """
  Calculate Gini impurity for a set of labels.
  
  Args:
      y: List of class labels
      
  Returns:
      Gini impurity value between 0 (pure) and 1-1/k (impure)
      where k is the number of classes
  """
  if len(y) == 0:
      raise ValueError("Input array cannot be empty")
      
  # Count frequency of each class
  class_freq = Counter(y)
  
  # Total number of samples
  n_samples = len(y)
  
  # Calculate probability for each class
  probabilities = [freq/n_samples for freq in class_freq.values()]
  
  # Calculate Gini impurity, rounded to three decimal places as the task specifies
  return round(1 - sum(p**2 for p in probabilities), 3)

In [6]:
y = [0, 0, 0, 0, 1, 1, 1, 1]
print(gini_impurity(y))

0.5


In [7]:
y = [0, 0, 0, 0, 0, 1]
print(gini_impurity(y))

0.278


In [8]:
y = [0, 1, 2, 2, 2, 1, 2]
print(gini_impurity(y))

0.571
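
To connect the calculation back to its use in decision trees, the sketch below scores two candidate splits of the same parent node. It assumes one common convention, a size-weighted average of the child impurities; the helpers `gini_unrounded` and `weighted_gini` are hypothetical names introduced here and are not part of the exercise.

```python
from collections import Counter

def gini_unrounded(labels):
    """Gini impurity without rounding (hypothetical helper)."""
    n = len(labels)
    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())

def weighted_gini(left, right):
    """Size-weighted average impurity of two child nodes (hypothetical helper)."""
    n = len(left) + len(right)
    return (len(left) / n) * gini_unrounded(left) + (len(right) / n) * gini_unrounded(right)

parent = [0, 0, 0, 1, 1, 1]
split_a = ([0, 0, 0], [1, 1, 1])  # perfectly separates the classes
split_b = ([0, 0, 1], [0, 1, 1])  # leaves both children mixed

print(round(gini_unrounded(parent), 3))   # 0.5
print(round(weighted_gini(*split_a), 3))  # 0.0
print(round(weighted_gini(*split_b), 3))  # 0.444
```

Under this convention, `split_a` would be preferred because it yields the lower weighted impurity.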
