-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# built-in | ||
import ast | ||
import sys | ||
from _frozen_importlib_external import PathFinder | ||
from types import ModuleType | ||
from typing import Callable, Optional, List | ||
|
||
from .linter._extractors.common import get_name | ||
from . import _aliases | ||
|
||
|
||
class DealFinder(PathFinder): | ||
@classmethod | ||
def find_spec(cls, *args, **kwargs): | ||
spec = super().find_spec(*args, **kwargs) | ||
if spec is not None: | ||
spec.loader = DealLoader(spec.loader) | ||
return spec | ||
|
||
|
||
class DealLoader: | ||
def __init__(self, loader): | ||
self._loader = loader | ||
|
||
def __getattr__(self, name: str): | ||
return getattr(self._loader, name) | ||
|
||
def exec_module(self, module: ModuleType) -> None: | ||
if not hasattr(self._loader, 'get_source'): | ||
return self._loader.exec_module(module) | ||
|
||
# get nodes with module-level contracts from the source code | ||
source = self._loader.get_source(module.__name__) | ||
if source is None: | ||
return self._loader.exec_module(module) | ||
tree = ast.parse(source) | ||
nodes = self._get_contracts(tree=tree) | ||
if not nodes: | ||
return self._loader.exec_module(module) | ||
|
||
# convert contracts nodes into real contracts | ||
contracts = [] | ||
for node in nodes: | ||
contract = self._exec_contract(node=node) | ||
if contract is None: | ||
msg = 'unsupported contract: {}'.format(ast.dump(node)) | ||
raise RuntimeError(msg) | ||
contracts.append(contract) | ||
|
||
# execute module with contracts | ||
wrapped = _aliases.chain(contract)(self._loader.exec_module) | ||
wrapped(module) | ||
|
||
@staticmethod | ||
def _get_contracts(tree: ast.Module) -> List[ast.AST]: | ||
for node in tree.body: | ||
if not type(node) is ast.Expr: | ||
continue | ||
if not type(node.value) is ast.Call: | ||
continue | ||
if get_name(node.value.func) != 'deal.module_load': | ||
continue | ||
return node.value.args | ||
return [] | ||
|
||
@classmethod | ||
def _exec_contract(cls, node: ast.AST) -> Optional[Callable]: | ||
"""Get AST node and return a contract function | ||
""" | ||
if type(node) is ast.Call and not node.args: | ||
return cls._exec_contract(node.func) | ||
|
||
if not isinstance(node, ast.Attribute): | ||
return None | ||
if node.value.id != 'deal': | ||
return None | ||
contract = getattr(_aliases, node.attr, None) | ||
if contract is None: | ||
return None | ||
return contract | ||
|
||
|
||
def module_load(*contracts) -> None: | ||
if not contracts: | ||
raise RuntimeError('no contracts specified') | ||
if DealFinder not in sys.meta_path: | ||
msg = 'deal.activate must be called ' | ||
msg += 'before importing anything with deal.module_load contract' | ||
raise RuntimeError(msg) | ||
|
||
|
||
def activate() -> bool: | ||
"""Activate module-level checks. | ||
This function must be called before importing anything | ||
with deal.module_load() contract. | ||
""" | ||
if DealFinder in sys.meta_path: | ||
return False | ||
index = sys.meta_path.index(PathFinder) | ||
sys.meta_path[index] = DealFinder | ||
return True | ||
|
||
|
||
def deactivate() -> bool: | ||
"""used in tests | ||
""" | ||
if DealFinder not in sys.meta_path: | ||
return False | ||
index = sys.meta_path.index(DealFinder) | ||
sys.meta_path[index] = PathFinder | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import ast | ||
from textwrap import dedent | ||
|
||
import pytest | ||
|
||
import deal | ||
from deal.linter._extractors.common import get_name | ||
from deal._imports import DealLoader, deactivate | ||
|
||
|
||
def test_get_contracts(): | ||
text = """ | ||
import deal | ||
a = 1 | ||
1 / 0 | ||
not_a_deal.module_load(something) | ||
deal.module_load(deal.silent) | ||
""" | ||
text = dedent(text) | ||
tree = ast.parse(text) | ||
print(ast.dump(tree)) | ||
nodes = DealLoader._get_contracts(tree=tree) | ||
assert [get_name(node) for node in nodes] == ['deal.silent'] | ||
|
||
|
||
@pytest.mark.parametrize('text, expected', [ | ||
('deal.silent', deal.silent), | ||
('deal.silent()', deal.silent), | ||
('deal.pre(something)', None), | ||
('not_a_deal.silent', None), | ||
('deal.typo', None), | ||
]) | ||
def test_exec_contract(text, expected): | ||
tree = ast.parse(text) | ||
print(ast.dump(tree)) | ||
actual = DealLoader._exec_contract(node=tree.body[0].value) | ||
assert actual == expected | ||
|
||
|
||
class TestException(Exception): | ||
pass | ||
|
||
|
||
class TestModule: | ||
pass | ||
|
||
|
||
class SubLoader: | ||
def __init__(self, ok, text): | ||
self.ok = ok | ||
self.text = text | ||
|
||
def get_source(self, module): | ||
assert module == 'TestModule' | ||
return self.text | ||
|
||
def exec_module(self, module): | ||
assert module is TestModule | ||
if not self.ok: | ||
raise TestException | ||
print(1) | ||
|
||
|
||
def test_exec_module(): | ||
text = """ | ||
import deal | ||
deal.module_load(deal.silent) | ||
print(1) | ||
""" | ||
text = dedent(text) | ||
|
||
with pytest.raises(TestException): | ||
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule) | ||
|
||
with pytest.raises(deal.SilentContractError): | ||
DealLoader(loader=SubLoader(ok=True, text=text)).exec_module(TestModule) | ||
|
||
|
||
def test_exec_module_invalid_contract(): | ||
text = """ | ||
import deal | ||
deal.module_load(deal.pre(something)) | ||
print(1) | ||
""" | ||
text = dedent(text) | ||
with pytest.raises(RuntimeError, match='unsupported contract:.+'): | ||
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule) | ||
|
||
|
||
def test_exec_module_no_contracts(): | ||
text = """ | ||
import deal | ||
print(1) | ||
""" | ||
text = dedent(text) | ||
DealLoader(loader=SubLoader(ok=True, text=text)).exec_module(TestModule) | ||
with pytest.raises(TestException): | ||
DealLoader(loader=SubLoader(ok=False, text=text)).exec_module(TestModule) | ||
|
||
|
||
def test_module_load(): | ||
assert deal.activate() | ||
try: | ||
deal.module_load(deal.silent) | ||
finally: | ||
assert deactivate() | ||
|
||
with pytest.raises(RuntimeError): | ||
deal.module_load(deal.silent) | ||
|
||
|
||
def test_activate(): | ||
try: | ||
assert deal.activate() | ||
assert not deal.activate() | ||
finally: | ||
assert deactivate() | ||
assert not deactivate() |