# Decision Tree

In [45]:
using CSV, DataFrames, MLDataUtils, DecisionTree

In [46]:
data = CSV.read("./data/data.csv");

In [47]:
data2020 = filter(row -> row.year == 2020, data);

In [48]:
# All the numerical column names
colnames = [
    "acousticness",
    "danceability",
    "duration_ms",
    "energy",
    "explicit",
    "instrumentalness",
    "key",
    "liveness",
    "loudness",
    "mode",
    "speechiness",
    "tempo",
    "valence",
]

X = data2020[:, colnames];

In [49]:
y = Int.(data2020.popularity .> 70);

In [50]:
Xtrain, Xtest = splitobs(X, at = 0.7);
ytrain, ytest = splitobs(y, at = 0.7);

Xtrain = convert(Array{Float64}, Xtrain)
Xtest = convert(Array{Float64}, Xtest);

In [51]:
n = length(y)
n_train = length(ytrain)
n_test = length(ytest);

In [52]:
model = DecisionTreeClassifier(max_depth=4)
fit!(model, Xtrain, ytrain)
print_tree(model, 10)

Feature 2, Threshold 0.6635
L-> Feature 1, Threshold 0.6315
    L-> Feature 3, Threshold 239976.5
        L-> Feature 8, Threshold 0.591
            L-> 0 : 235/334
            R-> 1 : 8/10
        R-> Feature 5, Threshold 0.5
            L-> 0 : 57/59
            R-> 0 : 18/26
    R-> Feature 7, Threshold 7.5
        L-> Feature 13, Threshold 0.429
            L-> 0 : 51/58
            R-> 0 : 21/21
        R-> 0 : 35/35
R-> Feature 9, Threshold -5.838
    L-> Feature 12, Threshold 92.01750000000001
        L-> Feature 1, Threshold 0.14
            L-> 0 : 31/33
            R-> 0 : 15/25
        R-> Feature 9, Threshold -12.942499999999999
            L-> 0 : 10/10
            R-> 0 : 187/322
    R-> Feature 12, Threshold 123.995
        L-> Feature 11, Threshold 0.207
            L-> 1 : 88/135
            R-> 0 : 11/16
        R-> Feature 12, Threshold 163.527
            L-> 0 : 76/114
            R-> 1 : 21/31


In [53]:
yhat_train = DecisionTree.predict(model, Xtrain)
error_train = (n - (sum(yhat_train .== ytrain))) / n

0.5079726651480638

In [54]:
yhat_test = DecisionTree.predict(model, Xtest)
error_test = (n - (sum(ypred .== yhat_test))) / n

0.6998861047835991