From 17044b16aa1a4ed3472b2083c8429e6f601ab7a0 Mon Sep 17 00:00:00 2001 From: Hong Minhee Date: Sat, 14 Feb 2015 19:46:11 +0900 Subject: [PATCH] urllib.request mock --- tests/conftest.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index e2c56a6..4215acb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,17 @@ import datetime +import http.client +import io import os +import re import threading +import urllib.request from paramiko.pkey import PKey from paramiko.rsakey import RSAKey from paramiko.sftp_client import SFTPClient from paramiko.transport import Transport from pytest import fixture, yield_fixture +from werkzeug.test import EnvironBuilder from geofront.keystore import format_openssh_pubkey from geofront import server @@ -184,3 +189,97 @@ def fx_authorized_servers(fx_sftpd, fx_master_key): f.write(format_openssh_pubkey(fx_master_key)) thread.start() return fx_sftpd + + +class TestHTTPHandler(urllib.request.HTTPHandler): + + mock_hosts = {} + mock_urls = {} + + def http_open(self, req): + cls = type(self) + url = req.full_url + url_without_qs = re.sub(r'\?.*$', '', url) + try: + handler = cls.mock_urls[url_without_qs] + except KeyError: + try: + wsgi_app = cls.mock_hosts[req.host] + except KeyError: + return super().http_open(req) + builder = EnvironBuilder( + path=req.selector, + base_url=re.match(r'^https?://[^/]+', url).group(0), + method=req.get_method(), + headers=req.headers, + data=req.data, + content_type=req.headers.get( + 'content-type', + 'application/x-www-form-urlencoded' if req.data else None + ) + ) + status_code = None + headers = None + + def start_response(code, hlist): + nonlocal status_code, headers + status_code = code + headers = hlist + buffer_ = io.BytesIO() + for chunk in wsgi_app(builder.get_environ(), start_response): + buffer_.write(chunk) + buffer_.seek(0) + resp = urllib.request.addinfourl( + buffer_, + {k.lower().strip(): v for k, v in headers}, + url + ) + code, resp.msg = status_code.split(None, 1) + resp.code = resp.status = int(code) + resp.reason = resp.msg + resp.version = 10 + return resp + content, status_code, headers = handler(req) + if isinstance(content, str): + buffer_ = io.StringIO(content) + elif isinstance(content, bytes): + buffer_ = io.BytesIO(content) + elif isinstance(content, io.IOBase): + buffer_ = content + else: + raise TypeError('content must be a string, or a bytes, or a file ' + 'object, not ' + repr(content)) + resp = urllib.request.addinfourl( + buffer_, + {k.lower().strip(): v for k, v in headers.items()}, + url + ) + resp.code = status_code + resp.msg = http.client.responses[status_code] + return resp + + @classmethod + def route(cls, url: str): + def decorate(function): + cls.mock_urls[url] = function + return function + return decorate + + @classmethod + def route_wsgi(cls, host, app): + assert callable(app) + cls.mock_hosts[host] = app + + +@yield_fixture +def fx_urllib_mock(request): + original_opener = urllib.request._opener + handler_cls = type( + 'TestHTTPHandler_', + (TestHTTPHandler,), + {'mock_urls': {}, 'mock_hosts': {}} + ) + opener = urllib.request.build_opener(handler_cls) + urllib.request.install_opener(opener) + yield handler_cls + urllib.request._opener = original_opener