Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions struct_module/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ class ModelWrapper:
def __init__(self, logger=None):
self.logger = logger or logging.getLogger(__name__)
self.model_name = os.getenv("AI_MODEL") or "openai:gpt-4.1"

# Set default API key if not provided to prevent startup crashes
if self.model_name.startswith("openai:") and not os.getenv("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = "sk-default-placeholder-key"
self.logger.warning("OPENAI_API_KEY not set. Using placeholder. AI features may not work properly.")

self.agent = Agent(model=self.model_name)
self.logger.debug(f"Configured Agent with model: {self.model_name}")

Expand All @@ -22,11 +28,20 @@ def generate_content(self, system_prompt, user_prompt, dry_run=False):
if dry_run:
self.logger.info("[DRY RUN] Would generate content using AI agent.")
return "[DRY RUN] Generating content using AI agent"

# Check if using placeholder API key
if os.getenv("OPENAI_API_KEY") == "sk-default-placeholder-key":
self.logger.warning("Using placeholder API key. Set OPENAI_API_KEY environment variable for AI features to work.")
return "AI generation skipped: Please set OPENAI_API_KEY environment variable."

prompt = f"{user_prompt}"
try:
self.agent.system_prompt = system_prompt
result = self.agent.run_sync(prompt)
return result.output
except Exception as e:
self.logger.error(f"AI agent generation failed: {e}")
# Provide more helpful error message for API key issues
if "api_key" in str(e).lower() or "unauthorized" in str(e).lower():
return "AI generation failed: Please check your OPENAI_API_KEY environment variable."
return f"AI agent generation failed: {e}"
132 changes: 132 additions & 0 deletions tests/test_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest
import os
import logging
from unittest.mock import patch, MagicMock
from struct_module.model_wrapper import ModelWrapper


class TestModelWrapper:
"""Test ModelWrapper functionality including API key handling."""

def test_init_with_existing_openai_api_key(self):
"""Test that ModelWrapper initializes correctly when OPENAI_API_KEY is already set."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test-key", "AI_MODEL": "openai:gpt-4.1"}, clear=False):
with patch('struct_module.model_wrapper.Agent') as mock_agent:
wrapper = ModelWrapper()
assert wrapper.model_name == "openai:gpt-4.1"
mock_agent.assert_called_once_with(model="openai:gpt-4.1")
# Should not set placeholder key when real key exists
assert os.environ["OPENAI_API_KEY"] == "sk-test-key"

def test_init_without_openai_api_key_sets_placeholder(self):
"""Test that ModelWrapper sets placeholder API key when OPENAI_API_KEY is not set."""
# Remove OPENAI_API_KEY if it exists
env_vars = os.environ.copy()
if "OPENAI_API_KEY" in env_vars:
del env_vars["OPENAI_API_KEY"]
env_vars["AI_MODEL"] = "openai:gpt-4.1"

with patch.dict(os.environ, env_vars, clear=True):
with patch('struct_module.model_wrapper.Agent') as mock_agent:
mock_logger = MagicMock()
wrapper = ModelWrapper(logger=mock_logger)

# Should set placeholder key
assert os.environ["OPENAI_API_KEY"] == "sk-default-placeholder-key"
# Should log warning
mock_logger.warning.assert_called_with(
"OPENAI_API_KEY not set. Using placeholder. AI features may not work properly."
)
# Agent should still be created
mock_agent.assert_called_once_with(model="openai:gpt-4.1")

def test_init_with_non_openai_model_no_placeholder(self):
"""Test that ModelWrapper doesn't set placeholder for non-OpenAI models."""
# Remove OPENAI_API_KEY if it exists
env_vars = os.environ.copy()
if "OPENAI_API_KEY" in env_vars:
del env_vars["OPENAI_API_KEY"]
env_vars["AI_MODEL"] = "anthropic:claude-3"

with patch.dict(os.environ, env_vars, clear=True):
with patch('struct_module.model_wrapper.Agent') as mock_agent:
mock_logger = MagicMock()
wrapper = ModelWrapper(logger=mock_logger)

# Should not set placeholder key for non-OpenAI models
assert "OPENAI_API_KEY" not in os.environ
# Should not log warning
mock_logger.warning.assert_not_called()
# Agent should still be created
mock_agent.assert_called_once_with(model="anthropic:claude-3")

def test_generate_content_with_placeholder_key(self):
"""Test that generate_content handles placeholder API key gracefully."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-default-placeholder-key"}, clear=False):
with patch('struct_module.model_wrapper.Agent'):
mock_logger = MagicMock()
wrapper = ModelWrapper(logger=mock_logger)

result = wrapper.generate_content("system", "user")

assert result == "AI generation skipped: Please set OPENAI_API_KEY environment variable."
mock_logger.warning.assert_called_with(
"Using placeholder API key. Set OPENAI_API_KEY environment variable for AI features to work."
)

def test_generate_content_with_valid_key(self):
"""Test that generate_content works normally with valid API key."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-valid-key"}, clear=False):
with patch('struct_module.model_wrapper.Agent') as mock_agent_class:
mock_agent = MagicMock()
mock_agent_class.return_value = mock_agent

# Mock the agent run result
mock_result = MagicMock()
mock_result.output = "Generated content"
mock_agent.run_sync.return_value = mock_result

wrapper = ModelWrapper()
result = wrapper.generate_content("system prompt", "user prompt")

assert result == "Generated content"
mock_agent.run_sync.assert_called_once_with("user prompt")

def test_generate_content_dry_run(self):
"""Test that generate_content handles dry run mode."""
with patch('struct_module.model_wrapper.Agent'):
mock_logger = MagicMock()
wrapper = ModelWrapper(logger=mock_logger)

result = wrapper.generate_content("system", "user", dry_run=True)

assert result == "[DRY RUN] Generating content using AI agent"
mock_logger.info.assert_called_with("[DRY RUN] Would generate content using AI agent.")

def test_generate_content_api_key_error_handling(self):
"""Test that generate_content provides helpful error messages for API key issues."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-invalid-key"}, clear=False):
with patch('struct_module.model_wrapper.Agent') as mock_agent_class:
mock_agent = MagicMock()
mock_agent_class.return_value = mock_agent

# Mock an API key error
mock_agent.run_sync.side_effect = Exception("Invalid API_KEY provided")

mock_logger = MagicMock()
wrapper = ModelWrapper(logger=mock_logger)
result = wrapper.generate_content("system", "user")

assert result == "AI generation failed: Please check your OPENAI_API_KEY environment variable."
mock_logger.error.assert_called_once()

def test_default_model_name(self):
"""Test that default model is set correctly when AI_MODEL is not provided."""
env_vars = os.environ.copy()
if "AI_MODEL" in env_vars:
del env_vars["AI_MODEL"]

with patch.dict(os.environ, env_vars, clear=True):
with patch('struct_module.model_wrapper.Agent'):
wrapper = ModelWrapper()
assert wrapper.model_name == "openai:gpt-4.1"
Loading