From 98afa800979a022128cd76afe6111b428712efad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Levasseur?= Date: Tue, 22 Feb 2022 11:33:38 -0500 Subject: [PATCH] fix(distributed): canceling a task that is not handled locally waits for the task to be canceled (#180) * fix(distributed): task cancelation waits for instance to answer * fix: allow training to start * chore: allow setting max linting through CLI * fix: canceling a task works when task is not handled on local instance * chore(e2e): added a test that ensures canceling an unexistant training fails with 404 * chore(gh): add gh check to make sure trainings are distributed * chore: refactor e2e tests + fix race condition in distributed queue (#181) --- .github/workflows/bench.yml | 2 +- .github/workflows/e2e.yml | 168 ++++++++++++++++-- packages/distributed/src/queues/base-queue.ts | 89 ++++------ packages/distributed/src/queues/errors.ts | 6 + packages/distributed/src/queues/index.ts | 2 +- .../distributed/src/queues/local-queue.ts | 25 ++- .../src/queues/pg-distributed-queue.ts | 111 +++++++++--- .../src/queues/pg-event-observer.ts | 62 +++++++ packages/distributed/src/queues/typings.ts | 2 +- packages/nlu-cli/src/parameters/nlu-server.ts | 4 + packages/nlu-e2e/src/assertions.ts | 17 ++ packages/nlu-e2e/src/tests/model-lifecycle.ts | 3 + packages/nlu-server/src/application/index.ts | 2 - .../src/application/linting-queue/index.ts | 15 +- .../application/linting-queue/lint-handler.ts | 3 +- .../src/application/training-queue/index.ts | 15 +- .../training-queue/train-handler.ts | 4 +- packages/nlu-server/src/bootstrap/config.ts | 3 +- .../nlu-server/src/bootstrap/documentation.ts | 2 +- packages/nlu-server/src/bootstrap/launcher.ts | 2 +- .../database-utils.ts} | 0 .../linting-repo/db-linting-repo.ts | 2 +- .../model-repo/db-model-repo.ts | 2 +- .../training-repo/db-training-repo.ts | 2 +- packages/nlu-server/src/typings.d.ts | 1 + 25 files changed, 416 insertions(+), 128 deletions(-) create mode 100644 packages/distributed/src/queues/pg-event-observer.ts rename packages/nlu-server/src/{utils/database.ts => infrastructure/database-utils.ts} (100%) diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index 3b841fd6..d04f8b5a 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -23,5 +23,5 @@ jobs: - name: Run Regression Test run: | yarn start lang --dim 100 & - sleep 15s && yarn start nlu --doc false --verbose 0 --ducklingEnabled false --languageURL http://localhost:3100 & + sleep 15s && yarn start nlu --doc false --log-level "critical" --ducklingEnabled false --languageURL http://localhost:3100 & sleep 25s && yarn bench --skip="clinc150" diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 91b49d8b..f6360446 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -1,15 +1,67 @@ name: E2E on: [pull_request] jobs: - run_e2e: - name: Run e2e tests using binary executable file + fs: + name: file system + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@master + - uses: actions/setup-node@v1 + with: + node-version: '16.13.0' + - name: Fetch Node Packages + run: | + yarn --verbose + - name: Build + run: | + yarn build + - name: package + run: | + yarn package --linux + - name: Rename binary + id: rename_binary + run: | + bin_original_name=$(node -e "console.log(require('./scripts/utils/binary').getFileName())") + echo "Moving ./dist/$bin_original_name to ./nlu ..." + mv ./dist/$bin_original_name ./nlu + - name: Download language models + run: | + ./nlu lang download --lang en --dim 25 + - name: Start Language Server + run: | + ./nlu lang --dim 25 & + echo "Lang Server started on pid $!" + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '15s' + - name: Start NLU Server + run: | + ./nlu \ + --log-level "critical" \ + --ducklingEnabled false \ + --languageURL http://localhost:3100 \ + --port 3200 & + nlu_pid=$! + echo "NLU Server started on pid $nlu_pid" + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '15s' + - name: Run Tests + run: | + yarn e2e --nlu-endpoint http://localhost:3200 + + db: + name: database runs-on: ubuntu-latest services: postgres: # Docker Hub image image: postgres env: - POSTGRES_DB: botpress-nlu + POSTGRES_DB: botpress-nlu-1 POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres POSTGRES_PORT: 5432 @@ -53,31 +105,111 @@ jobs: uses: jakejarvis/wait-action@master with: time: '15s' - - name: Run Tests on File System + - name: Start NLU Server run: | ./nlu \ - --verbose 0 \ + --log-level "critical" \ --ducklingEnabled false \ --languageURL http://localhost:3100 \ - --port 3200 & + --port 3201 \ + --dbURL postgres://postgres:postgres@localhost:5432/botpress-nlu-1 & \ nlu_pid=$! echo "NLU Server started on pid $nlu_pid" + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '15s' + - name: Run Tests + run: | + yarn e2e --nlu-endpoint http://localhost:3201 + + cluster: + name: cluster + runs-on: ubuntu-latest + services: + postgres: + # Docker Hub image + image: postgres + env: + POSTGRES_DB: botpress-nlu-2 + POSTGRES_PASSWORD: postgres + POSTGRES_USER: postgres + POSTGRES_PORT: 5432 + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: + - name: Checkout code + uses: actions/checkout@master + - uses: actions/setup-node@v1 + with: + node-version: '16.13.0' + - name: Fetch Node Packages + run: | + yarn --verbose + - name: Build + run: | + yarn build + - name: package + run: | + yarn package --linux + - name: Rename binary + id: rename_binary + run: | + bin_original_name=$(node -e "console.log(require('./scripts/utils/binary').getFileName())") + echo "Moving ./dist/$bin_original_name to ./nlu ..." + mv ./dist/$bin_original_name ./nlu + - name: Download language models + run: | + ./nlu lang download --lang en --dim 25 + - name: Start Language Server + run: | + ./nlu lang --dim 25 & + echo "Lang Server started on pid $!" + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '15s' - sleep 10s && \ - yarn e2e --nlu-endpoint http://localhost:3200 && \ - kill -9 $nlu_pid + - name: Start First NLU Server on port 3202 + run: | + ./nlu \ + --log-level "critical" \ + --ducklingEnabled false \ + --maxTraining 0 \ + --maxLinting 0 \ + --languageURL http://localhost:3100 \ + --port 3202 \ + --dbURL postgres://postgres:postgres@localhost:5432/botpress-nlu-2 & \ + nlu_pid1=$! + echo "NLU Server started on pid $nlu_pid1" - - name: Run Tests on Database + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '5s' + + - name: Start Second NLU Server on port 3203 run: | ./nlu \ - --verbose 0 \ + --log-level "critical" \ --ducklingEnabled false \ --languageURL http://localhost:3100 \ - --port 3200 \ - --dbURL postgres://postgres:postgres@localhost:5432/botpress-nlu & \ - nlu_pid=$! - echo "NLU Server started on pid $nlu_pid" + --port 3203 \ + --dbURL postgres://postgres:postgres@localhost:5432/botpress-nlu-2 & \ + nlu_pid2=$! + echo "NLU Server started on pid $nlu_pid2" - sleep 10s && \ - yarn e2e --nlu-endpoint http://localhost:3200 && \ - kill -9 $nlu_pid + - name: Sleep + uses: jakejarvis/wait-action@master + with: + time: '15s' + + - name: Run Tests + run: | + yarn e2e --nlu-endpoint http://localhost:3202 diff --git a/packages/distributed/src/queues/base-queue.ts b/packages/distributed/src/queues/base-queue.ts index 6362684d..31e29436 100644 --- a/packages/distributed/src/queues/base-queue.ts +++ b/packages/distributed/src/queues/base-queue.ts @@ -4,7 +4,7 @@ import _ from 'lodash' import moment from 'moment' import { nanoid } from 'nanoid' -import { TaskAlreadyStartedError, TaskNotFoundError } from './errors' +import { TaskAlreadyStartedError } from './errors' import { createTimer, InterruptTimer } from './interrupt' import { Task, @@ -18,19 +18,20 @@ import { TaskQueue as ITaskQueue } from './typings' -export class BaseTaskQueue implements ITaskQueue { +export abstract class BaseTaskQueue implements ITaskQueue { private _schedulingTimmer!: InterruptTimer<[]> protected _clusterId: string = nanoid() constructor( protected _taskRepo: SafeTaskRepository, - private _taskRunner: TaskRunner, - private _logger: Logger, - private _idToString: (id: TId) => string, - private _options: QueueOptions + protected _taskRunner: TaskRunner, + protected _logger: Logger, + protected _idToString: (id: TId) => string, + protected _options: QueueOptions ) {} public async initialize() { + this._logger.debug(`cluster id: "${this._clusterId}"`) await this._taskRepo.initialize() this._schedulingTimmer = createTimer(this._runSchedulerInterrupt.bind(this), this._options.maxProgressDelay * 2) } @@ -68,53 +69,29 @@ export class BaseTaskQueue implements ITaskQueue { - const taskKey = this._idToString(taskId) - return this._taskRepo.inTransaction(async (repo) => { - const currentTask = await repo.get(taskId) - if (!currentTask) { - throw new TaskNotFoundError(taskKey) - } - - const zombieTasks = await this._getZombies(repo) - const isZombie = !!zombieTasks.find((t) => this._idToString(t) === taskKey) - - if (currentTask.status === 'pending' || isZombie) { - const newTask = { ...currentTask, status: 'canceled' } - return repo.set(newTask) - } - - if (currentTask.cluster !== this._clusterId) { - this._logger.debug(`Task "${taskId}" was not launched on this instance`) - return - } - - if (currentTask.status === 'running') { - return this._taskRunner.cancel(currentTask) - } - }, 'cancelTask') - } + public abstract cancelTask(taskId: TId): Promise protected async runSchedulerInterrupt() { - return this._schedulingTimmer.run() + try { + return this._schedulingTimmer.run() + } catch (thrown) { + const err = thrown instanceof Error ? thrown : new Error(`${thrown}`) + this._logger.attachError(err).error('An error occured when running scheduler interrupt.') + } } private _runSchedulerInterrupt = async () => { return this._taskRepo.inTransaction(async (repo) => { + await this._queueBackZombies(repo) + const localTasks = await repo.query({ cluster: this._clusterId, status: 'running' }) if (localTasks.length >= this._options.maxTasks) { + this._logger.debug( + `[${this._clusterId}/${this._options.queueId}] max allowed of task already launched in queue.` + ) return } - const zombieTasks = await this._getZombies(repo) - if (zombieTasks.length) { - this._logger.debug(`Queuing back ${zombieTasks.length} tasks because they seem to be zombies.`) - - const progress = this._options.initialProgress - const newState = { status: 'zombie', cluster: this._clusterId, progress } - await Bluebird.each(zombieTasks, (z) => repo.set({ ...z, ...newState })) - } - const pendings = await repo.query({ status: 'pending' }) if (pendings.length <= 0) { return @@ -122,10 +99,10 @@ export class BaseTaskQueue implements ITaskQueue implements ITaskQueue { - await this._taskRepo.inTransaction(async (repo) => { - return repo.set(task) - }, 'progressCallback') + await this._taskRepo.inTransaction((repo) => repo.set(task), 'progressCallback') }, this._options.progressThrottle) try { @@ -153,9 +128,7 @@ export class BaseTaskQueue implements ITaskQueue { - return repo.set(terminatedTask) - }, '_task_terminated') + await this._taskRepo.inTransaction((repo) => repo.set(terminatedTask), '_task_terminated') } } catch (thrown) { updateTask.flush() @@ -168,8 +141,20 @@ export class BaseTaskQueue implements ITaskQueue) => { + protected _queueBackZombies = async (repo: TaskRepository) => { const zombieThreshold = moment().subtract(this._options.maxProgressDelay, 'ms').toDate() - return repo.queryOlderThan({ status: 'running' }, zombieThreshold) + const newZombies = await repo.queryOlderThan({ status: 'running' }, zombieThreshold) + if (newZombies.length) { + this._logger.debug(`Queuing back ${newZombies.length} tasks because they seem to be zombies.`) + + const progress = this._options.initialProgress + const newState = { status: 'zombie', cluster: this._clusterId, progress } + await Bluebird.each(newZombies, (z) => repo.set({ ...z, ...newState })) + } + } + + protected _isCancelable = (task: Task) => { + const cancellableStatus: TaskStatus[] = ['running', 'pending', 'zombie'] + return cancellableStatus.includes(task.status) } } diff --git a/packages/distributed/src/queues/errors.ts b/packages/distributed/src/queues/errors.ts index 5fb7f3da..b56699c3 100644 --- a/packages/distributed/src/queues/errors.ts +++ b/packages/distributed/src/queues/errors.ts @@ -4,6 +4,12 @@ export class TaskNotFoundError extends Error { } } +export class TaskNotRunning extends Error { + constructor(taskId: string) { + super(`no current running or pending task for model: ${taskId}`) + } +} + export class TaskAlreadyStartedError extends Error { constructor(taskId: string) { super(`Training "${taskId}" already started...`) diff --git a/packages/distributed/src/queues/index.ts b/packages/distributed/src/queues/index.ts index 922e3074..44048d91 100644 --- a/packages/distributed/src/queues/index.ts +++ b/packages/distributed/src/queues/index.ts @@ -2,4 +2,4 @@ export * from './typings' export { PGDistributedTaskQueue } from './pg-distributed-queue' export { LocalTaskQueue } from './local-queue' -export { TaskAlreadyStartedError, TaskNotFoundError } from './errors' +export { TaskAlreadyStartedError, TaskNotFoundError, TaskNotRunning } from './errors' diff --git a/packages/distributed/src/queues/local-queue.ts b/packages/distributed/src/queues/local-queue.ts index cffda292..af555e82 100644 --- a/packages/distributed/src/queues/local-queue.ts +++ b/packages/distributed/src/queues/local-queue.ts @@ -1,10 +1,12 @@ import { Logger } from '@botpress/logger' import _ from 'lodash' import { InMemoryTransactionLocker } from '../locks' +import { TaskNotFoundError } from '.' import { BaseTaskQueue } from './base-queue' +import { TaskNotRunning } from './errors' import { SafeTaskRepo } from './safe-repo' -import { TaskRunner, TaskRepository, QueueOptions, TaskQueue as ITaskQueue } from './typings' +import { TaskRunner, TaskRepository, QueueOptions, TaskQueue as ITaskQueue, TaskStatus } from './typings' export class LocalTaskQueue extends BaseTaskQueue @@ -20,4 +22,25 @@ export class LocalTaskQueue const safeRepo = new SafeTaskRepo(taskRepo, new InMemoryTransactionLocker(logCb)) super(safeRepo, taskRunner, logger, idToString, opt) } + + public cancelTask(taskId: TId): Promise { + const taskKey = this._idToString(taskId) + return this._taskRepo.inTransaction(async (repo) => { + await this._queueBackZombies(repo) + + const currentTask = await repo.get(taskId) + if (!currentTask) { + throw new TaskNotFoundError(taskKey) + } + if (!this._isCancelable(currentTask)) { + throw new TaskNotRunning(taskKey) + } + + if (currentTask.status === 'pending' || currentTask.status === 'zombie') { + const newTask = { ...currentTask, status: 'canceled' } + return repo.set(newTask) + } + return this._taskRunner.cancel(currentTask) + }, 'cancelTask') + } } diff --git a/packages/distributed/src/queues/pg-distributed-queue.ts b/packages/distributed/src/queues/pg-distributed-queue.ts index 45a658dc..05427ea7 100644 --- a/packages/distributed/src/queues/pg-distributed-queue.ts +++ b/packages/distributed/src/queues/pg-distributed-queue.ts @@ -1,21 +1,21 @@ import { Logger } from '@botpress/logger' +import Bluebird from 'bluebird' +import ms from 'ms' import PGPubSub from 'pg-pubsub' import { PGTransactionLocker } from '../locks' +import { TaskNotFoundError } from '.' import { BaseTaskQueue } from './base-queue' -import { LocalTaskQueue } from './local-queue' +import { TaskNotRunning } from './errors' +import { PGQueueEventObserver } from './pg-event-observer' import { SafeTaskRepo } from './safe-repo' -import { TaskRunner, TaskRepository, QueueOptions, TaskQueue as ITaskQueue } from './typings' +import { TaskRunner, TaskRepository, QueueOptions, TaskQueue as ITaskQueue, TaskStatus } from './typings' -type Func = (...x: X) => Y +const DISTRIBUTED_CANCEL_TIMEOUT_DELAY = ms('2s') export class PGDistributedTaskQueue extends BaseTaskQueue implements ITaskQueue { - private _pubsub: PGPubSub - private _queueId: string - - private _broadcastCancelTask!: LocalTaskQueue['cancelTask'] - private _broadcastSchedulerInterrupt!: () => Promise + private _obs: PGQueueEventObserver constructor( pgURL: string, @@ -26,10 +26,8 @@ export class PGDistributedTaskQueue opt: QueueOptions ) { super(PGDistributedTaskQueue._makeSafeRepo(pgURL, taskRepo, logger), taskRunner, logger, idToString, opt) - this._pubsub = new PGPubSub(pgURL, { - log: () => {} - }) - this._queueId = opt.queueId + const _pubsub = new PGPubSub(pgURL, { log: () => {} }) + this._obs = new PGQueueEventObserver(_pubsub, opt.queueId) } private static _makeSafeRepo( @@ -43,29 +41,84 @@ export class PGDistributedTaskQueue public async initialize() { await super.initialize() + await this._obs.initialize() + this._obs.on('run_scheduler_interrupt', super.runSchedulerInterrupt.bind(this)) + this._obs.on('cancel_task', ({ taskId, clusterId }) => this._handleCancelTaskEvent(taskId, clusterId)) + } + + public async cancelTask(taskId: TId) { + const taskKey = this._idToString(taskId) + + return this._taskRepo.inTransaction(async (repo) => { + await this._queueBackZombies(repo) - this._broadcastCancelTask = await this._broadcast<[TId]>( - `${this._queueId}:cancel_task`, - super.cancelTask.bind(this) - ) - this._broadcastSchedulerInterrupt = await this._broadcast<[]>( - `${this._queueId}:scheduler_interrupt`, - super.runSchedulerInterrupt.bind(this) - ) + const currentTask = await this._taskRepo.get(taskId) + if (!currentTask) { + throw new TaskNotFoundError(taskKey) + } + if (!this._isCancelable(currentTask)) { + throw new TaskNotRunning(taskKey) + } + + if (currentTask.status === 'pending' || currentTask.status === 'zombie') { + const newTask = { ...currentTask, status: 'canceled' } + return repo.set(newTask) + } + + if (currentTask.cluster === this._clusterId) { + return this._taskRunner.cancel(currentTask) + } + + this._logger.debug(`Task "${taskId}" was not launched on this instance`) + await Bluebird.race([ + this._cancelAndWaitForResponse(taskId, currentTask.cluster), + this._timeoutTaskCancelation(DISTRIBUTED_CANCEL_TIMEOUT_DELAY) + ]) + }, 'cancelTask') } - // for if a different instance gets the cancel task http call - public cancelTask(taskId: TId) { - return this._broadcastCancelTask(taskId) + private _cancelAndWaitForResponse = (taskId: TId, clusterId: string): Promise => + new Promise(async (resolve, reject) => { + this._obs.onceOrMore('cancel_task_done', async (response) => { + if (this._idToString(response.taskId) !== this._idToString(taskId)) { + return 'stay' // canceled task is not the one we're waiting for + } + + if (response.err) { + const { message, stack } = response.err + const err = new Error(message) + err.stack = stack + reject(err) + return 'leave' + } + + resolve() + return 'leave' + }) + await this._obs.emit('cancel_task', { taskId, clusterId }) + }) + + private _timeoutTaskCancelation = (ms: number): Promise => + new Promise((_resolve, reject) => { + setTimeout(() => reject(new Error(`Canceling operation took more than ${ms} ms`)), ms) + }) + + private _handleCancelTaskEvent = async (taskId: TId, clusterId: string) => { + if (clusterId !== this._clusterId) { + return // message was not adressed to this instance + } + + try { + await this._taskRunner.cancel(taskId) + await this._obs.emit('cancel_task_done', { taskId }) + } catch (thrown) { + const { message, stack } = thrown instanceof Error ? thrown : new Error(`${thrown}`) + await this._obs.emit('cancel_task_done', { taskId, err: { message, stack } }) + } } // for if an completly busy instance receives a queue task http call protected runSchedulerInterrupt() { - return this._broadcastSchedulerInterrupt() - } - - private _broadcast = async (name: string, fn: Func>) => { - await this._pubsub.addChannel(name, (x) => fn(...x)) - return (...x: X) => this._pubsub.publish(name, x) + return this._obs.emit('run_scheduler_interrupt', undefined) } } diff --git a/packages/distributed/src/queues/pg-event-observer.ts b/packages/distributed/src/queues/pg-event-observer.ts new file mode 100644 index 00000000..6ab2aa15 --- /dev/null +++ b/packages/distributed/src/queues/pg-event-observer.ts @@ -0,0 +1,62 @@ +import Bluebird from 'bluebird' +import { EventEmitter2 } from 'eventemitter2' +import PGPubSub from 'pg-pubsub' + +const CHANNELS = ['cancel_task', 'run_scheduler_interrupt', 'cancel_task_done'] as const +type Channel = typeof CHANNELS[number] + +type CancelTaskError = { + message: string + stack?: string +} + +type PGQueueEventData = C extends 'run_scheduler_interrupt' + ? void + : C extends 'cancel_task' + ? { taskId: TId; clusterId: string } + : C extends 'cancel_task_done' + ? { taskId: TId; err?: CancelTaskError } + : never + +export class PGQueueEventObserver { + constructor(private _pubsub: PGPubSub, private _queueId: string) {} + + private _evEmitter = new EventEmitter2() + + public initialize = async (): Promise => { + await Bluebird.map(CHANNELS, (c: Channel) => + this._pubsub.addChannel(this._pgChannelId(c), (x) => this._evEmitter.emit(c, x)) + ) + } + + public teardown = async (): Promise => { + await Bluebird.map(CHANNELS, (c: Channel) => this._pubsub.removeChannel(this._pgChannelId(c))) + } + + public on(c: C, handler: (data: PGQueueEventData) => Promise): void { + this._evEmitter.on(c, handler) + } + + public off(c: C, handler: (data: PGQueueEventData) => Promise): void { + this._evEmitter.off(c, handler) + } + + public onceOrMore( + c: C, + handler: (data: PGQueueEventData) => Promise<'stay' | 'leave'> + ): void { + const cb = async (x: PGQueueEventData) => { + const y = await handler(x) + if (y === 'leave') { + this._evEmitter.off(c, cb) + } + } + this._evEmitter.on(c, cb) + } + + public async emit(c: C, data: PGQueueEventData): Promise { + return this._pubsub.publish(this._pgChannelId(c), data) + } + + private _pgChannelId = (c: Channel) => `${this._queueId}:${c}` +} diff --git a/packages/distributed/src/queues/typings.ts b/packages/distributed/src/queues/typings.ts index 6b82e250..915c82e1 100644 --- a/packages/distributed/src/queues/typings.ts +++ b/packages/distributed/src/queues/typings.ts @@ -8,7 +8,7 @@ export type TaskRunner = { task: Task, progress: ProgressCb ) => Promise | undefined> - cancel: (task: Task) => Promise + cancel: (taskId: TId) => Promise } export type TaskTerminatedStatus = 'done' | 'canceled' | 'errored' diff --git a/packages/nlu-cli/src/parameters/nlu-server.ts b/packages/nlu-cli/src/parameters/nlu-server.ts index 6515c2a3..6883accc 100644 --- a/packages/nlu-cli/src/parameters/nlu-server.ts +++ b/packages/nlu-cli/src/parameters/nlu-server.ts @@ -95,6 +95,10 @@ export const parameters = asYargs({ description: 'The max allowed amount of simultaneous trainings on a single instance', type: 'number' }, + maxLinting: { + description: 'The max allowed amount of simultaneous lintings on a single instance', + type: 'number' + }, usageURL: { description: 'Endpoint to send usage info to.', type: 'string' diff --git a/packages/nlu-e2e/src/assertions.ts b/packages/nlu-e2e/src/assertions.ts index 417a3c6a..18b165f8 100644 --- a/packages/nlu-e2e/src/assertions.ts +++ b/packages/nlu-e2e/src/assertions.ts @@ -170,6 +170,23 @@ export const assertQueueTrainingFails = async ( chai.expect(error.type).to.equal(expectedError) } +export const assertCancelTrainingFails = async ( + args: AssertionArgs, + modelId: string, + expectedError: http.ErrorType +): Promise => { + const { client, logger, appId } = args + logger.debug('assert cancel training fails') + + const cancelRes = await client.cancelTraining(appId, modelId) + if (cancelRes.success) { + throw new Error(`Expected training cancel to fail with error: "${expectedError}"`) + } + + const { error } = cancelRes + chai.expect(error.type).to.equal(expectedError) +} + export const assertTrainingCancels = async (args: AssertionArgs, modelId: string): Promise => { const { client, logger, appId } = args logger.debug('assert training cancels') diff --git a/packages/nlu-e2e/src/tests/model-lifecycle.ts b/packages/nlu-e2e/src/tests/model-lifecycle.ts index 98982ceb..601533f6 100644 --- a/packages/nlu-e2e/src/tests/model-lifecycle.ts +++ b/packages/nlu-e2e/src/tests/model-lifecycle.ts @@ -1,6 +1,7 @@ import ms from 'ms' import { AssertionArgs, Test } from 'src/typings' import { + assertCancelTrainingFails, assertModelsInclude, assertModelsPrune, assertPredictionFails, @@ -23,6 +24,8 @@ export const modelLifecycleTest: Test = { const modelLifecycleLogger = logger.sub(NAME) const modelLifecycleArgs = { ...args, logger: modelLifecycleLogger } + await assertCancelTrainingFails(modelLifecycleArgs, 'my-model-id-lol', 'training_not_found') + let clinc150_42_modelId = await assertTrainingStarts(modelLifecycleArgs, clinc50_42_dataset) await sleep(ms('1s')) diff --git a/packages/nlu-server/src/application/index.ts b/packages/nlu-server/src/application/index.ts index 7f1a415f..a1031cae 100644 --- a/packages/nlu-server/src/application/index.ts +++ b/packages/nlu-server/src/application/index.ts @@ -236,9 +236,7 @@ You can increase your cache size by the CLI or config. specifications: this._engine.getSpecifications() }) - // unhandled promise to return asap await this._lintingQueue.queueLinting(appId, modelId, trainInput) - return modelId } diff --git a/packages/nlu-server/src/application/linting-queue/index.ts b/packages/nlu-server/src/application/linting-queue/index.ts index 21daa490..6514fb98 100644 --- a/packages/nlu-server/src/application/linting-queue/index.ts +++ b/packages/nlu-server/src/application/linting-queue/index.ts @@ -62,7 +62,7 @@ export abstract class LintingQueue { const lintId: LintingId = { appId, modelId } await this.taskQueue.cancelTask(lintId) } catch (thrown) { - if (thrown instanceof q.TaskNotFoundError) { + if (thrown instanceof q.TaskNotFoundError || thrown instanceof q.TaskNotRunning) { throw new LintingNotFoundError(appId, modelId) } throw thrown @@ -104,12 +104,13 @@ export class LocalLintingQueue extends LintingQueue { const lintingLogger = baseLogger.sub(LINTING_PREFIX) const lintHandler = new LintHandler(engine, lintingLogger) - const options = opt.maxLinting - ? { - ...TASK_OPTIONS, - maxTasks: opt.maxLinting - } - : TASK_OPTIONS + const options = + opt.maxLinting === undefined + ? TASK_OPTIONS + : { + ...TASK_OPTIONS, + maxTasks: opt.maxLinting + } const taskQueue = new q.LocalTaskQueue(lintTaskRepo, lintHandler, lintingLogger, idToString, options) super(taskQueue, lintingLogger) diff --git a/packages/nlu-server/src/application/linting-queue/lint-handler.ts b/packages/nlu-server/src/application/linting-queue/lint-handler.ts index 108f282a..d098be4d 100644 --- a/packages/nlu-server/src/application/linting-queue/lint-handler.ts +++ b/packages/nlu-server/src/application/linting-queue/lint-handler.ts @@ -2,6 +2,7 @@ import { Logger } from '@botpress/logger' import { DatasetIssue, IssueCode, LintingError } from '@botpress/nlu-client' import * as NLUEngine from '@botpress/nlu-engine' import _ from 'lodash' +import { LintingId } from '../../infrastructure' import { idToString } from '../training-queue' import { LintTask, LintTaskProgress, LintTaskRunner, TerminatedLintTask } from './typings' @@ -28,7 +29,7 @@ export class LintHandler implements LintTaskRunner { } } - public async cancel(task: LintTask): Promise { + public async cancel(task: LintingId): Promise { const trainKey = idToString(task) return this._engine.cancelTraining(trainKey) } diff --git a/packages/nlu-server/src/application/training-queue/index.ts b/packages/nlu-server/src/application/training-queue/index.ts index 452f7b16..acd96e7a 100644 --- a/packages/nlu-server/src/application/training-queue/index.ts +++ b/packages/nlu-server/src/application/training-queue/index.ts @@ -63,7 +63,7 @@ export abstract class TrainingQueue { try { await this.taskQueue.cancelTask({ modelId, appId }) } catch (thrown) { - if (thrown instanceof q.TaskNotFoundError) { + if (thrown instanceof q.TaskNotFoundError || thrown instanceof q.TaskNotRunning) { throw new TrainingNotFoundError(appId, modelId) } throw thrown @@ -84,12 +84,13 @@ export class PgTrainingQueue extends TrainingQueue { const trainTaskRepo = new TrainTaskRepo(trainingRepo) const trainHandler = new TrainHandler(engine, modelRepo, trainingLogger) - const options = opt.maxTraining - ? { - ...TASK_OPTIONS, - maxTasks: opt.maxTraining - } - : TASK_OPTIONS + const options = + opt.maxTraining === undefined + ? TASK_OPTIONS + : { + ...TASK_OPTIONS, + maxTasks: opt.maxTraining + } const taskQueue = new q.PGDistributedTaskQueue( pgURL, diff --git a/packages/nlu-server/src/application/training-queue/train-handler.ts b/packages/nlu-server/src/application/training-queue/train-handler.ts index 5e006308..a89a9721 100644 --- a/packages/nlu-server/src/application/training-queue/train-handler.ts +++ b/packages/nlu-server/src/application/training-queue/train-handler.ts @@ -2,7 +2,7 @@ import { Logger } from '@botpress/logger' import { TrainingErrorType } from '@botpress/nlu-client' import * as NLUEngine from '@botpress/nlu-engine' import _ from 'lodash' -import { ModelRepository } from '../../infrastructure' +import { ModelRepository, TrainingId } from '../../infrastructure' import { idToString, MIN_TRAINING_HEARTBEAT } from '.' import { TerminatedTrainTask, TrainTask, TrainTaskProgress, TrainTaskRunner } from './typings' @@ -73,7 +73,7 @@ export class TrainHandler implements TrainTaskRunner { } } - public cancel(task: TrainTask): Promise { + public cancel(task: TrainingId): Promise { const trainKey = idToString(task) return this.engine.cancelTraining(trainKey) } diff --git a/packages/nlu-server/src/bootstrap/config.ts b/packages/nlu-server/src/bootstrap/config.ts index 40c4a7d6..831d7ff3 100644 --- a/packages/nlu-server/src/bootstrap/config.ts +++ b/packages/nlu-server/src/bootstrap/config.ts @@ -22,7 +22,8 @@ const DEFAULT_OPTIONS = (): NLUServerOptions => ({ logLevel: 'info', debugFilter: undefined, logFormat: 'text', - maxTraining: 2 + maxTraining: 2, + maxLinting: 2 }) export const getConfig = async (cliOptions: CommandLineOptions): Promise => { diff --git a/packages/nlu-server/src/bootstrap/documentation.ts b/packages/nlu-server/src/bootstrap/documentation.ts index d3fa32a7..0e361cd7 100644 --- a/packages/nlu-server/src/bootstrap/documentation.ts +++ b/packages/nlu-server/src/bootstrap/documentation.ts @@ -17,7 +17,7 @@ export const displayDocumentation = (logger: Logger, options: NLUServerOptions) {green /** * Gets the current version of the NLU engine being used. Usefull to test if your installation is working. - * @returns {bold info}: version, health and supported languages. + * @returns {bold info}: version and supported languages. */} {bold GET ${baseUrl}/info} diff --git a/packages/nlu-server/src/bootstrap/launcher.ts b/packages/nlu-server/src/bootstrap/launcher.ts index 09c4fda1..5723510f 100644 --- a/packages/nlu-server/src/bootstrap/launcher.ts +++ b/packages/nlu-server/src/bootstrap/launcher.ts @@ -46,7 +46,7 @@ export const logLaunchingMessage = async (info: NLUServerOptions & LaunchingInfo launcherLogger.info(`models stored at "${info.modelDir}"`) } - if (info.batchSize > 0) { + if (info.batchSize > 1) { launcherLogger.info(`batch size: allowing up to ${info.batchSize} predictions in one call to POST /predict`) } diff --git a/packages/nlu-server/src/utils/database.ts b/packages/nlu-server/src/infrastructure/database-utils.ts similarity index 100% rename from packages/nlu-server/src/utils/database.ts rename to packages/nlu-server/src/infrastructure/database-utils.ts diff --git a/packages/nlu-server/src/infrastructure/linting-repo/db-linting-repo.ts b/packages/nlu-server/src/infrastructure/linting-repo/db-linting-repo.ts index 00369bb5..3df18d6b 100644 --- a/packages/nlu-server/src/infrastructure/linting-repo/db-linting-repo.ts +++ b/packages/nlu-server/src/infrastructure/linting-repo/db-linting-repo.ts @@ -14,7 +14,7 @@ import { Knex } from 'knex' import _ from 'lodash' import moment from 'moment' import ms from 'ms' -import { createTableIfNotExists } from '../../utils/database' +import { createTableIfNotExists } from '../database-utils' import { packTrainSet, unpackTrainSet } from '../dataset-serializer' import { LintingRepository } from '.' import { Linting, LintingId, LintingState } from './typings' diff --git a/packages/nlu-server/src/infrastructure/model-repo/db-model-repo.ts b/packages/nlu-server/src/infrastructure/model-repo/db-model-repo.ts index 215e9b79..fc61154f 100644 --- a/packages/nlu-server/src/infrastructure/model-repo/db-model-repo.ts +++ b/packages/nlu-server/src/infrastructure/model-repo/db-model-repo.ts @@ -3,7 +3,7 @@ import * as NLUEngine from '@botpress/nlu-engine' import Bluebird from 'bluebird' import { Knex } from 'knex' import _ from 'lodash' -import { createTableIfNotExists } from '../../utils/database' +import { createTableIfNotExists } from '../database-utils' import { compressModel, decompressModel } from './compress-model' import { ModelRepository, PruneOptions } from './typings' diff --git a/packages/nlu-server/src/infrastructure/training-repo/db-training-repo.ts b/packages/nlu-server/src/infrastructure/training-repo/db-training-repo.ts index 320b9346..c8798195 100644 --- a/packages/nlu-server/src/infrastructure/training-repo/db-training-repo.ts +++ b/packages/nlu-server/src/infrastructure/training-repo/db-training-repo.ts @@ -5,7 +5,7 @@ import { Knex } from 'knex' import _ from 'lodash' import moment from 'moment' import ms from 'ms' -import { createTableIfNotExists } from '../../utils/database' +import { createTableIfNotExists } from '../database-utils' import { packTrainSet, unpackTrainSet } from '../dataset-serializer' import { Training, TrainingId, TrainingState, TrainingRepository, TrainingListener } from './typings' diff --git a/packages/nlu-server/src/typings.d.ts b/packages/nlu-server/src/typings.d.ts index 584757f3..915790ef 100644 --- a/packages/nlu-server/src/typings.d.ts +++ b/packages/nlu-server/src/typings.d.ts @@ -23,6 +23,7 @@ export type NLUServerOptions = { apmEnabled?: boolean apmSampleRate?: number maxTraining: number + maxLinting: number languageURL: string languageAuthToken?: string ducklingURL: string