# Decision Tree
## Dataset: Wine Dataset

In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, classification_report

# ignore warnings
import warnings
warnings.filterwarnings("ignore")


In [18]:
# Load Wine dataset
wine = load_wine()
X = wine.data
y = wine.target


In [19]:
# Splitting dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


In [40]:
# Class representing a node in a decision tree.
class DecisionTreeNode:
    """
    A decision tree node.
    
    Parameters
    ----------
    gini : float
        Gini impurity of the node.
    num_samples : int
        Number of samples at the node.
    num_samples_per_class : list
        Number of samples per class at the node.
    predicted_class : int
        Class predicted at the node.
    feature_index : int
        Index of the feature used for splitting.
    threshold : float
        Threshold value at the node used for splitting.
    left : DecisionTreeNode
        Left child node.
    right : DecisionTreeNode
        Right child node.
    """
    def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
        self.gini = gini
        self.num_samples = num_samples
        self.num_samples_per_class = num_samples_per_class
        self.predicted_class = predicted_class
        self.feature_index = 0
        self.threshold = 0
        self.left = None
        self.right = None

# Function to calculate the Gini impurity of a set of labels.
def gini(y):
    m = len(y)
    # Gini impurity formula implementation.
    return 1.0 - sum((np.sum(y == c) / m) ** 2 for c in np.unique(y))

# Function to find the best split for the data.
def best_split(X, y):
    m, n = X.shape
    if m <= 1:
        return None, None

    # Preparing for split calculation.
    unique_classes = np.unique(y)
    num_classes = len(unique_classes)
    class_dict = {c: i for i, c in enumerate(unique_classes)}
    num_parent = [np.sum(y == c) for c in unique_classes]
    best_gini = 1.0 - sum((n / m) ** 2 for n in num_parent)
    best_idx, best_thr = None, None

    # Iterating over all features to find the best split.
    for idx in range(n):
        # Sorting data and labels based on current feature.
        thresholds, classes = zip(*sorted(zip(X[:, idx], y)))
        num_left = [0] * num_classes
        num_right = num_parent.copy()
        # Calculating Gini impurity for each possible split.
        for i in range(1, m):
            c = class_dict[classes[i - 1]]
            num_left[c] += 1
            num_right[c] -= 1
            gini_left = 1.0 - sum((num_left[x] / i) ** 2 for x in range(num_classes))
            gini_right = 1.0 - sum((num_right[x] / (m - i)) ** 2 for x in range(num_classes))
            gini = (i * gini_left + (m - i) * gini_right) / m
            # Skipping equal thresholds.
            if thresholds[i] == thresholds[i - 1]:
                continue
            # Updating best split if a better one is found.
            if gini < best_gini:
                best_gini = gini
                best_idx = idx
                best_thr = (thresholds[i] + thresholds[i - 1]) / 2 

    return best_idx, best_thr

# Function to grow the decision tree recursively.
def grow_tree(X, y, depth=0, max_depth=100):
    # Counting samples per class and choosing the predicted class.
    num_samples_per_class = [np.sum(y == i) for i in np.unique(y)]
    predicted_class = np.argmax(num_samples_per_class)
    # Creating a new tree node.
    node = DecisionTreeNode(
        gini=gini(y), 
        num_samples=len(y), 
        num_samples_per_class=num_samples_per_class, 
        predicted_class=predicted_class,
    )

    # Recursively growing the tree if depth limit is not reached.
    if depth < max_depth:
        idx, thr = best_split(X, y)
        if idx is not None:
            # Splitting the dataset based on the best split.
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            # Assigning split feature and threshold to the node.
            node.feature_index = idx
            node.threshold = thr
            # Recursively creating left and right children.
            node.left = grow_tree(X_left, y_left, depth + 1, max_depth)
            node.right = grow_tree(X_right, y_right, depth + 1, max_depth)
    return node


In [41]:
def predict(sample, node):
    while node.left:
        if sample[node.feature_index] < node.threshold:
            node = node.left
        else:
            node = node.right
    return node.predicted_class


In [42]:
# Assuming X_train and y_train are your data and labels
tree = grow_tree(X_train, y_train, max_depth=5)


In [43]:
# For a single prediction
prediction = predict(X_test[0], tree)

# For all test data
predictions = [predict(x, tree) for x in X_test]

In [44]:
# Show Represenation of Tree
def print_tree(node, depth=0):
    if node is None:
        print("{}LEAF {}".format("\t" * depth, node))
    elif node.left is None and node.right is None:
        print("{}LEAF {}".format("\t" * depth, node.predicted_class))
    else:
        print("{}[X{} < {}]".format("\t" * depth, node.feature_index, node.threshold))
        print_tree(node.left, depth + 1)
        print_tree(node.right, depth + 1)

print_tree(tree)

[X9 < 3.82]
	[X2 < 3.0700000000000003]
		[X11 < 3.8200000000000003]
			LEAF 0
			LEAF 0
		LEAF 0
	[X6 < 1.4]
		LEAF 0
		[X12 < 724.5]
			[X0 < 13.145]
				LEAF 0
				LEAF 0
			LEAF 0


In [45]:
# Show Confusion Matrix

print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))


[[19  0  0]
 [21  0  0]
 [14  0  0]]
              precision    recall  f1-score   support

           0       0.35      1.00      0.52        19
           1       0.00      0.00      0.00        21
           2       0.00      0.00      0.00        14

    accuracy                           0.35        54
   macro avg       0.12      0.33      0.17        54
weighted avg       0.12      0.35      0.18        54

