diff --git a/openapi_python_sdk/client.py b/openapi_python_sdk/client.py index 18b6341..4ba5211 100644 --- a/openapi_python_sdk/client.py +++ b/openapi_python_sdk/client.py @@ -1,4 +1,5 @@ import json +import threading from typing import Any, Dict import httpx @@ -15,13 +16,33 @@ class Client: """ def __init__(self, token: str, client: Any = None, timeout: float = 30.0): - self.client = client if client is not None else httpx.Client(timeout=timeout) + self._client = client + self._thread_local = threading.local() + self.timeout = timeout self.auth_header: str = f"Bearer {token}" self.headers: Dict[str, str] = { "Authorization": self.auth_header, "Content-Type": "application/json", } + @property + def client(self) -> Any: + """ + Thread-safe access to the underlying HTTP client. + If a custom client was provided at initialization, it is returned. + Otherwise, a thread-local httpx.Client is created and returned. + """ + if self._client is not None: + return self._client + + if not hasattr(self._thread_local, "client"): + self._thread_local.client = httpx.Client(timeout=self.timeout) + return self._thread_local.client + + @client.setter + def client(self, value: Any): + self._client = value + def __enter__(self): """Enable use as a synchronous context manager.""" return self diff --git a/openapi_python_sdk/oauth_client.py b/openapi_python_sdk/oauth_client.py index f6a418d..a3be9db 100644 --- a/openapi_python_sdk/oauth_client.py +++ b/openapi_python_sdk/oauth_client.py @@ -1,4 +1,5 @@ import base64 +import threading from typing import Any, Dict, List import httpx @@ -13,7 +14,9 @@ class OauthClient: """ def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None, timeout: float = 30.0): - self.client = client if client is not None else httpx.Client(timeout=timeout) + self._client = client + self._thread_local = threading.local() + self.timeout = timeout self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL self.auth_header: str = ( "Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode() @@ -23,6 +26,23 @@ def __init__(self, username: str, apikey: str, test: bool = False, client: Any = "Content-Type": "application/json", } + @property + def client(self) -> Any: + """ + Thread-safe access to the underlying HTTP client. + If a custom client was provided at initialization, it is returned. + Otherwise, a thread-local httpx.Client is created and returned. + """ + if self._client is not None: + return self._client + if not hasattr(self._thread_local, "client"): + self._thread_local.client = httpx.Client(timeout=self.timeout) + return self._thread_local.client + + @client.setter + def client(self, value: Any): + self._client = value + def __enter__(self): """Enable use as a synchronous context manager.""" return self diff --git a/tests/test_thread_safety.py b/tests/test_thread_safety.py new file mode 100644 index 0000000..5763084 --- /dev/null +++ b/tests/test_thread_safety.py @@ -0,0 +1,65 @@ +import threading +import unittest + +import httpx + +from openapi_python_sdk import Client, OauthClient + + +class TestThreadSafety(unittest.TestCase): + def test_oauth_client_thread_safety(self): + oauth = OauthClient(username="user", apikey="key") + + clients = [] + def get_client(): + clients.append(oauth.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have gotten a unique client instance + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 5) + + def test_client_thread_safety(self): + client = Client(token="tok") + + clients = [] + def get_client(): + clients.append(client.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have gotten a unique client instance + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 5) + + def test_shared_client_injection_still_works(self): + # If we explicitly pass a client, it SHOULD be shared (backward compatibility) + shared_engine = httpx.Client() + oauth = OauthClient(username="user", apikey="key", client=shared_engine) + + clients = [] + def get_client(): + clients.append(oauth.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should have the SAME instance because it was injected + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 1) + self.assertEqual(id(clients[0]), id(shared_engine)) + +if __name__ == "__main__": + unittest.main()