# Training decision trees

Our objective here is to implement the code necessary to train a decision tree.

First, we will define some types that will help us to define the signatures of our methods and functions:

- A `Value` corresponds to a value in a certain attribute for our domain. It can be numeric (e.g., 3.14) or categorical (e.g., "blue").
- `Observation` is a list of `Value`s that represent all the values for all the attributes in a real world observation.
- `Data` corresponds to a list of observations.
- `Labels` is a vector with the labels for our problem. Item in position `i` for this vector (or list) corresponds to the label for observation `i` in our dataset.
- `ScoreFn` is a function that takes a list of labels and returns a numeric score measuring the impurity of the distribution.

In [8]:
from typing import Callable

Value = int | float | str
Observation = list[Value]
Data = list[Observation]
Labels = list[Value]
ScoreFn = Callable[[Labels], float]

## Utility functions
First, let's define some functions that can help you in the process of creating our first classification model.

For now, just remember that they are here, and when you need them you can come back and implement them. Left as they are, the notebook should run fine until you call them.

In [3]:
def unique_counts(values: list[Value]) -> dict[Value, int]:
    """Count how many times each value appears in `values`"""
    return {v : values.count(v) for v in values }

def is_numeric(value: Value) -> bool:
    """Checks if a value is numeric (i.e. a float or an int)"""
    return isinstance(value, int) or isinstance(value, float)

def get_query_fn(column: int, value: Value) -> Callable[[Observation], bool]:
    """
    Create a function that separates observations based on a query.
    The query can be:

    a) categorical: the created function returns true
       iff. the observation has the exact value in the column specified.
    b) continuous: the created function returns true
       iff. the observation has a value smaller than the reference one
       in the column specified.

    Note: consider any column with a numeric value as continuous.
    """
    if is_numeric(value):
        return lambda obs : obs[column] <= value
    else:
        return lambda obs : obs[column] == value


def unique_values(table: list[list[Value]], column_idx: int):
    """Returns a set of the values in the columns of a table."""
    return set(table[column_idx])

def cast_to(value_str: str) -> Value:
    """
    Given a value represented as a string, try to convert it
    to a more specific type (int, float) or fail back to string.
    """
    try:
        return int(value_str)
    except ValueError:
        pass
    try:
        return float(value_str)
    except ValueError:
        pass
    return value_str

from math import log
def log2(x): return log(x) / log(2)

## Representing a decision tree

Let's create a class that represents a decision tree.

We will use the [dataclasses](https://docs.python.org/3/library/dataclasses.html) module from Python to reduce the boilerplate we need to write for this class.

The instances of `Node` will have 5 attributes:

- column: The column of the data we are splitting on.
- value: The value for this column that we use for splitting the rows.
- results: A counter on how many rows we have for each label that reached this node.
- true_branch: In case we are in a decision node, the node that an observation reaches if it answers positively to the query of this node.
- false_branch: In case we are in a decision node, the node that an observation reaches if it answers negatively to the query of this node.

The attributes `column` and `value` define the query we are performing in this decision node. For leaves, those attributes are `None`. We consider queries as seen in the slides, where for continuous columns we perform the query `>= value`, and for categorical columns we perform the query `== value`.

The attribute `results` can be defined only on the leaves of the tree.

`true_branch` and `false_branch` are `None` for leaves.

In [7]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    column: Optional[int]
    value: Optional[Value]
    results: Optional[dict[Value, int]]
    true_branch: Optional[Node]
    false_branch: Optional[Node]

    def is_leaf(self):
        return self.true_branch is None
    
    @classmethod
    def new_node(cls, column, value, true_branch, false_branch) -> Node:
        """Create a new instance of this class representing a decision node."""
        return cls(column,value,None,true_branch,false_branch)

    @classmethod
    def new_leaf(cls, labels: Labels) -> Node:
        """Create a new instance of this class representing a leaf."""
        results = unique_counts(labels)
        return cls(None,None,results,None,None)

    def print_tree(self, indent=''):
        """Prints to stdout a representation of the tree."""
        # Is this a leaf node?
        if self.results!=None:
            print(self.results)
        else:
            # Print the criteria
            if is_numeric(self.value): #type:ignore
                print(f"{self.column}: >= {self.value}?")
            else:
                print(f"{self.column}: {self.value}?")
            # Print the branches
            print(f"{indent}T->", end="")
            self.true_branch.print_tree(indent+' ') #type:ignore
            print(f"{indent}F->", end="")
            self.false_branch.print_tree(indent+' ') #type:ignore
        
    def follow_tree(self, observation: Observation) -> Node:
        """
        Traverse the (sub)tree by answering the queries, until a leaf is reached.
        
        This method returns the leaf that this observation reaches.
        """

        current = self

        while not current.is_leaf():
            query = get_query_fn(current.column,current.value)
            if query(observation):
                current = current.true_branch
            else:
                current = current.false_branch
        return current



And try a predefined tree with some data:

In [5]:
tree = Node.new_node(
    0, "Red",
    Node.new_leaf(["Grape", "Grape"]),
    Node.new_node(
        1, 3,
        Node.new_leaf(["Apple", "Lemon", "Apple"]),
        Node.new_leaf(["Grape"])
    )
)

tree.print_tree()

0: Red?
T->{'Grape': 2}
F->1: >= 3?
 T->{'Apple': 2, 'Lemon': 1}
 F->{'Grape': 1}


In [6]:
example_data = [
    ["Green", 3],
    ["Yellow", 3],
    ["Red", 1],
    ["Red", 1],
    ["Yellow", 2],
]

for observation in example_data:
    print(tree.follow_tree(observation).results)

{'Apple': 2, 'Lemon': 1}
{'Apple': 2, 'Lemon': 1}
{'Grape': 2}
{'Grape': 2}
{'Apple': 2, 'Lemon': 1}


## Building the tree

Defining the trees manually is a tedious process. So, as good engineers, we will automate this task.

In [22]:
def _iterate_queries(observations : Data) -> Generator[tuple[int, Value],None,None]:
    assert len(observations) > 0, "No data"

    ncols = len(observations[0])
    for col in range(0, ncols):
        for value in unique_values(observations, col):
            yield col, value

def recursive_build_tree(scoref: ScoreFn, observations: Data, labels: Labels) -> Node:
    if not observations:
        return Node.new_leaf([])
    root_imp = scoref(labels)
    if root_imp == 0:
        return Node.new_leaf(labels)
    
    best_query, best_goodness, best_T, best_F = None, 0, None, None
    for col, value in _iterate_queries(observations):
        obs_true, labels_true, obs_false, labels_false = divideset(observations, labels, col, value)
        root = scoref(labels)
        tb = len(obs_true) / len(observations) * scoref(labels_true) 
        fb = len(obs_false) / len(observations) * scoref(labels_false)
        goodness = root - tb - fb
                                  
        if best_query is None or goodness > best_goodness:
                 best_query = col, value
                 best_goodness = goodness
                 best_T = obs_true, labels_true
                 best_F = obs_false, labels_false
    return Node.new_node(*best_query, recursive_build_tree(scoref, *best_T), recursive_build_tree(scoref, *best_F))
        
def divideset(
    observations: Data, labels: Labels, column: int, value: Value
) -> tuple[Data, Labels, Data, Labels]:
    """
    Divides a set on a specific column.
    Can handle numeric or categorical values
    """                              
                                  
    query_fn = get_query_fn(column, value)
    
    observations_true, labels_true, observations_false, labels_false = [], [], [], []

    for obv, label in zip(observations, labels):
        if query_fn(observation):
            observations_true.append(obv)
            labels_true.append(label)
        else:
            observations_false.append(obv)
            labels_false.append(label)

    return observations_true, labels_true, observations_false, labels_false

### Impurity functions

As our building function requires some functions to score the impurity of a node, let's implement the ones that we have seen previously: the gini score and the entropy.

Gini score:

$$
gini(t) = \sum^{K}_{j_1,j_2 = 1 : j_1 \ne j_2} p(j_1|t) p(j_2|t) = 1 - \sum^{K}_{j=1} p(j|t)^2
$$

In [23]:
def gini(labels: Labels) -> float:
    total = len(labels)
    results = unique_counts(labels)
    probs={ label : count/total for label,count in results.items() }
    return 1 - sum(p**2 for p in probs.values())

Entropy:

$$
entropy(t) = -\sum^{K}_{j=1} p(j|t) log_2(p(j|t))
$$

**Note** that we are using the $log_2$ function, not $log$.

In [None]:
def entropy(labels):
    total = len(labels)
    results = unique_counts(labels)
    ...

### Training our first classifier

We are ready to build our first tree from some training data:

| **Color** | **Size** | **Fruit** |
|-----------|----------|-----------|
| Green     | 3        | Apple     |
| Yellow    | 3        | Apple     |
| Red       | 1        | Grape     |
| Red       | 2        | Grape     |
| Yellow    | 2        | Lemon     |

In [None]:
data = [
    ["Green", 3],
    ["Yellow", 3],
    ["Red", 1],
    ["Red", 2],
    ["Yellow", 2],
]

labels = ["Apple", "Apple", "Grape", "Grape", "Lemon"]

tree = recursive_build_tree(gini, data, labels)
tree.print_tree()

## Predicting using our decision tree

Building the decision tree is interesting but it does not help us in the long term.

What we need is some way to predict labels for new observations. Given an observation, we should find which leaf we should consider when deciding the label:

In [None]:
for observation in data:
    print(observation, "->", tree.follow_tree(observation))

## A DIY classifier.

Finally, we will encapsulate the training and usage of the decision tree in a classs.

This class (the machine learning model) can be trained (aka fit) with some data, and then we can use it to classify new observations.

Our decision tree model will have 3 different parameters:

- Score function: the function that we will use to measure the impurity of the nodes
- Beta: the minimum decrease of impurity required to split a node
- Prune threshold: the maximum decrease of impurity allowed when prunning a tree

The conventions we will follow for this class are:

1. The constructor is dummy, and only stores the parameters for our model.
2. All the parameters in the constructor must have a default, so we can do `DecisionTreeModel()`.
3. The `fit` method can create new attributes on the class, learnt during the training process. We will name those attributes with a final underscore to distinguish them from the parameters set in the constructor. We can use the existence of such attributes to check if the model has been trained.
4. The `predict` method must accept multiple observations at once. It should check if the model has been fitted before (e.g. using [hasattr](https://docs.python.org/3.9/library/functions.html?highlight=hasattr#hasattr) over some attribute learnt in `fit`).

For convenience, we can add a `score` method to our class. This method will tell us the accuracy of our model on some labelled data. The accuracy of a classifier is the ratio of correctly predicted labels over the total number of observations.



**Note:** the conventions follow closely the API defined by the *de facto* standard of Machine Learning in Python: [scikit-learn](https://scikit-learn.org/stable/developers/develop.html).

In [None]:
from typing import Self

class DecisionTreeModel:
    def __init__(self, scoref: ScoreFn = gini, beta: float = 0, prune_threshold: float = 0):
        self.scoref = scoref
        self.beta = beta
        self.prune_threshold = prune_threshold
    
    def fit(self, observations: Data, labels: Labels) -> Self:
        self.tree_ = ...
        return self
    
    def predict(self, observations: Data) -> Labels:
        ...

    def score(self, data: Data, labels: Labels) -> float:
        predicted = self.predict(data)
        correct = sum(
            1 if pred == expected else 0
            for pred, expected in zip(predicted, labels)
        )
        return correct / len(data)

## Utilities

We have all the components needed to train decision tress. Despite this, we still have to define our data manually in the code. Ideally, we would load this data from a file, a database, etc.

Let's implement a method that reads our training set from a [CSV](https://en.wikipedia.org/wiki/Comma-separated_values) file.

As the file we load will have both the observations and the labels, we also want a method that separates those. We will assume for our case that the labels are always the last column of the CSV file.


In [None]:
def read_csv(file_name: str) -> list[list[Value]]:
    table = []
    with open(file_name) as f:
        for line in f:
            ...    
    return table


def split_observations_and_labels(table: list[list[Value]]) -> tuple[Data, Labels]:
    data, labels = [], []
    for row in table:
        data.append(row[:-1])
        labels.append(row[-1])
    return data, labels

Let's load the data from the example in the slides:

In [None]:
data, labels = split_observations_and_labels(read_csv("decision_tree_example.csv"))

print(data)

print(labels)

And train a decision tree for this dataset:

In [None]:
model = DecisionTreeModel(gini)
model.fit(data, labels)

model.tree_.print_tree()

print("The accuracy of our tree in the training set is", model.score(data, labels))