<a href="https://colab.research.google.com/github/lkarjun/Data-Science-from-Scratch/blob/master/17%20Decision%20Tree/decisionTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Decision Tree

## Entropy

In [36]:
from typing import List
import math

In [37]:
def entropy(class_probabilities: List[float]) -> float:
    '''Return the Shannon entropy, in bits, of a list of class probabilities.

    Args:
        class_probabilities: probability of each class. Entries that are
            not strictly positive are skipped, so a zero probability
            contributes nothing and an empty list yields 0.

    Returns:
        The sum of ``-p * log2(p)`` over the strictly positive entries.
    '''
    # math.log2 is used instead of math.log(p, 2): the two-argument form
    # divides two rounded natural logarithms and can be off by an ulp even
    # for exact powers of two.
    return sum(-p * math.log2(p)
               for p in class_probabilities
               if p > 0)

In [38]:
# Two equally likely classes carry exactly one bit of entropy.
assert entropy([0.5, 0.5]) == 1

In [39]:
from typing import Any
from collections import Counter

In [40]:
def class_probabilities(labels: List[Any]) -> List[float]:
    '''Return the relative frequency of every distinct label.

    The fractions are listed in the order in which each label first
    appears in ``labels``; an empty input gives an empty list.
    '''
    tallies = Counter(labels)
    n_labels = len(labels)
    return [tallies[label] / n_labels for label in tallies]

def data_entropy(labels: List[Any]) -> float:
    '''Return the entropy of the label frequencies observed in ``labels``.'''
    label_distribution = class_probabilities(labels)
    return entropy(label_distribution)

In [41]:
# When every label is the same there is no uncertainty: entropy is zero.
assert data_entropy(['a']) == 0
assert data_entropy([True, True]) == 0.0
# Two labels in equal proportion give exactly one bit.
assert data_entropy(['yes', 'yes', 'no', 'no']) == 1

## The Entropy of a Partition

In [42]:
def partion_entropy(subsets: List[List[Any]]) -> float:
    '''Return the size-weighted average entropy of the label subsets.

    Each subset's entropy is weighted by its share of the total number
    of labels across all subsets.
    '''
    n_labels = sum(map(len, subsets))

    def weighted_entropy(group: List[Any]) -> float:
        # Entropy of one subset, scaled by the fraction of labels it holds.
        return data_entropy(group) * len(group) / n_labels

    return sum(map(weighted_entropy, subsets))

## Creating a Decision Tree

In [43]:
from typing import NamedTuple, Optional

In [44]:
class Candidate(NamedTuple):
    '''One row of the example data set used to build and query the tree.'''
    # The first four fields are the attributes the notebook offers to
    # build_tree_id3 as split attributes.
    level: str
    lang: str
    tweets: bool
    phd: bool
    # Label used as the target attribute; it defaults to None so that an
    # unlabeled candidate can be constructed and passed to classify().
    did_well: Optional[bool] = None

In [45]:
# Sample data: 14 labeled candidates, given positionally as
# (level, lang, tweets, phd, did_well).
inputs = [Candidate('Senior', 'Java', False, False, False),
         Candidate('Senior', 'Java', False, True, False),
         Candidate('Mid', 'Python', False, False, True),
         Candidate('Junior', 'Python', False, False, True),
         Candidate('Junior', 'R', True, False, True),
         Candidate('Junior', 'R', True, True, False),
         Candidate('Mid', 'R', True, True, True),
         Candidate('Senior', 'Python', False, False, False),
         Candidate('Senior', 'R', True, False, True),
         Candidate('Junior', 'Python', True, False, True),
         Candidate('Senior', 'Python', True, True, True),
         Candidate('Mid', 'Python', False, True, True),
         Candidate('Mid', 'Java', True, False, True),
         Candidate('Junior', 'Python', False, True, False)]

In [46]:
from typing import Dict, TypeVar
from collections import defaultdict

In [47]:
# Generic placeholder for the element type of the lists partition_by handles.
T = TypeVar("T")

In [51]:
def partition_by(inputs: List[T], attribute: str) -> Dict[Any, List[T]]:
    '''Group ``inputs`` by the value each item holds in ``attribute``.

    Returns a ``defaultdict(list)`` that maps every distinct attribute
    value to the items carrying it, with items kept in input order.
    '''
    groups: Dict[Any, List[T]] = defaultdict(list)
    for item in inputs:
        groups[getattr(item, attribute)].append(item)
    return groups

In [54]:
def partition_entropy_by(inputs: List[Any],
                         attribute: str,
                         label_attribute: str) -> float:
    '''Return the entropy of the labels after partitioning on ``attribute``.

    ``inputs`` are grouped by ``attribute``; the ``label_attribute`` value
    of every item in each group is collected, and the resulting label
    lists are scored with ``partion_entropy``.
    '''
    def labels_of(group: List[Any]) -> List[Any]:
        # Pull the label out of every item in one partition.
        return [getattr(member, label_attribute) for member in group]

    groups = partition_by(inputs, attribute).values()
    return partion_entropy([labels_of(group) for group in groups])

In [55]:
# Print, for each attribute, the entropy of the 'did_well' labels after
# partitioning the sample data on that attribute.
for key in ['level', 'lang', 'tweets', 'phd']:
    print(key, partition_entropy_by(inputs, key, 'did_well'))

level 0.6935361388961919
lang 0.8601317128547441
tweets 0.7884504573082896
phd 0.8921589282623617


In [56]:
# Keep only the candidates whose level is 'Senior'.
senior_inputs = [input for input in inputs if input.level == 'Senior']

In [57]:
# Entropy of 'did_well' among the Senior candidates when split on 'lang'.
partition_entropy_by(senior_inputs, 'lang', 'did_well')

0.4

## Putting it all together    

In [58]:
from typing import NamedTuple, Union, Any

In [61]:
class Leaf(NamedTuple):
    '''Terminal node of a decision tree.'''
    # The prediction classify() returns when it reaches this node.
    value: Any
        
class Split(NamedTuple):
    '''Internal node of a decision tree that branches on one attribute.'''
    # Name of the attribute read from the input being classified.
    attribute: str
    # Maps each known value of that attribute to the subtree to follow.
    subtrees: dict
    # Returned by classify() when the input's attribute value has no
    # entry in subtrees.
    default_value: Any = None
    
# A decision tree is either a Leaf (a prediction) or a Split (a branch).
DecisionTree = Union[Leaf, Split]

In [63]:
# Hand-built example tree: branch on level first; Junior candidates are
# then decided by phd, Senior candidates by tweets, and Mid is always True.
hiring_tree = Split('level', {
        'Junior': Split('phd', {
            False: Leaf(True),
            True: Leaf(False)}),
        'Mid': Leaf(True),
        'Senior': Split('tweets',{
            False: Leaf(False),
            True: Leaf(True)
        })
})

In [64]:
def classify(tree: DecisionTree, input: Any) -> Any:
    '''Return the prediction the decision tree makes for ``input``.

    Walks down from ``tree``: at each Split the value of the split's
    attribute is read from ``input`` and the matching subtree is followed.
    If the value has no matching subtree, that Split's ``default_value``
    is returned; otherwise the walk ends at a Leaf, whose value is
    returned.
    '''
    node = tree
    while not isinstance(node, Leaf):
        branch_key = getattr(input, node.attribute)
        if branch_key not in node.subtrees:
            # Attribute value this Split has no branch for: fall back.
            return node.default_value
        node = node.subtrees[branch_key]
    return node.value

In [73]:
def build_tree_id3(inputs: List[Any],
                   split_attributes: List[str],
                   target_attribute: str) -> DecisionTree:
    '''Recursively build a decision tree by greedy entropy minimization.

    Args:
        inputs: labeled examples; attributes are read with ``getattr``.
        split_attributes: names of the attributes still available to
            split on.
        target_attribute: name of the attribute holding the label.

    Returns:
        A Leaf holding the most common label when all examples share one
        label or no split attributes remain; otherwise a Split on the
        attribute with the lowest partition entropy, with one subtree per
        observed value and the most common label as its default value.

    Raises:
        IndexError: if ``inputs`` is empty (there is no most common label).
    '''
    target_tally = Counter(getattr(example, target_attribute)
                           for example in inputs)
    majority_label = target_tally.most_common(1)[0][0]

    # Stop when the labels are already pure or nothing is left to split on.
    if len(target_tally) == 1 or not split_attributes:
        return Leaf(majority_label)

    # Choose the attribute whose partition leaves the least label entropy.
    chosen = min(
        split_attributes,
        key=lambda name: partition_entropy_by(inputs, name, target_attribute))
    remaining = [name for name in split_attributes if name != chosen]

    # Grow one subtree per observed value of the chosen attribute.
    branches = {}
    for value, members in partition_by(inputs, chosen).items():
        branches[value] = build_tree_id3(members, remaining, target_attribute)

    return Split(chosen, branches, default_value=majority_label)


In [75]:
# Build a tree from the sample data: all four attributes are available for
# splitting and 'did_well' is the label to predict.
tree = build_tree_id3(inputs, ['level', 'lang', 'tweets', 'phd'], 'did_well')

In [76]:
# Classify an unlabeled candidate (did_well is left at its None default).
classify(tree, Candidate('Junior', 'Java', True, False))

True

In [78]:
# 'Intern' is a level that does not occur in the sample data; classify()
# returns a Split's default_value when an attribute value has no subtree.
classify(tree, Candidate('Intern', 'Java', True, True))

True