Examples from the `torch.fx` documentation: https://docs.pytorch.org/docs/stable/fx.html

In [1]:
import torch


# Simple module for demonstration
class MyModule(torch.nn.Module):
    """Demo module: add a learned offset, project with a linear layer, clamp to [0, 1].

    Attributes:
        param: learnable tensor of shape (3, 4), initialised with ``torch.rand``
            and added to the input (the input is presumably expected to
            broadcast against (3, 4) — TODO confirm for other shapes).
        linear: ``torch.nn.Linear(4, 5)`` applied after the offset.
    """

    def __init__(self) -> None:
        super().__init__()
        # Keep this registration order (param, then linear): it fixes the
        # order reported by named_parameters() / state_dict().
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # Local names are only for readability: FX names graph nodes after the
        # operations (add / linear / clamp), so the traced graph is identical.
        shifted = x + self.param
        projected = self.linear(shifted)
        return projected.clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation.
# Expected output (formatting details such as the `num_users` annotation
# may differ between PyTorch versions):
#
#   graph():
#       %x : [num_users=1] = placeholder[target=x]
#       %param : [num_users=1] = get_attr[target=param]
#       %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
#       %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
#       %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
#       return clamp
print(symbolic_traced.graph)

# Code generation - valid Python code. Expected output:
#
#   def forward(self, x):
#       param = self.param
#       add = x + param;  x = param = None
#       linear = self.linear(add);  add = None
#       clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
#       return clamp
#
# (Kept as comments, not a bare string literal: a string left as the cell's
# last expression would be echoed by Jupyter as an extra output.)
print(symbolic_traced.code)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp



def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    


'\ndef forward(self, x):\n    param = self.param\n    add = x + param;  x = param = None\n    linear = self.linear(add);  add = None\n    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None\n    return clamp\n'

In [None]:
def forward(self, x):
    """Hand-expanded equivalent of the ``forward`` FX generates for ``MyModule``.

    Expects ``self`` to expose ``param`` and a callable ``linear`` whose result
    has a ``clamp`` method. Each intermediate is released as soon as it has
    been consumed, mirroring the ``name = None`` lines in the generated code.
    """
    param = self.param
    add = x + param
    del x, param  # drop input references early, as the generated code does
    linear = self.linear(add)
    del add
    clamp = linear.clamp(min=0.0, max=1.0)
    del linear
    return clamp

In [7]:
!pip install tabulate -q

In [8]:
import torch
import torch.fx


class MyModule(torch.nn.Module):
    """Demo module whose traced graph uses get_attr and every FX call opcode.

    ``forward`` offsets the input by the linear layer's own weight, applies the
    layer and a ReLU, sums the last dimension and returns ``torch.topk(..., 3)``.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): `param` is registered but never read by forward(); it
        # still shows up in parameters() / state_dict().
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # linear.weight has shape (out_features, in_features) = (5, 4), so x
        # must broadcast against (5, 4).
        shifted = x + self.linear.weight
        activated = self.linear(shifted).relu()
        row_sums = torch.sum(activated, dim=-1)
        # Returns torch.topk's (values, indices) for the 3 largest sums.
        return torch.topk(row_sums, 3)


m = MyModule()
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(root=m)

# One row per graph node: opcode, name, target, args, kwargs.
gm.graph.print_tabular()

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7c28a9ef6f80>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7c28a9ef6f80>  (sum_1, 3) 

In [None]:
import torch
from torch import fx


# Sample module
class M(torch.nn.Module):
    """Toy module whose forward is a single ``torch.add`` call."""

    def forward(self, x, y):
        # Call torch.add explicitly rather than writing `x + y`: the operator
        # form traces to operator.add, while this records torch.add itself.
        summed = torch.add(x, y)
        return summed


def transform(m: torch.nn.Module, tracer_class: type = fx.Tracer) -> torch.nn.Module:
    """Trace ``m`` and retarget every ``torch.add`` call to ``torch.mul``.

    Args:
        m: Module to trace; also passed as the root of the returned GraphModule.
        tracer_class: ``fx.Tracer`` (or a subclass); instantiated with no
            arguments and used to trace ``m``.

    Returns:
        A new ``fx.GraphModule`` built from ``m`` and the rewritten graph.
    """
    traced_graph: fx.Graph = tracer_class().trace(m)

    # graph.nodes is an ordered sequence, so nodes are visited in program
    # order. Only call_function nodes whose target is exactly torch.add are
    # rewritten; e.g. operator.add (from `a + b`) is left alone.
    for fx_node in traced_graph.nodes:
        if fx_node.op == "call_function" and fx_node.target == torch.add:
            fx_node.target = torch.mul

    # Sanity-check that the mutated graph is still well-formed.
    traced_graph.lint()

    return fx.GraphModule(m, traced_graph)