Skip to content

Commit

Permalink
Added client support for header
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Nov 12, 2013
1 parent 89b64a2 commit 0f5c26a
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 34 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
@@ -1,6 +1,12 @@
CHANGES
=======

0.4.1 (11-12-2013)
------------------

- Added client support for `expect: 100-continue` header.


0.4 (11-06-2013)
----------------

Expand Down
2 changes: 2 additions & 0 deletions README.rst
Expand Up @@ -55,6 +55,7 @@ The signature of request is the following::
conn_timeout=None,
compress=None,
chunked=None,
expect100=False,
session=None,
verify_ssl=True,
loop=None
Expand All @@ -81,6 +82,7 @@ It constructs and sends a request. It returns response object. Parameters are ex
with deflate encoding.
- ``chunked``: Boolean or Integer. Set to chunk size for chunked
transfer encoding.
- ``expect100``: Boolean. Expect 100-continue response from server.
- ``session``: ``aiohttp.Session`` instance to support connection pooling and
session cookies.
- ``loop``: Optional event loop.
Expand Down
108 changes: 75 additions & 33 deletions aiohttp/client.py
Expand Up @@ -40,6 +40,7 @@ def request(method, url, *,
conn_timeout=None,
compress=None,
chunked=None,
expect100=False,
session=None,
verify_ssl=True,
loop=None):
Expand All @@ -65,6 +66,7 @@ def request(method, url, *,
with deflate encoding.
chunked: Boolean or Integer. Set to chunk size for chunked
transfer encoding.
expect100: Boolean. Expect 100-continue response from server.
session: aiohttp.Session instance to support connection pooling and
session cookies.
loop: Optional event loop.
Expand All @@ -88,7 +90,7 @@ def request(method, url, *,
method, url, params=params, headers=headers, data=data,
cookies=cookies, files=files, auth=auth, encoding=encoding,
version=version, compress=compress, chunked=chunked,
verify_ssl=verify_ssl, loop=loop)
verify_ssl=verify_ssl, loop=loop, expect100=expect100)

if session is None:
conn = _connect(req, loop)
Expand Down Expand Up @@ -254,6 +256,7 @@ def request(self, method=None, path=None, *,
timeout=None,
conn_timeout=None,
chunked=None,
expect100=False,
verify_ssl=True):

if method is None:
Expand All @@ -280,14 +283,12 @@ def request(self, method=None, path=None, *,
self._schema, host_info[0], host_info[1]), path)
try:
resp = yield from request(
method, url,
params=params, data=data, headers=headers,
cookies=cookies, files=files, auth=auth,
allow_redirects=allow_redirects,
max_redirects=max_redirects,
encoding=encoding, version=version,
timeout=timeout, conn_timeout=conn_timeout,
compress=compress, chunked=chunked, verify_ssl=verify_ssl,
method, url, params=params, data=data, headers=headers,
cookies=cookies, files=files, auth=auth, encoding=encoding,
allow_redirects=allow_redirects, version=version,
max_redirects=max_redirects, conn_timeout=conn_timeout,
timeout=timeout, compress=compress, chunked=chunked,
verify_ssl=verify_ssl, expect100=expect100,
session=self._session, loop=self._loop)
except (aiohttp.ConnectionError, aiohttp.TimeoutError):
if host_info in hosts:
Expand Down Expand Up @@ -319,11 +320,13 @@ class HttpRequest:
response = None

_writer = None # async task for streaming body
_continue = None # waiter future for '100 Continue' response

def __init__(self, method, url, *,
params=None, headers=None, data=None, cookies=None,
files=None, auth=None, encoding='utf-8', version=(1, 1),
compress=None, chunked=None, verify_ssl=True, loop=None):
compress=None, chunked=None, expect100=False,
verify_ssl=True, loop=None):
self.url = url
self.method = method.upper()
self.encoding = encoding
Expand All @@ -347,6 +350,13 @@ def __init__(self, method, url, *,
self.update_body_from_files(files, data)

self.update_transfer_encoding()
self.update_expect_continue(expect100)

def __del__(self):
"""Close request on GC"""
if self._writer is not None:
self._writer.cancel()
self._writer = None

def update_host(self, url):
"""Update destination host, port and connection type (ssl)."""
Expand Down Expand Up @@ -592,26 +602,47 @@ def update_transfer_encoding(self):
if 'content-length' not in self.headers:
self.headers['content-length'] = str(len(self.body))

def update_expect_continue(self, expect=False):
if expect:
self.headers['expect'] = '100-continue'
elif self.headers.get('expect', '').lower() == '100-continue':
expect = True

if expect:
self._continue = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def write_bytes(self, request, stream):
def write_bytes(self, request, is_stream):
"""Support coroutines that yields bytes objects."""
value = None
# 100 response
if self._continue is not None:
yield from self._continue

while True:
try:
result = stream.send(value)
except StopIteration as exc:
if isinstance(exc.value, bytes):
request.write(exc.value)
break
if is_stream:
value = None
stream = self.body

if isinstance(result, asyncio.Future):
value = yield from result
elif isinstance(result, bytes):
request.write(result)
value = None
else:
raise ValueError('Bytes object is expected.')
while True:
try:
result = stream.send(value)
except StopIteration as exc:
if isinstance(exc.value, bytes):
request.write(exc.value)
break

if isinstance(result, asyncio.Future):
value = yield from result
elif isinstance(result, bytes):
request.write(result)
value = None
else:
raise ValueError('Bytes object is expected.')
else:
if isinstance(self.body, bytes):
self.body = (self.body,)

for chunk in self.body:
request.write(chunk)

request.write_eof()
self._writer = None
Expand All @@ -629,9 +660,11 @@ def send(self, transport):
request.add_headers(*self.headers.items())
request.send_headers()

if inspect.isgenerator(self.body):
is_stream = inspect.isgenerator(self.body)

if is_stream or self._continue is not None:
self._writer = asyncio.async(
self.write_bytes(request, self.body), loop=self.loop)
self.write_bytes(request, is_stream), loop=self.loop)
else:
if isinstance(self.body, bytes):
self.body = (self.body,)
Expand All @@ -641,7 +674,8 @@ def send(self, transport):

request.write_eof()

self.response = HttpResponse(self.method, self.path, self.host)
self.response = HttpResponse(
self.method, self.path, self.host, self._continue)
return self.response

@asyncio.coroutine
Expand Down Expand Up @@ -670,13 +704,14 @@ class HttpResponse(http.client.HTTPMessage):

_response_parser = aiohttp.HttpResponseParser()

def __init__(self, method, url, host=''):
def __init__(self, method, url, host='', continue100=None):
super().__init__()

self.method = method
self.url = url
self.host = host
self._content = None
self._continue = continue100

def __del__(self):
self.close()
Expand All @@ -693,10 +728,17 @@ def start(self, stream, transport):
self.stream = stream
self.transport = transport

httpstream = stream.set_parser(self._response_parser)
while True:
httpstream = stream.set_parser(self._response_parser)

# read response
self.message = yield from httpstream.read()
if self.message.code != 100:
break

# read response
self.message = yield from httpstream.read()
if self._continue is not None and not self._continue.done():
self._continue.set_result(True)
self._continue = None

# response status
self.version = self.message.version
Expand Down
5 changes: 5 additions & 0 deletions aiohttp/test_utils.py
Expand Up @@ -67,6 +67,11 @@ def handle_request(self, message, payload):
if properties.get('noresponse', False):
yield from asyncio.sleep(99999)

for hdr, val in message.headers:
if (hdr == 'EXPECT') and (val == '100-continue'):
self.transport.write(b'HTTP/1.0 100 Continue\r\n\r\n')
break

if router is not None:
body = bytearray()
try:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
import os
from setuptools import setup, find_packages

version = '0.4'
version = '0.4.1'

install_requires = ['asyncio']
tests_require = install_requires + ['nose', 'gunicorn']
Expand Down
13 changes: 13 additions & 0 deletions tests/http_client_functional_test.py
Expand Up @@ -417,6 +417,19 @@ def stream():
self.assertEqual(str(len(data)),
content['headers']['Content-Length'])

def test_expect_continue(self):
with test_utils.run_server(self.loop, router=Functional) as httpd:
url = httpd.url('method', 'post')
r = self.loop.run_until_complete(
client.request('post', url, data={'some': 'data'},
expect100=True, loop=self.loop))
self.assertEqual(r.status, 200)

content = self.loop.run_until_complete(r.read(True))
self.assertEqual('100-continue', content['headers']['Expect'])
self.assertEqual(r.status, 200)
r.close()

def test_encoding(self):
with test_utils.run_server(self.loop, router=Functional) as httpd:
r = self.loop.run_until_complete(
Expand Down
57 changes: 57 additions & 0 deletions tests/http_client_test.py
Expand Up @@ -304,6 +304,19 @@ def test_chunked_length(self):
self.assertEqual(req.headers['Transfer-Encoding'], 'chunked')
self.assertNotIn('Content-Length', req.headers)

def test_expect100(self):
req = HttpRequest('get', 'http://python.org/',
expect100=True, loop=self.loop)
req.send(self.transport)
self.assertEqual('100-continue', req.headers['expect'])
self.assertIsNotNone(req._continue)

req = HttpRequest('get', 'http://python.org/',
headers={'expect': '100-continue'}, loop=self.loop)
req.send(self.transport)
self.assertEqual('100-continue', req.headers['expect'])
self.assertIsNotNone(req._continue)

def test_data_stream(self):
def gen():
yield b'binary data'
Expand Down Expand Up @@ -335,6 +348,50 @@ def gen():
self.assertRaises(
ValueError, self.loop.run_until_complete, req._writer)

def test_data_stream_continue(self):
def gen():
yield b'binary data'
return b' result'

req = HttpRequest(
'POST', 'http://python.org/', data=gen(),
expect100=True, loop=self.loop)
self.assertTrue(req.chunked)
self.assertTrue(inspect.isgenerator(req.body))

def coro():
yield from asyncio.sleep(0.0001, loop=self.loop)
req._continue.set_result(1)

asyncio.async(coro(), loop=self.loop)

req.send(self.transport)
self.loop.run_until_complete(req._writer)
self.assertEqual(
self.transport.write.mock_calls[-3:],
[unittest.mock.call(b'binary data result'),
unittest.mock.call(b'\r\n'),
unittest.mock.call(b'0\r\n\r\n')])

def test_data_continue(self):
req = HttpRequest(
'POST', 'http://python.org/', data=b'data',
expect100=True, loop=self.loop)

def coro():
yield from asyncio.sleep(0.0001, loop=self.loop)
req._continue.set_result(1)

asyncio.async(coro(), loop=self.loop)

req.send(self.transport)
self.assertEqual(1, len(self.transport.write.mock_calls))

self.loop.run_until_complete(req._writer)
self.assertEqual(
self.transport.write.mock_calls[-1],
unittest.mock.call(b'data'))

def test_close(self):
@asyncio.coroutine
def gen():
Expand Down

0 comments on commit 0f5c26a

Please sign in to comment.