Skip to content

Commit

Permalink
Add argument to allow connection to specific hosts (#9)
Browse files Browse the repository at this point in the history
Resolves #6 
Closes #8
  • Loading branch information
sadams authored and miketheman committed Jun 26, 2018
1 parent 4459011 commit 07e1c89
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -71,3 +71,4 @@ target/
# pyenv
.python-version

.pytest_cache/
19 changes: 17 additions & 2 deletions README.rst
Expand Up @@ -16,7 +16,7 @@ pytest-socket
:target: https://ci.appveyor.com/project/miketheman/pytest-socket/branch/master
:alt: See Build Status on AppVeyor

A plugin to use with Pytest to disable ``socket`` calls during tests to ensure network calls are prevented.
A plugin to use with Pytest to disable or restrict ``socket`` calls during tests to ensure network calls are prevented.

----

Expand Down Expand Up @@ -77,7 +77,22 @@ Usage
@pytest.mark.enable_socket
def test_explicitly_enable_socket_with_mark():
assert socket.socket(socket.AF_INET, socket.SOCK_STREAM)
assert socket.socket(socket.AF_INET, socket.SOCK_STREAM)
* To allow only specific hosts per-test:

.. code:: python
@pytest.mark.allow_hosts(['127.0.0.1'])
def test_explicitly_enable_socket_with_mark():
assert socket.socket.connect(('127.0.0.1', 80))
or for whole test run

.. code:: ini
[pytest]
addopts = --allow-hosts=127.0.0.1,127.0.1.1
Contributing
Expand Down
66 changes: 66 additions & 0 deletions pytest_socket.py
Expand Up @@ -4,13 +4,23 @@
import pytest

_true_socket = socket.socket
_true_connect = socket.socket.connect


class SocketBlockedError(RuntimeError):
def __init__(self, *args, **kwargs):
super(SocketBlockedError, self).__init__("A test tried to use socket.socket.")


class SocketConnectBlockedError(RuntimeError):
def __init__(self, allowed, host, *args, **kwargs):
if allowed:
allowed = ','.join(allowed)
super(SocketConnectBlockedError, self).__init__(
'A test tried to use socket.socket.connect() with host "{0}" (allowed: "{1}").'.format(host, allowed)
)


def pytest_addoption(parser):
group = parser.getgroup('socket')
group.addoption(
Expand All @@ -19,6 +29,12 @@ def pytest_addoption(parser):
dest='disable_socket',
help='Disable socket.socket by default to block network calls.'
)
group.addoption(
'--allow-hosts',
dest='allow_hosts',
metavar='ALLOWED_HOSTS_CSV',
help='Only allow specified hosts through socket.socket.connect((host, port)).'
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -62,3 +78,53 @@ def enable_socket():
""" re-enable socket.socket to enable the Internet. useful in testing.
"""
socket.socket = _true_socket


def pytest_configure(config):
config.addinivalue_line("markers", "disable_socket(): Disable socket connections for a specific test")
config.addinivalue_line("markers", "enable_socket(): Enable socket connections for a specific test")
config.addinivalue_line("markers", "allow_hosts([hosts]): Restrict socket connection to defined list of hosts")


def pytest_runtest_setup(item):
mark_restrictions = item.get_closest_marker('allow_hosts')
cli_restrictions = item.config.getoption('--allow-hosts')
hosts = None
if mark_restrictions:
hosts = mark_restrictions.args[0]
elif cli_restrictions:
hosts = cli_restrictions
socket_allow_hosts(hosts)


def pytest_runtest_teardown():
remove_host_restrictions()


def host_from_connect_args(args):
address = args[0]
if isinstance(address, tuple) and isinstance(address[0], str):
return address[0]


def socket_allow_hosts(allowed=None):
""" disable socket.socket.connect() to disable the Internet. useful in testing.
"""
if isinstance(allowed, str):
allowed = allowed.split(',')
if not isinstance(allowed, list):
return

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host and host in allowed:
return _true_connect(inst, *args)
raise SocketConnectBlockedError(allowed, host)

socket.socket.connect = guarded_connect


def remove_host_restrictions():
""" restore socket.socket.connect() to allow access to the Internet. useful in testing.
"""
socket.socket.connect = _true_connect
231 changes: 231 additions & 0 deletions tests/test_restrict_hosts.py
@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
import pytest

try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse

import inspect


localhost = '127.0.0.1'

connect_code_template = """
import socket
import pytest
{3}
def {2}():
socket.socket().connect(('{0}', {1}))
"""

urlopen_code_template = """
import pytest
try:
from urllib.request import urlopen
except ImportError:
from urllib2 import urlopen
{3}
def {2}():
assert urlopen('http://{0}:{1}/').getcode() == 200
"""


def assert_host_blocked(result, host):
result.stdout.fnmatch_lines('*A test tried to use socket.socket.connect() with host "{0}"*'.format(host))


@pytest.fixture
def assert_connect(httpbin, testdir):
def assert_socket_connect(should_pass, **kwargs):
# get the name of the calling function
test_name = inspect.stack()[1][3]
test_url = urlparse(httpbin.url)

mark = ''
cli_arg = kwargs.get('cli_arg', None)
code_template = kwargs.get('code_template', connect_code_template)
mark_arg = kwargs.get('mark_arg', None)

if mark_arg and isinstance(mark_arg, str):
mark = '@pytest.mark.allow_hosts("{0}")'.format(mark_arg)
elif mark_arg and isinstance(mark_arg, list):
mark = '@pytest.mark.allow_hosts(["{0}"])'.format('","'.join(mark_arg))
code = code_template.format(test_url.hostname, test_url.port, test_name, mark)
testdir.makepyfile(code)

if cli_arg:
result = testdir.runpytest("--verbose", '--allow-hosts={0}'.format(cli_arg))
else:
result = testdir.runpytest("--verbose")

if should_pass:
result.assert_outcomes(1, 0, 0)
else:
result.assert_outcomes(0, 0, 1)
assert_host_blocked(result, test_url.hostname)
return assert_socket_connect


def test_help_message(testdir):
result = testdir.runpytest(
'--help',
)
result.stdout.fnmatch_lines([
'socket:',
'*--allow-hosts=ALLOWED_HOSTS_CSV',
'*Only allow specified hosts through',
'*socket.socket.connect((host, port)).'
])


def test_marker_help_message(testdir):
result = testdir.runpytest(
'--markers',
)
result.stdout.fnmatch_lines([
'@pytest.mark.allow_hosts([[]hosts[]]): Restrict socket connection to defined list of hosts',
])


def test_default_connect_enabled(assert_connect):
assert_connect(True)


def test_single_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost)


def test_multiple_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost + ',1.2.3.4')


def test_single_mark_arg_connect_enabled(assert_connect):
assert_connect(True, mark_arg=localhost)


def test_multiple_mark_arg_csv_connect_enabled(assert_connect):
assert_connect(True, mark_arg=localhost + ',1.2.3.4')


def test_multiple_mark_arg_list_connect_enabled(assert_connect):
assert_connect(True, mark_arg=[localhost, '1.2.3.4'])


def test_mark_cli_conflict_mark_wins_connect_enabled(assert_connect):
assert_connect(True, mark_arg=[localhost], cli_arg='1.2.3.4')


def test_single_cli_arg_connect_disabled(assert_connect):
assert_connect(False, cli_arg='1.2.3.4')


def test_multiple_cli_arg_connect_disabled(assert_connect):
assert_connect(False, cli_arg='5.6.7.8,1.2.3.4')


def test_single_mark_arg_connect_disabled(assert_connect):
assert_connect(False, mark_arg='1.2.3.4')


def test_multiple_mark_arg_csv_connect_disabled(assert_connect):
assert_connect(False, mark_arg='5.6.7.8,1.2.3.4')


def test_multiple_mark_arg_list_connect_disabled(assert_connect):
assert_connect(False, mark_arg=['5.6.7.8', '1.2.3.4'])


def test_mark_cli_conflict_mark_wins_connect_disabled(assert_connect):
assert_connect(False, mark_arg=['1.2.3.4'], cli_arg=localhost)


def test_default_urllib_succeeds_by_default(assert_connect):
assert_connect(True, code_template=urlopen_code_template)


def test_single_cli_arg_urlopen_enabled(assert_connect):
assert_connect(True, cli_arg=localhost + ',1.2.3.4', code_template=urlopen_code_template)


def test_single_mark_arg_urlopen_enabled(assert_connect):
assert_connect(True, mark_arg=[localhost, '1.2.3.4'], code_template=urlopen_code_template)


def test_global_restrict_via_config_fail(testdir):
testdir.makepyfile("""
import socket
def test_global_restrict_via_config_fail():
socket.socket().connect(('127.0.0.1', 80))
""")
testdir.makeini("""
[pytest]
addopts = --allow-hosts=2.2.2.2
""")
result = testdir.runpytest("--verbose")
result.assert_outcomes(0, 0, 1)
assert_host_blocked(result, '127.0.0.1')


def test_global_restrict_via_config_pass(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile("""
import socket
def test_global_restrict_via_config_pass():
socket.socket().connect(('{0}', {1}))
""".format(test_url.hostname, test_url.port))
testdir.makeini("""
[pytest]
addopts = --allow-hosts={0}
""".format(test_url.hostname))
result = testdir.runpytest("--verbose")
result.assert_outcomes(1, 0, 0)


def test_test_isolation(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile("""
import pytest
import socket
@pytest.mark.allow_hosts('{0}')
def test_pass():
socket.socket().connect(('{0}', {1}))
@pytest.mark.allow_hosts('2.2.2.2')
def test_fail():
socket.socket().connect(('{0}', {1}))
def test_pass_2():
socket.socket().connect(('{0}', {1}))
""".format(test_url.hostname, test_url.port))
result = testdir.runpytest("--verbose")
result.assert_outcomes(2, 0, 1)
assert_host_blocked(result, test_url.hostname)


def test_conflicting_cli_vs_marks(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile("""
import pytest
import socket
@pytest.mark.allow_hosts('{0}')
def test_pass():
socket.socket().connect(('{0}', {1}))
@pytest.mark.allow_hosts('2.2.2.2')
def test_fail():
socket.socket().connect(('{0}', {1}))
def test_fail_2():
socket.socket().connect(('2.2.2.2', {1}))
""".format(test_url.hostname, test_url.port))
result = testdir.runpytest("--verbose", '--allow-hosts=1.2.3.4')
result.assert_outcomes(1, 0, 2)
assert_host_blocked(result, '2.2.2.2')
assert_host_blocked(result, test_url.hostname)
5 changes: 4 additions & 1 deletion tox.ini
Expand Up @@ -2,8 +2,11 @@
[tox]
envlist = py27,py34,py35,py36,pypy,flake8,coverage


[testenv]
deps = pytest
deps =
pytest
pytest-httpbin
commands = pytest {posargs:tests}

[testenv:flake8]
Expand Down

0 comments on commit 07e1c89

Please sign in to comment.