In [1]:
from model import CausalSelfAttention

In [2]:
from torch.nn.functional import scaled_dot_product_attention
import math
import torch.nn.functional as F

In [1]:
import torch
import torch.nn as nn
import torch.fx as fx
import torch.onnx as onnx

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
att_module = CausalSelfAttention()

In [None]:
from torch.export.

In [5]:
sample_x = torch.randn((1, att_module.config.block_size, att_module.config.n_embd)) 

In [61]:
gm = fx.symbolic_trace(att_module, concrete_args={"x": sample_x})



In [62]:
gm.graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  --------------------------------------------
placeholder    x_1                x_1                                ()                              {}
get_attr       _tensor_constant0  _tensor_constant0                  ()                              {}
call_module    c_attn             c_attn                             (_tensor_constant0,)            {}
call_method    split              split                              (c_attn, 768)                   {'dim': 2}
call_function  getitem            <built-in function getitem>        (split, 0)                      {}
call_function  getitem_1          <built-in function getitem>        (split, 1)                      {}
call_function  getitem_2          <built-in function getitem>        (split, 2)                      {}
call_metho

In [63]:
for node in gm.graph.nodes:
    if node.op == "call_module":
        target = node.target
        sub_module = getattr(gm, target)
        if not isinstance(sub_module, nn.Dropout):
            continue
        node.kwargs = {"p": sub_module.p, "training": sub_module.training, "inplace": sub_module.inplace}
        node.target = F.dropout
        node.op = "call_function"
_ = gm.recompile()

In [64]:
gm.graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  ----------------------------------------------
placeholder    x_1                x_1                                ()                              {}
get_attr       _tensor_constant0  _tensor_constant0                  ()                              {}
call_module    c_attn             c_attn                             (_tensor_constant0,)            {}
call_method    split              split                              (c_attn, 768)                   {'dim': 2}
call_function  getitem            <built-in function getitem>        (split, 0)                      {}
call_function  getitem_1          <built-in function getitem>        (split, 1)                      {}
call_function  getitem_2          <built-in function getitem>        (split, 2)                      {}
call_met

In [65]:
gm(sample_x)

tensor([[[-5.7285e-01,  5.5057e-01, -7.4936e-01,  ...,  5.2238e-04,
          -3.5449e-01, -1.6398e-01],
         [-1.2847e-01,  1.0453e-01, -3.9423e-02,  ..., -1.1701e-01,
          -2.7683e-01, -9.2623e-02],
         [-1.3511e-01, -2.8250e-02, -8.8111e-02,  ...,  1.2655e-01,
          -3.9318e-02, -2.9731e-01],
         ...,
         [-4.6995e-02,  1.9405e-02, -5.9787e-03,  ...,  4.6384e-02,
           2.2823e-02, -4.5394e-02],
         [-4.1740e-02,  1.3565e-02, -1.4079e-02,  ...,  3.2467e-02,
           2.4587e-02, -3.9794e-02],
         [-3.1664e-02,  1.4018e-02, -1.6495e-02,  ...,  2.2946e-02,
           1.7570e-02, -3.2156e-02]]], grad_fn=<ViewBackward0>)

----

In [66]:
def replacement(q, k, v):
    return scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )

In [67]:
config = att_module.config

In [68]:
class PatternModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    def forward(self, q, k, v):
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:1024,:1024] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = F.dropout(att, p=0.0)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(1, 1024, 768)
        return y

In [69]:
fx.symbolic_trace(replacement).graph.print_tabular()

opcode         name                          target                                            args                             kwargs
-------------  ----------------------------  ------------------------------------------------  -------------------------------  --------------------------------------------------------
placeholder    q                             q                                                 ()                               {}
placeholder    k                             k                                                 ()                               {}
placeholder    v                             v                                                 ()                               {}
call_function  scaled_dot_product_attention  <built-in function scaled_dot_product_attention>  (q, k, v)                        {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}
output         output                        output                                            (scaled

In [70]:
def pattern(q, k, v):
    pm = PatternModule(att_module.config)
    return pm.forward(q, k, v)

In [71]:
fx.symbolic_trace(pattern).graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  ----------------------------------------------
placeholder    q                  q                                  ()                              {}
placeholder    k                  k                                  ()                              {}
placeholder    v                  v                                  ()                              {}
call_method    transpose          transpose                          (k, -2, -1)                     {}
call_function  matmul             <built-in function matmul>         (q, transpose)                  {}
call_method    size               size                               (k, -1)                         {}
call_function  sqrt               <built-in function sqrt>           (size,)                         {}
call_function  t

In [72]:
fx.replace_pattern(gm, pattern=pattern, replacement=replacement)

[Match(anchor=view, nodes_map={view: view_3, contiguous: contiguous, transpose_1: transpose_4, matmul_1: matmul_1, dropout: attn_dropout, softmax: softmax, masked_fill: masked_fill, mul: mul, matmul: matmul, q: transpose_1, transpose: transpose_3, k: transpose, truediv: truediv, sqrt: sqrt, size: size, _tensor_constant0: _tensor_constant4, v: transpose_2})]

In [73]:
gm.graph.print_tabular()

opcode         name                          target                                            args                                   kwargs
-------------  ----------------------------  ------------------------------------------------  -------------------------------------  --------------------------------------------------------
placeholder    x_1                           x_1                                               ()                                     {}
get_attr       _tensor_constant0             _tensor_constant0                                 ()                                     {}
call_module    c_attn                        c_attn                                            (_tensor_constant0,)                   {}
call_method    split                         split                                             (c_attn, 768)                          {'dim': 2}
call_function  getitem                       <built-in function getitem>                       (split, 0)       

In [83]:
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        return self.relu(x)

In [85]:
model = nn.Sequential(MLP(256, 128), MLP(128, 64))

In [87]:
x = torch.randn(1, 256)

In [89]:
ep = export(model, (x,))

In [94]:
nodes_map = {}
for node in ep.module().graph.nodes:
    print(node.op, node.target, node.name, [k for k in node.meta])
    nodes_map[node.name] = node

get_attr 0_linear_weight _0_linear_weight ['val', 'tensor_meta', 'stack_trace', 'nn_module_stack', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value']
get_attr 0_linear_bias _0_linear_bias ['val', 'tensor_meta', 'stack_trace', 'nn_module_stack', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value']
get_attr 1_linear_weight _1_linear_weight ['val', 'tensor_meta', 'stack_trace', 'nn_module_stack', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value']
get_attr 1_linear_bias _1_linear_bias ['val', 'tensor_meta', 'stack_trace', 'nn_module_stack', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value']
placeholder l_args_0_ l_args_0_ ['val', 'tensor_meta']
call_function aten.t.default t ['stack_trace', 'nn_module_stack', 'source_fn_stack', 'original_aten', 'from_node', 'seq_nr', 'val', 'tensor_meta']
call_function aten.addmm.default addmm ['stack_trace', 'nn_module_stack', 'source_fn_stack', 'original_aten', 'from_node', 'seq_nr', 'val', 'tensor_meta']
call_function aten.

In [96]:
nodes_map['t'].meta

{'stack_trace': '  File "/Users/dboyliao/Work/Sciwork/SciConf_2024/.venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner\n    return fn(*args, **kwargs)\n  File "/Users/dboyliao/Work/Sciwork/SciConf_2024/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl\n    return forward_call(*args, **kwargs)\n  File "/var/folders/5l/jrch3z5x6jx6tcz33qkv1mgh0000gn/T/ipykernel_44342/1711239079.py", line 8, in forward\n    x = self.linear(x)\n',
 'nn_module_stack': {'L__self__': ('', torch.nn.modules.container.Sequential),
  'fn': ("L['fn']", torch.nn.modules.container.Sequential),
  'fn_0': ("getattr(L['fn'], '0')", __main__.MLP),
  'getattr_L__fn_____0___linear': ("getattr(L['fn'], '0').linear",
   torch.nn.modules.linear.Linear)},
 'source_fn_stack': [('getattr_l__fn_____0___linear',
   torch.nn.modules.linear.Linear)],
 'original_aten': <OpOverload(op='aten.t', overload='default')>,
 'from_node': [('x', 'getattr_L__fn_____0___l

In [97]:
from torchinfo import summary

In [102]:
summary(ep.module(), x.shape)

Layer (type:depth-idx)                   Output Shape              Param #
GraphModule                              [1, 64]                   41,152
Total params: 41,152
Trainable params: 41,152
Non-trainable params: 0
Total mult-adds (M): 2.63
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.16
Estimated Total Size (MB): 0.17