# Decision Tree Basics

In [1]:
using Pkg
Pkg.activate(".")
Pkg.add(["StatsBase", "DataFrames", "MLDatasets"])

[32m[1m  Activating[22m[39m project at `~/repos/Decision-Tree-Classifier`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/repos/Decision-Tree-Classifier/Project.toml`
[32m[1m  No Changes[22m[39m to `~/repos/Decision-Tree-Classifier/Manifest.toml`


In [2]:
using MLDatasets, DataFrames, StatsBase

In [3]:
#load iris dataset
iris = Iris()
#add class labels as numbers: name => integer label, plus the inverse lookup
#used later to turn a predicted label back into its name
str_to_class = Dict{String, Int64}("Iris-setosa" => 0, "Iris-versicolor" => 1, "Iris-virginica" => 2)
class_to_str = Dict{Int64, String}(label => name for (name, label) in str_to_class)

#I merge here the targets into the feature DataFrame for easier filtering later on.
#The mapping is computed exactly once and stored as the `class` column.
iris.features.class = map(name -> str_to_class[name], iris.targets.class);

In [4]:
#quick look at the dataset.
# first five rows, all columns (including the integer `class` column added above)
iris.features[1:5, :]

Row,sepallength,sepalwidth,petallength,petalwidth,class
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Int64
1,5.1,3.5,1.4,0.2,0
2,4.9,3.0,1.4,0.2,0
3,4.7,3.2,1.3,0.2,0
4,4.6,3.1,1.5,0.2,0
5,5.0,3.6,1.4,0.2,0


In [5]:
# Common supertype of all tree nodes; lets a node hold children of type `Node`.
abstract type Node end
#Define Tree Nodes

# One node of the classification tree. Every field may be `missing`; a node
# whose `thresh` is `missing` is treated as a leaf by `predict` and `traverse`.
# Field order matters: the positional constructor is used elsewhere.
mutable struct ClassifierNode <: Node
    # child that receives the rows with feature value < thresh
    left::Union{Node, Missing}
    # child that receives the remaining rows
    right::Union{Node, Missing}
    # name of the feature column this node splits on
    split::Union{String, Missing}
    # depth of the node in the tree (the root is created with depth 0)
    depth::Union{Int64, Missing}
    # threshold compared against the split feature
    thresh::Union{Float64, Missing}
    # impurity value stored for this node by train!
    gini::Union{Float64, Missing}
    # class label => number of training rows with that label at this node
    class_count::Union{Dict{Int64, Int64}, Missing}
end

# A decision tree together with the hyperparameters used to fit it.
mutable struct DecisionTreeClassifier
    # root node of the tree
    root::ClassifierNode
    # function applied to a class => count Dict to score a node (e.g. gini_impurity)
    impurity_measure::Function
    # passed to train! as the depth at which nodes stop splitting
    max_depth::Int64
    # passed to train! as the minimum impurity change required to accept a split
    min_impurity_decrease::Float64
    # passed to train! as the row-count limit for splitting
    min_samples_leaf::Int64
end

In [6]:
# Convenience constructor: an empty node at depth `i`.
# Every field other than `depth` starts out as `missing`.
function ClassifierNode(i::Int64)
    return ClassifierNode(missing, missing, missing, i, missing, missing, missing)
end

ClassifierNode

In [7]:
# Convenience constructor: builds an unfitted tree whose root is an empty
# node at depth 0, storing the given impurity function and hyperparameters.
function DecisionTreeClassifier(impurity_measure::Function, max_depth::Int64, min_impurity_decrease::Float64, min_samples_leaf::Int64)
    root = ClassifierNode(0)
    return DecisionTreeClassifier(root, impurity_measure, max_depth, min_impurity_decrease, min_samples_leaf)
end

DecisionTreeClassifier

In [8]:
#counts each class
# Returns a Dict mapping every label in `classes` to the number of entries of
# `data` equal to it. Labels that never occur are still present with count 0.
function class_distribution(data, classes)
    return Dict{Int64, Int64}(label => count(==(label), data) for label in classes)
end

class_distribution (generic function with 1 method)

In [9]:
#calculates gini impurity for a class distribution
# `dist` maps class label => count. The result is 1 - sum(p_i^2), where p_i is
# the share of class i. If all counts are zero the shares are 0/0, so the
# result is NaN.
function gini_impurity(dist)
    counts = collect(values(dist))
    shares = counts ./ sum(counts)
    return 1 - sum(abs2, shares)
end

gini_impurity (generic function with 1 method)

Now that we have set up the helper functions, we have to think about how a decision tree works.
1. Find the best feature to split the values
2. Perform split and add new nodes
3. Perform this recursively

So step-by-step this looks as follows:
First we iterate over every feature and evaluate candidate splits at different threshold values with our impurity measure (Gini). There are multiple ways to choose split values; I used the quantiles of each feature. You can also make smarter choices, e.g. heuristics or other statistical criteria (feel free to try). We compare the candidates across all features, and once we have found the best one we split the dataset accordingly, add new nodes on the left and right sides with the corresponding rows, and call the train function on those nodes (with the split data).

The next step is to implement guards that decide when to stop. `max_depth` limits the overall growth of the tree, and `min_samples_leaf` prevents the tree from creating many tiny leaves (overfitting). Another important hyperparameter is `min_impurity_decrease`: it makes sure that a split is only made if the Gini impurity decreases by at least a minimal amount.

In the end this leaves us with a nicely fitted decision tree classifier.

In [97]:
# Recursively grow the subtree rooted at `classifier_node` from the rows of `data`.
#
# Arguments
#   classifier_node       - node to fit; its `depth` must already be set
#   data                  - feature columns plus the target column
#                           (feature columns are presumably numeric - they are
#                           passed to `quantile` and compared with `<`)
#   target_name           - name of the target column in `data`
#   classes               - every class label that should be counted
#   impurity_measure      - maps a class => count Dict to an impurity value
#   maxdepth              - nodes at this depth or deeper are not split
#   min_samples_leaf      - nodes with this many rows or fewer are not split;
#                           also the minimum size of the left side of a split
#   min_impurity_decrease - minimum |parent impurity - best candidate impurity|
#                           required to accept a split
#
# The node is mutated in place (children are created as needed); returns nothing.
#
# NOTE(review): a candidate split is scored by the impurity of its left side
# only, and only the left side is checked against min_samples_leaf. This is
# kept from the original design; a size-weighted impurity of both sides would
# be the textbook criterion.
function train!(classifier_node::ClassifierNode, data::DataFrame, target_name::Symbol,classes::Vector{Int64}, impurity_measure::Function ,maxdepth::Int64 = 5, min_samples_leaf::Int64 = 3, min_impurity_decrease::Float64 = 0.1)

    targets = data[!, target_name]

    #initial distribution of classes for node
    init_dist = class_distribution(targets, classes)
    parent_gini = impurity_measure(init_dist)

    # Store the node statistics before the stopping test, so that a node that
    # stops right away (including an unsplit root) can still be used as a leaf.
    classifier_node.class_count = copy(init_dist)
    classifier_node.gini = parent_gini

    if classifier_node.depth >= maxdepth || size(data, 1) <= min_samples_leaf
        # node stays a leaf
        return nothing
    end

    feature_names = [name for name in names(data) if name != String(target_name)]

    best_feature = ""
    best_split_val = Inf
    best_gini = parent_gini

    for feature in feature_names
        column = data[!, feature]
        # Candidate thresholds: the 1%..99% quantiles of the feature. Repeated
        # values are dropped; this does not change which threshold wins.
        splitvals = unique(quantile(column, 0.01:0.01:0.99))
        for splitval in splitvals
            # The mask is built from the same (unsorted) column that it is
            # applied to, so it selects exactly the rows below the threshold.
            left_rows = column .< splitval
            new_dist = class_distribution(targets[left_rows], classes)
            curr_gini = impurity_measure(new_dist)
            if (curr_gini <= best_gini) && (count(left_rows) >= min_samples_leaf)
                best_gini = curr_gini
                best_split_val = splitval
                best_feature = feature
            end
        end
    end

    # no candidate satisfied the constraints
    if best_feature == ""
        return nothing
    end

    # the best candidate does not change the impurity enough
    if !(abs(best_gini - parent_gini) >= min_impurity_decrease)
        return nothing
    end

    classifier_node.split = best_feature
    classifier_node.thresh = best_split_val

    left_mask = data[!, best_feature] .< best_split_val
    left_data = data[left_mask, :]
    right_data = data[.!left_mask, :]

    # Each recursive call stores the child's class_count and gini itself and
    # receives ALL hyperparameters, including min_impurity_decrease.
    classifier_node.left = ClassifierNode(classifier_node.depth + 1)
    train!(classifier_node.left, left_data, target_name, classes, impurity_measure, maxdepth, min_samples_leaf, min_impurity_decrease)

    classifier_node.right = ClassifierNode(classifier_node.depth + 1)
    train!(classifier_node.right, right_data, target_name, classes, impurity_measure, maxdepth, min_samples_leaf, min_impurity_decrease)

    return nothing
end


train! (generic function with 5 methods)

In [98]:
# Fit `tree` on `data`: trains the root node using the impurity function and
# hyperparameters stored in the tree.
function train!(tree::DecisionTreeClassifier, data::DataFrame, target_name::Symbol, classes::Vector{Int64})
    root = tree.root
    return train!(
        root, data, target_name, classes, tree.impurity_measure,
        tree.max_depth, tree.min_samples_leaf, tree.min_impurity_decrease,
    )
end

train! (generic function with 5 methods)

In [99]:
# positional order of the convenience constructor:
# impurity_measure, max_depth = 2, min_impurity_decrease = 0.1, min_samples_leaf = 10
t = DecisionTreeClassifier(gini_impurity, 2, 0.1, 10)
# fit on the iris features; :class is the integer label column, [0, 1, 2] are the labels to count
train!(t, iris.features, :class, [0, 1, 2])

In [100]:
# Predict the class of a single observation.
#
# Walks from the root down to a leaf (a node whose `thresh` is missing): at
# each inner node the row goes left if its value for the split feature is
# below the threshold, otherwise right. The label with the highest count in
# the leaf is looked up in `cstr` (label => output value, e.g. a class name).
function predict(DT::DecisionTreeClassifier, x::DataFrameRow, cstr)
    curr_node = DT.root
    while !ismissing(curr_node.thresh)
        goes_left = x[curr_node.split] < curr_node.thresh
        curr_node = goes_left ? curr_node.left : curr_node.right
    end
    return cstr[findprob(curr_node.class_count)]
end

# Predict every row of a DataFrame; returns one prediction per row, in row order.
# A whole DataFrame cannot be indexed with a single column name the way a row
# can, so each row is predicted on its own.
function predict(DT::DecisionTreeClassifier, x::DataFrame, cstr)
    return [predict(DT, row, cstr) for row in eachrow(x)]
end
    

predict (generic function with 2 methods)

In [101]:
# Return the key of `d` with the largest value.
# Only values strictly greater than 0 are considered; if there is none
# (empty input or all values <= 0) the result is 0.
function findprob(d)
    best_label = 0
    best_count = 0
    for (label, n) in pairs(d)
        if n > best_count
            best_count, best_label = n, label
        end
    end
    return best_label
end

findprob (generic function with 1 method)

In [102]:
# classify row 140 of the iris features and translate the integer label back to its name
predict(t, iris.features[140, :], class_to_str)

"Iris-virginica"

In [103]:
#just a helper function to show the tree structure
#approximately....
# Prints `node` indented by `spacing` columns, then its children: the left
# child `dp` columns further left, the right child at `2*spacing - dp`.
# Inner nodes show their split rule, leaves show the predicted class name
# (looked up in the global `class_to_str`). Output goes to stdout.
function traverse(node::ClassifierNode, spacing)

    dp = 15
    # `repeat` throws for a negative count, which happens once the left-hand
    # indentation runs out on deeper trees; clamp the padding at zero instead.
    pad(n) = repeat(" ", max(n, 0))

    #operations on current node
    p = pad(spacing)

    if !ismissing(node.thresh)
        println("$(p)|$(node.split) < $(node.thresh)")
        println("$(p)|class distribution: $(values(node.class_count))")
        println("$(p)-----------------------------------------------")
    else
        println("$(p)Predicted Class: $(  class_to_str[findprob(node.class_count)])")
        println("$(p)|class distribution: $(values(node.class_count))")
    end

    if !ismissing(node.left)
        # five rows of connector lines below the node
        for i in 1:5
            np = pad(spacing - i)
            pp = pad(spacing + 2 * i)
            if !ismissing(node.right)
                println("$(np)/$(pp)\\")
            else
                println("$(np)")
            end
        end
        traverse(node.left, spacing - dp)
    end
    if !ismissing(node.right)
        # connector lines leading to the right child
        for i in 1:5
            np = pad(2 * spacing + i + 8)
            println("$(np)\\")
        end
        traverse(node.right, 2 * spacing - dp)
    end
end

traverse (generic function with 1 method)

In [104]:
# print the fitted tree, starting the root at an indentation of 30 columns
traverse(t.root, 30)

                              |petalwidth < 1.0
                              |class distribution: [50, 50, 50]
                              -----------------------------------------------
                             /                                \
                            /                                  \
                           /                                    \
                          /                                      \
                         /                                        \
               Predicted Class: Iris-setosa
               |class distribution: [50, 0, 0]
                                                                     \
                                                                      \
                                                                       \
                                                                        \
                                                                         \
                        

# Conclusion
We have seen a really straightforward method to implement a decision tree and more importantly how to fit it to numerical data for classification. We also tested it on the iris dataset, so we know it works alright. 

