Skip to content

Commit

Permalink
Merge pull request #9 from eriknw/recurse_funcs
Browse files Browse the repository at this point in the history
Allow wrapped functions to define functions and classes.
  • Loading branch information
eriknw committed Sep 9, 2020
2 parents 9bfc3e4 + 4ab17de commit 4687c27
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 16 deletions.
67 changes: 51 additions & 16 deletions innerscope/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,46 @@
import functools
import inspect
from collections.abc import Mapping
from types import CellType, FunctionType
from types import CellType, CodeType, FunctionType
from tlz import concatv, merge


def _get_globals_recursive(func, *, seen=None, isclass=False):
""" Get all global names used by func and all functions and classes defined within it."""
if isclass:
global_names = set()
local_names = {"__name__"}
for inst in dis.get_instructions(func):
if inst.opname == "STORE_NAME":
local_names.add(inst.argval)
elif inst.opname == "LOAD_NAME":
if inst.argval not in local_names:
global_names.add(inst.argval)
elif inst.opname == "LOAD_GLOBAL": # pragma: no cover
global_names.add(inst.argval)
else:
global_names = {
inst.argval for inst in dis.get_instructions(func) if inst.opname == "LOAD_GLOBAL"
}
if seen is None:
seen = set()
num_classes = 0
for inst in dis.get_instructions(func):
if inst.opname == "LOAD_CONST" and type(inst.argval) is CodeType:
if num_classes > 0:
code_inst = next(dis.get_instructions(inst.argval))
isclass = code_inst.opname == "LOAD_NAME" and code_inst.argval == "__name__"
num_classes -= isclass
if inst.argval in seen: # pragma: no cover
# I don't know how to get into a recursive cycle, but let's prevent it anyway.
continue
seen.add(inst.argval)
global_names.update(_get_globals_recursive(inst.argval, seen=seen, isclass=isclass))
elif inst.opname == "LOAD_BUILD_CLASS":
num_classes += 1
return global_names


def _get_repr_table(title, scope, add_break=False):
if not scope:
return f'{"<br>" if add_break else ""}<tt>- {title}: {{}}</tt>'
Expand Down Expand Up @@ -48,7 +84,7 @@ def _get_repr_set(title, names):


class Scope(Mapping):
""" A read-only mapping of the inner and outer scope of a function.
"""A read-only mapping of the inner and outer scope of a function.
This is the return value when a `ScopedFunction` is called.
"""
Expand All @@ -71,7 +107,7 @@ def __len__(self):
return len(self.outer_scope) + len(self.inner_scope)

def bindto(self, func, *, use_closures=None, use_globals=None):
""" Bind the variables of this object to a function.
"""Bind the variables of this object to a function.
>>> @call
... def haz_cheezburger():
Expand Down Expand Up @@ -115,7 +151,7 @@ def bindto(self, func, *, use_closures=None, use_globals=None):
return ScopedFunction(func, self, use_closures=use_closures, use_globals=use_globals)

def call(self, func, *args, **kwargs):
""" Bind the variables of this object to a function and call the function.
"""Bind the variables of this object to a function and call the function.
>>> @call
... def haz_cheezburger():
Expand Down Expand Up @@ -156,7 +192,7 @@ def call(self, func, *args, **kwargs):
return self.bindto(func)(*args, **kwargs)

def callwith(self, *args, **kwargs):
""" ♪ But here's my number, so call me maybe ♪
"""♪ But here's my number, so call me maybe ♪
>>> @call
... def haz_cheezburger():
Expand Down Expand Up @@ -255,7 +291,7 @@ def _repr_html_(self):


class ScopedFunction:
""" Use to expose the inner scope of a wrapped function after being called.
"""Use to expose the inner scope of a wrapped function after being called.
The wrapped function should have no return statements. Instead of a return value,
a `Scope` object is returned when called, which is a Mapping of the inner scope.
Expand Down Expand Up @@ -328,10 +364,7 @@ def __init__(self, func, *mappings, use_closures=True, use_globals=True):
]
)

# co_names has more than just the global names
global_names = {
inst.argval for inst in dis.get_instructions(self.func) if inst.opname == "LOAD_GLOBAL"
}
global_names = _get_globals_recursive(self.func)
# Only keep variables needed by the function (globals and closures)
outer_scope = {
key: outer_scope[key]
Expand Down Expand Up @@ -365,7 +398,9 @@ def __init__(self, func, *mappings, use_closures=True, use_globals=True):
else:
# stacksize must be at least 3, because we make a length three tuple
self._code = code.replace(
co_code=co_code, co_names=co_names, co_stacksize=max(code.co_stacksize, 3),
co_code=co_code,
co_names=co_names,
co_stacksize=max(code.co_stacksize, 3),
)

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -449,7 +484,7 @@ def __call__(self, *args, **kwargs):
)

def bind(self, *mappings, **kwargs):
""" Bind variables to a function's outer scope.
"""Bind variables to a function's outer scope.
This returns a new ScopedFunction object and leaves the original unmodified.
Expand Down Expand Up @@ -515,7 +550,7 @@ def _repr_html_(self):


def scoped_function(func=None, *mappings, use_closures=True, use_globals=True):
""" Use to expose the inner scope of a wrapped function after being called.
"""Use to expose the inner scope of a wrapped function after being called.
The wrapped function should have no return statements. Instead of a return value,
a `Scope` object is returned when called, which is a Mapping of the inner scope.
Expand Down Expand Up @@ -561,7 +596,7 @@ def inner_scoped_func(func):


def bindwith(*mappings, **kwargs):
""" Bind variables to a function's outer scope, but don't yet call the function.
"""Bind variables to a function's outer scope, but don't yet call the function.
>>> @bindwith(cheez='cheddar')
... def makez_cheezburger():
Expand Down Expand Up @@ -589,7 +624,7 @@ def bindwith_inner(func, *, use_closures=True, use_globals=True):


def call(func, *args, **kwargs):
""" Useful for making simple pipelines to go from functions to scopes.
"""Useful for making simple pipelines to go from functions to scopes.
>>> @call
... def haz_cheezburger():
Expand Down Expand Up @@ -630,7 +665,7 @@ def call(func, *args, **kwargs):


def callwith(*args, **kwargs):
""" Useful for making simple pipelines to go from functions with arguments to scopes.
"""Useful for making simple pipelines to go from functions with arguments to scopes.
>>> @callwith(extra_cheez_pleez=True)
... def haz_cheezburger(extra_cheez_pleez=False):
Expand Down
76 changes: 76 additions & 0 deletions innerscope/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,79 @@ def f(w, x=1, *args, y=2, z, **kwargs):
pass

assert f == {"w": 0, "x": 1, "y": 2, "z": 3, "args": (), "kwargs": {}}


def test_list_comprehension():
closure_val = 2

def f():
y = [i for i in range(global_x)]
z = [j for j in range(closure_val)]

assert innerscope.call(f) == {"y": [0], "z": [0, 1], "global_x": 1, "closure_val": 2}
scoped_f = scoped_function(f, use_globals=False, use_closures=False)
assert scoped_f.missing == {"global_x", "closure_val"}
scope = scoped_f.bind(global_x=2, closure_val=1)()
assert scope == {"y": [0, 1], "z": [0], "global_x": 2, "closure_val": 1}


def test_inner_functions():
def f():
closure_val = 10

def g():
y = global_x + 1
z = closure_val + 1
return y, z

scope = innerscope.call(f)
assert scope.keys() == {"closure_val", "g", "global_x"}
assert scope["g"]() == (2, 11)
scoped_f = scoped_function(f, use_globals=False, use_closures=False)
assert scoped_f.missing == {"global_x"}
scope = scoped_f.bind(global_x=2)()
assert scope.keys() == {"closure_val", "g", "global_x"}
assert scope["g"]() == (3, 11)


def test_inner_class():
def f1():
class A:
x = global_x + 1

scope = innerscope.call(f1)
assert scope.keys() == {"A", "global_x"}
assert scope["A"].x == 2
scoped_f = scoped_function(f1, use_globals=False, use_closures=False)
assert scoped_f.missing == {"global_x"}
assert scoped_f.bind(global_x=2)()["A"].x == 3

a = 10

def f2():
b = 100

def g(self):
pass

class A:
x = global_x + 1

def __init__(self):
pass

y = x + 1
z = a + b
gm = g

scope = innerscope.call(f2)
assert scope.outer_scope.keys() == {"a", "global_x"}
assert scope.inner_scope.keys() == {"b", "g", "A"}
assert scope["A"].x == 2
assert scope["A"].z == 110
assert scope["A"]().gm() is None
scoped_f = scoped_function(f2, use_globals=False, use_closures=False)
assert scoped_f.missing == {"a", "global_x"}
scope = scoped_f.bind(a=20, global_x=2)()
assert scope["A"].x == 3
assert scope["A"].z == 120

0 comments on commit 4687c27

Please sign in to comment.