In [1]:
using DecisionTree
using Random
using Statistics

In [2]:
# Seed the global RNG; `rand` and `shuffle` in the following cells draw from it,
# so the generated points and the train/test split are reproducible.
n_samples = let seed = 42
    Random.seed!(seed)
    100  # number of synthetic 2-D points
end

100

In [3]:
# Feature matrix: one row per sample, two columns, entries uniform on [0, 1).
X = rand(Float64, (n_samples, 2))

100×2 Matrix{Float64}:
 0.173575     0.937288
 0.321662     0.732393
 0.258585     0.0484837
 0.166439     0.699958
 0.527015     0.117331
 0.483022     0.638571
 0.390663     0.957774
 0.802763     0.817224
 0.980758     0.802052
 0.0944314    0.658507
 0.544758     0.647766
 0.433914     0.22013
 0.211228     0.727899
 ⋮            
 0.310797     0.650511
 0.199094     0.233777
 0.529996     0.869046
 0.000715255  0.642221
 0.95628      0.546373
 0.19808      0.802571
 0.902321     0.590144
 0.624447     0.936078
 0.400373     0.966444
 0.774017     0.918565
 0.280243     0.65797
 0.0369673    0.172211

In [4]:
# Ground-truth labels: "A" for points strictly below the diagonal x1 + x2 = 1,
# "B" otherwise (points exactly on the line are labelled "B").
# Iterating over the rows of `X` keeps the labels aligned with `X` by
# construction instead of relying on the separate `n_samples` global.
y = [row[1] + row[2] < 1 ? "A" : "B" for row in eachrow(X)]

100-element Vector{String}:
 "B"
 "B"
 "A"
 "A"
 "A"
 "B"
 "B"
 "B"
 "B"
 "A"
 "B"
 "A"
 "A"
 ⋮
 "A"
 "A"
 "B"
 "A"
 "B"
 "B"
 "B"
 "B"
 "B"
 "B"
 "A"
 "A"

In [5]:
# Random permutation of the sample indices; its first `train_size` entries
# will form the training set.
shuffle_idx = shuffle(1:n_samples)
# 80/20 train/test split (80 training samples when n_samples == 100).
train_size = let train_fraction = 0.8
    round(Int, train_fraction * n_samples)
end

80

In [6]:
# Split the shuffled indices: the leading `train_size` go to training and all
# remaining ones are held out, so the two sets are disjoint by construction.
train_idx = shuffle_idx[begin:train_size]
test_idx = shuffle_idx[(train_size + 1):end]

20-element Vector{Int64}:
 78
  8
 99
 24
 62
 75
  3
 50
 80
 14
 92
  4
 51
 47
 71
 85
 65
 72
 68
  6

In [7]:
# Materialize the train/test subsets (indexing with index vectors copies rows).
X_train = X[train_idx, :]
y_train = y[train_idx]
X_test, y_test = X[test_idx, :], y[test_idx]

([0.8268177665896753 0.38896967499662816; 0.802762551279973 0.8172244613027728; … ; 0.980814224191682 0.18229747129540086; 0.48302213696845187 0.6385712905749883], ["B", "B", "A", "B", "B", "A", "A", "B", "B", "A", "A", "A", "A", "B", "A", "A", "B", "B", "B", "B"])

In [8]:
# Only the depth is limited (to 3) so the tree stays small and its printout
# readable; every other hyperparameter keeps the package default.
model = DecisionTreeClassifier(; max_depth = 3)
# `fit!` trains `model` in place; its return value is what the cell displays.
fit!(model, X_train, y_train)

DecisionTreeClassifier
max_depth:                3
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  ["A", "B"]
root:                     Decision Tree
Leaves: 6
Depth:  3

In [9]:
# Predicted class labels ("A"/"B") for the held-out samples.
y_pred = DecisionTree.predict(model, X_test)

20-element Vector{String}:
 "B"
 "B"
 "A"
 "B"
 "B"
 "A"
 "A"
 "B"
 "B"
 "A"
 "A"
 "A"
 "A"
 "A"
 "A"
 "A"
 "B"
 "B"
 "B"
 "A"

In [10]:
# Test-set accuracy: share of held-out samples whose predicted label equals the
# true label. The printed message means "test data accuracy: …".
accuracy = let hits = y_pred .== y_test
    count(hits) / length(hits)
end
println("テストデータの精度: $accuracy")

テストデータの精度: 0.9


In [11]:
# Print the fitted tree's splits. Each leaf line shows its label and how many of
# the training samples reaching that leaf carry it (e.g. `A : 7/9`).
DecisionTree.print_tree(model)

Feature 2 < 0.7301 ?
├─ Feature 1 < 0.7509 ?
    ├─ Feature 2 < 0.6259 ?
        ├─ A : 28/28
        └─ A : 7/9
    └─ Feature 2 < 0.1171 ?
        ├─ A : 1/1
        └─ B : 14/14
└─ Feature 1 < 0.1313 ?
    ├─ A : 2/2
    └─ B : 26/26
