-
Notifications
You must be signed in to change notification settings - Fork 0
/
forest.go
71 lines (56 loc) · 1.53 KB
/
forest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package iforestgo
import (
"bytes"
"encoding/gob"
"errors"
"math"
"math/rand"
)
// Value is the type constraint for the numeric element type of the input
// data: it admits exactly float32 and float64.
type Value interface {
float32 | float64
}
// Forest is an ensemble of trees built by NewForest from random sub-samples
// of an input data set (presumably an isolation forest, going by the package
// name).
//
// Serialize and Deserialize use encoding/gob, which only transmits exported
// struct fields; the unexported rand field is therefore not part of the
// encoded form.
type Forest[V Value] struct {
// Trees holds the trees of the ensemble, one per requested tree.
Trees []*Tree[V]
// SubSamplingSize is the number of input rows drawn for each tree.
SubSamplingSize int
// rand is the seeded generator used while the forest is being built.
rand *rand.Rand
// InputDimesion is the length of the first input row given to NewForest.
InputDimesion int
}
// ErrSubSamplingSizeToolarge is returned by NewForest when the requested
// sub-sampling size is greater than the number of rows in the input data.
var ErrSubSamplingSizeToolarge = errors.New("the requested sub-sampling size exceeds the total number of samples in the input data")
// NewForest builds a forest of nTrees trees from the rows of X.
//
// A generator seeded with seed drives the construction. For every tree a
// fresh permutation of the row indices of X is drawn and its first
// subSamplingSize entries select the rows handed to NewTree, together with
// the same generator.
//
// It returns ErrSubSamplingSizeToolarge when subSamplingSize exceeds len(X).
// It also returns an error when X is empty or when nTrees or subSamplingSize
// is negative; these inputs would otherwise cause a run-time panic (indexing
// X[0], make with a negative length, or slicing with a negative bound).
func NewForest[V Value](X [][]V, nTrees int, subSamplingSize int, seed int64) (*Forest[V], error) {
	// Checked first so that every input that used to yield this sentinel
	// still does.
	if len(X) < subSamplingSize {
		return nil, ErrSubSamplingSizeToolarge
	}
	if len(X) == 0 {
		return nil, errors.New("the input data must contain at least one sample")
	}
	if nTrees < 0 {
		return nil, errors.New("the number of trees must not be negative")
	}
	if subSamplingSize < 0 {
		return nil, errors.New("the sub-sampling size must not be negative")
	}

	r := rand.New(rand.NewSource(seed))
	forest := Forest[V]{
		Trees:           make([]*Tree[V], nTrees),
		SubSamplingSize: subSamplingSize,
		rand:            r,
		InputDimesion:   len(X[0]),
	}
	for i := range forest.Trees {
		// A full permutation is drawn per tree and truncated, which gives
		// subSamplingSize distinct row indices.
		sampleIdxs := r.Perm(len(X))[:subSamplingSize]
		forest.Trees[i] = NewTree(&X, sampleIdxs, forest.rand)
	}
	return &forest, nil
}
// CalculateAnomalyScore returns 2^(-E/c) for the point x, where E is the mean
// of PathLength(x, t) over every tree t in the forest and c is
// avgPathLength(f.SubSamplingSize).
//
// With an empty forest E is 0/0, so the result is NaN.
func (f *Forest[V]) CalculateAnomalyScore(x []V) float64 {
	total := 0.0
	for i := range f.Trees {
		total += PathLength[V](x, f.Trees[i])
	}

	meanDepth := total / float64(len(f.Trees))
	normalizer := avgPathLength(f.SubSamplingSize)

	exponent := -meanDepth / normalizer
	return math.Pow(2, exponent)
}
// Serialize gob-encodes the forest into a newly allocated buffer.
//
// The buffer is returned even when encoding fails, so callers must check the
// error before relying on its contents.
func (f *Forest[V]) Serialize() (*bytes.Buffer, error) {
	out := new(bytes.Buffer)
	if err := gob.NewEncoder(out).Encode(f); err != nil {
		return out, err
	}
	return out, nil
}
// Deserialize decodes a gob stream, such as one produced by Serialize, from b
// into a Forest.
//
// A nil buffer is reported as an error instead of panicking inside the
// decoder. On any failure the zero Forest is returned rather than a partially
// decoded value. Because gob only carries exported fields, the unexported
// rand field of the returned Forest is nil.
func Deserialize[V Value](b *bytes.Buffer) (Forest[V], error) {
	if b == nil {
		return Forest[V]{}, errors.New("cannot deserialize a forest from a nil buffer")
	}

	var f Forest[V]
	if err := gob.NewDecoder(b).Decode(&f); err != nil {
		return Forest[V]{}, err
	}
	return f, nil
}