-
Notifications
You must be signed in to change notification settings - Fork 2
/
treeda_dt.jl
123 lines (114 loc) · 3.04 KB
/
treeda_dt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
    treeda_dt(X, y; kwargs...)
Discrimination tree (CART) with DecisionTree.jl.
* `X` : X-data (n, p).
* `y` : Univariate class membership (n).
Keyword arguments:
* `n_subfeatures` : Nb. variables to select at random
    at each split (default: 0 ==> keep all).
* `max_depth` : Maximum depth of the
    decision tree (default: -1 ==> no maximum).
* `min_samples_leaf` : Minimum number of samples
    each leaf needs to have.
* `min_samples_split` : Minimum number of observations
    needed for a split.
* `scal` : Boolean. If `true`, each column of `X`
    is scaled by its uncorrected standard deviation.

The function fits a single discrimination tree (CART) using
package `DecisionTree.jl`.

## References
Breiman, L., Friedman, J. H., Olshen, R. A., and
Stone, C. J. Classification And Regression Trees.
Chapman & Hall, 1984.

DecisionTree.jl
https://github.com/JuliaAI/DecisionTree.jl

Gey, S., 2002. Bornes de risque, détection de ruptures,
boosting : trois thèmes statistiques autour de CART en
régression (These de doctorat). Paris 11.
http://www.theses.fr/2002PA112245

## Examples
```julia
using JchemoData, JLD2
path_jdat = dirname(dirname(pathof(JchemoData)))
db = joinpath(path_jdat, "data/forages2.jld2")
@load db dat
pnames(dat)
X = dat.X
Y = dat.Y
n, p = size(X)
s = Bool.(Y.test)
Xtrain = rmrow(X, s)
ytrain = rmrow(Y.typ, s)
Xtest = X[s, :]
ytest = Y.typ[s]
ntrain = nro(Xtrain)
ntest = nro(Xtest)
(ntot = n, ntrain, ntest)
tab(ytrain)
tab(ytest)
n_subfeatures = p / 3
max_depth = 10
mod = model(treeda_dt; n_subfeatures, max_depth)
fit!(mod, Xtrain, ytrain)
pnames(mod)
pnames(mod.fm)
fm = mod.fm ;
fm.lev
fm.ni
res = predict(mod, Xtest) ;
pnames(res)
@head res.pred
errp(res.pred, ytest)
conf(res.pred, ytest).cnt
```
"""
## For DA in DecisionTree.jl, y must be Int or String.
function treeda_dt(X, y::Union{Array{Int}, Array{String}};
        kwargs...)
    par = recovkwargs(Par, kwargs)
    X = ensure_mat(X)
    Q = eltype(X)
    y = vec(y)
    p = nco(X)
    ## Class levels and their counts, stored in the fitted object
    taby = tab(y)
    ## Optional scaling of each column of X by its (uncorrected) std deviation
    xscales = ones(Q, p)
    if par.scal
        xscales .= colstd(X)
        X = fscale(X, xscales)
    end
    ## DecisionTree.jl requires an Int nb. of candidate features per split
    n_subfeatures = Int(round(par.n_subfeatures))
    min_purity_increase = 0
    fm = build_tree(y, X,
        n_subfeatures,
        par.max_depth,
        par.min_samples_leaf,
        par.min_samples_split,
        min_purity_increase;
        #rng = Random.GLOBAL_RNG
        #rng = 3
        )
    featur = collect(1:p)
    TreedaDt(fm, xscales, featur, taby.keys, taby.vals, kwargs, par)
end
"""
    predict(object::TreedaDt, X)
Compute Y-predictions from a fitted model.
* `object` : The fitted model.
* `X` : X-data for which predictions are computed.
"""
function predict(object::TreedaDt, X)
    X = ensure_mat(X)
    q = nro(X)
    ## Apply the same column scaling as at fitting time
    zX = fscale(X, object.xscales)
    ## A single tree exposes :node as its first property; a forest does not.
    istree = pnames(object.fm)[1] == :node
    pred = istree ? apply_tree(object.fm, zX) :
        apply_forest(object.fm, zX;
            use_multithreading = object.par.mth)
    ## Return predictions as a (q, 1) matrix, wrapped in a named tuple
    pred = reshape(pred, q, 1)
    (pred = pred,)
end