In [3]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model_name = "gpt2"  # smallest GPT-2

# Load the tokenizer and the LM-head model for the same checkpoint name.
# `tokenizer` and `model` are reused by most of the cells below.
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Smoke test: tokenize a short string into PyTorch tensors and run one forward pass.
text = "Hello, world!"
inputs = tokenizer(text, return_tensors="pt")

# The saved output shows logits of shape [1, 4, 50257]; 50257 matches the
# vocabulary size of the model's `wte` embedding.
output = model(**inputs)
print(output.logits.shape)

torch.Size([1, 4, 50257])


In [17]:
# Weight shape of the first block's `c_attn` projection; the saved output is
# 768 x 2304, i.e. 3 * 768 output columns.
model.transformer.h[0].attn.c_attn.weight.shape

torch.Size([768, 2304])

In [18]:
# Display the module tree of the loaded model.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [29]:
# Prompt used by the inspection cells below; rebinds the global `text` and `inputs`.
# The string ends with a space, and the saved output ends with id 220 —
# presumably the trailing space tokenized on its own (TODO confirm).
text = "Tomorrow I will not speak, nor weep; For my soul is lost through "
inputs = tokenizer(text, return_tensors="pt")
inputs['input_ids']

tensor([[49488,   314,   481,   407,  2740,    11,  4249, 49671,    26,  1114,
           616,  5848,   318,  2626,   832,   220]])

In [30]:
# Highest-scoring vocabulary id at every position of the first batch element
# (argmax over the last axis of the per-position logits).
model(**inputs).logits[0].argmax(dim=1)

tensor([  11, 1101,  307,  307,  286,  475,  466,   11,  475,  314, 5848,  318,
         407,   11,  262, 2171])

# `inputs` is the mapping returned by the tokenizer, so it has to be unpacked
# into keyword arguments (as every other call in this notebook does) rather
# than passed as the first positional argument.
model(**inputs)

In [35]:
# Column 0 of the first 10 rows of the token-embedding matrix (`wte`).
model.transformer.wte.weight[:10, 0]

tensor([-0.1101,  0.0403, -0.1275, -0.0927, -0.0506,  0.0112, -0.0839, -0.1300,
        -0.0797, -0.0401], grad_fn=<SelectBackward0>)

In [48]:
# First scalar (batch 0, position 0, feature 0) of every hidden state.
# The forward pass runs once and its `hidden_states` are iterated directly,
# instead of re-running the whole model for each of the 13 entries.
[layer_hs[0][0][0] for layer_hs in model(**inputs, output_hidden_states=True).hidden_states]

[tensor(-0.0016, grad_fn=<SelectBackward0>),
 tensor(1.8395, grad_fn=<SelectBackward0>),
 tensor(2.1406, grad_fn=<SelectBackward0>),
 tensor(2.0247, grad_fn=<SelectBackward0>),
 tensor(2.0568, grad_fn=<SelectBackward0>),
 tensor(1.8415, grad_fn=<SelectBackward0>),
 tensor(1.8090, grad_fn=<SelectBackward0>),
 tensor(1.6873, grad_fn=<SelectBackward0>),
 tensor(1.6717, grad_fn=<SelectBackward0>),
 tensor(1.6966, grad_fn=<SelectBackward0>),
 tensor(1.5764, grad_fn=<SelectBackward0>),
 tensor(1.4559, grad_fn=<SelectBackward0>),
 tensor(0.0812, grad_fn=<SelectBackward0>)]

In [49]:
# Keep the per-layer hidden states of one forward pass for the cells below.
hs = model(**inputs, output_hidden_states=True).hidden_states

In [63]:
# Learned scale vector of the first block's `ln_1` LayerNorm.
model.transformer.h[0].ln_1.weight

Parameter containing:
tensor([0.2232, 0.1820, 0.1534, 0.1917, 0.2036, 0.1948, 0.1467, 0.1865, 0.2143,
        0.1956, 0.2118, 0.2153, 0.1882, 0.2074, 0.1871, 0.2040, 0.2044, 0.1900,
        0.1952, 0.0475, 0.1909, 0.2115, 0.1971, 0.2202, 0.1998, 0.2108, 0.2303,
        0.1879, 0.1939, 0.2018, 0.1891, 0.1861, 0.1958, 0.1832, 0.1978, 0.2243,
        0.0706, 0.1958, 0.1943, 0.1939, 0.1978, 0.1951, 0.1995, 0.1912, 0.2083,
        0.2037, 0.1849, 0.1945, 0.2189, 0.0419, 0.1977, 0.1979, 0.0608, 0.1824,
        0.2055, 0.0476, 0.1892, 0.2079, 0.2047, 0.2233, 0.2097, 0.2075, 0.2076,
        0.1793, 0.1312, 0.1841, 0.1939, 0.1561, 0.0577, 0.1948, 0.2048, 0.1717,
        0.1942, 0.1708, 0.1989, 0.1993, 0.2082, 0.1071, 0.1968, 0.1770, 0.2164,
        0.1864, 0.1938, 0.2184, 0.1343, 0.1707, 0.0683, 0.1401, 0.1823, 0.2045,
        0.2007, 0.1853, 0.1783, 0.1889, 0.1870, 0.1975, 0.2114, 0.2108, 0.2083,
        0.2409, 0.1938, 0.2022, 0.0857, 0.1823, 0.1879, 0.1979, 0.1850, 0.1029,
        0.1762, 0.

In [72]:
# Apply the first block's `ln_1` to the first hidden-state entry (`hs[0]`)
# by hand; `a` is fed into `c_attn` in the next cell.
a = model.transformer.h[0].ln_1(hs[0])
a

tensor([[[-0.0019, -0.1037,  0.0011,  ...,  0.1639, -0.0937,  0.0341],
         [ 0.1208, -0.0494, -0.0716,  ..., -0.0105, -0.1055, -0.0614],
         [ 0.0803, -0.0301,  0.0376,  ..., -0.0025, -0.1169, -0.1349],
         ...,
         [ 0.0287,  0.1587,  0.0989,  ..., -0.2045,  0.0797, -0.0762],
         [-0.1655, -0.1087,  0.0521,  ..., -0.1481, -0.1840,  0.0692],
         [ 0.1190, -0.0457,  0.0814,  ..., -0.0793, -0.2008,  0.1165]]],
       grad_fn=<NativeLayerNormBackward0>)

In [91]:
# Run the first block's `c_attn` projection on the normalized input and cut
# its last axis into three equal 768-wide pieces: query, key and value.
q, k, v = model.transformer.h[0].attn.c_attn(a).split(768, dim=2)

In [100]:
# Hidden width divided by 12 (the head count used later as N_HEAD): 64 features per head.
768/12

64.0

In [102]:
# Shape of the first 64 feature columns of q for batch element 0
# (presumably head 0's query slice — TODO confirm head layout).
q[0][:, :64].shape

torch.Size([16, 64])

In [109]:
# Show the two operands of the score product side by side: the first 64
# columns of q, and the transpose of the first 64 columns of k.
(q[0][:, :64], k[0][:, :64].T)

(tensor([[ 0.0912, -0.3976, -0.1741,  ..., -0.6445,  0.8169,  0.0283],
         [ 0.2345,  0.4826,  0.0843,  ..., -0.7618,  1.3198,  0.8108],
         [ 1.0434,  0.2520, -0.3885,  ..., -0.3062, -0.7437,  0.4505],
         ...,
         [-0.5736, -0.9146, -1.0684,  ..., -0.4004,  1.0381, -1.4904],
         [ 1.3110, -1.4384, -0.5634,  ..., -0.7802, -0.4422,  0.5120],
         [-0.0185, -0.7402, -0.2828,  ..., -0.7718,  0.6120,  0.4303]],
        grad_fn=<SliceBackward0>),
 tensor([[-1.0132, -2.4647, -1.9612,  ..., -3.2596, -3.1811, -2.5158],
         [ 2.1841,  2.4373,  2.0603,  ...,  2.4532,  2.7200,  1.9560],
         [ 0.5509,  2.5695,  1.9895,  ...,  1.4560,  2.0114,  3.3688],
         ...,
         [-1.4600, -1.3429, -1.5031,  ..., -1.4433, -0.7054, -2.1890],
         [-0.3480, -1.6592, -2.1551,  ..., -3.2121, -2.0480, -1.0909],
         [ 1.3589,  1.9584,  1.9182,  ...,  1.0802,  1.3847,  3.0028]],
        grad_fn=<PermuteBackward0>))

In [110]:
# Raw query-key dot products for the first 64-column slice. No scaling and no
# masking is applied in this expression.
(q[0][:, :64] @ k[0][:, :64].T)

tensor([[  2.4854,  -8.8071, -11.4450, -13.1997, -10.8294, -11.9119, -13.0449,
          -4.3931, -11.3048, -11.4803, -13.5225,  -9.5102, -15.5763, -12.9939,
         -11.9211, -15.6077],
        [  9.3668, -13.5898, -23.5171, -20.7397, -22.4040, -20.5137,  -3.0508,
         -14.9600, -21.2427, -12.9855, -20.3471, -13.3441, -25.0722, -16.9933,
         -21.9441, -12.1705],
        [  0.2705, -12.4505, -19.0622, -16.8973, -12.4341, -16.7591, -11.0500,
          -8.3409, -15.6627, -10.8733, -14.9455,  -6.5480, -18.0123, -10.3377,
         -20.1341, -12.8820],
        [  2.3355,  -5.8269, -10.3535, -16.7616,  -9.0833, -17.6115,  -7.2563,
         -12.5624, -12.1397,  -8.2168,  -5.8472,  -7.2905, -10.4731,  -8.7743,
         -12.0062,  -3.2087],
        [ -0.3024,  -6.1341, -11.9915, -10.1298, -16.1768, -11.7815,  -6.2491,
           2.1572,  -9.8635,  -4.6895,  -8.6208,  -0.4197,  -8.4663,   0.2582,
          -9.2789,  -9.8129],
        [ -1.9413, -13.1588, -22.6552, -22.5430, -19.7533, -

In [112]:
# List the parameter names; these dotted names are what the graph cells below
# use as weight tensor names.
model.state_dict().keys()

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.

In [166]:
import importlib
import gen
import torch
# Reload the local `gen` module so edits to it are picked up without
# restarting the kernel.
importlib.reload(gen)

# Fresh graph object used by the `g.*` cells below.
g = gen.Graph()


In [None]:
# Register every model parameter in the graph as a weight tensor.
# The loop variables are deliberately not named `k`/`v`, so the key/value
# tensors created by the attention cells are not overwritten in the kernel
# namespace.
for weight_name, weight in model.state_dict().items():
    g.add_tensor(weight_name, weight, is_weight=True)

In [133]:
# Register a 4-element token-id tensor; the returned value is passed to
# `add_kernel_op` / `emit` below as the tensor's handle.
toks = g.add_tensor("toks", torch.tensor([1, 2, 3, 4]))

In [171]:

class lodash:
    """Short-hand builder around a ``gen.Graph``.

    Each instance owns its own graph (``self.g``) and exposes terse aliases
    for registering tensors, adding kernel ops and looking tensors up by name.
    """

    def __init__(self):
        # Every builder starts from an empty graph.
        self.g = gen.Graph()

    def t(self, name: str, t: torch.Tensor, is_weight=False):
        """Register tensor ``t`` under ``name``.

        Returns whatever ``Graph.add_tensor`` returns; callers in this
        notebook pass that value in the ``inps``/``outs`` lists of ``op``.
        """
        return self.g.add_tensor(name, t, is_weight)

    def op(self, name: str, num_blocks: int, inps: list[int], outs: list[int], consts: "list[int] | None" = None):
        """Add kernel op ``name`` with the given inputs, outputs and constants.

        ``consts`` defaults to a fresh empty list on every call. A literal
        ``[]`` default would be one list object shared by all calls, so a
        mutation by the graph or a caller would leak into later ops.
        """
        if consts is None:
            consts = []
        self.g.add_kernel_op(name, num_blocks, inps, outs, consts)

    def __getitem__(self, name: str):
        """Look up the handle of the tensor registered as ``name``."""
        return self.g.get_tensor_idx(name)



In [134]:
# Register a (4, 768) buffer named "x", then add an "embedding" op that takes
# `toks` and the token-embedding weight and lists `x` as its output.
# The second argument (10) is the op's block count; it shows up as the loop
# bound in the emitted code below.
x = g.add_tensor("x", torch.rand((4, 768)))
g.add_kernel_op("embedding", 10, [toks, g.get_tensor_idx("transformer.wte.weight")], [x])

In [135]:
# Print the code generated for the ops between `toks` and `x`.
# NOTE(review): the saved output's loop strides by blockDim.x while starting
# at blockIdx.x; a per-block stride is normally gridDim.x — check gen's emitter.
print(g.emit([toks], [x]))


                    
                for(int bidx = blockIdx.x; bidx < 10; bidx += blockDim.x) {
                    embedding((toks),(transformer_wte_weight),(x), bidx);
                }
            
                    


In [142]:
# Print the generated struct holding one pointer per weight.
# NOTE(review): the saved output uses the dotted parameter names as field
# names (e.g. `transformer.wte.weight`), which are not valid C identifiers;
# the op emitter prints them with underscores instead.
print(g.emit_weight_struct())


                typedef struct {
                    float* transformer.wte.weight;
float* transformer.wpe.weight;
float* transformer.h.0.ln_1.weight;
float* transformer.h.0.ln_1.bias;
float* transformer.h.0.attn.c_attn.weight;
float* transformer.h.0.attn.c_attn.bias;
float* transformer.h.0.attn.c_proj.weight;
float* transformer.h.0.attn.c_proj.bias;
float* transformer.h.0.ln_2.weight;
float* transformer.h.0.ln_2.bias;
float* transformer.h.0.mlp.c_fc.weight;
float* transformer.h.0.mlp.c_fc.bias;
float* transformer.h.0.mlp.c_proj.weight;
float* transformer.h.0.mlp.c_proj.bias;
float* transformer.h.1.ln_1.weight;
float* transformer.h.1.ln_1.bias;
float* transformer.h.1.attn.c_attn.weight;
float* transformer.h.1.attn.c_attn.bias;
float* transformer.h.1.attn.c_proj.weight;
float* transformer.h.1.attn.c_proj.bias;
float* transformer.h.1.ln_2.weight;
float* transformer.h.1.ln_2.bias;
float* transformer.h.1.mlp.c_fc.weight;
float* transformer.h.1.mlp.c_fc.bias;
float* transformer.h.1.mlp.c_p

In [None]:
# Print the generated weight-allocation code (this cell has no saved execution).
print(g.emit_weight_allocator())

In [None]:
# NOTE(review): verbatim repeat of the earlier "x"/"embedding" cell. Running
# it registers a second "x" tensor and a second embedding op on the same
# graph `g` — confirm this is intended or drop the cell.
x = g.add_tensor("x", torch.rand((4, 768)))
g.add_kernel_op("embedding", 10, [toks, g.get_tensor_idx("transformer.wte.weight")], [x])

In [172]:
# Create a builder with its own empty graph.
# NOTE(review): binding `_` shadows IPython's "last output" variable for the
# rest of the session.
_ = lodash()


In [None]:
# --- Model dimensions --------------------------------------------------------
# These agree with the `model` repr shown in this notebook: Embedding(50257, 768),
# 12 blocks, LayerNorm eps=1e-05, and an MLP width of 3072.
VOCAB_SIZE = 50257
N_LAYER = 12
N_HEAD = 12
N_EMBD = 768
EPS = 1e-5

# Derived sizes.
HEAD_DIM = N_EMBD // N_HEAD  # 64 features per attention head
N_INNER = 4 * N_EMBD  # 3072, width of the MLP hidden layer

# --- Kernel / batch configuration -------------------------------------------
THREADS = 256
NUM_BLOCKS = 512
BATCH = 8

In [161]:
# Display the module tree again as a reference for the layer names used below.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
# TODO: install Triton — this cell contains no install command (e.g. `%pip install triton`)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Defaulting to user installation because normal site-packages is not writeable


In [189]:
_ = lodash()

# Register every pretrained parameter as a graph weight.  The loop variables
# are deliberately not named `k`/`v`, so the key/value tensors created by the
# attention-exploration cells are not overwritten in the kernel namespace.
for weight_name, weight in model.state_dict().items():
    _.t(weight_name, weight, is_weight=True)

SEQ_LEN = 100
OP_NUM_BLOCKS = 10  # block count handed to every op in this cell

# Shapes of the placeholder activation buffers registered below.
act_shape = (BATCH, SEQ_LEN, N_EMBD)
attn_shape = (BATCH, N_HEAD, SEQ_LEN, SEQ_LEN)

# Token ids with shape (BATCH, SEQ_LEN), so the input agrees with the
# (BATCH, SEQ_LEN, N_EMBD) embedding output declared right after it.
# The ids are placeholders (0..SEQ_LEN-1 in every row).
# NOTE(review): assumes the embedding op expects one id per (batch, position).
toks = _.t("toks", torch.arange(SEQ_LEN).repeat(BATCH, 1))

embed_out = _.t("embed_out", torch.rand(act_shape))
_.op("embedding", OP_NUM_BLOCKS, [toks, _["transformer.wte.weight"]], [embed_out])

pos_out = _.t("pos_out", torch.rand(act_shape))
_.op("position", OP_NUM_BLOCKS, [embed_out, _["transformer.wpe.weight"]], [pos_out])

# NOTE(review): `res` is registered but not consumed by any op yet.
res = _.t("residual", torch.rand(act_shape))
x = pos_out

for layer_num in range(N_LAYER):
    prefix = f"transformer.h.{layer_num}"

    ln1_out = _.t(f"ln1_out_{layer_num}", torch.rand(act_shape))
    _.op("layer_norm", OP_NUM_BLOCKS, [x, _[f"{prefix}.ln_1.weight"], _[f"{prefix}.ln_1.bias"]], [ln1_out])

    # One buffer holding the three 768-wide projections per position.
    qkv = _.t(f"qkv_{layer_num}", torch.rand((BATCH, SEQ_LEN, 3, N_EMBD)))
    _.op("linear_layer", OP_NUM_BLOCKS, [ln1_out, _[f"{prefix}.attn.c_attn.weight"], _[f"{prefix}.attn.c_attn.bias"]], [qkv])

    attn_scores = _.t(f"a_scores_{layer_num}", torch.rand(attn_shape))
    _.op("attn", OP_NUM_BLOCKS, [qkv], [attn_scores])

    # TODO: masking

    attn_probs = _.t(f"a_probs_{layer_num}", torch.rand(attn_shape))
    _.op("softmax", OP_NUM_BLOCKS, [attn_scores], [attn_probs])

    y = _.t(f"y_{layer_num}", torch.rand(act_shape))
    _.op("v_mul", OP_NUM_BLOCKS, [attn_probs, qkv], [y])

    # NOTE(review): no attn.c_proj, residual add, ln_2 or MLP ops are emitted
    # yet; the next layer consumes `y` directly.
    x = y
    # TODO: dropout


In [190]:
# Print the generated code from the token input up to the last layer's `y`
# buffer.  The end tensor name is derived from N_LAYER instead of being
# hard-coded, so it stays correct if the layer count changes.
print(_.g.emit(start_idx=[toks], end_idx=[_[f"y_{N_LAYER - 1}"]]))


                    
                for(int bidx = blockIdx.x; bidx < 10; bidx += blockDim.x) {
                    embedding((toks),(transformer_wte_weight),(embed_out), bidx);
                }
            
                    
                    
                for(int bidx = blockIdx.x; bidx < 10; bidx += blockDim.x) {
                    position((embed_out),(transformer_wpe_weight),(pos_out), bidx);
                }
            
                    
                    
                for(int bidx = blockIdx.x; bidx < 10; bidx += blockDim.x) {
                    layer_norm((pos_out),(transformer_h_0_ln_1_weight),(transformer_h_0_ln_1_bias),(ln1_out_0), bidx);
                }
            
                    
                    
                for(int bidx = blockIdx.x; bidx < 10; bidx += blockDim.x) {
                    linear_layer((ln1_out_0),(transformer_h_0_attn_c_attn_weight),(transformer_h_0_attn_c_attn_bias),(qkv_0), bidx);
                }
            
      