forked from sjwhitworth/golearn
/
bagging.go
188 lines (175 loc) · 5.07 KB
/
bagging.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package meta
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"math/rand"
"runtime"
"strings"
"sync"
)
// BaggedModel trains base.Classifiers on bootstrap samples (rows drawn with
// replacement) of the original Instances, optionally restricted to a random
// subset of the non-class attributes, and combines their results through
// majority voting.
type BaggedModel struct {
base.BaseClassifier
// Models holds the ensemble members; Fit trains each of them and Predict
// collects a vote from each of them.
Models []base.Classifier
// RandomFeatures is the number of non-class attributes sampled for each
// model. 0 means every non-class attribute is used.
RandomFeatures int
// lock guards writes to selectedAttributes, which happen from the
// goroutines Fit starts for each model.
lock sync.Mutex
// selectedAttributes maps a model's index in Models to the attributes it
// was trained on (selected features followed by all class attributes), so
// that prediction can present the same columns. It is recreated by Fit.
selectedAttributes map[int][]base.Attribute
}
// generateTrainingAttrs chooses the attributes the model at index model will
// be trained on, records them in b.selectedAttributes (so that
// generatePredictionInstances can later present the same columns) and
// returns them.
//
// When b.RandomFeatures is 0 every non-class attribute of from is used, in
// its original order. Otherwise up to b.RandomFeatures distinct non-class
// attributes are drawn uniformly at random without replacement. If
// RandomFeatures exceeds the number of available non-class attributes, all
// of them are used (in random order) instead of looping forever. A negative
// RandomFeatures selects no features. All class attributes of from are
// always appended after the selected features.
//
// The write to b.selectedAttributes is guarded by b.lock, so this may be
// called concurrently for different models.
func (b *BaggedModel) generateTrainingAttrs(model int, from base.FixedDataGrid) []base.Attribute {
	attrs := base.NonClassAttributes(from)
	classAttrs := from.AllClassAttributes()

	// Always build a fresh slice so appending the class attributes can never
	// write into the backing array of attrs.
	ret := make([]base.Attribute, 0, len(attrs)+len(classAttrs))
	if b.RandomFeatures == 0 {
		ret = append(ret, attrs...)
	} else {
		// Walking a random permutation yields a uniform sample without
		// replacement and always terminates, unlike repeatedly drawing
		// rand.Intn until enough distinct attributes have been seen (which
		// spins forever when RandomFeatures > len(attrs) and panics when
		// attrs is empty).
		for _, idx := range rand.Perm(len(attrs)) {
			if len(ret) >= b.RandomFeatures {
				break
			}
			attr := attrs[idx]
			// Keep the original Equals-based de-duplication.
			matched := false
			for _, a := range ret {
				if a.Equals(attr) {
					matched = true
					break
				}
			}
			if !matched {
				ret = append(ret, attr)
			}
		}
	}
	ret = append(ret, classAttrs...)

	b.lock.Lock()
	b.selectedAttributes[model] = ret
	b.lock.Unlock()
	return ret
}
// generatePredictionInstances returns a view of from restricted to the
// attributes that were recorded for the model at index model when it was
// trained.
//
// The map is read without taking b.lock; this assumes no Fit is running
// concurrently (Predict only reads it). NOTE(review): if the model was never
// fitted the lookup yields a nil slice — confirm how
// base.NewInstancesViewFromAttrs treats that.
func (b *BaggedModel) generatePredictionInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
	return base.NewInstancesViewFromAttrs(from, b.selectedAttributes[model])
}
// generateTrainingInstances builds the training set for the model at index
// model. It draws a bootstrap sample of from (as many rows as from has,
// sampled with replacement) and then chooses the model's attributes via
// generateTrainingAttrs. It returns a view of the sample limited to those
// attributes.
func (b *BaggedModel) generateTrainingInstances(model int, from base.FixedDataGrid) base.FixedDataGrid {
	_, rowCount := from.Size()
	bootstrap := base.SampleWithReplacement(from, rowCount)
	chosen := b.generateTrainingAttrs(model, from)
	return base.NewInstancesViewFromAttrs(bootstrap, chosen)
}
// AddModel appends m to the ensemble. It is trained by the next call to Fit
// and contributes a vote in Predict.
func (b *BaggedModel) AddModel(m base.Classifier) {
	b.Models = append(b.Models, m)
}
// Fit trains every model in b.Models on its own randomised subset of from.
// Each model receives a bootstrap sample of the rows, restricted to the
// attributes chosen by generateTrainingAttrs. Any attribute selections from a
// previous Fit are discarded.
//
// Models are trained concurrently, but at most runtime.NumCPU() at a time
// (matching the worker count used by Predict). This matters because every
// in-flight model holds its own bootstrap copy of the data. Fit returns once
// all models have been trained.
func (b *BaggedModel) Fit(from base.FixedDataGrid) {
	b.selectedAttributes = make(map[int][]base.Attribute)

	var wait sync.WaitGroup
	// Semaphore bounding the number of models trained at once.
	slots := make(chan struct{}, runtime.NumCPU())
	for i, m := range b.Models {
		wait.Add(1)
		slots <- struct{}{}
		go func(c base.Classifier, model int) {
			defer wait.Done()
			defer func() { <-slots }()
			c.Fit(b.generateTrainingInstances(model, from))
		}(m, i)
	}
	wait.Wait()
}
// Predict gathers predictions from all the classifiers and outputs, for each
// row of from, the most common (majority) class.
//
// Models are evaluated by runtime.NumCPU() worker goroutines. Each model
// predicts on a view of from restricted to the attributes it was trained on.
// A single collector goroutine tallies the votes. A model whose Predict call
// returns an error casts no votes.
//
// IMPORTANT: in the event of a tie, the lexicographically smallest of the
// tied class labels is output, so the result does not depend on map
// iteration order or on the order in which models finish.
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
	n := runtime.NumCPU()
	// Channel to receive each model's predictions as they come in.
	votes := make(chan base.DataGrid, n)
	// voting[row][class] counts how many models predicted class for row. It
	// is touched only by the collector goroutine until votingwait is released.
	voting := make(map[int]map[string]int)

	var votingwait sync.WaitGroup
	votingwait.Add(1)
	go func() {
		defer votingwait.Done()
		for incoming := range votes {
			cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
			// The callback never fails; any error from MapOverRows itself is
			// not surfaced (Predict has no error return).
			incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
				counts, seen := voting[predRow]
				if !seen {
					// First vote for this row: create its tally.
					counts = make(map[string]int)
					voting[predRow] = counts
				}
				counts[base.GetClass(incoming, predRow)]++
				return true, nil
			})
		}
	}()

	// Workers pull model indices and push that model's predictions.
	processpipe := make(chan int, n)
	var processwait sync.WaitGroup
	for w := 0; w < n; w++ {
		processwait.Add(1)
		go func() {
			defer processwait.Done()
			for i := range processpipe {
				l := b.generatePredictionInstances(i, from)
				v, err := b.Models[i].Predict(l)
				if err != nil {
					// v is meaningless on error (and may be nil, which would
					// crash the collector), so this model abstains.
					continue
				}
				votes <- v
			}
		}()
	}

	// Send all the models to the workers for prediction.
	for i := range b.Models {
		processpipe <- i
	}
	close(processpipe) // Finished sending models to be predicted
	processwait.Wait() // Predictors all finished processing
	close(votes)       // Close the vote channel and allow it to drain
	votingwait.Wait()  // All the votes are in

	// Generate the overall consensus.
	ret := base.GeneratePredictionVector(from)
	for row, counts := range voting {
		base.SetClass(ret, row, mostVotedClass(counts))
	}
	return ret
}

// mostVotedClass returns the key of counts with the highest count, breaking
// ties in favour of the lexicographically smallest key. It returns "" when
// counts is empty.
func mostVotedClass(counts map[string]int) string {
	best, bestCount := "", 0
	for class, count := range counts {
		if count > bestCount || (count == bestCount && class < best) {
			best, bestCount = class, count
		}
	}
	return best
}
// String renders the BaggedModel as "BaggedModel(" followed by a newline and
// one "index: model" entry per ensemble member, then a closing ")". Entries
// are separated by a newline and a tab.
func (b *BaggedModel) String() string {
	parts := make([]string, len(b.Models))
	for idx := range b.Models {
		parts[idx] = fmt.Sprintf("%d: %s", idx, b.Models[idx])
	}
	return fmt.Sprintf("BaggedModel(\n%s)", strings.Join(parts, "\n\t"))
}