In [1]:
from model import CausalSelfAttention

In [2]:
import math
from copy import deepcopy

import numpy as np
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention

In [3]:
# Attention module that the rest of the notebook traces and rewrites;
# built with no constructor arguments. Later cells read `att_module.config`.
att_module = CausalSelfAttention()

In [4]:
# Random input of shape (1, block_size, n_embd), sized from the module's config.
sample_x = torch.randn(
    1,
    att_module.config.block_size,
    att_module.config.n_embd,
)

In [5]:
# Trace the module with `x` specialised to `sample_x` through `concrete_args`.
# NOTE(review): in the graph printed by the next cell, `c_attn` consumes
# `_tensor_constant0` rather than the `x_1` placeholder, so `sample_x` appears
# to be baked into the graph as a constant. Confirm that `gm` is only ever
# expected to be evaluated on `sample_x`.
gm = fx.symbolic_trace(att_module, concrete_args={"x": sample_x})



In [6]:
# Show the freshly traced graph as a table (opcode, name, target, args, kwargs).
gm.graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  --------------------------------------------
placeholder    x_1                x_1                                ()                              {}
get_attr       _tensor_constant0  _tensor_constant0                  ()                              {}
call_module    c_attn             c_attn                             (_tensor_constant0,)            {}
call_method    split              split                              (c_attn, 768)                   {'dim': 2}
call_function  getitem            <built-in function getitem>        (split, 0)                      {}
call_function  getitem_1          <built-in function getitem>        (split, 1)                      {}
call_function  getitem_2          <built-in function getitem>        (split, 2)                      {}
call_metho

In [7]:
def normalize_dropout(gm: fx.GraphModule):
    """
    Find all `self.dropout(x)` and replace it with `F.dropout`

    Every ``call_module`` node whose target resolves to an ``nn.Dropout``
    instance is turned into a ``call_function`` node that calls
    ``F.dropout`` with the module's ``p``, ``training`` and ``inplace``
    values passed as keyword arguments. The values are read from the module
    at the time this function runs.

    Args:
        gm: Graph module to normalise. It is deep-copied first, so the
            object passed in is left untouched.

    Returns:
        The recompiled deep copy with the dropout module calls rewritten.
    """
    gm = deepcopy(gm)
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        # `get_submodule` walks dotted targets such as "block.attn.dropout";
        # a plain `getattr` only resolves direct children of the root module.
        sub_module = gm.get_submodule(node.target)
        if not isinstance(sub_module, nn.Dropout):
            continue
        # Keep whatever keywords the call already carried (e.g. `input=`)
        # and add the module's settings on top of them.
        node.kwargs = {
            **node.kwargs,
            "p": sub_module.p,
            "training": sub_module.training,
            "inplace": sub_module.inplace,
        }
        node.target = F.dropout
        node.op = "call_function"
    _ = gm.recompile()
    return gm

In [8]:
# Rebind `gm` to the dropout-normalised copy, then print the resulting graph.
gm = normalize_dropout(gm)
gm.graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  ----------------------------------------------
placeholder    x_1                x_1                                ()                              {}
get_attr       _tensor_constant0  _tensor_constant0                  ()                              {}
call_module    c_attn             c_attn                             (_tensor_constant0,)            {}
call_method    split              split                              (c_attn, 768)                   {'dim': 2}
call_function  getitem            <built-in function getitem>        (split, 0)                      {}
call_function  getitem_1          <built-in function getitem>        (split, 1)                      {}
call_function  getitem_2          <built-in function getitem>        (split, 2)                      {}
call_met

----

In [10]:
# Shorthand for the attention module's config object.
# NOTE(review): no later cell visible here reads `config` (the pattern cell
# uses `att_module.config` directly) — confirm it is still needed.
config = att_module.config

In [11]:
def replacement(q, k, v):
    """Causal attention over ``q``, ``k``, ``v`` as a single fused call.

    Delegates to ``scaled_dot_product_attention`` with no explicit mask,
    zero dropout probability and ``is_causal=True``.
    """
    sdpa_options = {"attn_mask": None, "dropout_p": 0.0, "is_causal": True}
    return scaled_dot_product_attention(q, k, v, **sdpa_options)

In [13]:
# Trace `replacement` on its own to inspect the graph that gets spliced in.
fx.symbolic_trace(replacement).graph.print_tabular()

opcode         name                          target                                            args                             kwargs
-------------  ----------------------------  ------------------------------------------------  -------------------------------  --------------------------------------------------------
placeholder    q                             q                                                 ()                               {}
placeholder    k                             k                                                 ()                               {}
placeholder    v                             v                                                 ()                               {}
call_function  scaled_dot_product_attention  <built-in function scaled_dot_product_attention>  (q, k, v)                        {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}
output         output                        output                                            (scaled

In [12]:
class PatternModule(nn.Module):
    """Hand-written masked attention used as the subgraph to search for.

    The `pattern` function traces this module's `forward`, and the traced
    operator sequence is what gets matched against the attention graph, so
    the statements below should not be reordered or rewritten with
    different operators.
    """

    def __init__(self, config):
        """Register the lower-triangular mask buffer sized by `config.block_size`."""
        super().__init__()
        # Lower-triangular matrix of ones, viewed as (1, 1, block_size, block_size).
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    def forward(self, q, k, v):
        """Return masked, softmax-normalised attention of `q`/`k` applied to `v`."""
        # Scores scaled by 1 / sqrt(size of k's last dimension).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Positions where the mask is 0 are set to -inf before the softmax.
        # NOTE(review): the slice bound 1024 is hard-coded; presumably it equals
        # the block_size of the traced module — verify.
        att = att.masked_fill(self.bias[:,:,:1024,:1024] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Dropout with p=0.0; only `p` is passed explicitly here.
        att = F.dropout(att, p=0.0)
        return att @ v

In [14]:
def pattern(q, k, v):
    """Subgraph to look for: `PatternModule.forward` applied to ``q``, ``k``, ``v``.

    A fresh `PatternModule` is built from `att_module.config` on every call
    and its `forward` method is invoked directly.
    """
    return PatternModule(att_module.config).forward(q, k, v)

In [15]:
# Trace `pattern` on its own to inspect the subgraph that will be searched for.
fx.symbolic_trace(pattern).graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  ----------------------------------------------
placeholder    q                  q                                  ()                              {}
placeholder    k                  k                                  ()                              {}
placeholder    v                  v                                  ()                              {}
call_method    transpose          transpose                          (k, -2, -1)                     {}
call_function  matmul             <built-in function matmul>         (q, transpose)                  {}
call_method    size               size                               (k, -1)                         {}
call_function  sqrt               <built-in function sqrt>           (size,)                         {}
call_function  t

In [16]:
# Search `gm` for subgraphs shaped like `pattern` and swap each one for
# `replacement`. The cell output is the list of matches that were found.
fx.replace_pattern(gm, pattern=pattern, replacement=replacement)

[Match(anchor=matmul_1, nodes_map={matmul_1: matmul_1, dropout: attn_dropout, softmax: softmax, masked_fill: masked_fill, mul: mul, matmul: matmul, q: transpose_1, transpose: transpose_3, k: transpose, truediv: truediv, sqrt: sqrt, size: size, _tensor_constant0: _tensor_constant1, v: transpose_2})]

In [17]:
# Print the graph of `gm` after the pattern replacement.
gm.graph.print_tabular()

opcode         name                          target                                            args                                   kwargs
-------------  ----------------------------  ------------------------------------------------  -------------------------------------  --------------------------------------------------------
placeholder    x_1                           x_1                                               ()                                     {}
get_attr       _tensor_constant0             _tensor_constant0                                 ()                                     {}
call_module    c_attn                        c_attn                                            (_tensor_constant0,)                   {}
call_method    split                         split                                             (c_attn, 768)                          {'dim': 2}
call_function  getitem                       <built-in function getitem>                       (split, 0)       

In [18]:
# Numerical check: the rewritten graph module must reproduce the original
# module's output on `sample_x`. Both forward passes run under `no_grad`
# because only the values are compared, so no autograd state is needed.
# NOTE(review): the printed graph feeds `c_attn` from `_tensor_constant0`, so
# this check appears to cover `sample_x` only — confirm before relying on it
# for other inputs.
with torch.no_grad():
    expected_out = att_module(sample_x).cpu().numpy()
    rewritten_out = gm(sample_x).cpu().numpy()

np.allclose(expected_out, rewritten_out, atol=1e-7)

True

---