Skip to content

Commit

Permalink
fix(nlu): using multiple threads for training
Browse files Browse the repository at this point in the history
  • Loading branch information
allardy committed Aug 24, 2019
1 parent 96d7a2a commit 69822fe
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 12 deletions.
22 changes: 12 additions & 10 deletions modules/nlu/src/backend/pipelines/intents/svm_classifier.ts
Expand Up @@ -93,6 +93,9 @@ export default class SVMClassifier {
}

public async train(intentDefs: sdk.NLU.IntentDefinition[], modelHash: string): Promise<Model[]> {
const svmOptions: Partial<sdk.MLToolkit.SVM.SVMOptions> = { kernel: 'LINEAR', classifier: 'C_SVC' }
const svm = new this.toolkit.SVM.Trainer(svmOptions)

this.realtime.sendPayload(
this.realtimePayload.forAdmins('statusbar.event', getProgressPayload(identityProgress)(0.1))
)
Expand Down Expand Up @@ -204,28 +207,27 @@ export default class SVMClassifier {
}
}

const svm = new this.toolkit.SVM.Trainer({ kernel: 'LINEAR', classifier: 'C_SVC' })

const ratioedProgressForIndex = ratioedProgress(index)

await svm.train(l1Points, progress => {
const updateProgress = progress => {
this.realtime.sendPayload(
this.realtimePayload.forAdmins('statusbar.event', getProgressPayload(ratioedProgressForIndex)(progress))
)
debugTrain('SVM => progress for INT', { context, progress })
})
}

const modelStr = svm.serialize()
const ratioedProgressForIndex = ratioedProgress(index)
const modelStr = await svm.train(l1Points, updateProgress, svmOptions)

models.push({
meta: { context, created_on: Date.now(), hash: modelHash, scope: 'bot', type: 'intent-l1' },
model: new Buffer(modelStr, 'utf8')
})
}

const svm = new this.toolkit.SVM.Trainer({ kernel: 'LINEAR', classifier: 'C_SVC' })
await svm.train(l0Points, progress => debugTrain('SVM => progress for CTX %d', progress))
const ctxModelStr = svm.serialize()
const ctxModelStr = await svm.train(
l0Points,
progress => debugTrain('SVM => progress for CTX %d', progress),
svmOptions
)

this.l1Tfidf = _.mapValues(l1Tfidf, x => x['__avg__'])
this.l0Tfidf = l0Tfidf['__avg__']
Expand Down
8 changes: 8 additions & 0 deletions src/bp/bootstrap.ts
Expand Up @@ -6,6 +6,7 @@ import './common/polyfills'

import sdk from 'botpress/sdk'
import chalk from 'chalk'
import cluster from 'cluster'
import { Botpress, Config, Logger } from 'core/app'
import center from 'core/logger/center'
import { ModuleLoader } from 'core/module-loader'
Expand All @@ -16,6 +17,13 @@ import os from 'os'
import { FatalError } from './errors'

async function start() {
if (cluster.isMaster) {
cluster.fork()
} else {
// The worker doesn't need anything else beside rewire and getos
return
}

const logger = await Logger('Launcher')
logger.info(chalk`========================================
{bold ${center(`Botpress Server`, 40)}}
Expand Down
23 changes: 23 additions & 0 deletions src/bp/ml/svm.ts
Expand Up @@ -30,6 +30,29 @@ export class Trainer implements sdk.MLToolkit.SVM.Trainer {
}

async train(
points: sdk.MLToolkit.SVM.DataPoint[],
callback?: sdk.MLToolkit.SVM.TrainProgressCallback | undefined,
options?: Partial<sdk.MLToolkit.SVM.SVMOptions>
): Promise<string> {
if (options) {
const args = { ...DefaultTrainArgs, ...options }

this.clf = new binding.SVM({
svmType: args.classifier,
kernelType: args.kernel,
c: args.c,
gamma: args.gamma,
reduce: false,
probability: true,
kFold: 4
})
}

await this._train(points, callback)
return this.serialize()
}

private async _train(
points: sdk.MLToolkit.SVM.DataPoint[],
callback?: sdk.MLToolkit.SVM.TrainProgressCallback | undefined
): Promise<any> {
Expand Down
37 changes: 37 additions & 0 deletions src/bp/ml/toolkit.ts
@@ -1,4 +1,5 @@
import * as sdk from 'botpress/sdk'
import cluster from 'cluster'

const { Tagger, Trainer: CRFTrainer } = require('./crfsuite')
import { FastTextModel } from './fasttext'
Expand All @@ -21,4 +22,40 @@ const MLToolkit: typeof sdk.MLToolkit = {
SentencePiece: { createProcessor: processor }
}

if (cluster.isMaster) {
MLToolkit.SVM.Trainer.prototype.train = (
points: sdk.MLToolkit.SVM.DataPoint[],
callback?: sdk.MLToolkit.SVM.TrainProgressCallback | undefined,
options?: Partial<sdk.MLToolkit.SVM.SVMOptions>
): any => {
return Promise.fromCallback(cb => {
const worker = cluster.workers[1]!

const messageHandler = msg => {
if (callback && msg.type === 'progress') {
callback(msg.progress)
}

if (msg.type === 'svm_trained') {
worker.off('message', messageHandler)
cb(undefined, msg.result)
}
}

worker.send({ type: 'svm_train', points, options })
worker.on('message', messageHandler)
})
}
}

if (cluster.isWorker) {
process.on('message', async msg => {
if (msg.type === 'svm_train') {
const svm = new SVMTrainer(msg.options)
const result = await svm.train(msg.points, progress => process.send!({ type: 'progress', progress }))
process.send!({ type: 'svm_trained', result })
}
})
}

export default MLToolkit
3 changes: 1 addition & 2 deletions src/bp/sdk/botpress.d.ts
Expand Up @@ -264,9 +264,8 @@ declare module 'botpress/sdk' {

export class Trainer {
constructor(options?: Partial<SVMOptions>)
train(points: DataPoint[], callback?: TrainProgressCallback): Promise<void>
train(points: DataPoint[], callback?: TrainProgressCallback, options?: Partial<SVMOptions>): Promise<string>
isTrained(): boolean
serialize(): string
}

export class Predictor {
Expand Down

0 comments on commit 69822fe

Please sign in to comment.