From db584c9aab90c36fc9ba65ae186d9bb55954b941 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Sun, 9 Apr 2023 13:51:33 -0700 Subject: [PATCH] allow for user passed requests.Session --- openai/__init__.py | 7 ++++++- openai/api_requestor.py | 4 ++++ openai/tests/test_endpoints.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/openai/__init__.py b/openai/__init__.py index 3e2cb1e281..ecf663a3b0 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -4,7 +4,7 @@ import os import sys -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union, Callable from contextvars import ContextVar @@ -36,6 +36,7 @@ from openai.version import VERSION if TYPE_CHECKING: + import requests from aiohttp import ClientSession api_key = os.environ.get("OPENAI_API_KEY") @@ -58,6 +59,10 @@ debug = False log = None # Set to either 'debug' or 'info', controls console logging +requestssession: Optional[ + Union["requests.Session", Callable[[], "requests.Session"]] +] = None # Provide a requests.Session or Session factory. + aiosession: ContextVar[Optional["ClientSession"]] = ContextVar( "aiohttp-session", default=None ) # Acts as a global aiohttp ClientSession that reuses connections. diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 827b73b78e..b2b6fc1890 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -76,6 +76,10 @@ def _aiohttp_proxies_arg(proxy) -> Optional[str]: def _make_session() -> requests.Session: + if openai.requestssession: + if isinstance(openai.requestssession, requests.Session): + return openai.requestssession + return openai.requestssession() if not openai.verify_ssl_certs: warnings.warn("verify_ssl_certs is ignored; openai always verifies.") s = requests.Session() diff --git a/openai/tests/test_endpoints.py b/openai/tests/test_endpoints.py index c3fc1094bb..958e07f091 100644 --- a/openai/tests/test_endpoints.py +++ b/openai/tests/test_endpoints.py @@ -2,6 +2,7 @@ import json import pytest +import requests import openai from openai import error @@ -86,3 +87,32 @@ def test_timeout_does_not_error(): model="ada", request_timeout=10, ) + + +def test_user_session(): + with requests.Session() as session: + openai.requestssession = session + + completion = openai.Completion.create( + prompt="hello world", + model="ada", + ) + assert completion + + +def test_user_session_factory(): + def factory(): + session = requests.Session() + session.mount( + "https://", + requests.adapters.HTTPAdapter(max_retries=4), + ) + return session + + openai.requestssession = factory + + completion = openai.Completion.create( + prompt="hello world", + model="ada", + ) + assert completion