/
svm.ts
144 lines (123 loc) · 3.58 KB
/
svm.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import * as sdk from 'botpress/sdk'
import _ from 'lodash'
const binding = require('./svm-js/index.js')
/** Candidate values forwarded to the binding as `c` when the caller does not override them. */
const DEFAULT_C_CANDIDATES = [0.1, 1, 2, 5, 10, 20, 100]
/** Candidate values forwarded to the binding as `gamma` when the caller does not override them. */
const DEFAULT_GAMMA_CANDIDATES = [0.01, 0.1, 0.25, 0.5, 0.75]

/**
 * Baseline SVM options. `Trainer` spreads caller-supplied options on top of this object,
 * so every field a caller omits falls back to the value defined here.
 */
export const DefaultTrainArgs: Partial<sdk.MLToolkit.SVM.SVMOptions> = {
  c: DEFAULT_C_CANDIDATES,
  classifier: 'C_SVC',
  gamma: DEFAULT_GAMMA_CANDIDATES,
  kernel: 'LINEAR'
}
/**
 * Trains an SVM classifier through the `svm-js` binding and serializes the resulting
 * model together with the label table needed to decode its numeric class indices.
 */
export class Trainer implements sdk.MLToolkit.SVM.Trainer {
  /** SVM instance created by the binding; rebuilt whenever train() receives options. */
  private clf: any
  /** Labels in first-seen order for the current training run; a label's position is its class index. */
  private labels: string[] = []
  /** Model object delivered by the binding on success; its fields are merged into serialize()'s output. */
  private model?: any
  /** Second value delivered by the binding's train promise; stored but not otherwise read in this class. */
  private report?: any

  /**
   * @param options Overrides merged on top of DefaultTrainArgs to configure the classifier.
   */
  constructor(options: Partial<sdk.MLToolkit.SVM.SVMOptions> = DefaultTrainArgs) {
    this.clf = Trainer.createClassifier(options)
  }

  /**
   * Trains on `points` and returns the serialized model.
   *
   * @param points Training samples; each `label` is mapped to a numeric class index.
   * @param callback Optional listener invoked with each progress value the binding reports.
   * @param options When provided, the classifier is rebuilt from DefaultTrainArgs merged with
   *   these options (options given to the constructor are not carried over).
   * @returns The JSON string produced by serialize().
   * @throws Error when the binding reports a training failure.
   */
  async train(
    points: sdk.MLToolkit.SVM.DataPoint[],
    callback?: sdk.MLToolkit.SVM.TrainProgressCallback | undefined,
    options?: Partial<sdk.MLToolkit.SVM.SVMOptions>
  ): Promise<string> {
    if (options) {
      this.clf = Trainer.createClassifier(options)
    }
    await this._train(points, callback)
    return this.serialize()
  }

  /**
   * Builds a binding SVM from DefaultTrainArgs overlaid with `options`. The remaining binding
   * flags (`reduce: false`, `probability: true`, `kFold: 4`) are fixed.
   */
  private static createClassifier(options: Partial<sdk.MLToolkit.SVM.SVMOptions>): any {
    const args = { ...DefaultTrainArgs, ...options }
    return new binding.SVM({
      svmType: args.classifier,
      kernelType: args.kernel,
      c: args.c,
      gamma: args.gamma,
      reduce: false,
      probability: true,
      kFold: 4
    })
  }

  /**
   * Runs one training pass on the current classifier, rebuilding the label table from scratch.
   * Resolves once the binding delivers the model; rejects with an Error otherwise.
   */
  private async _train(
    points: sdk.MLToolkit.SVM.DataPoint[],
    callback?: sdk.MLToolkit.SVM.TrainProgressCallback | undefined
  ): Promise<void> {
    this.labels = []
    return new Promise<void>((resolve, reject) => {
      const dataset = points.map(point => [point.coordinates, this.getLabelIdx(point.label)])
      // The binding's promise exposes progress() and spread(); adapt it to a native Promise.
      this.clf
        .train(dataset)
        .progress(progress => {
          if (typeof callback === 'function') {
            callback(progress)
          }
        })
        .spread((trainedModel, report) => {
          this.model = trainedModel
          this.report = report
          resolve()
        })
        // Keep Error instances intact: new Error(err) would stringify them ("Error: ...")
        // and discard the original stack.
        .catch(err => reject(err instanceof Error ? err : new Error(err)))
    })
  }

  /** Returns the class index of `label`, appending it to the label table on first sight. */
  private getLabelIdx(label: string): number {
    const existing = this.labels.indexOf(label)
    if (existing !== -1) {
      return existing
    }
    this.labels.push(label)
    return this.labels.length - 1
  }

  /** True once a training run has stored a (truthy) model. */
  isTrained(): boolean {
    return !!this.model
  }

  /**
   * Serializes the model fields plus `labels_idx` (the label table) as pretty-printed JSON.
   * Before any successful training only `labels_idx` is present.
   */
  serialize(): string {
    return JSON.stringify({ ...this.model, labels_idx: this.labels }, undefined, 2)
  }
}
/**
 * Restores a serialized SVM model through the `svm-js` binding and turns its per-class
 * probabilities into label predictions sorted by confidence.
 */
export class Predictor implements sdk.MLToolkit.SVM.Predictor {
  /** Trailing suffix stripped from labels; labels that differ only by it are merged in predict(). */
  private static readonly LABEL_SUFFIX = /__k__\d+$/

  /** Classifier restored by the binding. */
  private clf: any
  /** Label table from the model; position corresponds to the class index reported by the binding. */
  private labels: string[]

  /**
   * @param model JSON in the format written by `Trainer.serialize()`: binding model fields plus
   *   `labels_idx`. The label table is removed before the remaining fields go to `binding.restore`.
   */
  constructor(model: string) {
    const { labels_idx, ...options } = JSON.parse(model)
    // Without a label table no index can be decoded; keep an empty table so predict() fails
    // with a descriptive error rather than a TypeError.
    this.labels = labels_idx || []
    this.clf = binding.restore({ ...options, kFold: 1 })
  }

  /**
   * Maps a (possibly stringified, possibly fractional) class index to its label.
   * @throws Error when the rounded index is negative, NaN, or past the end of the label table.
   */
  private getLabelByIdx(idx: number | string): string {
    const rounded = Math.round(Number(idx))
    const label = rounded >= 0 ? this.labels[rounded] : undefined
    if (typeof label !== 'string') {
      throw new Error(`Invalid prediction, prediction must be between 0 and ${this.labels.length}`)
    }
    return label
  }

  /**
   * Predicts label confidences for one input vector.
   * Probabilities of labels that differ only by a `__k__<n>` suffix are summed.
   * @returns One entry per distinct label, highest confidence first.
   */
  async predict(coordinates: number[]): Promise<sdk.MLToolkit.SVM.Prediction[]> {
    const results = await this.clf.predictProbabilities(coordinates)

    // A Map (not a plain object) so labels such as "constructor" or "__proto__" cannot
    // collide with Object.prototype members.
    const confidenceByLabel = new Map<string, number>()
    for (const idx of Object.keys(results)) {
      const label = this.getLabelByIdx(idx).replace(Predictor.LABEL_SUFFIX, '')
      confidenceByLabel.set(label, (confidenceByLabel.get(label) || 0) + results[idx])
    }

    return Array.from(confidenceByLabel, ([label, confidence]) => ({ label, confidence })).sort(
      (a, b) => b.confidence - a.confidence
    )
  }

  /** True when the binding returned a (truthy) classifier from restore(). */
  isLoaded(): boolean {
    return !!this.clf
  }

  /** Returns a shallow copy of the label table, suffixes included. */
  getLabels(): string[] {
    return _.values(this.labels)
  }
}