From 9550b03c3f765c19720f08bdd900c31354d6dc47 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 5 Jun 2024 22:03:45 +0700 Subject: [PATCH] feat: add query messages of thread api Signed-off-by: James --- .../controllers/threads.controller.ts | 80 ++++++++++++++++++- cortex-js/src/infrastructure/dtos/page.dto.ts | 28 +++++++ cortex-js/src/main.ts | 1 + .../src/usecases/threads/threads.usecases.ts | 49 +++++++++++- 4 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 cortex-js/src/infrastructure/dtos/page.dto.ts diff --git a/cortex-js/src/infrastructure/controllers/threads.controller.ts b/cortex-js/src/infrastructure/controllers/threads.controller.ts index fd48927ea..7ca45cab8 100644 --- a/cortex-js/src/infrastructure/controllers/threads.controller.ts +++ b/cortex-js/src/infrastructure/controllers/threads.controller.ts @@ -7,14 +7,24 @@ import { Param, Delete, UseInterceptors, + HttpCode, + Query, + DefaultValuePipe, } from '@nestjs/common'; import { ThreadsUsecases } from '@/usecases/threads/threads.usecases'; import { CreateThreadDto } from '@/infrastructure/dtos/threads/create-thread.dto'; import { UpdateThreadDto } from '@/infrastructure/dtos/threads/update-thread.dto'; import { DeleteThreadResponseDto } from '@/infrastructure/dtos/threads/delete-thread.dto'; import { GetThreadResponseDto } from '@/infrastructure/dtos/threads/get-thread.dto'; -import { ApiOperation, ApiParam, ApiTags, ApiResponse } from '@nestjs/swagger'; +import { + ApiOperation, + ApiParam, + ApiTags, + ApiResponse, + ApiQuery, +} from '@nestjs/swagger'; import { TransformInterceptor } from '../interceptors/transform.interceptor'; +import { ListMessagesResponseDto } from '../dtos/messages/list-message.dto'; @ApiTags('Threads') @Controller('threads') @@ -41,6 +51,74 @@ export class ThreadsController { return this.threadsService.findAll(); } + @HttpCode(200) + @ApiResponse({ + status: 200, + description: 'A list of message objects.', + type: ListMessagesResponseDto, + }) + @ApiOperation({ + summary: 'List messages', + description: 'Returns a list of messages for a given thread.', + }) + @ApiParam({ + name: 'id', + required: true, + description: 'The ID of the thread the messages belong to.', + }) + @ApiQuery({ + name: 'limit', + type: Number, + required: false, + description: + 'A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.', + }) + @ApiQuery({ + name: 'order', + type: String, + required: false, + description: + 'Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.', + }) + @ApiQuery({ + name: 'after', + type: String, + required: false, + description: + 'A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.', + }) + @ApiQuery({ + name: 'before', + type: String, + required: false, + description: + 'A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.', + }) + @ApiQuery({ + name: 'run_id', + type: String, + required: false, + description: 'Filter messages by the run ID that generated them.', + }) + @Get(':id/messages') + getMessagesOfThread( + @Param('id') id: string, + @Query('limit', new DefaultValuePipe(20)) limit: number, + @Query('order', new DefaultValuePipe('desc')) order: 'asc' | 'desc', + @Query('after') after?: string, + @Query('before') before?: string, + @Query('run_id') runId?: string, + ) { + return this.threadsService.getMessagesOfThread( + id, + limit, + order, + after, + before, + runId, + ); + } + @ApiResponse({ status: 200, description: 'Ok', diff --git a/cortex-js/src/infrastructure/dtos/page.dto.ts b/cortex-js/src/infrastructure/dtos/page.dto.ts new file mode 100644 index 000000000..f81f3878e --- /dev/null +++ b/cortex-js/src/infrastructure/dtos/page.dto.ts @@ -0,0 +1,28 @@ +import { ApiProperty } from '@nestjs/swagger'; +import { IsArray } from 'class-validator'; + +export class PageDto { + @ApiProperty() + readonly object: string; + + @IsArray() + @ApiProperty({ isArray: true }) + readonly data: T[]; + + @ApiProperty() + readonly first_id: string | undefined; + + @ApiProperty() + readonly last_id: string | undefined; + + @ApiProperty() + readonly has_more: boolean; + + constructor(data: T[], hasMore: boolean, firstId?: string, lastId?: string) { + this.object = 'list'; + this.data = data; + this.first_id = firstId; + this.last_id = lastId; + this.has_more = hasMore; + } +} diff --git a/cortex-js/src/main.ts b/cortex-js/src/main.ts index f0e2b6cda..9ec978be7 100644 --- a/cortex-js/src/main.ts +++ b/cortex-js/src/main.ts @@ -19,6 +19,7 @@ async function bootstrap() { app.useGlobalPipes( new ValidationPipe({ + transform: true, enableDebugMessages: true, }), ); diff --git a/cortex-js/src/usecases/threads/threads.usecases.ts b/cortex-js/src/usecases/threads/threads.usecases.ts index 75e84c1ea..0d10e88e9 100644 --- a/cortex-js/src/usecases/threads/threads.usecases.ts +++ b/cortex-js/src/usecases/threads/threads.usecases.ts @@ -1,15 +1,19 @@ -import { Inject, Injectable } from '@nestjs/common'; +import { Inject, Injectable, NotFoundException } from '@nestjs/common'; import { CreateThreadDto } from '@/infrastructure/dtos/threads/create-thread.dto'; import { UpdateThreadDto } from '@/infrastructure/dtos/threads/update-thread.dto'; import { ThreadEntity } from '@/infrastructure/entities/thread.entity'; import { Repository } from 'typeorm'; import { v4 as uuidv4 } from 'uuid'; +import { MessageEntity } from '@/infrastructure/entities/message.entity'; +import { PageDto } from '@/infrastructure/dtos/page.dto'; @Injectable() export class ThreadsUsecases { constructor( @Inject('THREAD_REPOSITORY') private threadRepository: Repository, + @Inject('MESSAGE_REPOSITORY') + private messageRepository: Repository, ) {} async create(createThreadDto: CreateThreadDto): Promise { @@ -29,6 +33,49 @@ export class ThreadsUsecases { return this.threadRepository.find(); } + async getMessagesOfThread( + id: string, + limit: number, + order: 'asc' | 'desc', + after?: string, + before?: string, + runId?: string, + ) { + const thread = await this.findOne(id); + if (!thread) { + throw new NotFoundException(`Thread with id ${id} not found`); + } + + const queryBuilder = this.messageRepository.createQueryBuilder(); + const normalizedOrder = order === 'asc' ? 'ASC' : 'DESC'; + + queryBuilder + .where('thread_id = :id', { id }) + .orderBy('created', normalizedOrder) + .take(limit + 1); // Fetch one more record than the limit + + if (after) { + queryBuilder.andWhere('id > :after', { after }); + } + + if (before) { + queryBuilder.andWhere('id < :before', { before }); + } + + const { entities: messages } = await queryBuilder.getRawAndEntities(); + + let hasMore = false; + if (messages.length > limit) { + hasMore = true; + messages.pop(); // Remove the extra record + } + + const firstId = messages[0]?.id ?? undefined; + const lastId = messages[messages.length - 1]?.id ?? undefined; + + return new PageDto(messages, hasMore, firstId, lastId); + } + findOne(id: string) { return this.threadRepository.findOne({ where: { id } }); }