-
Notifications
You must be signed in to change notification settings - Fork 4
/
model.go
44 lines (37 loc) · 1.2 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package model
import (
"fmt"
"github.com/bobonovski/gotm/corpus"
"github.com/bobonovski/gotm/sstable"
)
// constructors maps a model type name to the constructor registered for it
// through Register; GetModel looks names up here.
//
// NOTE(review): the map is not guarded by a mutex, so registration is
// presumably expected to finish (e.g. in init functions) before lookups
// start — confirm.
var constructors = map[string]ModelCtor{}
// Model is the common interface that new LDA samplers should implement.
// Implementations make themselves available to GetModel through Register.
type Model interface {
	// Train trains the model on dat for iter iterations.
	Train(dat *corpus.Corpus, iter int)
	// Infer runs inference on the new documents in dat for iter iterations.
	Infer(dat *corpus.Corpus, iter int)
	// Phi returns the word-topic distribution (the quantity SavePhi
	// serializes).
	Phi() *sstable.Float32Matrix
	// Theta returns the document-topic distribution (the quantity SaveTheta
	// serializes).
	Theta() *sstable.Float32Matrix
	// SaveTheta serializes the posterior document-topic distribution to fn
	// (presumably a file path).
	SaveTheta(fn string) error
	// SavePhi serializes the posterior word-topic distribution to fn
	// (presumably a file path).
	SavePhi(fn string) error
	// SaveWordTopic serializes the word-topic count table to fn.
	SaveWordTopic(fn string) error
	// LoadWordTopic deserializes the word-topic count table from fn.
	LoadWordTopic(fn string) error
}
// Register records m as the constructor for modelType so that GetModel can
// hand it out later. New LDA samplers should register themselves through
// this function. Calling Register again with the same modelType replaces the
// previously registered constructor; m is stored without validation.
//
// NOTE(review): the underlying map is unguarded, so Register is presumably
// meant to be called during initialization rather than concurrently with
// GetModel — confirm.
func Register(modelType string, m ModelCtor) { constructors[modelType] = m }
// ModelCtor builds a Model for topicNum topics with the given alpha and beta
// hyperparameters (presumably the LDA Dirichlet priors — confirm against the
// registered samplers).
type ModelCtor func(topicNum uint32, alpha, beta float32) Model
// GetModel returns the constructor registered under modelType via Register.
//
// It returns an error if no constructor is registered under that name. A nil
// constructor stored under the name is reported the same way, so callers
// never receive a nil ModelCtor together with a nil error (which would panic
// as soon as they invoked it).
func GetModel(modelType string) (ModelCtor, error) {
	// Single comma-ok lookup instead of probing the map twice.
	ctor, ok := constructors[modelType]
	if !ok || ctor == nil {
		return nil, fmt.Errorf("model %s not registered", modelType)
	}
	return ctor, nil
}