The next step after implementing decision trees is to use an ensemble of them. A random forest contains multiple decision trees. Th training data is bootstrapped from the starting dataset and no tree is fitted using all features, but only a subset of those too. So lets go on and implement these things step by step.

In [None]:
using Pkg;
Pkg.activate(".")

In [1]:
using MLDatasets, DataFrames, Random, StatsBase

In [2]:
include("DecisionTrees.jl")

Main.MyDecisionTree

In [3]:
using .MyDecisionTree: DecisionTreeClassifier, calculate_prob, traverse, train!, gini_impurity

In [4]:
iris = Iris()
str_to_class = Dict{String, Int64}([("Iris-setosa", 0) ,("Iris-versicolor", 1), ("Iris-virginica", 2)])
class_to_str = Dict{Int64, String}([(v, k) for (k, v) in str_to_class])
map(i->(str_to_class[i]), iris.targets.class)

#I merge here the targets into the feature DataFrame for easier filtering later on.
iris.features.class = map(i->(str_to_class[i]), iris.targets.class);

In [5]:
#
function bootstrap_subset(data::DataFrame, n_sub_features::Int64, n_samples::Int64=0, target_name::String="class")
    N = size(data)[1]
    feature_names = [name for name in names(data) if name != String(target_name)]
    
    if n_sub_features >= N
        prinln("Too much subfeatures defined, going to use N-1")
        n_sub_features = N -1
    end
    
    if n_samples == 0
        n_samples = N
    end
    subset_features = sample(feature_names, n_sub_features; replace=false)
    push!(subset_features, target_name)
    subset_idxs = rand(1:N, n_samples)
    subset = data[subset_idxs, subset_features]
    return subset
end

bootstrap_subset (generic function with 3 methods)

In [6]:
mutable struct RandomForestClassifier
    n_classifiers::Int64
    n_samples::Int64
    n_sub_features::Int64
    trees::Vector{DecisionTreeClassifier}
    tree_constructor_args::Union{Missing, Dict{String, Any}}
end

In [7]:
RandomForestClassifier(n_classifiers::Int64, 
                        n_samples::Int64,
                        n_sub_features::Int64, 
                        tree_constructor_args::Union{Missing, Dict{String, Any}}) = 

RandomForestClassifier(n_classifiers, 
                    n_samples,n_sub_features,
                    [DecisionTreeClassifier(gini_impurity,tree_constructor_args["max_depth"][i], tree_constructor_args["min_samples_leaf"][i], tree_constructor_args["min_impurity_decrease"][i]) for i in 1:n_classifiers], 
                    tree_constructor_args)


RandomForestClassifier(n_classifiers::Int64, 
                            n_samples::Int64,
                                n_sub_features::Int64) = 

RandomForestClassifier(n_classifiers, 
                    n_samples,n_sub_features,
                    [DecisionTreeClassifier(gini_impurity,5,20,1e-2) for i in 1:n_classifiers], missing)



RandomForestClassifier

In [8]:
function Main.train!(tree::DecisionTreeClassifier, data::DataFrame, target_name::Symbol, classes::Vector{Int64})
    train!(tree.root, data, target_name,classes, tree.impurity_measure, tree.max_depth, tree.min_samples_leaf, tree.min_impurity_decrease)
end

In [9]:
function Main.train!(forest::RandomForestClassifier, dataset::DataFrame)
    for t in forest.trees
        subset = bootstrap_subset(dataset, 3, 0, "class")
        feature_names = [name for name in names(subset) if name != "class"]
        t.features = feature_names
        train!(t, subset, :class, [0, 1, 2])
    end
end

In [10]:
n_classifiers = 20

20

In [11]:
args = Dict{String, Any}([("max_depth", rand(3:10, n_classifiers)), 
                                ("min_samples_leaf", rand(5:30, n_classifiers)), 
                                ("min_impurity_decrease", rand(0:0.001:0.5, n_classifiers))])

Dict{String, Any} with 3 entries:
  "min_impurity_decrease" => [0.13, 0.011, 0.041, 0.488, 0.318, 0.07, 0.217, 0.…
  "max_depth"             => [8, 6, 3, 6, 4, 8, 10, 6, 4, 10, 7, 4, 6, 8, 5, 6,…
  "min_samples_leaf"      => [30, 7, 22, 23, 12, 20, 17, 7, 30, 22, 30, 28, 26,…

In [12]:
RFC = RandomForestClassifier(20, 200, 2, args)

RandomForestClassifier(20, 200, 2, DecisionTreeClassifier[DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, missing, missing, 0, missing, missing, missing), missing, Main.MyDecisionTree.gini_impurity, 8, 30, 0.13), DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, missing, missing, 0, missing, missing, missing), missing, Main.MyDecisionTree.gini_impurity, 6, 7, 0.011), DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, missing, missing, 0, missing, missing, missing), missing, Main.MyDecisionTree.gini_impurity, 3, 22, 0.041), DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, missing, missing, 0, missing, missing, missing), missing, Main.MyDecisionTree.gini_impurity, 6, 23, 0.488), DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, missing, missing, 0, missing, missing, missing), missing, Main.MyDecisionTree.gini_impurity, 4, 12, 0.318), DecisionTreeClassifier(Main.MyDecisionTree.ClassifierNode(missing, mi

In [13]:
train!(RFC, iris.features)

In [14]:
function predict_majority(RFC::RandomForestClassifier, x::Union{DataFrame, DataFrameRow}, cstr)
    prediction_counts = Dict{String, Int64}([(cls, 0) for cls in values(cstr)])
    for DT in RFC.trees
        x_sub = x[DT.features]
        curr_node = DT.root
        while !ismissing(curr_node.thresh)
            if x[curr_node.split] < curr_node.thresh
                curr_node = curr_node.left
            else
                curr_node = curr_node.right
            end
        end
        prediction_counts[cstr[calculate_prob(curr_node.class_count)]] += 1
    end
    return collect(keys(prediction_counts))[argmax(collect(values(prediction_counts)))]
end
    

predict_majority (generic function with 1 method)

In [15]:
predict_majority(RFC, iris.features[1, :], class_to_str)

"Iris-setosa"

In [16]:
class_to_str[iris.features[1, :].class]

"Iris-setosa"

In [17]:
function print_rft_tree(RFC::RandomForestClassifier, n::Int64)
    traverse(RFC.trees[n].root, 30, class_to_str)
end

print_rft_tree (generic function with 1 method)

In [28]:
print_rft_tree(RFC, 2)

                              |sepalwidth < 2.4
                              |class distribution: [49, 55, 46]
                              -----------------------------------------------
                             /                                \
                            /                                  \
                           /                                    \
                          /                                      \
                         /                                        \
               Predicted Class: Iris-versicolor
               |class distribution: [1, 2, 5]
                                                                     \
                                                                      \
                                                                       \
                                                                        \
                                                                         \
                     