In [1]:
import ast, copy, dis
import nbimport, sublambda

In [2]:
import Episode13

# Step 1: Setup episode 13 test case

In [3]:
%%mycell sublambda.sublambdas('ep13_test_case')
f lambda x: (z := x ** 2, z + x + 1)[-1]

<function sublambda.inline_lambdas.<locals>.inliner(source)>

In [4]:
%%mycell ep13_test_case

def wonkification():
    z = 42
    f(99)
    return z

In [5]:
wonky = wonkification()
print(wonky)
wonky == 42 # Nope...

9801


False

In [11]:
print(ast.dump(ast.parse('lambda x, y: (x := 2*x, x + y)[-1]', mode='eval'), indent=2))

Expression(
  body=Lambda(
    args=arguments(
      posonlyargs=[],
      args=[
        arg(arg='x'),
        arg(arg='y')],
      kwonlyargs=[],
      kw_defaults=[],
      defaults=[]),
    body=Subscript(
      value=Tuple(
        elts=[
          NamedExpr(
            target=Name(id='x', ctx=Store()),
            value=BinOp(
              left=Constant(value=2),
              op=Mult(),
              right=Name(id='x', ctx=Load()))),
          BinOp(
            left=Name(id='x', ctx=Load()),
            op=Add(),
            right=Name(id='y', ctx=Load()))],
        ctx=Load()),
      slice=UnaryOp(
        op=USub(),
        operand=Constant(value=1)),
      ctx=Load())))


# Step 2: Write renaming transformer

In [23]:
class RenameNameTransformer(ast.NodeTransformer):
    def __init__(self, names=None):
        if names is None:
            names = []
        self.symbols = {name: None for name in names}

    def visit_Lambda(self, node):
        for arg in node.args.args:
            self.symbols[arg.arg] = arg.arg
        return self.generic_visit(node)

    def visit_NamedExpr(self, node):
        rhs = self.visit(node.value)
        lhs = self.visit(node.target)
        return ast.NamedExpr(lhs, rhs)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Store):
            new_id = f'${node.id}'
            self.symbols[node.id] = new_id
            return ast.Name(new_id, node.ctx)
        elif node.id in self.symbols:
            sub_id = self.symbols[node.id]
            if sub_id is not None:
                return ast.Name(sub_id, node.ctx)
        return self.generic_visit(node)

In [22]:
renamer = RenameNameTransformer()
lambda_tree_0 = ast.parse('lambda x, y: (x := 2*x, x + y)[-1]', mode='eval')
lambda_tree_1 = ast.fix_missing_locations(renamer.visit(lambda_tree_0))
print(ast.dump(lambda_tree_1, indent=2))
print(ast.unparse(lambda_tree_1))

Name(id='x', ctx=Load())
Name(id='x', ctx=Store())
Name(id='x', ctx=Load())
Name(id='y', ctx=Load())
Expression(
  body=Lambda(
    args=arguments(
      posonlyargs=[],
      args=[
        arg(arg='x'),
        arg(arg='y')],
      kwonlyargs=[],
      kw_defaults=[],
      defaults=[]),
    body=Subscript(
      value=Tuple(
        elts=[
          NamedExpr(
            target=Name(id='$x', ctx=Store()),
            value=BinOp(
              left=Constant(value=2),
              op=Mult(),
              right=Name(id='x', ctx=Load()))),
          BinOp(
            left=Name(id='$x', ctx=Load()),
            op=Add(),
            right=Name(id='y', ctx=Load()))],
        ctx=Load()),
      slice=UnaryOp(
        op=USub(),
        operand=Constant(value=1)),
      ctx=Load())))
lambda x, y: (($x := (2 * x)), $x + y)[-1]


In [25]:
code_obj = compile(lambda_tree_1, '<transformed-string>', 'eval')
dis.dis(code_obj)
eval(code_obj)(2, 3)

  1           0 LOAD_CONST               0 (<code object <lambda> at 0x118c063a0, file "<transformed-string>", line 1>)
              2 LOAD_CONST               1 ('<lambda>')
              4 MAKE_FUNCTION            0
              6 RETURN_VALUE

Disassembly of <code object <lambda> at 0x118c063a0, file "<transformed-string>", line 1>:
  1           0 LOAD_CONST               1 (2)
              2 LOAD_FAST                0 (x)
              4 BINARY_MULTIPLY
              6 DUP_TOP
              8 STORE_FAST               2 ($x)
             10 LOAD_FAST                2 ($x)
             12 LOAD_FAST                1 (y)
             14 BINARY_ADD
             16 BUILD_TUPLE              2
             18 LOAD_CONST               2 (-1)
             20 BINARY_SUBSCR
             22 RETURN_VALUE


7

In [26]:
ast.parse(ast.unparse(lambda_tree_1), mode='eval')

SyntaxError: invalid syntax (<unknown>, line 1)

# Step 3: Inject transformer into rewrite chain

In [37]:
def build_hygienic_closure(in_expr):
    renamed_expr = RenameNameTransformer().visit(in_expr)
    print(ast.dump(renamed_expr, indent=2))
    visitor = sublambda.LambdaVisitor(renamed_expr)
    def _sub_closure(*args, **kws):
        assert len(args) == len(visitor._lambda_args)
        result = copy.deepcopy(visitor._lambda_body)
        transformer = NameSubstitutionTransformer(**{k: v for k, v in zip(visitor._lambda_args, args)})
        return ast.fix_missing_locations(transformer.visit(result))
    return _sub_closure

def inline_hygienic_lambdas(pairs):
    substitutions = {name: build_hygienic_closure(expr) for name, expr in pairs}
    transformer = sublambda.ParametricSubstitutionTransformer(**substitutions)
    def inliner(source):
        target = ast.parse(source)
        return ast.fix_missing_locations(transformer.visit(target))
    return inliner

def build_hygienic_sublambdas(cell):
    the_lambdas0 = [ln.split(' ', 1) for ln in cell.split('\n')]
    the_lambdas = [(content[0], content[1])
                   for content in the_lambdas0
                   if len(content) == 2]
    return inline_hygienic_lambdas(the_lambdas)

def hygienic_sublambdas(transformer_name):
    def binding_hygienic_sublambdas(cell0, shell):
        inliner = build_hygienic_sublambdas(cell0)
        def the_magic(cell1, shell):
            inlined = inliner(cell1)
            return shell.ex(compile(inlined, '<substituted-string>', 'exec'))
        shell.user_global_ns[transformer_name] = the_magic
        return inliner
    return binding_hygienic_sublambdas


In [39]:
%%mycell hygienic_sublambdas('ep14_test_case')
f lambda x: (z := x ** 2, z + x + 1)[-1]

AttributeError: 'str' object has no attribute '_fields'

In [30]:
%%mycell ep14_test_case
def wonkification2():
    z = 42
    f(99)
    return z

In [31]:
not_so_wonky = wonkification2()
print(not_so_wonky)
not_so_wonky == 42 # Maybe?!

9801


False

In [38]:
# Fail!?! Let's debug...
inliner = _29
inliner_result = inliner(Episode13.test_code)
print(ast.dump(inliner_result, indent=2))

Module(
  body=[
    FunctionDef(
      name='wonkification',
      args=arguments(
        posonlyargs=[],
        args=[],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
      body=[
        Assign(
          targets=[
            Name(id='z', ctx=Store())],
          value=Constant(value=42)),
        Expr(
          value=Subscript(
            value=Tuple(
              elts=[
                NamedExpr(
                  target=Name(id='z', ctx=Store()),
                  value=BinOp(
                    left=Constant(value=99),
                    op=Pow(),
                    right=Constant(value=2))),
                BinOp(
                  left=BinOp(
                    left=Name(id='z', ctx=Load()),
                    op=Add(),
                    right=Constant(value=99)),
                  op=Add(),
                  right=Constant(value=1))],
              ctx=Load()),
            slice=UnaryOp(
              op=USub(),
              operand=Consta