Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/bot/cogs/open_ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import discord
from discord.ext import commands
from openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from src.bot.constants.settings import get_bot_settings
from src.bot.discord_bot import Bot
Expand All @@ -14,7 +14,7 @@ class OpenAi(commands.Cog):
def __init__(self, bot: Bot) -> None:
self.bot = bot
self._bot_settings = get_bot_settings()
self._openai_client: OpenAI = OpenAI(api_key=self._bot_settings.openai_api_key)
self._openai_client: AsyncOpenAI = AsyncOpenAI(api_key=self._bot_settings.openai_api_key)

@commands.command()
@commands.cooldown(1, CoolDowns.OpenAI.value, commands.BucketType.user)
Expand Down Expand Up @@ -57,8 +57,8 @@ async def _get_ai_response(self, message: str) -> str:
ChatCompletionUserMessageParam(role="user", content=message),
]

# Use the correct OpenAI API endpoint
response = self._openai_client.chat.completions.create(
# Use the correct OpenAI API endpoint (async — does not block the event loop)
response = await self._openai_client.chat.completions.create(
model=model,
messages=messages,
max_completion_tokens=1000,
Expand Down
16 changes: 8 additions & 8 deletions src/bot/tools/bot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _is_transient_discord_error(e: discord.HTTPException) -> bool:
return (isinstance(status, int) and status >= 500) or code == 40062


async def _send_with_retry(ctx, send_method, *args, max_attempts: int = 3, base_delay: float = 1.0, **kwargs):
async def send_with_retry(ctx, send_method, *args, max_attempts: int = 3, base_delay: float = 1.0, **kwargs):
"""Call send_method(*args, **kwargs) and retry on transient Discord errors.

On the first transient failure, posts a one-time "retrying" notice to the channel.
Expand Down Expand Up @@ -163,25 +163,25 @@ async def send_embed(ctx, embed, dm=False):

if is_private_message(ctx):
# Already in DM, just send the embed
await _send_with_retry(ctx, ctx.author.send, embed=embed)
await send_with_retry(ctx, ctx.author.send, embed=embed)
elif dm:
# Send to DM and notify in channel
try:
await _send_with_retry(ctx, ctx.author.send, embed=embed)
await send_with_retry(ctx, ctx.author.send, embed=embed)
notification_embed = discord.Embed(
description="📬 Response sent to your DM", color=discord.Color.green()
)
notification_embed.set_author(
name=ctx.author.display_name,
icon_url=ctx.author.avatar.url if ctx.author.avatar else ctx.author.default_avatar.url,
)
await _send_with_retry(ctx, ctx.send, embed=notification_embed)
await send_with_retry(ctx, ctx.send, embed=notification_embed)
except discord.Forbidden, discord.HTTPException:
# DM failed, fall back to sending in the channel
await _send_with_retry(ctx, ctx.send, embed=embed)
await send_with_retry(ctx, ctx.send, embed=embed)
else:
# Send to channel
await _send_with_retry(ctx, ctx.send, embed=embed)
await send_with_retry(ctx, ctx.send, embed=embed)
except (discord.Forbidden, discord.HTTPException) as e:
ctx.bot.log.error(f"Failed to send message: {e}")
if dm or is_private_message(ctx):
Expand Down Expand Up @@ -240,10 +240,10 @@ def _dict_to_embed(data: dict) -> discord.Embed:
async def send_and_save(self, ctx) -> None:
"""Send the first page and save all pages to the database.

Uses _send_with_retry so transient Discord errors (5xx, code 40062) are
Uses send_with_retry so transient Discord errors (5xx, code 40062) are
retried before propagating to the command error handler.
"""
msg = await _send_with_retry(ctx, ctx.send, embed=self.pages[0], view=self)
msg = await send_with_retry(ctx, ctx.send, embed=self.pages[0], view=self)
self.message = msg
from src.database.dal.bot.embed_pages_dal import EmbedPagesDal

Expand Down
39 changes: 1 addition & 38 deletions src/gw2/cogs/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,6 @@
from src.gw2.tools.gw2_cooldowns import GW2CoolDowns


async def _keep_typing_alive(ctx, stop_event):
"""Helper to keep Discord typing indicator alive during long operations."""
try:
while not stop_event.is_set():
try:
await ctx.message.channel.typing()
await asyncio.sleep(4) # Renew every 4 seconds (Discord typing lasts ~5s)
except asyncio.CancelledError:
raise # Re-raise CancelledError
except discord.HTTPException, discord.Forbidden:
# Handle Discord API errors gracefully and stop the loop
break
except asyncio.CancelledError:
# Clean up and re-raise CancelledError as required
raise


async def _fetch_guild_info_standalone(gw2_api, guild_id, api_key, ctx):
"""Helper to fetch individual guild information."""
try:
Expand Down Expand Up @@ -83,10 +66,6 @@ async def account(ctx):
if "account" not in permissions:
return await bot_utils.send_error_msg(ctx, gw2_messages.API_KEY_NO_PERMISSION, True)

# Initialize variables for cleanup
stop_typing = None
typing_task = None

try:
# Send progress message as embed
color = ctx.bot.settings["gw2"]["EmbedColor"]
Expand All @@ -95,11 +74,7 @@ async def account(ctx):
color=color,
)
progress_embed.set_author(name=ctx.message.author.display_name, icon_url=ctx.message.author.display_avatar.url)
progress_msg = await ctx.send(embed=progress_embed)

# Start background typing keeper
stop_typing = asyncio.Event()
typing_task = asyncio.create_task(_keep_typing_alive(ctx, stop_typing))
progress_msg = await bot_utils.send_with_retry(ctx, ctx.send, embed=progress_embed)

# Fetch basic account info and server info in parallel
account_task = gw2_api.call_api("account", api_key)
Expand Down Expand Up @@ -251,23 +226,11 @@ async def limited_guild_fetch(task):
text=f"{bot_utils.get_current_date_time_str_long()} UTC",
)

# Stop the background typing task
stop_typing.set()
typing_task.cancel()

# Clean up progress message and send final result
await progress_msg.delete()
await bot_utils.send_embed(ctx, embed)

except Exception as e:
# Stop the background typing task if it exists
if stop_typing is not None and typing_task is not None:
try:
stop_typing.set()
typing_task.cancel()
except AttributeError, RuntimeError:
# Handle cases where task is already done or event is invalid
pass
await bot_utils.send_error_msg(ctx, e)
return ctx.bot.log.error(ctx, e)

Expand Down
10 changes: 8 additions & 2 deletions src/gw2/tools/gw2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,17 @@ class Gw2Servers(Enum):
async def send_progress_embed(
ctx: commands.Context, message: str = "Please wait, I'm fetching data from GW2 API... (this may take a moment)"
) -> discord.Message:
"""Send a progress embed that can be deleted when the operation completes."""
"""Send a progress embed that can be deleted when the operation completes.

Uses send_with_retry so transient Discord errors (5xx, code 40062) are retried
instead of bubbling up to the command error handler.
"""
from src.bot.tools.bot_utils import send_with_retry

color = ctx.bot.settings["gw2"]["EmbedColor"]
embed = discord.Embed(description=f"\U0001f504 **{message}**", color=color)
embed.set_author(name=ctx.message.author.display_name, icon_url=ctx.message.author.display_avatar.url)
return await ctx.send(embed=embed)
return await send_with_retry(ctx, ctx.send, embed=embed)


async def send_msg(ctx: commands.Context, description: str, dm: bool = False) -> None:
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/bot/cogs/test_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def mock_bot():
@pytest.fixture
def openai_cog(mock_bot):
"""Create an OpenAi cog instance."""
with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.OpenAI"):
with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.AsyncOpenAI"):
mock_settings.return_value = MagicMock(openai_api_key="test-key", openai_model="gpt-3.5-turbo")
return OpenAi(mock_bot)

Expand Down Expand Up @@ -79,7 +79,7 @@ class TestOpenAi:

def test_init(self, mock_bot):
"""Test OpenAi cog initialization."""
with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.OpenAI"):
with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.AsyncOpenAI"):
mock_settings.return_value = MagicMock(openai_api_key="test-key", openai_model="gpt-3.5-turbo")
cog = OpenAi(mock_bot)
assert cog.bot == mock_bot
Expand Down Expand Up @@ -140,7 +140,7 @@ async def test_get_ai_response_success(

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_openai_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
openai_cog._openai_client = mock_client

result = await openai_cog._get_ai_response("What is Python?")
Expand Down Expand Up @@ -173,7 +173,7 @@ async def test_get_ai_response_with_leading_trailing_spaces(

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_openai_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
openai_cog._openai_client = mock_client

result = await openai_cog._get_ai_response("Test message")
Expand Down Expand Up @@ -250,7 +250,7 @@ async def test_ai_command_with_different_models(self, mock_send_embed, openai_co

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_openai_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
openai_cog._openai_client = mock_client

await openai_cog.ai.callback(openai_cog, mock_ctx, msg_text="Test question")
Expand Down Expand Up @@ -303,7 +303,7 @@ async def test_get_ai_response_system_message_content(

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_openai_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
openai_cog._openai_client = mock_client

await openai_cog._get_ai_response("Test message")
Expand All @@ -323,7 +323,7 @@ async def test_get_ai_response_api_parameters(

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_openai_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
openai_cog._openai_client = mock_client

await openai_cog._get_ai_response("Test message")
Expand All @@ -348,7 +348,7 @@ async def test_setup_function(self, mock_bot):
"""Test the setup function."""
from src.bot.cogs.open_ai import setup

with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.OpenAI"):
with patch("src.bot.cogs.open_ai.get_bot_settings") as mock_settings, patch("src.bot.cogs.open_ai.AsyncOpenAI"):
mock_settings.return_value = MagicMock(openai_api_key="test-key", openai_model="gpt-3.5-turbo")
await setup(mock_bot)

Expand Down Expand Up @@ -414,7 +414,7 @@ async def test_get_ai_response_empty_response(self, mock_get_settings, openai_co

# Mock the client instance directly
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = mock_response
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
openai_cog._openai_client = mock_client

result = await openai_cog._get_ai_response("Test message")
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/bot/tools/test_bot_utils_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def _make_http_exception(status: int, code: int = 0) -> discord.HTTPException:


class TestSendWithRetry:
"""Test _send_with_retry helper for transient Discord errors."""
"""Test send_with_retry helper for transient Discord errors."""

@pytest.fixture
def mock_ctx(self):
Expand All @@ -1121,7 +1121,7 @@ def mock_ctx(self):
async def test_success_on_first_attempt_no_retry(self, mock_ctx):
"""Happy path: send_method called once, no notice sent."""
send = AsyncMock(return_value="ok")
result = await bot_utils._send_with_retry(mock_ctx, send, embed="x")
result = await bot_utils.send_with_retry(mock_ctx, send, embed="x")
assert result == "ok"
send.assert_awaited_once_with(embed="x")
mock_ctx.send.assert_not_called()
Expand All @@ -1131,7 +1131,7 @@ async def test_retries_on_500_then_succeeds(self, mock_ctx):
"""500 error → retry, second attempt succeeds; one channel notice sent."""
send = AsyncMock(side_effect=[_make_http_exception(500), "ok"])
with patch("src.bot.tools.bot_utils.asyncio.sleep", new_callable=AsyncMock):
result = await bot_utils._send_with_retry(mock_ctx, send, embed="x")
result = await bot_utils.send_with_retry(mock_ctx, send, embed="x")
assert result == "ok"
assert send.await_count == 2
# Notice sent exactly once
Expand All @@ -1144,7 +1144,7 @@ async def test_retries_on_429_code_40062(self, mock_ctx):
"""429 with code 40062 is treated as transient and retried."""
send = AsyncMock(side_effect=[_make_http_exception(429, code=40062), "ok"])
with patch("src.bot.tools.bot_utils.asyncio.sleep", new_callable=AsyncMock):
result = await bot_utils._send_with_retry(mock_ctx, send)
result = await bot_utils.send_with_retry(mock_ctx, send)
assert result == "ok"
assert send.await_count == 2

Expand All @@ -1154,7 +1154,7 @@ async def test_does_not_retry_on_403_forbidden(self, mock_ctx):
forbidden = discord.Forbidden(MagicMock(status=403), {"message": "no", "code": 50007})
send = AsyncMock(side_effect=forbidden)
with pytest.raises(discord.Forbidden):
await bot_utils._send_with_retry(mock_ctx, send)
await bot_utils.send_with_retry(mock_ctx, send)
send.assert_awaited_once()
mock_ctx.send.assert_not_called()

Expand All @@ -1163,7 +1163,7 @@ async def test_does_not_retry_on_429_other_code(self, mock_ctx):
"""429 without code 40062 is not retried by this helper."""
send = AsyncMock(side_effect=_make_http_exception(429, code=20016))
with pytest.raises(discord.HTTPException):
await bot_utils._send_with_retry(mock_ctx, send)
await bot_utils.send_with_retry(mock_ctx, send)
send.assert_awaited_once()

@pytest.mark.asyncio
Expand All @@ -1172,7 +1172,7 @@ async def test_exhausts_retries_then_raises(self, mock_ctx):
send = AsyncMock(side_effect=_make_http_exception(500))
with patch("src.bot.tools.bot_utils.asyncio.sleep", new_callable=AsyncMock):
with pytest.raises(discord.HTTPException):
await bot_utils._send_with_retry(mock_ctx, send, max_attempts=3)
await bot_utils.send_with_retry(mock_ctx, send, max_attempts=3)
assert send.await_count == 3
# Notice sent at most once even across multiple failed attempts
assert mock_ctx.send.call_count == 1
Expand All @@ -1183,6 +1183,6 @@ async def test_notice_failure_does_not_break_retry(self, mock_ctx):
mock_ctx.send.side_effect = _make_http_exception(500)
send = AsyncMock(side_effect=[_make_http_exception(500), "ok"])
with patch("src.bot.tools.bot_utils.asyncio.sleep", new_callable=AsyncMock):
result = await bot_utils._send_with_retry(mock_ctx, send)
result = await bot_utils.send_with_retry(mock_ctx, send)
assert result == "ok"
assert send.await_count == 2
Loading
Loading