diff --git a/mlflow/openai/__init__.py b/mlflow/openai/__init__.py
index 9ebd445fce539..6dbd68a970abb 100644
--- a/mlflow/openai/__init__.py
+++ b/mlflow/openai/__init__.py
@@ -501,13 +501,51 @@ def _mock_request_chat_completion():
         yield m
 
 
+class MockResponse:
+    def __init__(self, status, json_data):
+        self.status = status
+        self._json = json_data
+        self.headers = {"Content-Type": "application/json"}
+
+    async def json(self):
+        return self._json
+
+    async def __aexit__(self, exc_type, exc, tb):
+        pass
+
+    async def __aenter__(self):
+        return self
+
+
+@contextmanager
+def _mock_aiohttp_request_chat_completion():
+    response = MockResponse(
+        200,
+        {
+            "id": "chatcmpl-123",
+            "object": "chat.completion",
+            "created": 1677652288,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {"role": "assistant", "content": _TEST_CONTENT},
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
+        },
+    )
+    with mock.patch("aiohttp.ClientSession.post", return_value=response) as m:
+        yield m
+
+
 class _TestOpenAIWrapper(_OpenAIWrapper):
     """
     A wrapper class that should be used for testing purposes only.
     """
 
     def predict(self, pdf):
-        with _mock_request_chat_completion() as m:
+        with _mock_aiohttp_request_chat_completion() as m:
             res = super().predict(pdf)
             m.assert_called_once()
             return res
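
For context, a minimal sketch of how the patched `aiohttp.ClientSession.post` gets consumed: the `MockResponse` above acts as an async context manager, so an `async with session.post(...)` call made while the mock is active receives the canned chat completion instead of reaching the OpenAI API. The helper name `_call_chat_completion`, the URL, and the request payload below are illustrative assumptions, not part of the patch.

```python
import asyncio

import aiohttp


async def _call_chat_completion(messages):
    # Hypothetical caller: with _mock_aiohttp_request_chat_completion() active,
    # session.post() returns the MockResponse from the patch, and the async
    # context manager protocol (__aenter__/__aexit__) hands it back unchanged.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.openai.com/v1/chat/completions",  # illustrative URL
            json={"model": "gpt-3.5-turbo", "messages": messages},
        ) as resp:
            return await resp.json()


# Exercising the mock directly (assumes the helpers from the diff are in scope):
# with _mock_aiohttp_request_chat_completion() as m:
#     payload = asyncio.run(_call_chat_completion([{"role": "user", "content": "hi"}]))
#     assert payload["choices"][0]["message"]["content"] == _TEST_CONTENT
#     m.assert_called_once()
```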