Skip to content

Commit

Permalink
Merge branch 'main' into task/changeAdvSimInit
Browse files Browse the repository at this point in the history
  • Loading branch information
nagkumar91 committed Jun 24, 2024
2 parents fcf5b56 + ff5ee3d commit ba9db9c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ def __init__(
if isinstance(conversation_starter_content, dict):
self.conversation_starter = conversation_starter_content
else:
self.conversation_starter = jinja2.Template(
conversation_starter_content, undefined=jinja2.StrictUndefined
)
try:
self.conversation_starter = jinja2.Template(
conversation_starter_content, undefined=jinja2.StrictUndefined
)
except jinja2.exceptions.TemplateSyntaxError as e: # noqa: F841
self.conversation_starter = conversation_starter_content
else:
self.logger.info(
"This simulated bot will generate the first turn as no conversation starter is provided"
Expand Down Expand Up @@ -121,8 +124,10 @@ async def generate_response(
# if conversation_starter is a dictionary, pass it into samples as is
if isinstance(self.conversation_starter, dict):
samples = [self.conversation_starter]
if isinstance(self.conversation_starter, jinja2.Template):
samples = [self.conversation_starter.render(**self.persona_template_args)]
else:
samples = [self.conversation_starter.render(**self.persona_template_args)] # type: ignore[attr-defined]
samples = [self.conversation_starter] # type: ignore[attr-defined]
time_taken = 0

finish_reason = ["stop"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
LLMBase,
OpenAIChatCompletionsModel,
)
from promptflow.evals.synthetic._model_tools import AsyncHTTPClientWithRetry


# Mock classes for dependencies
Expand Down Expand Up @@ -45,6 +46,16 @@ def bot_assistant_params():
}


@pytest.fixture
def bot_invalid_jinja_params():
return {
"role": ConversationRole.USER,
"model": MockOpenAIChatCompletionsModel(),
"conversation_template": "Hello, {{ name }}!!!!",
"instantiation_parameters": {"name": "TestUser", "conversation_starter": "Hello, world! {{world }"},
}


@pytest.mark.unittest
class TestConversationBot:
@pytest.mark.asyncio
Expand All @@ -54,6 +65,30 @@ async def test_conversation_bot_initialization_user(self, bot_user_params):
assert bot.name == "TestUser"
assert isinstance(bot.conversation_template, jinja2.Template)

@pytest.mark.asyncio
async def test_conversation_bot_initialization_user_invalid_jinja(self, bot_invalid_jinja_params):
bot = ConversationBot(**bot_invalid_jinja_params)

assert bot.role == ConversationRole.USER
assert bot.name == "TestUser"
assert isinstance(bot.conversation_template, jinja2.Template)
assert isinstance(bot.conversation_starter, str)
assert bot.conversation_starter is not None
asyncHttpClient = AsyncHTTPClientWithRetry(
n_retry=1,
retry_timeout=0,
logger=None,
)
client = asyncHttpClient.client
parsed_response, req, time_taken, full_response = await bot.generate_response(
session=client, conversation_history=[], max_history=0, turn_number=0
)
assert (
parsed_response["samples"][0]
== bot_invalid_jinja_params["instantiation_parameters"]["conversation_starter"]
)
client.close()

@pytest.mark.asyncio
async def test_conversation_bot_initialization_assistant(self, bot_assistant_params):
bot = ConversationBot(**bot_assistant_params)
Expand Down

0 comments on commit ba9db9c

Please sign in to comment.