diff --git a/BE/apps/api-server/src/modules/ai/ai.service.ts b/BE/apps/api-server/src/modules/ai/ai.service.ts index 411c220..791ffb0 100644 --- a/BE/apps/api-server/src/modules/ai/ai.service.ts +++ b/BE/apps/api-server/src/modules/ai/ai.service.ts @@ -11,6 +11,8 @@ import { OpenAiRequestDto } from './dto/openai.request.dto'; import { ClovaSpeechRequestDto } from './dto/clova.speech.request.dtd'; import { plainToInstance } from 'class-transformer'; import { OPENAI_PROMPT } from 'apps/api-server/src/common/constant'; +import { RedisService } from '@liaoliaots/nestjs-redis'; +import Redis from 'ioredis'; export interface TextAiResponse { keyword: string; @@ -20,15 +22,27 @@ export interface TextAiResponse { @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); + private readonly redis: Redis | null; constructor( private readonly configService: ConfigService, private readonly httpService: HttpService, private readonly nodeService: NodeService, private readonly publisherService: PublisherService, - ) {} + private readonly redisService: RedisService, + ) { + this.redis = redisService.getOrThrow('general'); + } async requestOpenAi(aiDto: AiDto) { try { + const aiCount = await this.redis.hget(aiDto.connectionId, 'aiCount'); + if (Number(aiCount) <= 0) { + this.publisherService.publish('api-socket', { + event: 'textAiSocket', + data: { error: 'AI 사용 횟수가 모두 소진되었습니다.', connectionId: aiDto.connectionId }, + }); + return; + } const apiKey = this.configService.get('OPENAI_API_KEY'); const openai = new OpenAI(apiKey); diff --git a/BE/apps/socket-server/src/modules/map/map.gateway.ts b/BE/apps/socket-server/src/modules/map/map.gateway.ts index 5fe7b86..514e0a8 100644 --- a/BE/apps/socket-server/src/modules/map/map.gateway.ts +++ b/BE/apps/socket-server/src/modules/map/map.gateway.ts @@ -135,14 +135,13 @@ export class MapGateway implements OnGatewayConnection, OnGatewayDisconnect { this.server.to(client.data.connectionId).emit('aiPending', { status: true }); } - textAiResponse(data) { + async textAiResponse(data) { const room = this.server.sockets.adapter.rooms.get(data.connectionId); if (data.error) { this.textAiError(data); this.server.to(data.connectionId).emit('error', { error: data.error }); } else { this.logger.log(`AI 응답 내용 : ${JSON.stringify(data.nodeData)}`); - if (room && room.size > 0) { // 첫 번째 클라이언트 ID 가져오기 const socketId = room.values().next().value; @@ -150,6 +149,7 @@ export class MapGateway implements OnGatewayConnection, OnGatewayDisconnect { if (clientSocket) { clientSocket.emit('aiResponse', data.nodeData); + await this.mapService.updateAiCount(data.connectionId); } else { this.logger.error(`Client socket not found for ID: ${socketId}`); } diff --git a/BE/apps/socket-server/src/modules/map/map.service.ts b/BE/apps/socket-server/src/modules/map/map.service.ts index 8543a88..3a6c4e6 100644 --- a/BE/apps/socket-server/src/modules/map/map.service.ts +++ b/BE/apps/socket-server/src/modules/map/map.service.ts @@ -110,7 +110,6 @@ export class MapService { event: 'textAiApi', data: { connectionId: client.data.connectionId, aiContent, mindmapId }, }); - await this.redis.hset(client.data.connectionId, 'aiCount', Number(aiCount) - 1); } catch (error) { if (error instanceof UnauthorizedException) { throw error; @@ -118,6 +117,15 @@ export class MapService { } } + async updateAiCount(connectionId: string) { + try { + const currentAiCount = await this.redis.hget(connectionId, 'aiCount'); + await this.redis.hset(connectionId, 'aiCount', Number(currentAiCount) - 1); + } catch { + throw new DatabaseException('aiCount 업데이트 실패'); + } + } + async checkAuth(client: Socket) { const type = await this.redis.hget(client.data.connectionId, 'type'); this.logger.log('연결 type: ' + type + ' 마인드맵');