From 9c0a22fd66ee6bc4703b10cc2d5213bd177105dd Mon Sep 17 00:00:00 2001 From: adkm12 Date: Wed, 4 Dec 2024 00:44:37 +0900 Subject: [PATCH] =?UTF-8?q?refactor=20:=20aiCount=20=EA=B2=80=EC=A6=9D=20?= =?UTF-8?q?=EC=B6=94=EA=B0=80=20=EB=B0=8F=20=EA=B0=90=EC=86=8C=20=EB=A1=9C?= =?UTF-8?q?=EC=A7=81=20response=EB=A1=9C=20=EC=9D=B4=EB=8F=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- BE/apps/api-server/src/modules/ai/ai.service.ts | 16 +++++++++++++++- .../socket-server/src/modules/map/map.gateway.ts | 4 ++-- .../socket-server/src/modules/map/map.service.ts | 10 +++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/BE/apps/api-server/src/modules/ai/ai.service.ts b/BE/apps/api-server/src/modules/ai/ai.service.ts index 411c220..791ffb0 100644 --- a/BE/apps/api-server/src/modules/ai/ai.service.ts +++ b/BE/apps/api-server/src/modules/ai/ai.service.ts @@ -11,6 +11,8 @@ import { OpenAiRequestDto } from './dto/openai.request.dto'; import { ClovaSpeechRequestDto } from './dto/clova.speech.request.dtd'; import { plainToInstance } from 'class-transformer'; import { OPENAI_PROMPT } from 'apps/api-server/src/common/constant'; +import { RedisService } from '@liaoliaots/nestjs-redis'; +import Redis from 'ioredis'; export interface TextAiResponse { keyword: string; @@ -20,15 +22,27 @@ export interface TextAiResponse { @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); + private readonly redis: Redis | null; constructor( private readonly configService: ConfigService, private readonly httpService: HttpService, private readonly nodeService: NodeService, private readonly publisherService: PublisherService, - ) {} + private readonly redisService: RedisService, + ) { + this.redis = redisService.getOrThrow('general'); + } async requestOpenAi(aiDto: AiDto) { try { + const aiCount = await this.redis.hget(aiDto.connectionId, 'aiCount'); + if (Number(aiCount) <= 0) { + this.publisherService.publish('api-socket', { + event: 'textAiSocket', + data: { error: 'AI 사용 횟수가 모두 소진되었습니다.', connectionId: aiDto.connectionId }, + }); + return; + } const apiKey = this.configService.get('OPENAI_API_KEY'); const openai = new OpenAI(apiKey); diff --git a/BE/apps/socket-server/src/modules/map/map.gateway.ts b/BE/apps/socket-server/src/modules/map/map.gateway.ts index 5fe7b86..514e0a8 100644 --- a/BE/apps/socket-server/src/modules/map/map.gateway.ts +++ b/BE/apps/socket-server/src/modules/map/map.gateway.ts @@ -135,14 +135,13 @@ export class MapGateway implements OnGatewayConnection, OnGatewayDisconnect { this.server.to(client.data.connectionId).emit('aiPending', { status: true }); } - textAiResponse(data) { + async textAiResponse(data) { const room = this.server.sockets.adapter.rooms.get(data.connectionId); if (data.error) { this.textAiError(data); this.server.to(data.connectionId).emit('error', { error: data.error }); } else { this.logger.log(`AI 응답 내용 : ${JSON.stringify(data.nodeData)}`); - if (room && room.size > 0) { // 첫 번째 클라이언트 ID 가져오기 const socketId = room.values().next().value; @@ -150,6 +149,7 @@ export class MapGateway implements OnGatewayConnection, OnGatewayDisconnect { if (clientSocket) { clientSocket.emit('aiResponse', data.nodeData); + await this.mapService.updateAiCount(data.connectionId); } else { this.logger.error(`Client socket not found for ID: ${socketId}`); } diff --git a/BE/apps/socket-server/src/modules/map/map.service.ts b/BE/apps/socket-server/src/modules/map/map.service.ts index 8543a88..3a6c4e6 100644 --- a/BE/apps/socket-server/src/modules/map/map.service.ts +++ b/BE/apps/socket-server/src/modules/map/map.service.ts @@ -110,7 +110,6 @@ export class MapService { event: 'textAiApi', data: { connectionId: client.data.connectionId, aiContent, mindmapId }, }); - await this.redis.hset(client.data.connectionId, 'aiCount', Number(aiCount) - 1); } catch (error) { if (error instanceof UnauthorizedException) { throw error; @@ -118,6 +117,15 @@ export class MapService { } } + async updateAiCount(connectionId: string) { + try { + const currentAiCount = await this.redis.hget(connectionId, 'aiCount'); + await this.redis.hset(connectionId, 'aiCount', Number(currentAiCount) - 1); + } catch { + throw new DatabaseException('aiCount 업데이트 실패'); + } + } + async checkAuth(client: Socket) { const type = await this.redis.hget(client.data.connectionId, 'type'); this.logger.log('연결 type: ' + type + ' 마인드맵');