Skip to content

Commit

Permalink
Merge e645e16 into 2a1f277
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Nov 23, 2019
2 parents 2a1f277 + e645e16 commit 3a2760f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._exceptions import * # noQA
from ._schemes import Scheme
from ._state import reset, switch
from ._imports import load, register
from ._testing import TestCase, cases


Expand All @@ -54,4 +55,8 @@
# aliases
'invariant',
'require',

# module level
'load',
'register',
]
85 changes: 85 additions & 0 deletions deal/_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# built-in
import ast
import sys
from _frozen_importlib_external import PathFinder
from typing import Callable, Optional

from .linter._extractors.common import get_name
from . import _aliases


class DealFinder(PathFinder):
@classmethod
def find_spec(cls, *args, **kwargs):
spec = super().find_spec(*args, **kwargs)
if spec is not None:
spec.loader = DealLoader(spec.loader)
return spec


class DealLoader:
def __init__(self, loader):
self._loader = loader

def __getattr__(self, name):
return getattr(self._loader, name)

def exec_module(self, module) -> None:
if not hasattr(self._loader, 'get_source'):
return self._loader.exec_module(module)

# get nodes with module-level contracts from the source code
source = self._loader.get_source(module.__name__)
if source is None:
return self._loader.exec_module(module)
tree = ast.parse(source)
nodes = self._get_contracts(tree=tree)
if not nodes:
return self._loader.exec_module(module)

# convert contracts nodes into real contracts
contracts = []
for node in nodes:
contract = self._exec_contract(node=node)
if contract is None:
msg = 'unsupported contract: {}'.format(ast.dump(contract))
raise RuntimeError(msg)
contracts.append(contract)

# execute module with contracts
wrapped = _aliases.chain(contract)(self._loader.exec_module)
wrapped(module)

@staticmethod
def _get_contracts(tree) -> Optional[list]:
for node in tree.body:
if not type(node) is ast.Expr:
continue
if not type(node.value) is ast.Call:
continue
if get_name(node.value.func) != 'deal.load':
continue
return node.value.args
return []

@staticmethod
def _exec_contract(node) -> Optional[Callable]:
if not isinstance(node, ast.Attribute):
return None
if node.value.id != 'deal':
return None
contract = getattr(_aliases, node.attr, None)
if contract is None:
return None
return contract


def load(*contracts) -> None:
return


def register():
if DealFinder in sys.meta_path:
return
index = sys.meta_path.index(PathFinder)
sys.meta_path[index] = DealFinder

0 comments on commit 3a2760f

Please sign in to comment.