diff --git a/meilisearch/client.py b/meilisearch/client.py index a74f0dce..bf077bcc 100644 --- a/meilisearch/client.py +++ b/meilisearch/client.py @@ -74,9 +74,12 @@ def __init__( self.config = Config(url, api_key, timeout=timeout, client_agents=client_agents) + # Store custom headers so they can be propagated to sub-clients (Index, TaskHandler, etc.) + self._custom_headers = custom_headers + self.http = HttpRequests(self.config, custom_headers) - self.task_handler = TaskHandler(self.config) + self.task_handler = TaskHandler(self.config, custom_headers) def create_index(self, uid: str, options: Optional[Mapping[str, Any]] = None) -> TaskInfo: """Create an index. @@ -99,7 +102,7 @@ def create_index(self, uid: str, options: Optional[Mapping[str, Any]] = None) -> MeilisearchApiError An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors """ - return Index.create(self.config, uid, options) + return Index.create(self.config, uid, options, custom_headers=self._custom_headers) def delete_index(self, uid: str) -> TaskInfo: """Deletes an index @@ -153,6 +156,7 @@ def get_indexes(self, parameters: Optional[Mapping[str, Any]] = None) -> Dict[st index["primaryKey"], index["createdAt"], index["updatedAt"], + custom_headers=self._custom_headers, ) for index in response["results"] ] @@ -201,7 +205,7 @@ def get_index(self, uid: str) -> Index: MeilisearchApiError An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors """ - return Index(self.config, uid).fetch_info() + return Index(self.config, uid, custom_headers=self._custom_headers).fetch_info() def get_raw_index(self, uid: str) -> Dict[str, Any]: """Get the index as a dictionary. @@ -239,7 +243,7 @@ def index(self, uid: str) -> Index: An Index instance. """ if uid is not None: - return Index(self.config, uid=uid) + return Index(self.config, uid=uid, custom_headers=self._custom_headers) raise ValueError("The index UID should not be None") def multi_search( diff --git a/meilisearch/index.py b/meilisearch/index.py index 2441e518..f61578fd 100644 --- a/meilisearch/index.py +++ b/meilisearch/index.py @@ -66,6 +66,7 @@ def __init__( primary_key: Optional[str] = None, created_at: Optional[Union[datetime, str]] = None, updated_at: Optional[Union[datetime, str]] = None, + custom_headers: Optional[Mapping[str, str]] = None, ) -> None: """ Parameters @@ -78,8 +79,8 @@ def __init__( Primary-key of the index. """ self.config = config - self.http = HttpRequests(config) - self.task_handler = TaskHandler(config) + self.http = HttpRequests(config, custom_headers) + self.task_handler = TaskHandler(config, custom_headers) self.uid = uid self.primary_key = primary_key self.created_at = iso_to_date_time(created_at) @@ -175,7 +176,12 @@ def get_primary_key(self) -> str | None: return self.fetch_info().primary_key @staticmethod - def create(config: Config, uid: str, options: Optional[Mapping[str, Any]] = None) -> TaskInfo: + def create( + config: Config, + uid: str, + options: Optional[Mapping[str, Any]] = None, + custom_headers: Optional[Mapping[str, str]] = None, + ) -> TaskInfo: """Create the index. Parameters @@ -199,7 +205,7 @@ def create(config: Config, uid: str, options: Optional[Mapping[str, Any]] = None if options is None: options = {} payload = {**options, "uid": uid} - task = HttpRequests(config).post(config.paths.index, payload) + task = HttpRequests(config, custom_headers).post(config.paths.index, payload) return TaskInfo(**task) diff --git a/meilisearch/task.py b/meilisearch/task.py index e492bd41..faaee15f 100644 --- a/meilisearch/task.py +++ b/meilisearch/task.py @@ -2,7 +2,7 @@ from datetime import datetime from time import sleep -from typing import Any, MutableMapping, Optional +from typing import Any, Mapping, MutableMapping, Optional from urllib import parse from meilisearch._httprequests import HttpRequests @@ -19,13 +19,13 @@ class TaskHandler: https://www.meilisearch.com/docs/reference/api/tasks """ - def __init__(self, config: Config): + def __init__(self, config: Config, custom_headers: Optional[Mapping[str, str]] = None): """Parameters ---------- config: Config object containing permission and location of Meilisearch. """ self.config = config - self.http = HttpRequests(config) + self.http = HttpRequests(config, custom_headers) def get_batches(self, parameters: Optional[MutableMapping[str, Any]] = None) -> BatchResults: """Get all task batches. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 36a8da06..2450a396 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -58,3 +58,13 @@ def test_headers(api_key, custom_headers, expected): client = meilisearch.Client("127.0.0.1:7700", api_key=api_key, custom_headers=custom_headers) assert client.http.headers.items() >= expected.items() + + +def test_index_inherits_custom_headers(): + custom_headers = {"header_key_1": "header_value_1", "header_key_2": "header_value_2"} + client = meilisearch.Client("127.0.0.1:7700", api_key=None, custom_headers=custom_headers) + + index = client.index("movies") + + # Index-level HttpRequests instance should also include the custom headers + assert index.http.headers.items() >= custom_headers.items()