Skip to content

Commit

Permalink
feat(nlu-server): models weights are validated on upload (#198)
Browse files Browse the repository at this point in the history
* chore: ensure all fields of protobuf message are either required, optional or repeated

* feat(nlu-server): models are validated

* chore: added a test that ensures a model of unsupported spec is reported as such

* also validate model content
  • Loading branch information
franklevasseur committed Mar 23, 2022
1 parent c1b40c3 commit aa2a471
Show file tree
Hide file tree
Showing 18 changed files with 161 additions and 76 deletions.
5 changes: 4 additions & 1 deletion packages/nlu-e2e/src/run-tests.ts
@@ -1,13 +1,14 @@
import { Logger } from '@botpress/logger'
import { Client as NLUClient } from '@botpress/nlu-client'

import fs from 'fs'
import _ from 'lodash'
import { nanoid } from 'nanoid'
import { assertModelsAreEmpty, assertServerIsReachable } from './assertions'
import { clinc50_42_dataset, clinc50_666_dataset, grocery_dataset } from './datasets'
import tests from './tests'
import { AssertionArgs, Test } from './typings'
import { syncE2ECachePath } from './utils'
import { getE2ECachePath, syncE2ECachePath } from './utils'

type CommandLineArgs = {
nluEndpoint: string
Expand Down Expand Up @@ -46,4 +47,6 @@ export const runTests = async (cliArgs: CommandLineArgs) => {
for (const test of testToRun) {
await test.handler(args)
}

fs.rmSync(getE2ECachePath(appId), { recursive: true, force: true })
}
30 changes: 21 additions & 9 deletions packages/nlu-e2e/src/tests/modelweights-transfer.ts
@@ -1,8 +1,9 @@
import fs from 'fs'
import _ from 'lodash'
import path from 'path'
import { AssertionArgs, Test } from 'src/typings'
import {
assertIntentPredictionWorks,
assertModelsAreEmpty,
assertModelsInclude,
assertModelsPrune,
assertModelTransferIsEnabled,
Expand All @@ -14,9 +15,8 @@ import {
assertTrainingFinishes,
assertTrainingStarts
} from '../assertions'
import fs from 'fs'
import { grocery_dataset, grocery_test_sample } from '../datasets'
import { corruptBuffer, getE2ECachePath } from '../utils'
import { bufferReplace, corruptBuffer, getE2ECachePath } from '../utils'

const NAME = 'modelweights-transfer'

Expand Down Expand Up @@ -64,12 +64,24 @@ export const modelWeightsTransferTest: Test = {
grocery_test_sample.intent
)

// TODO: ensure uploading a corrupted buffer fails
// const modelWeights = await fs.promises.readFile(fileLocation)
// const corruptedWeights = corruptBuffer(modelWeights)
// const corruptedFileLocation = path.join(cachePath, `${modelId}.corrupted.model`)
// await fs.promises.writeFile(corruptedFileLocation, corruptedWeights)
// await assertModelWeightsUploadFails(modelWeightsTransferArgs, corruptedFileLocation, 'INVALID_MODEL_FORMAT')
// ensure uploading a corrupted buffer fails
const originalModelWeights = await fs.promises.readFile(fileLocation)
const corruptedWeights = corruptBuffer(originalModelWeights)
const corruptedFileLocation = path.join(cachePath, `${modelId}.corrupted.model`)
await fs.promises.writeFile(corruptedFileLocation, corruptedWeights)
await assertModelWeightsUploadFails(modelWeightsTransferArgs, corruptedFileLocation, 'INVALID_MODEL_FORMAT')

// ensure uploading an older-version buffer fails
const specHash = modelId.split('.')[1]
const dummySpecHash = 'ffffff9999999999'
const deprecatedWeights = bufferReplace(
originalModelWeights,
Buffer.from(specHash, 'utf8'),
Buffer.from(dummySpecHash, 'utf8')
)
const deprecatedFileLocation = path.join(cachePath, `${modelId}.deprecated.model`)
await fs.promises.writeFile(deprecatedFileLocation, deprecatedWeights)
await assertModelWeightsUploadFails(modelWeightsTransferArgs, deprecatedFileLocation, 'UNSUPORTED_MODEL_SPEC')

// cleanup
await assertModelsPrune(modelWeightsTransferArgs)
Expand Down
18 changes: 18 additions & 0 deletions packages/nlu-e2e/src/utils.ts
Expand Up @@ -137,3 +137,21 @@ export const corruptBuffer = (buffer: Buffer): Buffer => {

return encrypted
}

/**
 * Returns a copy of `buffer` with the first occurrence of `from` replaced by `to`.
 * If the pattern is not found, the original buffer is returned unchanged (same reference).
 *
 * @param buffer - the buffer to search in
 * @param from - the byte pattern to replace (first occurrence only)
 * @param to - the replacement bytes (may differ in length from `from`)
 */
export const bufferReplace = (buffer: Buffer, from: Buffer, to: Buffer): Buffer => {
  const patternStart = buffer.indexOf(from)
  if (patternStart < 0) {
    return buffer
  }

  const patternEnd = patternStart + from.length

  // stitch together: bytes before the match, the replacement, bytes after the match
  return Buffer.concat([buffer.subarray(0, patternStart), to, buffer.subarray(patternEnd)])
}
4 changes: 4 additions & 0 deletions packages/nlu-engine/src/engine/index.ts
Expand Up @@ -237,6 +237,10 @@ export default class Engine implements IEngine {
return this._trainingWorkerQueue.cancelTraining(trainSessionId)
}

/**
 * Validates that serialized model weights can be deserialized by this engine.
 * Performs a full deserialization and discards the result.
 *
 * @param serialized - the model weights to validate
 * @throws when the model content has an invalid format (i.e. deserialization fails)
 */
public validateModel(serialized: Model): void {
deserializeModel(serialized) // try to deserialize a model to see if it throws
}

public async loadModel(serialized: Model) {
const stringId = modelIdService.toString(serialized.id)
this._logger.debug(`Load model ${stringId}`)
Expand Down
Expand Up @@ -16,7 +16,7 @@ type Predictors = Model
type ExactMatchIndex = _.Dictionary<{ intent: string }>

const PTBExactIndexValue = new ptb.PTBMessage('ExactIndexValue', {
intent: { type: 'string', id: 1 }
intent: { type: 'string', id: 1, rule: 'required' }
})

const PTBExactIntentModel = new ptb.PTBMessage('ExactIntentModel', {
Expand Down
Expand Up @@ -30,9 +30,9 @@ type Model = {

const PTBOOSIntentModel = new ptb.PTBMessage('OOSIntentModel', {
trainingVocab: { type: 'string', id: 1, rule: 'repeated' },
baseIntentClfModel: { type: SvmIntentClassifier.modelType, id: 2 },
baseIntentClfModel: { type: SvmIntentClassifier.modelType, id: 2, rule: 'required' },
oosSvmModel: { type: MLToolkit.SVM.Classifier.modelType, id: 3, rule: 'optional' },
exactMatchModel: { type: ExactIntenClassifier.modelType, id: 4 }
exactMatchModel: { type: ExactIntenClassifier.modelType, id: 4, rule: 'required' }
})

type Predictors = {
Expand Down
32 changes: 16 additions & 16 deletions packages/nlu-engine/src/engine/model-serializer.ts
Expand Up @@ -26,23 +26,23 @@ export type PredictableModel = Omit<Model, 'data'> & {
}

const PTBSlotDef = new ptb.PTBMessage('SlotDef', {
name: { type: 'string', id: 1 },
name: { type: 'string', id: 1, rule: 'required' },
entities: { type: 'string', id: 2, rule: 'repeated' }
})

const PTBIntentDef = new ptb.PTBMessage('IntentDef', {
name: { type: 'string', id: 1 },
name: { type: 'string', id: 1, rule: 'required' },
contexts: { type: 'string', id: 2, rule: 'repeated' },
slot_definitions: { type: PTBSlotDef, id: 3, rule: 'repeated' },
utterances: { type: 'string', id: 4, rule: 'repeated' }
})

const PTBPatternEntityDef = new ptb.PTBMessage('PatternEntityDef', {
name: { type: 'string', id: 1 },
pattern: { type: 'string', id: 2 },
name: { type: 'string', id: 1, rule: 'required' },
pattern: { type: 'string', id: 2, rule: 'required' },
examples: { type: 'string', id: 3, rule: 'repeated' },
matchCase: { type: 'bool', id: 4 },
sensitive: { type: 'bool', id: 5 }
matchCase: { type: 'bool', id: 4, rule: 'required' },
sensitive: { type: 'bool', id: 5, rule: 'required' }
})

const PTBSynonymValue = new ptb.PTBMessage('ListEntitySynonymValue', {
Expand All @@ -54,37 +54,37 @@ const PTBSynonym = new ptb.PTBMessage('ListEntitySynonym', {
})

const PTBListEntityModel = new ptb.PTBMessage('ListEntityModel', {
type: { type: 'string', id: 1 },
id: { type: 'string', id: 2 },
entityName: { type: 'string', id: 3 },
fuzzyTolerance: { type: 'double', id: 4 },
sensitive: { type: 'bool', id: 5 },
type: { type: 'string', id: 1, rule: 'required' },
id: { type: 'string', id: 2, rule: 'required' },
entityName: { type: 'string', id: 3, rule: 'required' },
fuzzyTolerance: { type: 'double', id: 4, rule: 'required' },
sensitive: { type: 'bool', id: 5, rule: 'required' },
mappingsTokens: { keyType: 'string', type: PTBSynonym, id: 6 }
})

const PTBCentroid = new ptb.PTBMessage('KmeanCentroid', {
centroid: { type: 'double', id: 1, rule: 'repeated' },
error: { type: 'double', id: 2 },
size: { type: 'int32', id: 3 }
error: { type: 'double', id: 2, rule: 'required' },
size: { type: 'int32', id: 3, rule: 'required' }
})

const PTBKmeansResult = new ptb.PTBMessage('KmeansResult', {
clusters: { type: 'int32', id: 1, rule: 'repeated' },
centroids: { type: PTBCentroid, id: 2, rule: 'repeated' },
iterations: { type: 'int32', id: 3 }
iterations: { type: 'int32', id: 3, rule: 'required' }
})

let model_data_idx = 0
const PTBPredictableModelData = new ptb.PTBMessage('PredictableModelData', {
intents: { type: PTBIntentDef, id: model_data_idx++, rule: 'repeated' },
languageCode: { type: 'string', id: model_data_idx++ },
languageCode: { type: 'string', id: model_data_idx++, rule: 'required' },
pattern_entities: { type: PTBPatternEntityDef, id: model_data_idx++, rule: 'repeated' },
contexts: { type: 'string', id: model_data_idx++, rule: 'repeated' },
list_entities: { type: PTBListEntityModel, id: model_data_idx++, rule: 'repeated' },
tfidf: { keyType: 'string', type: 'double', id: model_data_idx++ },
vocab: { type: 'string', id: model_data_idx++, rule: 'repeated' },
kmeans: { type: PTBKmeansResult, id: model_data_idx++, rule: 'optional' },
ctx_model: { type: SvmIntentClassifier.modelType, id: model_data_idx++ },
ctx_model: { type: SvmIntentClassifier.modelType, id: model_data_idx++, rule: 'required' },
intent_model_by_ctx: { keyType: 'string', type: OOSIntentClassifier.modelType, id: model_data_idx++ },
slots_model_by_intent: { keyType: 'string', type: SlotTagger.modelType, id: model_data_idx++ }
})
Expand Down
6 changes: 3 additions & 3 deletions packages/nlu-engine/src/engine/slots/slot-tagger.ts
Expand Up @@ -27,19 +27,19 @@ const CRF_TRAINER_PARAMS = {
}

const PTBSlotDefinition = new ptb.PTBMessage('SlotDefinition', {
name: { type: 'string', id: 1 },
name: { type: 'string', id: 1, rule: 'required' },
entities: { type: 'string', id: 2, rule: 'repeated' }
})

const PTBIntentSlotFeatures = new ptb.PTBMessage('IntentSlotFeatures', {
name: { type: 'string', id: 1 },
name: { type: 'string', id: 1, rule: 'required' },
vocab: { type: 'string', id: 2, rule: 'repeated' },
slot_entities: { type: 'string', id: 3, rule: 'repeated' }
})

const PTBSlotTaggerModel = new ptb.PTBMessage('SlotTaggerModel', {
crfModel: { type: MLToolkit.CRF.Tagger.modelType, id: 1, rule: 'optional' },
intentFeatures: { type: PTBIntentSlotFeatures, id: 2 },
intentFeatures: { type: PTBIntentSlotFeatures, id: 2, rule: 'required' },
slot_definitions: { type: PTBSlotDefinition, id: 3, rule: 'repeated' }
})

Expand Down
2 changes: 1 addition & 1 deletion packages/nlu-engine/src/ml/crf/base.ts
Expand Up @@ -8,7 +8,7 @@ import { MarginalPrediction, TagPrediction } from '.'
import { CRFTrainInput } from './typings'

const PTBCRFTaggerModel = new ptb.PTBMessage('CRFTaggerModel', {
content: { type: 'bytes', id: 1 }
content: { type: 'bytes', id: 1, rule: 'required' }
})

type CRFTaggerModel = ptb.Infer<typeof PTBCRFTaggerModel>
Expand Down
2 changes: 1 addition & 1 deletion packages/nlu-engine/src/ml/svm/flat-matrix.ts
Expand Up @@ -3,7 +3,7 @@ import _ from 'lodash'

let matrix_idx = 0
export const PTBFlatMatrixMsg = new ptb.PTBMessage('Matrix', {
nCol: { type: 'int32', id: matrix_idx++ },
nCol: { type: 'int32', id: matrix_idx++, rule: 'required' },
data: { type: 'double', id: matrix_idx++, rule: 'repeated' }
})
export type PTBFlatMatrix = ptb.Infer<typeof PTBFlatMatrixMsg>
Expand Down
38 changes: 19 additions & 19 deletions packages/nlu-engine/src/ml/svm/serialization.ts
Expand Up @@ -3,37 +3,37 @@ import { PTBFlatMatrixMsg } from './flat-matrix'

let param_idx = 0
export const PTBSVMClassifierParams = new ptb.PTBMessage('SVMClassifierParameters', {
svm_type: { type: 'int32', id: param_idx++ },
kernel_type: { type: 'int32', id: param_idx++ },
cache_size: { type: 'double', id: param_idx++ },
eps: { type: 'double', id: param_idx++ },
nr_weight: { type: 'int32', id: param_idx++ },
svm_type: { type: 'int32', id: param_idx++, rule: 'required' },
kernel_type: { type: 'int32', id: param_idx++, rule: 'required' },
cache_size: { type: 'double', id: param_idx++, rule: 'required' },
eps: { type: 'double', id: param_idx++, rule: 'required' },
nr_weight: { type: 'int32', id: param_idx++, rule: 'required' },
weight_label: { type: 'int32', id: param_idx++, rule: 'repeated' },
weight: { type: 'double', id: param_idx++, rule: 'repeated' },
shrinking: { type: 'bool', id: param_idx++ },
probability: { type: 'bool', id: param_idx++ },
C: { type: 'double', id: param_idx++ },
gamma: { type: 'double', id: param_idx++ },
degree: { type: 'int32', id: param_idx++ },
nu: { type: 'double', id: param_idx++ },
p: { type: 'double', id: param_idx++ },
coef0: { type: 'double', id: param_idx++ }
shrinking: { type: 'bool', id: param_idx++, rule: 'required' },
probability: { type: 'bool', id: param_idx++, rule: 'required' },
C: { type: 'double', id: param_idx++, rule: 'required' },
gamma: { type: 'double', id: param_idx++, rule: 'required' },
degree: { type: 'int32', id: param_idx++, rule: 'required' },
nu: { type: 'double', id: param_idx++, rule: 'required' },
p: { type: 'double', id: param_idx++, rule: 'required' },
coef0: { type: 'double', id: param_idx++, rule: 'required' }
})

let model_idx = 0
export const PTBSVMClassifierModel = new ptb.PTBMessage('SVMClassifierModel', {
param: { type: PTBSVMClassifierParams, id: model_idx++ },
nr_class: { type: 'int32', id: model_idx++ },
l: { type: 'int32', id: model_idx++ },
SV: { type: PTBFlatMatrixMsg, id: model_idx++ },
sv_coef: { type: PTBFlatMatrixMsg, id: model_idx++ },
param: { type: PTBSVMClassifierParams, id: model_idx++, rule: 'required' },
nr_class: { type: 'int32', id: model_idx++, rule: 'required' },
l: { type: 'int32', id: model_idx++, rule: 'required' },
SV: { type: PTBFlatMatrixMsg, id: model_idx++, rule: 'required' },
sv_coef: { type: PTBFlatMatrixMsg, id: model_idx++, rule: 'required' },
rho: { type: 'double', id: model_idx++, rule: 'repeated' },
probA: { type: 'double', id: model_idx++, rule: 'repeated' },
probB: { type: 'double', id: model_idx++, rule: 'repeated' },
sv_indices: { type: 'int32', id: model_idx++, rule: 'repeated' },
label: { type: 'int32', id: model_idx++, rule: 'repeated' },
nSV: { type: 'int32', id: model_idx++, rule: 'repeated' },
free_sv: { type: 'int32', id: model_idx++ },
free_sv: { type: 'int32', id: model_idx++, rule: 'required' },

mu: { type: 'double', id: param_idx++, rule: 'repeated' },
sigma: { type: 'double', id: param_idx++, rule: 'repeated' },
Expand Down
1 change: 1 addition & 0 deletions packages/nlu-engine/src/typings.d.ts
Expand Up @@ -85,6 +85,7 @@ export type Engine = {
getLanguages: () => string[]
getSpecifications: () => Specifications

validateModel(serialized: Model): void
loadModel: (model: Model) => Promise<void>
unloadModel: (modelId: ModelId) => void
hasModel: (modelId: ModelId) => boolean
Expand Down
9 changes: 8 additions & 1 deletion packages/nlu-server/src/application/errors.ts
Expand Up @@ -26,9 +26,16 @@ export class LintingNotFoundError extends ResponseError {
}
}

/**
 * Raised when uploaded model weights cannot be parsed/deserialized.
 * Maps to an HTTP 400 (Bad Request) response.
 */
export class InvalidModelFormatError extends ResponseError {
  constructor(message: string) {
    const code = 400 // bad request: the uploaded payload itself is malformed
    super(`model weights have an invalid format: ${message}`, code)
  }
}

export class InvalidModelSpecError extends ResponseError {
constructor(modelId: ModelId, currentSpec: string) {
super(`expected spec hash to be "${currentSpec}". target model has spec "${modelId.specificationHash}".`, 400)
const code = 455 // custom status code
super(`expected spec hash to be "${currentSpec}". target model has spec "${modelId.specificationHash}".`, code)
}
}

Expand Down

0 comments on commit aa2a471

Please sign in to comment.