# `torch.fx.Graph` Quick Introduction

In [1]:
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

## Symbolic Tracing

In [2]:
class MyModel(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x, y):
        x = torch.add(x, y)
        x = F.tanh(x)
        x = self.linear(x)
        return self.relu(x)

In [3]:
import torch.fx as fx

In [4]:
gm = fx.symbolic_trace(MyModel(100, 10))

In [5]:
type(gm)

torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [6]:
gm.graph.print_tabular()

opcode         name    target                                               args       kwargs
-------------  ------  ---------------------------------------------------  ---------  --------
placeholder    x       x                                                    ()         {}
placeholder    y       y                                                    ()         {}
call_function  add     <built-in method add of type object at 0x114dbf150>  (x, y)     {}
call_method    tanh    tanh                                                 (add,)     {}
call_module    linear  linear                                               (tanh,)    {}
call_module    relu    relu                                                 (linear,)  {}
output         output  output                                               (relu,)    {}


## `fx.Graph`

In [7]:
# `nodes` are in topological ordering
for node in gm.graph.nodes:
    print(node)

x
y
add
tanh
linear
relu
output


- Attributes of `fx.Node`
    - `op`(`str`): the op code (`call_function`, `call_method`, ...etc)
    - `target`: the target of the op
        - the type of `target` depends on the `op`
    - `args`: the argumments (a.k.a inputs) of the node
    - `kwargs`: the keyword arguments of the node

Read more about `fx.Node` by running:

```python
help(fx.Node)
```

## Basic Graph Editing

You can find a lot of examples [here](https://github.com/pytorch/examples/tree/1bef748fab064e2fc3beddcbda60fd51cb9612d2/fx).

In [8]:
# replace `x + y` with x - y
def transform(gm: fx.GraphModule):
    gm = deepcopy(gm)
    # loop over nodes to find the right target
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.add:
            node.target = torch.sub
    # **IMPORTANT**: remember to run `.recompile` once you are done with the editing
    gm.recompile()
    return gm

In [9]:
gm_new = transform(gm)

In [10]:
gm_new.graph.print_tabular()

opcode         name    target                                               args       kwargs
-------------  ------  ---------------------------------------------------  ---------  --------
placeholder    x       x                                                    ()         {}
placeholder    y       y                                                    ()         {}
call_function  add     <built-in method sub of type object at 0x114dbf150>  (x, y)     {}
call_method    tanh    tanh                                                 (add,)     {}
call_module    linear  linear                                               (tanh,)    {}
call_module    relu    relu                                                 (linear,)  {}
output         output  output                                               (relu,)    {}


In [11]:
# replace (tanh -> linear -> relu) with (linear -> dropout)
class ReplacementModule(nn.Module):
    def __init__(self, ori_module):
        super().__init__()
        self.add_module("ori_module", ori_module)

    def forward(self, x):
        ori_module = self.get_submodule("ori_module")
        return F.dropout(ori_module.linear(x))

In [12]:
class PatternModule(nn.Module):
    def __init__(self, ori_module):
        super().__init__()
        self.add_module("ori_module", ori_module)

    def forward(self, x):
        ori_module = self.get_submodule("ori_module")
        x = F.tanh(x)
        x = ori_module.linear(x)
        return ori_module.relu(x)

In [13]:
def transform_match(gm):
    gm = deepcopy(gm)
    fx.replace_pattern(
        gm,
        PatternModule(gm),
        ReplacementModule(gm)
    )
    return gm

In [14]:
gm_new2 = transform_match(gm_new)

In [15]:
# A failure :P
gm_new2.graph.print_tabular()

opcode         name    target                                               args       kwargs
-------------  ------  ---------------------------------------------------  ---------  --------
placeholder    x       x                                                    ()         {}
placeholder    y       y                                                    ()         {}
call_function  add     <built-in method sub of type object at 0x114dbf150>  (x, y)     {}
call_method    tanh    tanh                                                 (add,)     {}
call_module    linear  linear                                               (tanh,)    {}
call_module    relu    relu                                                 (linear,)  {}
output         output  output                                               (relu,)    {}
