In [None]:
import ast
import functools
import inspect
import os
import sys

from IPython.core.magic import register_cell_magic, register_line_magic
from IPython import get_ipython

debug_frame = None

def _is_cell_var(name, frame):
    if name in frame.f_locals:
        if name not in frame.f_globals:
            return False
        if frame.f_locals[name] is frame.f_globals[name]:
            return True
    global debug_frame
    debug_frame = frame
    assert name in frame.f_globals, 'could not find name %s' % name
    return True

class TraceState:
    """Mutable bookkeeping shared across all invocations of the trace function."""

    def __init__(self):
        # current nesting depth of traced frames ('call' adds one, 'return' removes one)
        self.call_depth = 0
        # cache of CodeLine objects, keyed by line number
        self.code_lines = {}
        # caller lines saved on 'call' events and restored on 'return' events
        self.stack = []
        # source text of the traced code as a list of lines; filled in lazily
        self.source = None
        # CodeLine of the most recent 'line' event in the current frame
        self.cur_frame_last_line = None
        # name of the previously handled trace event
        self.last_event = None

# TODO: maybe frame, scope, indentation, etc
class CodeLine:
    """A single traced source line paired with its parsed AST statement.

    Attributes are stored exactly as passed in; ``extra_dependencies`` starts
    out empty and accumulates names contributed by code called from this line.
    """

    def __init__(self, text, ast_node, lineno, call_depth, frame):
        self.text, self.ast_node = text, ast_node
        self.lineno, self.call_depth = lineno, call_depth
        self.frame = frame
        self.extra_dependencies = set()

    def compute_rval_dependencies(self):
        """Return the names read by this line's AST plus any extra dependencies."""
        own_names = get_all_rval_names(self.ast_node)
        return own_names | self.extra_dependencies

    @property
    def has_lval(self):
        """True when the line is a plain assignment statement."""
        # TODO: expand to AugAssign, method calls, etc.
        node = self.ast_node
        return isinstance(node, ast.Assign)

def _parse_traced_line(line):
    """Parse one stripped source line into a single AST statement node.

    Block headers (lines ending in ':') are first tried with a dummy ``pass``
    body so that they parse in isolation. Anything that still cannot be parsed
    on its own -- a blank or comment-only line, a decorator, an
    ``elif``/``else``/``except`` header, a fragment of a multi-line statement
    -- yields an ``ast.Pass`` placeholder instead of raising, so the trace
    function never crashes the program being traced.
    """
    candidates = [line]
    if line.endswith(':'):
        candidates.insert(0, line + '\n    pass')
    for candidate in candidates:
        try:
            body = ast.parse(candidate).body
        except (SyntaxError, ValueError):
            continue
        if body:
            return body[0]
    return ast.Pass()


def tracefunc(frame, event, arg, state):
    """Trace callback (see ``sys.settrace``) that records per-line dependencies.

    Args:
        frame: the frame the event occurred in.
        event: the trace event name; 'call', 'line' and 'return' are handled
            specially, other events only update ``state.last_event``.
        arg: event argument supplied by the interpreter; unused.
        state: the ``TraceState`` instance shared by all invocations.

    Returns:
        None for frames whose code does not come from a notebook cell (which
        stops tracing inside that scope), otherwise a callable bound to the
        same ``state`` to keep tracing.
    """
    # notebook filenames appear as 'ipython-input...'
    if 'ipython-input' not in frame.f_code.co_filename:
        return None

    tracer = functools.partial(tracefunc, state=state)

    # NOTE(review): the source is cached from the first traced frame only and
    # lines are keyed by number alone, so frames that come from a different
    # file would be matched against the wrong text -- confirm whether that
    # can happen in practice.
    if state.source is None:
        state.source = inspect.getsource(frame).split('\n')
    lineno = frame.f_lineno
    if 0 < lineno <= len(state.source):
        line = state.source[lineno - 1].strip()
    else:
        # line number outside the cached source: degrade to an empty line
        line = ''

    old_depth = state.call_depth

    # IPython quirk -- every line in outer scope apparently wrapped in lambda
    # We want to skip the outer 'call' and 'return' for these
    if event == 'call':
        state.call_depth += 1
        if old_depth == 0:
            return tracer

    if event == 'return':
        state.call_depth -= 1
        if old_depth <= 1:
            return tracer

    print(lineno, state.call_depth, event, line)

    # build (and parse) the CodeLine only the first time a line number is seen
    code_line = state.code_lines.get(lineno)
    if code_line is None:
        code_line = CodeLine(
            line, _parse_traced_line(line), lineno, state.call_depth, frame
        )
        state.code_lines[lineno] = code_line

    if event in ('line', 'return') and state.last_event != 'call':
        # this means that the previous line on the current frame is done executing
        if state.cur_frame_last_line is not None:
            print('done processing line {}'.format(state.cur_frame_last_line.text))
    if event == 'line':
        state.cur_frame_last_line = code_line
    if event == 'call':
        # remember which caller line we came from; the callee starts fresh
        state.stack.append(state.cur_frame_last_line)
        state.cur_frame_last_line = None
    if event == 'return':
        ret_line = state.stack.pop()
        assert ret_line is not None
        # reset 'cur_frame_last_line' for the previous frame, so that we push it again if it has another funcall
        state.cur_frame_last_line = ret_line
        print('{} @@returning to@@ {}'.format(code_line.text, ret_line.text))
        # whatever the returning line read becomes a dependency of the caller line
        ret_line.extra_dependencies |= code_line.compute_rval_dependencies()
    state.last_event = event
    return tracer

In [None]:
class GetAllRvalNames(ast.NodeVisitor):
    """Collect the identifiers that a single parsed statement reads.

    Call an instance on an AST node to get the set of names. The visitor is
    meant for statements parsed one source line at a time, so assignment
    targets, loop targets and loop bodies are deliberately skipped.
    """

    def __init__(self):
        self.name_set = set()

    def __call__(self, node):
        """Visit ``node`` and return the accumulated set of names."""
        self.visit(node)
        return self.name_set

    def visit_Name(self, node):
        # every Name that is reached counts, whatever its context
        self.name_set.add(node.id)

    def visit_Assign(self, node):
        # skip node.targets
        self.visit(node.value)

    def visit_AugAssign(self, node):
        # skip node.target
        self.visit(node.value)

    def visit_For(self, node):
        # skip node.target (gets bound to node.iter)
        # skip body too -- will have dummy since this visitor works line-by-line
        self.visit(node.iter)

    def visit_Lambda(self, node):
        """Record the free names of a lambda without disturbing outer names.

        Names bound by the lambda's own parameters are removed from the names
        read in its body only. Names collected outside the lambda are kept
        even when they share a name with a parameter, and default-value
        expressions are treated as reads in the enclosing scope.
        """
        outer_names = self.name_set

        # names referenced inside the lambda body
        self.name_set = set()
        self.visit(node.body)
        body_names = self.name_set

        # names bound by the lambda's parameters
        self.name_set = set()
        lambda_args = node.args
        # posonlyargs only exists on Python 3.8+; None is ignored by generic_visit
        self.visit(getattr(lambda_args, 'posonlyargs', None))
        self.visit(lambda_args.args)
        self.visit(lambda_args.vararg)
        self.visit(lambda_args.kwonlyargs)
        self.visit(lambda_args.kwarg)
        bound_names = self.name_set

        # back in the enclosing scope: defaults are read here, then add the
        # body names that the parameters do not bind
        self.name_set = outer_names
        self.visit(lambda_args)
        self.name_set |= body_names - bound_names

    def generic_visit(self, node):
        # tolerate None and plain lists so callers can visit optional and
        # list-valued AST fields directly
        if node is None:
            return
        elif isinstance(node, list):
            for item in node:
                self.visit(item)
        else:
            super().generic_visit(node)

    def visit_arguments(self, node):
        # skip over unbound args
        self.visit(node.defaults)
        self.visit(node.kw_defaults)

    def visit_arg(self, node):
        # only reached through the explicit parameter visits in visit_Lambda
        self.name_set.add(node.arg)


def get_all_rval_names(node: ast.AST):
    """Return the set of names collected by a fresh ``GetAllRvalNames`` visitor run on ``node``."""
    collector = GetAllRvalNames()
    return collector(node)

In [None]:
@register_cell_magic
def tracecell(_, cell):
    """Cell magic: execute ``cell`` with ``tracefunc`` installed as the trace function.

    Args:
        _: the magic's line argument; ignored.
        cell: source text of the cell body to run.

    A fresh ``TraceState`` is created for every invocation. Whatever trace
    function was active beforehand is restored afterwards, even if running
    the cell is interrupted.
    """
    state = TraceState()
    previous_trace = sys.gettrace()
    sys.settrace(functools.partial(tracefunc, state=state))
    try:
        get_ipython().run_cell(cell)
    finally:
        # never leave the tracer installed once the cell is done
        sys.settrace(previous_trace)

In [None]:
%%tracecell
y = 3 + 5
def f(x):
    def g(y):
        a = y + 5
        return a * x
    def foo(z):
        return g(z) * 3
    return foo

asdf = f(3)(4)
print(asdf)
# Demo: traces nested closures. f(3) returns foo, foo(4) calls g(4), and g
# reads x from f's scope, so the printed value is ((4 + 5) * 3) * 3 = 81.
# These notes sit after the last statement on purpose: the %%tracecell magic
# has to stay on the first line, and the tracer prints cell line numbers and
# line text, which comments placed earlier or inline would alter.

In [None]:
%%tracecell
z = 3
def f(x):
    return 3*x
def g(x, y=f(z)):
    return x + y
print(g(3))
# Demo: traces a call made from a default-argument expression. f(z) runs when
# the `def g` statement executes, giving y = 9 by default, so g(3) prints 12.
# These notes sit after the last statement on purpose: the %%tracecell magic
# has to stay on the first line, and the tracer prints cell line numbers and
# line text, which comments placed earlier or inline would alter.

In [None]:
x = 11  # module-level binding, shadowed by the local x inside h


def h():
    """Return a factory ``f`` whose product ``g`` prints three views of ``x``.

    ``g`` prints ``x`` as resolved by normal name lookup, then the entry for
    'x' in its own frame's ``f_locals``, then the entry in its frame's
    ``f_globals`` ('not found' is printed when an entry is missing).
    """
    x = 10

    def f():
        def g():
            print(x)
            for namespace in (
                inspect.currentframe().f_locals,
                inspect.currentframe().f_globals,
            ):
                print(namespace.get('x', 'not found'))
        return g
    return f


h()()()