Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion openapi_python_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import threading
from typing import Any, Dict

import httpx
Expand All @@ -15,13 +16,33 @@ class Client:
"""

def __init__(self, token: str, client: Any = None, timeout: float = 30.0):
self.client = client if client is not None else httpx.Client(timeout=timeout)
self._client = client
self._thread_local = threading.local()
self.timeout = timeout
self.auth_header: str = f"Bearer {token}"
self.headers: Dict[str, str] = {
"Authorization": self.auth_header,
"Content-Type": "application/json",
}

@property
def client(self) -> Any:
"""
Thread-safe access to the underlying HTTP client.
If a custom client was provided at initialization, it is returned.
Otherwise, a thread-local httpx.Client is created and returned.
"""
if self._client is not None:
return self._client

if not hasattr(self._thread_local, "client"):
self._thread_local.client = httpx.Client(timeout=self.timeout)
return self._thread_local.client

@client.setter
def client(self, value: Any):
self._client = value

def __enter__(self):
"""Enable use as a synchronous context manager."""
return self
Expand Down
22 changes: 21 additions & 1 deletion openapi_python_sdk/oauth_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import threading
from typing import Any, Dict, List

import httpx
Expand All @@ -13,7 +14,9 @@ class OauthClient:
"""

def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None, timeout: float = 30.0):
self.client = client if client is not None else httpx.Client(timeout=timeout)
self._client = client
self._thread_local = threading.local()
self.timeout = timeout
self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL
self.auth_header: str = (
"Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode()
Expand All @@ -23,6 +26,23 @@ def __init__(self, username: str, apikey: str, test: bool = False, client: Any =
"Content-Type": "application/json",
}

@property
def client(self) -> Any:
"""
Thread-safe access to the underlying HTTP client.
If a custom client was provided at initialization, it is returned.
Otherwise, a thread-local httpx.Client is created and returned.
"""
if self._client is not None:
return self._client
if not hasattr(self._thread_local, "client"):
self._thread_local.client = httpx.Client(timeout=self.timeout)
return self._thread_local.client

@client.setter
def client(self, value: Any):
self._client = value

def __enter__(self):
"""Enable use as a synchronous context manager."""
return self
Expand Down
65 changes: 65 additions & 0 deletions tests/test_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import threading
import unittest

import httpx

from openapi_python_sdk import Client, OauthClient


class TestThreadSafety(unittest.TestCase):
def test_oauth_client_thread_safety(self):
oauth = OauthClient(username="user", apikey="key")

clients = []
def get_client():
clients.append(oauth.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# Each thread should have gotten a unique client instance
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 5)

def test_client_thread_safety(self):
client = Client(token="tok")

clients = []
def get_client():
clients.append(client.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# Each thread should have gotten a unique client instance
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 5)

def test_shared_client_injection_still_works(self):
# If we explicitly pass a client, it SHOULD be shared (backward compatibility)
shared_engine = httpx.Client()
oauth = OauthClient(username="user", apikey="key", client=shared_engine)

clients = []
def get_client():
clients.append(oauth.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# All threads should have the SAME instance because it was injected
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 1)
self.assertEqual(id(clients[0]), id(shared_engine))

if __name__ == "__main__":
unittest.main()
Loading