# Decision Tree ID3

A *decision tree* is an acyclic graph that can be used to make descisions about data points. Each branching node of the graph examines a specific feature of the data point. According to some threshold the data point is classified and directed down a speicific branch of the tree.

## The Dataset

The dataset that we will use to develop this algorithm is the balanced scale dataset, found [here](http://archive.ics.uci.edu/ml/datasets/Balance+Scale). The dataset concerns the weights on two sides of a set of balance scales - the weight on each side, and the distance the weight is from the pivot. Each point then is classified as either B (for balanced), or L (for left side heavier), or R (for right side heavier).

The structure of each data point is:

  1. Class - either B, R, or L;
  2. Left-weight - the weight on the left side of the balance;
  3. Left-distance - the distance of the left weight from the pivot;
  4. Right-weight - the weight on the right side of the balance;
  5. Right-distance - the distance of the right weight from the pivot.
    
Whether or not the scale is balanced can be calculated using moments - if (right-weight * right-distance) == (left-weight * left-distance), then the scale is balanced. The goal is will be to create a decision tree that can determine whether or not the scale is balanced.

There are 625 instances - (49 balanced, 288 left, 288 right).

In [1]:
import pandas as pd

data = pd.read_csv("data/balance-scale.csv")

data.head(5) # let's have a look at the data

Unnamed: 0,Class,Left-weight,Left-distance,Right-weight,Right-distance
0,B,1,1,1,1
1,R,1,1,1,2
2,R,1,1,1,3
3,R,1,1,1,4
4,R,1,1,1,5


# The Algorithm

The aim is to create a binary decision tree from the dataset. This will be a binary tree, that at each node "splits" the data into two sets.

Optimising a cost function across all possible trees that we could construct is computationally unfeasible, so we will opt for a local approach - looking for the best possible way to "split" the data into to groups at each stage of the algorithm.

In order to do this we need a measure for what a "good" split is. For the categorical data that we are using, intuitively we would think that a "good" split is one in which the groups that the data is divided into contain mostly similar values (i.e. all the `balanced` data points end up in one group, and all the `right` and `left` leaning data points end up in the other group for example. We can achieve this using a measure of impurity called the Gini Index.

For a colection of data points $m$, with categories $k \in K$, let $p_{k,m}$ be the proportion of data points in $m$ that are category $k$. The the Gini index is calculated as follows:
$$\sum_{k} p_{m,k} (1 - p_{m,k})$$

In [33]:
def calculate_gini_index(sample, category):
    value_counts = sample[category].value_counts()
    total = value_counts.sum()
    return sum([gini_score(count, total) for count in value_counts])
        
def gini_score(count, total):
    proportion = count / total
    return proportion * (1 - proportion)

calculate_gini_index(data, "Class") # the Gini Index of the whole dataset (pretty impure!)

0.5691801599999999

So now all we need to do is find the binary split with the lowest Gini Index. Then we can split the data and then separately consider the two sets that we are left with. Then we can repeat the process for each of them - finding the split with the lowest Gini Index and then dividing the data.

In [41]:
def get_best_split(sample, category):
    features = sample.columns.values[1:]
    best_gini_index = 1
    best_split = ()
    for variable in features:
        p = sample[variable].unique()
        split_options = [(p[i + 1] + p[i]) / 2 for i in range(len(p) - 1)]
        for split in split_options:
            gini = calculate_gini_index_of_split(sample, variable, split, category)
            if gini < best_gini_index:
                best_gini_index = gini
                best_split = (variable, split)
    return best_split + (best_gini_index,)
            
def calculate_gini_index_of_split(sample, variable, split, category):
    split1 = sample[variable] <= split
    split2 = sample[variable] > split
    gini_split1 = calculate_gini_index(sample[split1], category)
    gini_split2 = calculate_gini_index(sample[split2], category)
    n1 = len(split1)
    n2 = len(split2)
    N = n1 + n2
    return (gini_split1 * (n1 / N)) + (gini_split2 * (n2 / N))
            
get_best_split(data, "Class")

('Left-weight', 1.5, 0.4581)

## Building a Decision Tree

Now we have a method for determining the best way to split a sample of the dataset, we need a way of arranging this into a binary tree.

In [100]:
class Tree:
    def __init__(self, data, category):
        self.category = category
        self.num_nodes = 0
        self.root = Node(self, data)
        self.root.buildtree()
        
    def predict(self, datapoint):
        return self.root.predict(datapoint)

class Node:
    def __init__(self, root, data):
        self.root = root
        self.root.num_nodes += 1
        self.data = data
        self.label = None
        
    def buildtree(self):
        self.split_data(self.root.category)
        
    def split_data(self, category):
        if len(self.data[category].unique()) > 1:
            variable, split, split_gini = get_best_split(self.data, category)
            self.condition = (variable, split)
            if not split_gini < self.gini_index():
                self.label = get_most_frequent(self.data, category)
                return
            else:
                self.left = Node(self.root, self.data[self.data[variable] <= split])
                self.right = Node(self.root, self.data[self.data[variable] > split])
                self.left.split_data(self.root.category)
                self.right.split_data(self.root.category)
        else:
            self.label = get_most_frequent(self.data, category)
        
    def gini_index(self):
        return calculate_gini_index(self.data, self.root.category)
    
    def predict(self, datapoint):
        if not self.label:
            test = self.test(datapoint)
            return getattr(self, test).predict(datapoint)
        else:
            return self.label
        
    def test(self, datapoint):
        variable, split = self.condition
        if datapoint[variable].at[0] <= split:
            return "left"
        else:
            return "right"
    
def get_most_frequent(data, category):
    return data[category].mode()[0]
        
tree = Tree(data, "Class")
print("done!")

done!


In [101]:
print(f"The tree has {tree.num_nodes} nodes")

The tree has 447 nodes


## Prediction

Let's try predicting a value that 

In [111]:
new_data = pd.DataFrame(data=[[8,6,8,6]], columns=['Left-weight', 'Left-distance', 'Right-weight', 'Right-distance'])
new_data.head()

Unnamed: 0,Left-weight,Left-distance,Right-weight,Right-distance
0,8,6,8,6


In [112]:
tree.predict(new_data)

'B'