diff --git a/requests_mock/adapter.py b/requests_mock/adapter.py index 5e3cf46..19bf058 100644 --- a/requests_mock/adapter.py +++ b/requests_mock/adapter.py @@ -12,6 +12,7 @@ import json as jsonutils +import requests from requests.adapters import BaseAdapter, HTTPAdapter from requests.packages.urllib3.response import HTTPResponse import six @@ -31,6 +32,56 @@ def __init__(self, headers, status_code, reason): self.reason = reason +class _RequestObjectProxy(object): + """A wrapper around a requests.Request that gives some extra information. + + This will be important both for matching and so that when it's save into + the request_history users will be able to access these properties. + """ + + def __init__(self, request): + self._request = request + self._url_parts_ = None + self._qs = None + + def __getattr__(self, name): + return getattr(self._request, name) + + @property + def _url_parts(self): + if self._url_parts_ is None: + self._url_parts_ = urlparse.urlparse(self._request.url.lower()) + + return self._url_parts_ + + @property + def scheme(self): + return self._url_parts.scheme + + @property + def netloc(self): + return self._url_parts.netloc + + @property + def path(self): + return self._url_parts.path + + @property + def query(self): + return self._url_parts.query + + @property + def qs(self): + if self._qs is None: + self._qs = urlparse.parse_qs(self.query) + + return self._qs + + @classmethod + def _create(cls, *args, **kwargs): + return cls(requests.Request(*args, **kwargs).prepare()) + + class _RequestHistoryTracker(object): def __init__(self): @@ -191,19 +242,18 @@ def _match_url(self, request): if hasattr(self._url, 'search'): return self._url.search(request.url) is not None - url = urlparse.urlparse(request.url.lower()) - - if self._url_parts.scheme and url.scheme != self._url_parts.scheme: + if self._url_parts.scheme and request.scheme != self._url_parts.scheme: return False - if self._url_parts.netloc and url.netloc != self._url_parts.netloc: + if self._url_parts.netloc and request.netloc != self._url_parts.netloc: return False - if (url.path or '/') != (self._url_parts.path or '/'): + if (request.path or '/') != (self._url_parts.path or '/'): return False + # construct our own qs structure as we remove items from it below + request_qs = urlparse.parse_qs(request.query) matcher_qs = urlparse.parse_qs(self._url_parts.query) - request_qs = urlparse.parse_qs(url.query) for k, vals in six.iteritems(matcher_qs): for v in vals: @@ -263,6 +313,7 @@ def __init__(self): self._matchers = [] def send(self, request, **kwargs): + request = _RequestObjectProxy(request) self._add_to_history(request) for matcher in reversed(self._matchers): diff --git a/requests_mock/tests/test_adapter.py b/requests_mock/tests/test_adapter.py index 4923794..6306070 100644 --- a/requests_mock/tests/test_adapter.py +++ b/requests_mock/tests/test_adapter.py @@ -14,6 +14,7 @@ import requests import six +from six.moves.urllib import parse as urlparse import requests_mock from requests_mock.tests import base @@ -42,6 +43,15 @@ def assertLastRequest(self, method='GET', body=None): self.assertEqual(method, self.adapter.last_request.method) self.assertEqual(body, self.adapter.last_request.body) + url_parts = urlparse.urlparse(self.url) + qs = urlparse.parse_qs(url_parts.query) + self.assertEqual(url_parts.scheme, self.adapter.last_request.scheme) + self.assertEqual(url_parts.netloc, self.adapter.last_request.netloc) + self.assertEqual(url_parts.path, self.adapter.last_request.path) + self.assertEqual(url_parts.query, self.adapter.last_request.query) + self.assertEqual(url_parts.query, self.adapter.last_request.query) + self.assertEqual(qs, self.adapter.last_request.qs) + def test_content(self): data = six.b('testdata') @@ -340,3 +350,31 @@ def test_called_and_called_count(self): self.assertEqual(len(resps), self.adapter.call_count) self.assertTrue(self.adapter.called) + + def test_query_string(self): + qs = 'a=1&b=2' + self.adapter.register_uri('GET', self.url, text='resp') + resp = self.session.get("%s?%s" % (self.url, qs)) + + self.assertEqual('resp', resp.text) + + self.assertEqual(qs, self.adapter.last_request.query) + self.assertEqual(['1'], self.adapter.last_request.qs['a']) + self.assertEqual(['2'], self.adapter.last_request.qs['b']) + + def test_adapter_picks_correct_adatper(self): + good = '%s://test3.url/' % self.PREFIX + self.adapter.register_uri('GET', + '%s://test1.url' % self.PREFIX, + text='bad') + self.adapter.register_uri('GET', + '%s://test2.url' % self.PREFIX, + text='bad') + self.adapter.register_uri('GET', good, text='good') + self.adapter.register_uri('GET', + '%s://test4.url' % self.PREFIX, + text='bad') + + resp = self.session.get(good) + + self.assertEqual('good', resp.text) diff --git a/requests_mock/tests/test_matcher.py b/requests_mock/tests/test_matcher.py index 0355a95..ba1081d 100644 --- a/requests_mock/tests/test_matcher.py +++ b/requests_mock/tests/test_matcher.py @@ -35,7 +35,9 @@ def match(self, [], complete_qs, request_headers) - request = requests.Request(request_method, url, headers).prepare() + request = adapter._RequestObjectProxy._create(request_method, + url, + headers) return matcher._match(request) def assertMatch(self,