diff --git a/gapipy/client.py b/gapipy/client.py index 4e72932..a00291a 100644 --- a/gapipy/client.py +++ b/gapipy/client.py @@ -142,7 +142,7 @@ def build(self, resource_name, data_dict, **kwargs): return resource_cls(data_dict, client=self, **kwargs) - def create(self, resource_name, data_dict): + def create(self, resource_name, data_dict, headers=None): """ Create an instance of the specified resource with `data_dict` """ @@ -151,4 +151,4 @@ def create(self, resource_name, data_dict): except AttributeError: raise AttributeError("No resource named %s is defined." % resource_name) - return resource_cls.create(self, data_dict) + return resource_cls.create(self, data_dict, headers=headers) diff --git a/gapipy/query.py b/gapipy/query.py index 81b271c..dbb5424 100644 --- a/gapipy/query.py +++ b/gapipy/query.py @@ -177,9 +177,9 @@ def count(self): self._filters = {} return out - def create(self, data_dict): + def create(self, data_dict, headers=None): """Create an instance of the query resource using the given data""" - return self.resource.create(self._client, data_dict) + return self.resource.create(self._client, data_dict, headers=headers) def first(self): """ diff --git a/gapipy/request.py b/gapipy/request.py index ab2fab2..5492a09 100644 --- a/gapipy/request.py +++ b/gapipy/request.py @@ -142,13 +142,13 @@ def update(self, resource_id, data, partial=True, uri=None): uri = '/{0}/{1}'.format(self._get_uri(), resource_id) return self._request(uri, method, data=data) - def create(self, data, uri=None): + def create(self, data, uri=None, headers=None): """ Create a single new resource with the given data. """ if not uri: uri = '/{0}'.format(self._get_uri()) - return self._request(uri, 'POST', data=data) + return self._request(uri, 'POST', data=data, additional_headers=headers) def list_raw(self, uri=None): """Return the raw response for listing resources. diff --git a/gapipy/resources/base.py b/gapipy/resources/base.py index 87aecc6..31b3d99 100644 --- a/gapipy/resources/base.py +++ b/gapipy/resources/base.py @@ -41,9 +41,9 @@ def fetch(self): return self @classmethod - def create(cls, client, data_dict): + def create(cls, client, data_dict, headers=None): request = APIRequestor(client, cls) - response = request.create(json.dumps(data_dict)) + response = request.create(json.dumps(data_dict), headers=headers) return cls(response, client=client) def __getattr__(self, name): diff --git a/tests/test_client.py b/tests/test_client.py index d9275c9..87d7893 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import json from mock import patch import unittest @@ -54,6 +55,32 @@ class MockResource(Resource): resource = self.gapi.create('foo', {'id': 1, 'foo': 'bar', 'context': 'abc'}) self.assertEqual(resource.id, 1) + @patch('gapipy.request.APIRequestor._request') + def test_create_extra_headers(self, mock_request): + """ + Test that extra HTTP headers can be passed through the `.create` + method on a resource + """ + class MockResource(Resource): + _as_is_fields = ['id', 'foo'] + _resource_name = 'foo' + self.gapi.foo = Query(self.gapi, MockResource) + + resource_data = {'id': 1, 'foo': 'bar'} # content doesn't really matter for this test + mock_request.return_value = resource_data + + # Create a `foo` while passing extra headers + extra_headers = {'X-Bender': 'I\'m not allowed to sing. Court order.'} + self.gapi.create('foo', resource_data, headers=extra_headers) + + # Did those headers make it all the way to the requestor? + mock_request.assert_called_once_with( + '/foo', + 'POST', + data=json.dumps(resource_data), + additional_headers=extra_headers, + ) + @patch('gapipy.query.Query.get_resource_data') def test_correct_client_is_associated_with_resources(self, mock_get_data): mock_get_data.return_value = { diff --git a/tests/test_query.py b/tests/test_query.py index 473f399..af46d70 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -356,7 +356,7 @@ def test_create_object(self, mock_request): r = MockResource(data, client=self.client) r.save() mock_request.assert_called_once_with( - '/mocks', 'POST', data=r.to_json()) + '/mocks', 'POST', data=r.to_json(), additional_headers=None) def test_update_object(self, mock_request): data = {