Skip to content

Commit

Permalink
Merge 0c810cb into 2a1f277
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Nov 25, 2019
2 parents 2a1f277 + 0c810cb commit b70f112
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._exceptions import * # noQA
from ._schemes import Scheme
from ._state import reset, switch
from ._imports import module_load, activate
from ._testing import TestCase, cases


Expand All @@ -54,4 +55,8 @@
# aliases
'invariant',
'require',

# module level
'module_load',
'activate',
]
112 changes: 112 additions & 0 deletions deal/_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# built-in
import ast
import sys
from _frozen_importlib_external import PathFinder
from types import ModuleType
from typing import Callable, Optional, List

from .linter._extractors.common import get_name
from . import _aliases


class DealFinder(PathFinder):
@classmethod
def find_spec(cls, *args, **kwargs):
spec = super().find_spec(*args, **kwargs)
if spec is not None:
spec.loader = DealLoader(spec.loader)
return spec


class DealLoader:
def __init__(self, loader):
self._loader = loader

def __getattr__(self, name: str):
return getattr(self._loader, name)

def exec_module(self, module: ModuleType) -> None:
if not hasattr(self._loader, 'get_source'):
return self._loader.exec_module(module)

# get nodes with module-level contracts from the source code
source = self._loader.get_source(module.__name__)
if source is None:
return self._loader.exec_module(module)
tree = ast.parse(source)
nodes = self._get_contracts(tree=tree)
if not nodes:
return self._loader.exec_module(module)

# convert contracts nodes into real contracts
contracts = []
for node in nodes:
contract = self._exec_contract(node=node)
if contract is None:
msg = 'unsupported contract: {}'.format(ast.dump(node))
raise RuntimeError(msg)
contracts.append(contract)

# execute module with contracts
wrapped = _aliases.chain(contract)(self._loader.exec_module)
wrapped(module)

@staticmethod
def _get_contracts(tree: ast.Module) -> List[ast.AST]:
for node in tree.body:
if not type(node) is ast.Expr:
continue
if not type(node.value) is ast.Call:
continue
if get_name(node.value.func) != 'deal.module_load':
continue
return node.value.args
return []

@classmethod
def _exec_contract(cls, node: ast.AST) -> Optional[Callable]:
"""Get AST node and return a contract function
"""
if type(node) is ast.Call and not node.args:
return cls._exec_contract(node.func)

if not isinstance(node, ast.Attribute):
return None
if node.value.id != 'deal':
return None
contract = getattr(_aliases, node.attr, None)
if contract is None:
return None
return contract


def module_load(*contracts) -> None:
if not contracts:
raise RuntimeError('no contracts specified')
if DealFinder not in sys.meta_path:
msg = 'deal.activate must be called '
msg += 'before importing anything with deal.module_load contract'
raise RuntimeError(msg)


def activate() -> bool:
"""Activate module-level checks.
This function must be called before importing anything
with deal.module_load() contract.
"""
if DealFinder in sys.meta_path:
return False
index = sys.meta_path.index(PathFinder)
sys.meta_path[index] = DealFinder
return True


def deactivate() -> bool:
"""used in tests
"""
if DealFinder not in sys.meta_path:
return False
index = sys.meta_path.index(DealFinder)
sys.meta_path[index] = PathFinder
return True
118 changes: 118 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import ast
from textwrap import dedent

import pytest

import deal
from deal.linter._extractors.common import get_name
from deal._imports import DealLoader, deactivate


def test_get_contracts():
text = """
import deal
a = 1
1 / 0
not_a_deal.module_load(something)
deal.module_load(deal.silent)
"""
text = dedent(text)
tree = ast.parse(text)
print(ast.dump(tree))
nodes = DealLoader._get_contracts(tree=tree)
assert [get_name(node) for node in nodes] == ['deal.silent']


@pytest.mark.parametrize('text, expected', [
('deal.silent', deal.silent),
('deal.silent()', deal.silent),
('deal.pre(something)', None),
('not_a_deal.silent', None),
('deal.typo', None),
])
def test_exec_contract(text, expected):
tree = ast.parse(text)
print(ast.dump(tree))
actual = DealLoader._exec_contract(node=tree.body[0].value)
assert actual == expected


class TestException(Exception):
pass


class TestModule:
pass


class SubLoader:
def __init__(self, ok, text):
self.ok = ok
self.text = text

def get_source(self, module):
assert module == 'TestModule'
return self.text

def exec_module(self, module):
assert module is TestModule
if not self.ok:
raise TestException
print(1)


def test_exec_module():
text = """
import deal
deal.module_load(deal.silent)
print(1)
"""
text = dedent(text)

with pytest.raises(TestException):
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule)

with pytest.raises(deal.SilentContractError):
DealLoader(loader=SubLoader(ok=True, text=text)).exec_module(TestModule)


def test_exec_module_invalid_contract():
text = """
import deal
deal.module_load(deal.pre(something))
print(1)
"""
text = dedent(text)
with pytest.raises(RuntimeError, match='unsupported contract:.+'):
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule)


def test_exec_module_no_contracts():
text = """
import deal
print(1)
"""
text = dedent(text)
DealLoader(loader=SubLoader(ok=True, text=text)).exec_module(TestModule)
with pytest.raises(TestException):
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule)


def test_module_load():
assert deal.activate()
try:
deal.module_load(deal.silent)
finally:
assert deactivate()

with pytest.raises(RuntimeError):
deal.module_load(deal.silent)


def test_activate():
try:
assert deal.activate()
assert not deal.activate()
finally:
assert deactivate()
assert not deactivate()

0 comments on commit b70f112

Please sign in to comment.