Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Fetching contributors…

Cannot retrieve contributors at this time

396 lines (321 sloc) 14.746 kb
#!/usr/bin/env python
from __future__ import absolute_import, division, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, LogTrapTestCase
from tornado.test.util import unittest
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler
import os
import shutil
import socket
import sys
import tempfile
try:
import ssl
except ImportError:
ssl = None
class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    """Base class for tests that serve one handler at ``/``.

    The handler is looked up as the ``Handler`` attribute of the concrete
    test class, so subclasses are expected to define a nested ``Handler``.
    """

    def get_app(self):
        """Return an Application routing ``/`` to the subclass's ``Handler``."""
        handler_class = self.__class__.Handler
        return Application([('/', handler_class)])

    def fetch_json(self, *args, **kwargs):
        """Fetch with ``self.fetch``, re-raise any error, and decode the JSON body.

        All arguments are forwarded unchanged to ``self.fetch``.
        """
        result = self.fetch(*args, **kwargs)
        # rethrow() is called before touching the body so that a failed
        # request is reported as its own error instead of a decode error.
        result.rethrow()
        payload = result.body
        return json_decode(payload)
class HelloWorldRequestHandler(RequestHandler):
    """Minimal handler: GET says hello, POST reports the body size.

    GET additionally verifies that the request arrived over the protocol
    the handler was configured with.
    """

    def initialize(self, protocol="http"):
        """Remember which protocol GET requests are expected to use."""
        self.expected_protocol = protocol

    def get(self):
        """Reply "Hello world", or raise if the request protocol is unexpected."""
        protocol_matches = self.request.protocol == self.expected_protocol
        if not protocol_matches:
            raise Exception("unexpected protocol")
        self.finish("Hello world")

    def post(self):
        """Reply with the number of bytes received in the request body."""
        received = len(self.request.body)
        message = "Got %d bytes in POST" % received
        self.finish(message)
# Skip decorator for tests that cannot run without the optional ssl module.
skipIfNoSSL = unittest.skipIf(ssl is None, "ssl module not present")

# With openssl older than 1.0, SSLv23 clients always send SSLv2
# ClientHello messages, and SSLv3 and TLSv1 servers reject those.
# OPENSSL_VERSION_INFO was formally introduced in python3.2, but it was
# already present (undocumented) in python 2.7.  When the attribute is
# missing, the (0, 0) fallback makes the version count as "old".
skipIfOldSSL = unittest.skipIf(
    (1, 0) > getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)),
    "old version of ssl module and/or openssl")
class BaseSSLTest(AsyncHTTPSTestCase, LogTrapTestCase):
    """Base for the SSL tests: serves HelloWorldRequestHandler at ``/``."""

    def get_app(self):
        """Return an Application whose handler expects "https" requests."""
        handler_kwargs = {"protocol": "https"}
        routes = [('/', HelloWorldRequestHandler, handler_kwargs)]
        return Application(routes)
class SSLTestMixin(object):
    """Shared SSL test methods; concrete classes supply ``get_ssl_version()``."""

    def get_ssl_options(self):
        """Return AsyncHTTPSTestCase's SSL options plus this class's ``ssl_version``."""
        # NOTE(review): the concrete test classes in this file list
        # BaseSSLTest before SSLTestMixin, so attribute lookup appears to
        # reach AsyncHTTPSTestCase.get_ssl_options first and this override
        # may never be called -- confirm against the MRO.
        # NOTE(review): the base method is invoked on the class without an
        # instance; confirm it is callable that way.
        version = self.get_ssl_version()
        base_options = AsyncHTTPSTestCase.get_ssl_options()
        return dict(ssl_version=version, **base_options)

    def get_ssl_version(self):
        """Return the ssl protocol constant to use; must be overridden."""
        raise NotImplementedError()

    def test_ssl(self):
        """A plain GET of ``/`` returns the hello-world body."""
        reply = self.fetch('/')
        self.assertEqual(reply.body, b("Hello world"))

    def test_large_post(self):
        """A 5000-byte POST body is received in full."""
        payload = 'A' * 5000
        reply = self.fetch('/', method='POST', body=payload)
        self.assertEqual(reply.body, b("Got 5000 bytes in POST"))

    def test_non_ssl_request(self):
        """A non-ssl request fails with code 599."""
        # Make sure the server closes the connection when it gets a
        # non-ssl connection, rather than waiting for a timeout or
        # otherwise misbehaving.  The timeouts are set very high so that a
        # 599 cannot simply be the client giving up.
        timeouts = dict(request_timeout=3600, connect_timeout=3600)
        url = self.get_url("/")
        self.http_client.fetch(url, self.stop, **timeouts)
        reply = self.wait()
        self.assertEqual(reply.code, 599)
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
# of SSLv23 allows it.
class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with ``ssl.PROTOCOL_SSLv23`` selected."""

    def get_ssl_version(self):
        """Return the protocol constant selected by this test class."""
        protocol = ssl.PROTOCOL_SSLv23
        return protocol


# Rebind the name to the wrapped class so the tests are skipped when the
# ssl module is unavailable.
SSLv23Test = skipIfNoSSL(SSLv23Test)
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with ``ssl.PROTOCOL_SSLv3`` selected."""

    def get_ssl_version(self):
        """Return the protocol constant selected by this test class."""
        protocol = ssl.PROTOCOL_SSLv3
        return protocol


# Apply the old-openssl skip first, then the no-ssl skip around it.
SSLv3Test = skipIfOldSSL(SSLv3Test)
SSLv3Test = skipIfNoSSL(SSLv3Test)
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with ``ssl.PROTOCOL_TLSv1`` selected."""

    def get_ssl_version(self):
        """Return the protocol constant selected by this test class."""
        protocol = ssl.PROTOCOL_TLSv1
        return protocol


# Apply the old-openssl skip first, then the no-ssl skip around it.
TLSv1Test = skipIfOldSSL(TLSv1Test)
TLSv1Test = skipIfNoSSL(TLSv1Test)
class BadSSLOptionsTest(unittest.TestCase):
    """HTTPServer should reject unusable ``ssl_options`` at construction time."""

    def test_missing_arguments(self):
        """Supplying only a keyfile (no certfile) raises KeyError."""
        app = Application()
        keyfile_only = {"keyfile": "/__missing__.crt"}
        self.assertRaises(KeyError, HTTPServer, app, ssl_options=keyfile_only)

    def test_missing_key(self):
        """A certificate or key path that does not exist raises ValueError."""
        app = Application()
        cert_path = os.path.join(os.path.dirname(__file__), 'test.crt')

        rejected_options = [
            # The certificate file itself does not exist.
            {"certfile": "/__mising__.crt"},
            # The certificate exists but the key file does not.
            {"certfile": cert_path, "keyfile": "/__missing__.key"},
        ]
        for options in rejected_options:
            self.assertRaises(ValueError, HTTPServer, app,
                              ssl_options=options)

        # This actually works because both files exist.
        accepted_options = {"certfile": cert_path, "keyfile": cert_path}
        server = HTTPServer(app, ssl_options=accepted_options)
class MultipartTestHandler(RequestHandler):
    """Echoes selected parts of a multipart POST back as a JSON object."""

    def post(self):
        """Report one header, one form argument and the first uploaded file."""
        # The lookups are done in the same order as the keys of the
        # response so that a missing piece fails in the same place.
        header_value = self.request.headers["X-Header-Encoding-Test"]
        argument_value = self.get_argument("argument")
        upload = self.request.files["files"][0]
        upload_name = upload.filename
        upload_body = _unicode(upload["body"])
        self.finish({
            "header": header_value,
            "argument": argument_value,
            "filename": upload_name,
            "filebody": upload_body,
        })
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
    """Connection that writes caller-supplied raw bytes as its request."""

    def set_request(self, request):
        """Store the raw request bytes to write once connected."""
        self.__next_request = request

    def _on_connect(self, parsed, parsed_hostname):
        """Write the stored request, clear it, and start reading headers."""
        pending = self.__next_request
        self.stream.write(pending)
        self.__next_request = None
        end_of_headers = b("\r\n\r\n")
        self.stream.read_until(end_of_headers, self._on_headers)
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Sends hand-written requests to ``/multipart`` and ``/hello``."""

    def get_handlers(self):
        """Return the URL routes served by this test's application."""
        routes = [
            ("/multipart", MultipartTestHandler),
            ("/hello", HelloWorldRequestHandler),
        ]
        return routes

    def get_app(self):
        """Build the application from ``get_handlers()``."""
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body):
        """Send a hand-assembled request and return the response.

        ``headers`` is a list of byte strings (the tests put the request
        line first).  A Content-Length header computed from ``body`` is
        appended, followed by a blank line and ``body`` itself.  Any error
        recorded on the response is re-raised through ``rethrow()``.
        """
        http_client = SimpleAsyncHTTPClient(self.io_loop)
        target = httpclient.HTTPRequest(self.get_url("/"))
        connection = RawRequestHTTPConnection(
            self.io_loop, http_client, target, None, self.stop, 1024 * 1024)
        length_header = utf8("Content-Length: %d\r\n" % len(body))
        head = b("\r\n").join(headers + [length_header])
        connection.set_request(head + b("\r\n") + body)
        result = self.wait()
        http_client.close()
        result.rethrow()
        return result

    def test_multipart_form(self):
        """Header, argument, filename and file body all decode correctly."""
        # Encodings here are tricky: headers are latin1, while bodies can
        # be anything (utf8 is used by default).
        request_headers = [
            b("POST /multipart HTTP/1.0"),
            b("Content-Type: multipart/form-data; boundary=1234567890"),
            b("X-Header-encoding-test: \xe9"),
        ]
        body_lines = [
            b("Content-Disposition: form-data; name=argument"),
            b(""),
            u"\u00e1".encode("utf-8"),
            b("--1234567890"),
            u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
            b(""),
            u"\u00fa".encode("utf-8"),
            b("--1234567890--"),
            b(""),
        ]
        response = self.raw_fetch(request_headers,
                                  b("\r\n").join(body_lines))
        decoded = json_decode(response.body)
        expectations = [
            ("header", u"\u00e9"),
            ("argument", u"\u00e1"),
            ("filename", u"\u00f3"),
            ("filebody", u"\u00fa"),
        ]
        for key, expected in expectations:
            self.assertEqual(expected, decoded[key])

    def test_100_continue(self):
        """Walk through an ``Expect: 100-continue`` exchange step by step."""
        # Run through a 100-continue interaction by hand: when given
        # Expect: 100-continue, we get a 100 response after the headers,
        # and then the real response after the body.
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        address = ("localhost", self.get_http_port())
        stream.connect(address, callback=self.stop)
        self.wait()

        request_head = b("\r\n").join([
            b("POST /hello HTTP/1.1"),
            b("Content-Length: 1024"),
            b("Expect: 100-continue"),
            b("Connection: close"),
            b("\r\n"),
        ])
        stream.write(request_head, callback=self.stop)
        self.wait()

        # The interim response must arrive before any body has been sent.
        stream.read_until(b("\r\n\r\n"), self.stop)
        interim = self.wait()
        self.assertTrue(interim.startswith(b("HTTP/1.1 100 ")), interim)

        # Now send the body and read the real response.
        stream.write(b("a") * 1024)
        stream.read_until(b("\r\n"), self.stop)
        status_line = self.wait()
        self.assertTrue(status_line.startswith(b("HTTP/1.1 200")),
                        status_line)

        stream.read_until(b("\r\n\r\n"), self.stop)
        raw_headers = self.wait()
        header_text = native_str(raw_headers.decode('latin1'))
        parsed_headers = HTTPHeaders.parse(header_text)
        stream.read_bytes(int(parsed_headers["Content-Length"]), self.stop)
        response_body = self.wait()
        self.assertEqual(response_body, b("Got 1024 bytes in POST"))
        stream.close()
class EchoHandler(RequestHandler):
    """Writes the request's arguments back to the client."""

    def get(self):
        """Write ``request.arguments`` after passing it through recursive_unicode."""
        arguments = recursive_unicode(self.request.arguments)
        self.write(arguments)
class TypeCheckHandler(RequestHandler):
    """Reports which request attributes do not have the expected type.

    The response is a dict mapping a check name to a message; it is empty
    when every checked value has exactly the expected type.
    """

    def prepare(self):
        """Run the type checks that apply to every request method."""
        self.errors = {}
        request = self.request

        # Plain request attributes that are all expected to be ``str``.
        str_fields = ('method', 'uri', 'version', 'remote_ip',
                      'protocol', 'host', 'path', 'query')
        for field in str_fields:
            self.check_type(field, getattr(request, field), str)

        # Only the first entry of each mapping is sampled.
        self.check_type('header_key', request.headers.keys()[0], str)
        self.check_type('header_value', request.headers.values()[0], str)

        self.check_type('cookie_key', request.cookies.keys()[0], str)
        self.check_type('cookie_value',
                        request.cookies.values()[0].value, str)
        # secure cookies: nothing is checked for them here.

        self.check_type('arg_key', request.arguments.keys()[0], str)
        self.check_type('arg_value',
                        request.arguments.values()[0][0], bytes_type)

    def post(self):
        """Also check the body type, then write the collected errors."""
        self.check_type('body', self.request.body, bytes_type)
        self.write(self.errors)

    def get(self):
        """Write the errors collected by ``prepare``."""
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        """Record an error under ``name`` unless ``type(obj)`` is exactly ``expected_type``."""
        actual_type = type(obj)
        # Exact type comparison: a subclass is reported as a mismatch.
        if expected_type == actual_type:
            return
        message = "expected %s, got %s" % (expected_type, actual_type)
        self.errors[name] = message
class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Tests of query decoding, request attribute types, and ``//`` paths."""

    def get_app(self):
        """Serve the echo handler (two routes) and the type-check handler."""
        routes = [
            ("/echo", EchoHandler),
            ("/typecheck", TypeCheckHandler),
            ("//doubleslash", EchoHandler),
        ]
        return Application(routes)

    def test_query_string_encoding(self):
        """A percent-encoded utf8 query value is echoed back as unicode."""
        reply = self.fetch("/echo?foo=%C3%A9")
        echoed = json_decode(reply.body)
        self.assertEqual(echoed, {u"foo": [u"\u00e9"]})

    def test_types(self):
        """The type-check handler reports no errors for GET or POST."""
        cookie_headers = {"Cookie": "foo=bar"}

        get_reply = self.fetch("/typecheck?foo=bar", headers=cookie_headers)
        self.assertEqual(json_decode(get_reply.body), {})

        post_reply = self.fetch("/typecheck", method="POST", body="foo=bar",
                                headers=cookie_headers)
        self.assertEqual(json_decode(post_reply.body), {})

    def test_double_slash(self):
        """A path beginning with ``//`` is routed like any other path."""
        # urlparse.urlsplit (which tornado.httpserver used to use
        # incorrectly) would parse paths beginning with "//" as
        # protocol-relative urls.
        reply = self.fetch("//doubleslash")
        self.assertEqual(200, reply.code)
        echoed = json_decode(reply.body)
        self.assertEqual(echoed, {})
class XHeaderTest(HandlerBaseTestCase):
    """Checks how ``X-Real-IP`` affects ``remote_ip`` when xheaders is on."""

    class Handler(RequestHandler):
        """Reports the request's ``remote_ip`` as JSON."""

        def get(self):
            self.write({"remote_ip": self.request.remote_ip})

    def get_httpserver_options(self):
        """Turn on the server's ``xheaders`` option."""
        return {"xheaders": True}

    def test_ip_headers(self):
        """Valid addresses in X-Real-IP are used; anything else is ignored."""
        # Without the header the address reported is 127.0.0.1.
        self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")

        # (X-Real-IP header value, remote_ip expected in the response)
        cases = [
            # valid ipv4
            ("4.4.4.4", "4.4.4.4"),
            # valid ipv6
            ("2620:0:1cfe:face:b00c::3", "2620:0:1cfe:face:b00c::3"),
            # invalid characters after an otherwise valid address
            ("4.4.4.4<script>", "127.0.0.1"),
            # a host name rather than an address
            ("www.google.com", "127.0.0.1"),
        ]
        for header_value, expected_ip in cases:
            reply = self.fetch_json("/", headers={"X-Real-IP": header_value})
            self.assertEqual(reply["remote_ip"], expected_ip)
class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
    """HTTPServers can listen on Unix sockets too.

    One reason to want this: Nginx can proxy to backends that listen on
    unix sockets, and managing a namespace for unix sockets can be easier
    than managing a bunch of TCP port numbers.

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so the request is written and parsed by hand.
    """

    def setUp(self):
        """Create a temporary directory to hold the socket file."""
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the temporary directory, then run the base tearDown."""
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        """Serve ``/hello`` on a unix socket and read the reply by hand."""
        socket_path = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(socket_path)
        application = Application([("/hello", HelloWorldRequestHandler)])
        server = HTTPServer(application, io_loop=self.io_loop)
        server.add_socket(listener)

        client_socket = socket.socket(socket.AF_UNIX)
        stream = IOStream(client_socket, io_loop=self.io_loop)
        stream.connect(socket_path, self.stop)
        self.wait()

        stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))

        # Status line first, then the headers, then the body.
        stream.read_until(b("\r\n"), self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b("HTTP/1.0 200 OK\r\n"))

        stream.read_until(b("\r\n\r\n"), self.stop)
        raw_headers = self.wait()
        parsed_headers = HTTPHeaders.parse(raw_headers.decode('latin1'))

        stream.read_bytes(int(parsed_headers["Content-Length"]), self.stop)
        response_body = self.wait()
        self.assertEqual(response_body, b("Hello world"))

        stream.close()
        server.stop()
# skipIf() only builds a decorator; it has to be applied to the class.
# Without the trailing call the name UnixSocketTest would be rebound to the
# decorator function itself and the test case would no longer exist.
# Skipped where AF_UNIX sockets are unavailable, and on cygwin.
UnixSocketTest = unittest.skipIf(
    not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin',
    "unix sockets not supported on this platform")(UnixSocketTest)
Jump to Line
Something went wrong with that request. Please try again.