From f70850f66094a0cd3661d13450fe8bfb60aa0dcf Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 28 May 2024 22:31:53 +0700 Subject: [PATCH] chore: cortex chat revamp with options --- .../infrastructure/commanders/chat.command.ts | 38 +++++--- .../commanders/usecases/chat.cli.usecases.ts | 28 ++++-- .../controllers/models.controller.ts | 88 ++++++++++++++----- .../controllers/threads.controller.ts | 54 ++++++++---- 4 files changed, 149 insertions(+), 59 deletions(-) diff --git a/cortex-js/src/infrastructure/commanders/chat.command.ts b/cortex-js/src/infrastructure/commanders/chat.command.ts index a7c7ad842..8f3e5c4e7 100644 --- a/cortex-js/src/infrastructure/commanders/chat.command.ts +++ b/cortex-js/src/infrastructure/commanders/chat.command.ts @@ -3,29 +3,43 @@ import { ChatCliUsecases } from './usecases/chat.cli.usecases'; import { exit } from 'node:process'; type ChatOptions = { - model?: string; threadId?: string; + message?: string; + attach: boolean; }; -@SubCommand({ name: 'chat', description: 'Start a chat with a model' }) +@SubCommand({ name: 'chat', description: 'Send a chat request to a model' }) export class ChatCommand extends CommandRunner { constructor(private readonly chatCliUsecases: ChatCliUsecases) { super(); } - async run(_input: string[], option: ChatOptions): Promise { - const modelId = option.model; + async run(_input: string[], options: ChatOptions): Promise { + const modelId = _input[0]; if (!modelId) { console.error('Model ID is required'); exit(1); } - return this.chatCliUsecases.chat(modelId, option.threadId); + return this.chatCliUsecases.chat( + modelId, + options.threadId, + options.message, + options.attach, + ); } @Option({ - flags: '-m, --model ', - description: 'Model Id to start chat with', + flags: '-t, --thread ', + description: 'Thread Id. If not provided, will create new thread', + }) + parseThreadId(value: string) { + return value; + } + + @Option({ + flags: '-m, --message ', + description: 'Message to send to the model', required: true, }) parseModelId(value: string) { @@ -33,10 +47,12 @@ export class ChatCommand extends CommandRunner { } @Option({ - flags: '-t, --thread ', - description: 'Thread Id. If not provided, will create new thread', + flags: '-a, --attach', + description: 'Attach to interactive chat session', + defaultValue: false, + name: 'attach', }) - parseThreadId(value: string) { - return value; + parseAttach() { + return true; } } diff --git a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts index 1995ea829..5147b2e1a 100644 --- a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts +++ b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts @@ -72,8 +72,13 @@ export class ChatCliUsecases { return this.threadUsecases.create(createThreadDto); } - async chat(modelId: string, threadId?: string): Promise { - console.log(`Inorder to exit, type '${this.exitClause}'.`); + async chat( + modelId: string, + threadId?: string, + message?: string, + attach: boolean = true, + ): Promise { + if (attach) console.log(`Inorder to exit, type '${this.exitClause}'.`); const thread = await this.getOrCreateNewThread(modelId, threadId); const messages: ChatCompletionMessage[] = ( await this.messagesUsecases.getLastMessagesByThread(thread.id, 10) @@ -87,18 +92,19 @@ export class ChatCliUsecases { output: stdout, prompt: this.userIndicator, }); - rl.prompt(); + if (message) sendCompletionMessage.bind(this)(message); + if (attach) rl.prompt(); rl.on('close', () => { this.cortexUsecases.stopCortex().then(() => { - console.log(this.exitMessage); + if (attach) console.log(this.exitMessage); exit(0); }); }); - const decoder = new TextDecoder('utf-8'); + rl.on('line', sendCompletionMessage.bind(this)); - rl.on('line', (userInput: string) => { + function sendCompletionMessage(userInput: string) { if (userInput.trim() === this.exitClause) { rl.close(); return; @@ -137,6 +143,8 @@ export class ChatCliUsecases { top_p: 0.7, }; + const decoder = new TextDecoder('utf-8'); + this.chatUsecases .inference(chatDto, {}) .then((response: stream.Readable) => { @@ -144,7 +152,8 @@ export class ChatCliUsecases { response.on('error', (error: any) => { console.error(error); - rl.prompt(); + if (attach) rl.prompt(); + else rl.close(); }); response.on('end', () => { @@ -170,7 +179,8 @@ export class ChatCliUsecases { this.messagesUsecases.create(createMessageDto).then(() => { assistantResponse = ''; console.log('\n'); - rl.prompt(); + if (attach) rl.prompt(); + else rl.close(); }); }); @@ -201,6 +211,6 @@ export class ChatCliUsecases { } }); }); - }); + } } } diff --git a/cortex-js/src/infrastructure/controllers/models.controller.ts b/cortex-js/src/infrastructure/controllers/models.controller.ts index 0c178f9ee..18a54c5d8 100644 --- a/cortex-js/src/infrastructure/controllers/models.controller.ts +++ b/cortex-js/src/infrastructure/controllers/models.controller.ts @@ -16,14 +16,7 @@ import { ModelDto } from '@/infrastructure/dtos/models/model-successfully-create import { ListModelsResponseDto } from '@/infrastructure/dtos/models/list-model-response.dto'; import { DeleteModelResponseDto } from '@/infrastructure/dtos/models/delete-model.dto'; import { DownloadModelResponseDto } from '@/infrastructure/dtos/models/download-model.dto'; -import { - ApiCreatedResponse, - ApiOkResponse, - ApiOperation, - ApiParam, - ApiTags, - ApiResponse -} from '@nestjs/swagger'; +import { ApiOperation, ApiParam, ApiTags, ApiResponse } from '@nestjs/swagger'; import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto'; import { ModelSettingParamsDto } from '../dtos/models/model-setting-params.dto'; import { TransformInterceptor } from '../interceptors/transform.interceptor'; @@ -40,7 +33,10 @@ export class ModelsController { description: 'The model has been successfully created.', type: StartModelSuccessDto, }) - @ApiOperation({ summary: 'Create model', description: "Creates a model `.json` instance file manually." }) + @ApiOperation({ + summary: 'Create model', + description: 'Creates a model `.json` instance file manually.', + }) @Post() create(@Body() createModelDto: CreateModelDto) { return this.modelsUsecases.create(createModelDto); @@ -52,8 +48,15 @@ export class ModelsController { description: 'The model has been successfully started.', type: StartModelSuccessDto, }) - @ApiOperation({ summary: 'Start model', description: "Starts a model operation defined by a model `id`." }) - @ApiParam({ name: 'modelId', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Start model', + description: 'Starts a model operation defined by a model `id`.', + }) + @ApiParam({ + name: 'modelId', + required: true, + description: 'The unique identifier of the model.', + }) @Post(':modelId/start') startModel( @Param('modelId') modelId: string, @@ -68,8 +71,15 @@ export class ModelsController { description: 'The model has been successfully stopped.', type: StartModelSuccessDto, }) - @ApiOperation({ summary: 'Stop model', description: "Stops a model operation defined by a model `id`." }) - @ApiParam({ name: 'modelId', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Stop model', + description: 'Stops a model operation defined by a model `id`.', + }) + @ApiParam({ + name: 'modelId', + required: true, + description: 'The unique identifier of the model.', + }) @Post(':modelId/stop') stopModel(@Param('modelId') modelId: string) { return this.modelsUsecases.stopModel(modelId); @@ -81,8 +91,15 @@ export class ModelsController { description: 'Ok', type: DownloadModelResponseDto, }) - @ApiOperation({ summary: 'Download model', description: "Downloads a specific model instance." }) - @ApiParam({ name: 'modelId', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Download model', + description: 'Downloads a specific model instance.', + }) + @ApiParam({ + name: 'modelId', + required: true, + description: 'The unique identifier of the model.', + }) @Get('download/:modelId') downloadModel(@Param('modelId') modelId: string) { return this.modelsUsecases.downloadModel(modelId); @@ -94,7 +111,11 @@ export class ModelsController { description: 'Ok', type: ListModelsResponseDto, }) - @ApiOperation({ summary: 'List models', description: "Lists the currently available models, and provides basic information about each one such as the owner and availability. [Equivalent to OpenAI's list model](https://platform.openai.com/docs/api-reference/models/list)." }) + @ApiOperation({ + summary: 'List models', + description: + "Lists the currently available models, and provides basic information about each one such as the owner and availability. [Equivalent to OpenAI's list model](https://platform.openai.com/docs/api-reference/models/list).", + }) @Get() findAll() { return this.modelsUsecases.findAll(); @@ -106,8 +127,16 @@ export class ModelsController { description: 'Ok', type: ModelDto, }) - @ApiOperation({ summary: 'Get model', description: "Retrieves a model instance, providing basic information about the model such as the owner and permissions. [Equivalent to OpenAI's list model](https://platform.openai.com/docs/api-reference/models/retrieve)." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Get model', + description: + "Retrieves a model instance, providing basic information about the model such as the owner and permissions. [Equivalent to OpenAI's list model](https://platform.openai.com/docs/api-reference/models/retrieve).", + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the model.', + }) @Get(':id') findOne(@Param('id') id: string) { return this.modelsUsecases.findOne(id); @@ -119,8 +148,15 @@ export class ModelsController { description: 'The model has been successfully updated.', type: UpdateModelDto, }) - @ApiOperation({ summary: 'Update model', description: "Updates a model instance defined by a model's `id`." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Update model', + description: "Updates a model instance defined by a model's `id`.", + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the model.', + }) @Patch(':id') update(@Param('id') id: string, @Body() updateModelDto: UpdateModelDto) { return this.modelsUsecases.update(id, updateModelDto); @@ -131,8 +167,16 @@ export class ModelsController { description: 'The model has been successfully deleted.', type: DeleteModelResponseDto, }) - @ApiOperation({ summary: 'Delete model', description: "Deletes a model. [Equivalent to OpenAI's delete model](https://platform.openai.com/docs/api-reference/models/delete)." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the model." }) + @ApiOperation({ + summary: 'Delete model', + description: + "Deletes a model. [Equivalent to OpenAI's delete model](https://platform.openai.com/docs/api-reference/models/delete).", + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the model.', + }) @Delete(':id') remove(@Param('id') id: string) { return this.modelsUsecases.remove(id); diff --git a/cortex-js/src/infrastructure/controllers/threads.controller.ts b/cortex-js/src/infrastructure/controllers/threads.controller.ts index abf26bc01..fd48927ea 100644 --- a/cortex-js/src/infrastructure/controllers/threads.controller.ts +++ b/cortex-js/src/infrastructure/controllers/threads.controller.ts @@ -6,7 +6,6 @@ import { Patch, Param, Delete, - HttpCode UseInterceptors, } from '@nestjs/common'; import { ThreadsUsecases } from '@/usecases/threads/threads.usecases'; @@ -14,14 +13,7 @@ import { CreateThreadDto } from '@/infrastructure/dtos/threads/create-thread.dto import { UpdateThreadDto } from '@/infrastructure/dtos/threads/update-thread.dto'; import { DeleteThreadResponseDto } from '@/infrastructure/dtos/threads/delete-thread.dto'; import { GetThreadResponseDto } from '@/infrastructure/dtos/threads/get-thread.dto'; -import { - ApiCreatedResponse, - ApiOkResponse, - ApiOperation, - ApiParam, - ApiTags, - ApiResponse -} from '@nestjs/swagger'; +import { ApiOperation, ApiParam, ApiTags, ApiResponse } from '@nestjs/swagger'; import { TransformInterceptor } from '../interceptors/transform.interceptor'; @ApiTags('Threads') @@ -30,13 +22,20 @@ import { TransformInterceptor } from '../interceptors/transform.interceptor'; export class ThreadsController { constructor(private readonly threadsService: ThreadsUsecases) {} - @ApiOperation({ summary: 'Create thread', description: "Creates a new thread." }) + @ApiOperation({ + summary: 'Create thread', + description: 'Creates a new thread.', + }) @Post() create(@Body() createThreadDto: CreateThreadDto) { return this.threadsService.create(createThreadDto); } - @ApiOperation({ summary: 'List threads', description: "Lists all the available threads along with its configurations." }) + @ApiOperation({ + summary: 'List threads', + description: + 'Lists all the available threads along with its configurations.', + }) @Get() findAll() { return this.threadsService.findAll(); @@ -47,8 +46,15 @@ export class ThreadsController { description: 'Ok', type: GetThreadResponseDto, }) - @ApiOperation({ summary: 'Get thread', description: "Retrieves a thread along with its configurations." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the thread." }) + @ApiOperation({ + summary: 'Get thread', + description: 'Retrieves a thread along with its configurations.', + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the thread.', + }) @Get(':id') findOne(@Param('id') id: string) { return this.threadsService.findOne(id); @@ -59,8 +65,15 @@ export class ThreadsController { description: 'The thread has been successfully updated.', type: UpdateThreadDto, }) - @ApiOperation({ summary: 'Update thread', description: "Updates a thread's configurations." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the thread." }) + @ApiOperation({ + summary: 'Update thread', + description: "Updates a thread's configurations.", + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the thread.', + }) @Patch(':id') update(@Param('id') id: string, @Body() updateThreadDto: UpdateThreadDto) { return this.threadsService.update(id, updateThreadDto); @@ -71,8 +84,15 @@ export class ThreadsController { description: 'The thread has been successfully deleted.', type: DeleteThreadResponseDto, }) - @ApiOperation({ summary: 'Delete thread', description: "Deletes a specific thread defined by a thread `id` ." }) - @ApiParam({ name: 'id', required: true, description: "The unique identifier of the thread." }) + @ApiOperation({ + summary: 'Delete thread', + description: 'Deletes a specific thread defined by a thread `id` .', + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The unique identifier of the thread.', + }) @Delete(':id') remove(@Param('id') id: string) { return this.threadsService.remove(id);