Skip to content

Commit

Permalink
Add wrapper around Request object
Browse files Browse the repository at this point in the history
The request object gets picked up in the history. It would be useful to
have a wrapper around it so that we can add some test helpers to the
last_request etc.
  • Loading branch information
Jamie Lennox committed Jul 30, 2014
1 parent 4ae5f15 commit 1e6e6a1
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 7 deletions.
63 changes: 57 additions & 6 deletions requests_mock/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import json as jsonutils

import requests
from requests.adapters import BaseAdapter, HTTPAdapter
from requests.packages.urllib3.response import HTTPResponse
import six
Expand All @@ -31,6 +32,56 @@ def __init__(self, headers, status_code, reason):
self.reason = reason


class _RequestObjectProxy(object):
"""A wrapper around a requests.Request that gives some extra information.
This will be important both for matching and so that when it's save into
the request_history users will be able to access these properties.
"""

def __init__(self, request):
self._request = request
self._url_parts_ = None
self._qs = None

def __getattr__(self, name):
return getattr(self._request, name)

@property
def _url_parts(self):
if self._url_parts_ is None:
self._url_parts_ = urlparse.urlparse(self._request.url.lower())

return self._url_parts_

@property
def scheme(self):
return self._url_parts.scheme

@property
def netloc(self):
return self._url_parts.netloc

@property
def path(self):
return self._url_parts.path

@property
def query(self):
return self._url_parts.query

@property
def qs(self):
if self._qs is None:
self._qs = urlparse.parse_qs(self.query)

return self._qs

@classmethod
def _create(cls, *args, **kwargs):
return cls(requests.Request(*args, **kwargs).prepare())


class _RequestHistoryTracker(object):

def __init__(self):
Expand Down Expand Up @@ -191,19 +242,18 @@ def _match_url(self, request):
if hasattr(self._url, 'search'):
return self._url.search(request.url) is not None

url = urlparse.urlparse(request.url.lower())

if self._url_parts.scheme and url.scheme != self._url_parts.scheme:
if self._url_parts.scheme and request.scheme != self._url_parts.scheme:
return False

if self._url_parts.netloc and url.netloc != self._url_parts.netloc:
if self._url_parts.netloc and request.netloc != self._url_parts.netloc:
return False

if (url.path or '/') != (self._url_parts.path or '/'):
if (request.path or '/') != (self._url_parts.path or '/'):
return False

# construct our own qs structure as we remove items from it below
request_qs = urlparse.parse_qs(request.query)
matcher_qs = urlparse.parse_qs(self._url_parts.query)
request_qs = urlparse.parse_qs(url.query)

for k, vals in six.iteritems(matcher_qs):
for v in vals:
Expand Down Expand Up @@ -263,6 +313,7 @@ def __init__(self):
self._matchers = []

def send(self, request, **kwargs):
request = _RequestObjectProxy(request)
self._add_to_history(request)

for matcher in reversed(self._matchers):
Expand Down
38 changes: 38 additions & 0 deletions requests_mock/tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import requests
import six
from six.moves.urllib import parse as urlparse

import requests_mock
from requests_mock.tests import base
Expand Down Expand Up @@ -42,6 +43,15 @@ def assertLastRequest(self, method='GET', body=None):
self.assertEqual(method, self.adapter.last_request.method)
self.assertEqual(body, self.adapter.last_request.body)

url_parts = urlparse.urlparse(self.url)
qs = urlparse.parse_qs(url_parts.query)
self.assertEqual(url_parts.scheme, self.adapter.last_request.scheme)
self.assertEqual(url_parts.netloc, self.adapter.last_request.netloc)
self.assertEqual(url_parts.path, self.adapter.last_request.path)
self.assertEqual(url_parts.query, self.adapter.last_request.query)
self.assertEqual(url_parts.query, self.adapter.last_request.query)
self.assertEqual(qs, self.adapter.last_request.qs)

def test_content(self):
data = six.b('testdata')

Expand Down Expand Up @@ -340,3 +350,31 @@ def test_called_and_called_count(self):

self.assertEqual(len(resps), self.adapter.call_count)
self.assertTrue(self.adapter.called)

def test_query_string(self):
qs = 'a=1&b=2'
self.adapter.register_uri('GET', self.url, text='resp')
resp = self.session.get("%s?%s" % (self.url, qs))

self.assertEqual('resp', resp.text)

self.assertEqual(qs, self.adapter.last_request.query)
self.assertEqual(['1'], self.adapter.last_request.qs['a'])
self.assertEqual(['2'], self.adapter.last_request.qs['b'])

def test_adapter_picks_correct_adatper(self):
good = '%s://test3.url/' % self.PREFIX
self.adapter.register_uri('GET',
'%s://test1.url' % self.PREFIX,
text='bad')
self.adapter.register_uri('GET',
'%s://test2.url' % self.PREFIX,
text='bad')
self.adapter.register_uri('GET', good, text='good')
self.adapter.register_uri('GET',
'%s://test4.url' % self.PREFIX,
text='bad')

resp = self.session.get(good)

self.assertEqual('good', resp.text)
4 changes: 3 additions & 1 deletion requests_mock/tests/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def match(self,
[],
complete_qs,
request_headers)
request = requests.Request(request_method, url, headers).prepare()
request = adapter._RequestObjectProxy._create(request_method,
url,
headers)
return matcher._match(request)

def assertMatch(self,
Expand Down

0 comments on commit 1e6e6a1

Please sign in to comment.