### In the cell below, we define a class that has a couple of instances that return self, and an attribute. We are looking to see if we can rewrite the AST to remove the intermediate self references.

In [10]:


class Selfie:
    def __init__(self, value):
        self.value = value

    def method_one(self):
        return self
    
    def method_two(self):
        return self
    

s = Selfie(42).method_one().method_two().value
s

42

In [11]:
src_code = _i
print(src_code)



class Selfie:
    def __init__(self, value):
        self.value = value

    def method_one(self):
        return self
    
    def method_two(self):
        return self
    

s = Selfie(42).method_one().method_two().value
s


In [14]:
import ast
print(ast.dump(ast.parse(src_code), indent=4))

Module(
    body=[
        ClassDef(
            name='Selfie',
            bases=[],
            keywords=[],
            body=[
                FunctionDef(
                    name='__init__',
                    args=arguments(
                        posonlyargs=[],
                        args=[
                            arg(arg='self'),
                            arg(arg='value')],
                        kwonlyargs=[],
                        kw_defaults=[],
                        defaults=[]),
                    body=[
                        Assign(
                            targets=[
                                Attribute(
                                    value=Name(id='self', ctx=Load()),
                                    attr='value',
                                    ctx=Store())],
                            value=Name(id='value', ctx=Load()))],
                    decorator_list=[],
                    type_params=[]),
                FunctionDef(
     

Too complicated. Let's just do the chained calls.

In [56]:
call_ast = ast.parse("Selfie(42).method_one().method_two().value")
print(ast.dump(call_ast, indent=4))

Module(
    body=[
        Expr(
            value=Attribute(
                value=Call(
                    func=Attribute(
                        value=Call(
                            func=Attribute(
                                value=Call(
                                    func=Name(id='Selfie', ctx=Load()),
                                    args=[
                                        Constant(value=42)],
                                    keywords=[]),
                                attr='method_one',
                                ctx=Load()),
                            args=[],
                            keywords=[]),
                        attr='method_two',
                        ctx=Load()),
                    args=[],
                    keywords=[]),
                attr='value',
                ctx=Load()))],
    type_ignores=[])


In [50]:
import copy
class NodeRemover(ast.NodeTransformer):
    def visit_Call(self, node):
        # Check if the call is a method call on an instance of Selfie
        if isinstance(node.func, ast.Attribute): 
            if hasattr(node.func, 'attr') and node.func.attr in ['method_one', 'method_two']:
                # If it's a method call, we can remove the self reference
                node.func.value = ast.Name(id='Selfie', ctx=ast.Load())
                return self.visit(node.func.value)  # Visit the modified node recursively
        return self.generic_visit(node)

transformer = NodeRemover()
transformed_ast = transformer.visit(copy.deepcopy(call_ast))
print(ast.unparse(transformed_ast))

Selfie.value


- That works. But it's a simple example. Let's try something more complex.

In [51]:
s1 = Selfie(42).method_one().method_two().value

s2 = Selfie(42).method_two().method_one().method_two().value

In [52]:
src_code = _i
ast_2 = ast.parse(src_code)

transformed_ast_2 = NodeRemover().visit(ast_2)  # This is calling visit() with no arguments!
print(ast.unparse(transformed_ast_2))

s1 = Selfie.value
s2 = Selfie.value


- If we wanted to use this for telemetry, we'd want to add a side effect to each visit call. Lets modify it to do that

In [64]:
print(ast.unparse(call_ast))
class NodeRemoverTelem(ast.NodeTransformer):
    def visit_Call(self, node):
        # Check if the call is a method call on an instance of Selfie
        node = self.generic_visit(node)  # Visit the node first to ensure all children are processed
        if isinstance(node.func, ast.Attribute): 
            if hasattr(node.func, 'attr') and node.func.attr in ['method_one', 'method_two']:
                # If it's a method call, we can remove the self reference
                print(f"Visiting {node.func.attr}")  # Side effect for telemetry
                # Replace the node with it's parent node
                node.func.value = ast.Name(id='Selfie', ctx=ast.Load())
                node.func.value = self.visit(node.func.value)
        return node

transformer = NodeRemoverTelem()
transformed_ast = transformer.visit(copy.deepcopy(call_ast))
print(ast.unparse(transformed_ast))

Selfie(42).method_one().method_two().value
Visiting method_one
Visiting method_two
Selfie.method_two().value


In [66]:
print(ast.unparse(call_ast))
class NodeRemoverTelem(ast.NodeTransformer):
    def visit_Call(self, node):
        # Check if the call is a method call on an instance of Selfie FIRST
        if isinstance(node.func, ast.Attribute): 
            if hasattr(node.func, 'attr') and node.func.attr in ['method_one', 'method_two']:
                # If it's a method call, we can remove the self reference
                print(f"Visiting {node.func.attr}")  # Side effect for telemetry
                # Replace the node with it's parent node
                node.func.value = ast.Name(id='Selfie', ctx=ast.Load())
                # Now visit the children to handle any nested transformations
                return self.generic_visit(node)
        # Only call generic_visit if we didn't transform this node
        return self.generic_visit(node)

transformer = NodeRemoverTelem()
transformed_ast = transformer.visit(copy.deepcopy(call_ast))
print(ast.unparse(transformed_ast))

Selfie(42).method_one().method_two().value
Visiting method_two
Selfie.method_two().value


- Looks like this is going to be really messy. 

The `astroid` package might help here, it extends `ast` & has parent nodes built in, as well as some more clever stuff

In [75]:
from astroid import parse

print(parse(ast.unparse(call_ast)).repr_tree(indent=" "))

Module(
 name='',
 file='<?>',
 path=['<?>'],
 package=False,
 pure_python=True,
 future_imports=set(),
 doc_node=None,
 body=[Expr(value=Attribute(
    attrname='value',
    expr=Call(
     func=Attribute(
      attrname='method_two',
      expr=Call(
       func=Attribute(
        attrname='method_one',
        expr=Call(
         func=Name(name='Selfie'),
         args=[Const(
           value=42,
           kind=None)],
         keywords=[])),
       args=[],
       keywords=[])),
     args=[],
     keywords=[])))])


In [87]:
from astroid import parse

astroid_tree = parse(ast.unparse(call_ast))
print(astroid_tree.repr_tree(indent=" "))

# Now you can access parent nodes
def show_parents(node, depth=0):
    indent = "  " * depth
    print(f"{indent}{type(node).__name__}: {getattr(node, 'name', getattr(node, 'attrname', ''))}")
    if hasattr(node, 'parent') and node.parent:
        print(f"{indent} Child::{getattr(node, 'attrname', 'no attr')}  -> Parent: {type(node.parent).__name__}")
    
    for child in node.get_children():
        show_parents(child, depth + 1)

print("\nNode hierarchy with parents:")
show_parents(astroid_tree)

Module(
 name='',
 file='<?>',
 path=['<?>'],
 package=False,
 pure_python=True,
 future_imports=set(),
 doc_node=None,
 body=[Expr(value=Attribute(
    attrname='value',
    expr=Call(
     func=Attribute(
      attrname='method_two',
      expr=Call(
       func=Attribute(
        attrname='method_one',
        expr=Call(
         func=Name(name='Selfie'),
         args=[Const(
           value=42,
           kind=None)],
         keywords=[])),
       args=[],
       keywords=[])),
     args=[],
     keywords=[])))])

Node hierarchy with parents:
Module: 
  Expr: 
   Child::no attr  -> Parent: Module
    Attribute: value
     Child::value  -> Parent: Expr
      Call: 
       Child::no attr  -> Parent: Attribute
        Attribute: method_two
         Child::method_two  -> Parent: Call
          Call: 
           Child::no attr  -> Parent: Attribute
            Attribute: method_one
             Child::method_one  -> Parent: Call
              Call: 
               Child::no attr  -> 