In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
from dataclasses import dataclass

In [3]:
@dataclass
class GPTConfig:
    """Hyperparameters shared by the GPT-style modules defined in this notebook."""

    # Side length of the square causal mask registered by the attention modules.
    block_size: int = 1_024
    # GPT-2 vocab_size of 50257, padded up to the nearest multiple of 64 for efficiency.
    vocab_size: int = 50_304
    n_layer: int = 12
    # CausalSelfAttention asserts that n_embd is divisible by n_head.
    n_head: int = 12
    n_embd: int = 768
    # Probability handed to the dropout layers.
    dropout: float = 0.0
    # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster.
    bias: bool = True

In [10]:
import math

In [66]:
# Small 5x5 preview of a lower-triangular matrix of ones.
torch.ones(5, 5).tril()

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [75]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention using an explicit (manual) attention computation.

    The input is projected to queries, keys and values in one batched linear layer,
    attention scores are masked with a lower-triangular buffer so a position only
    attends to itself and earlier positions, and the per-head results are
    re-assembled and passed through an output projection.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``block_size``, ``dropout``
            and ``bias``. When ``None``, a default ``GPTConfig()`` is used.
    """

    def __init__(self, config = None):
        super().__init__()
        if config is None:
            config = GPTConfig()
        # Each head gets an equal slice of the embedding.
        assert config.n_embd % config.n_head == 0
        self.__config = config
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Lower-triangular mask of ones, shaped (1, 1, block_size, block_size) so it
        # broadcasts over batch and heads. Registered as a buffer named "bias".
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    @property
    def config(self):
        """The configuration object this module was built from (read-only)."""
        return self.__config

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of size (B, T, C) with C equal to ``n_embd``.

        Returns:
            Tensor of size (B, T, C).
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Positions where the mask is 0 (future tokens) get -inf so softmax zeroes them.
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # F.dropout defaults to training=True, so it would also drop attention weights
        # in eval mode. Zero the probability outside training instead. Passing only
        # `p` keeps the call shape identical to the one used by the notebook's FX pattern.
        att = F.dropout(att, p=self.dropout if self.training else 0.0)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [76]:
from torch.nn.functional import scaled_dot_product_attention

In [77]:
# Print the docstring of scaled_dot_product_attention for reference.
print(scaled_dot_product_attention.__doc__)


scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor:

Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified.

.. code-block:: python

    # Efficient implementation equivalent to the following:
    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        # Efficient implementation equivalent to the following:
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype)
        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias

In [86]:
import torch.fx as fx
from torch.export import export, Dim

In [87]:
# No config is passed, so the constructor falls back to a default GPTConfig().
att_module = CausalSelfAttention()

In [88]:
# Example activations of shape (1, block_size, n_embd). They must be floating point:
# the module feeds them straight into nn.Linear, which cannot multiply an integer
# (torch.long) tensor with its float weights.
sample_x = torch.randn(1, att_module.config.block_size, att_module.config.n_embd)

In [119]:
# Sequence length used when building sample_x, read from the module's config.
att_module.config.block_size

1024

In [89]:
# Trace the module with `x` specialized to the concrete sample tensor.
# NOTE(review): in the graph table printed by a later cell, c_attn consumes
# `_tensor_constant0` rather than the `x_1` placeholder, so the sample appears to be
# baked into the graph as a constant — confirm this is intended before reusing `gm`
# on other inputs.
gm = fx.symbolic_trace(att_module, concrete_args={"x": sample_x})



In [90]:
# Inspect the type of the traced result.
type(gm)

torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [91]:
# Index the traced graph's nodes by name for direct lookup.
nodes_map = {node.name: node for node in gm.graph.nodes}

In [92]:
# Show the traced graph as a table, one row per node.
gm.graph.print_tabular()

opcode         name               target                             args                            kwargs
-------------  -----------------  ---------------------------------  ------------------------------  ----------------------------------------------
placeholder    x_1                x_1                                ()                              {}
get_attr       _tensor_constant0  _tensor_constant0                  ()                              {}
call_module    c_attn             c_attn                             (_tensor_constant0,)            {}
call_method    split              split                              (c_attn, 768)                   {'dim': 2}
call_function  getitem            <built-in function getitem>        (split, 0)                      {}
call_function  getitem_1          <built-in function getitem>        (split, 1)                      {}
call_function  getitem_2          <built-in function getitem>        (split, 2)                      {}
call_met

In [93]:
# Look up the node named "dropout" and show its opcode and call target.
nodes_map["dropout"].op, nodes_map["dropout"].target

('call_function',
 <function torch.nn.functional.dropout(input: torch.Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> torch.Tensor>)

In [99]:
def replacement(q, k, v):
    """Fused drop-in for the manual causal attention pattern.

    Args:
        q, k, v: per-head tensors of size (B, nh, T, hs).

    Returns:
        Tensor of size (B, T, nh * hs): causal scaled-dot-product attention with the
        head outputs re-assembled side by side. The matched pattern ends with this
        same re-assembly, so the replacement has to produce that layout as well
        rather than the raw (B, nh, T, hs) attention output.
    """
    y = scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )
    # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, nh * hs), without hard-coded sizes.
    return y.transpose(1, 2).contiguous().flatten(2)

In [126]:
class PatternModule(nn.Module):
    """Stand-alone copy of the manual attention core, used as an FX search pattern.

    Mirrors the score / mask / softmax / dropout / re-assembly steps of
    CausalSelfAttention.forward for a full-length sequence (T == block_size) and a
    batch size of 1. Sizes are taken from ``config`` instead of being hard-coded,
    so the pattern follows the configuration it is built with.

    Args:
        config: object exposing ``block_size`` and ``n_embd``.
    """
    def __init__(self, config):
        super().__init__()
        # Plain Python ints, so they show up as literal arguments when traced.
        self.block_size = config.block_size
        self.n_embd = config.n_embd
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    def forward(self, q, k, v):
        """Return causal attention over (1, nh, block_size, hs) inputs as (1, block_size, n_embd)."""
        T = self.block_size
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # p=0.0 makes this a no-op; it is kept so the op sequence matches the module's.
        att = F.dropout(att, p=0.0)
        y = att @ v
        # Batch size stays the literal 1 used by the traced sample.
        y = y.transpose(1, 2).contiguous().view(1, T, self.n_embd)
        return y

In [127]:
def pattern(q, k, v):
    """Attention subgraph to search for, built from att_module's config."""
    # NOTE(review): `.forward` is invoked directly instead of calling the module;
    # presumably deliberate for tracing — confirm before changing.
    return PatternModule(att_module.config).forward(q, k, v)

In [128]:
# Replace subgraphs of gm that match `pattern` with `replacement`; the returned
# list of Match objects is displayed as the cell output.
fx.replace_pattern(gm, pattern=pattern, replacement=replacement)

[Match(anchor=view, nodes_map={view: view_3, contiguous: contiguous, transpose_1: transpose_4, matmul_1: matmul_1, dropout: dropout, softmax: softmax, masked_fill: masked_fill, mul: mul, matmul: matmul, q: transpose_1, transpose: transpose_3, k: transpose, truediv: truediv, sqrt: sqrt, size: size, _tensor_constant0: _tensor_constant1, v: transpose_2})]

In [129]:
# Show the graph again after the replace_pattern call.
gm.graph.print_tabular()

opcode         name                          target                                            args                                   kwargs
-------------  ----------------------------  ------------------------------------------------  -------------------------------------  --------------------------------------------------------
placeholder    x_1                           x_1                                               ()                                     {}
get_attr       _tensor_constant0             _tensor_constant0                                 ()                                     {}
call_module    c_attn                        c_attn                                            (_tensor_constant0,)                   {}
call_method    split                         split                                             (c_attn, 768)                          {'dim': 2}
call_function  getitem                       <built-in function getitem>                       (split, 0)       