In [11]:
from datasets import load_dataset
from transformers import GPT2Tokenizer
import torch as t
from torch.utils.data import DataLoader, Dataset

from gpt2_ai.model import GPT2, CausalSelfAttention, Block
from gpt2_ai.trainer import Trainer
from gpt2_ai.config import GPT2Config, TrainerConfig

[![Open in Colab](https://img.shields.io/badge/Open%20in%20Colab-Notebook-orange?logo=google-colab)](https://colab.research.google.com/github/your_username/your_repository/blob/master/path/to/your_notebook.ipynb)


In [12]:
conf = GPT2Config(n_layer=4, d_model=64, n_ctx=24, n_head=8)

In [13]:
dataset_name = "stas/openwebtext-10k"
name = 'data/' + dataset_name.split('/')[-1]
ds = load_dataset(dataset_name, split='train')

Reusing dataset openwebtext10k (/home/gerold/.cache/huggingface/datasets/stas___openwebtext10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


In [14]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [15]:
tokenizer.pad_token = tokenizer.eos_token
loader = DataLoader(ds, batch_size=8, shuffle=True)
sample = next(iter(loader))

In [16]:
in_x = tokenizer(
    sample['text'], return_tensors='pt',
    return_attention_mask=False,
    max_length=conf.n_ctx,
     padding='max_length', truncation=True)

In [17]:
in_x['input_ids'].shape

torch.Size([8, 24])

In [18]:
model = GPT2(conf)

In [19]:
model.count_params()

Number of trainable parameters: 4,919,825


In [20]:
logits = model(in_x['input_ids'])

In [11]:
logits.shape

torch.Size([8, 24, 50257])

In [21]:
trainer_conf = TrainerConfig(ckpt_path='../ckpt', log_path='../logs')

In [22]:
trainer = Trainer(trainer_conf, model, loader, loader, tokenizer)

In [14]:
%load_ext autoreload

In [15]:
%autoreload 2

In [25]:
trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: Parent directory ckpt/2023-30-2700:30:28 does not exist.

In [None]:
tokenizer.

In [30]:
import numpy as np

In [31]:
np.log(tokenizer.vocab_size)

10.82490511970208

In [19]:
crit = t.nn.CrossEntropyLoss()

In [22]:
in_x['input_ids'][:, 1:].shape

torch.Size([8, 23])

In [25]:
in_x['input_ids'][:, 1:].shape

torch.Size([8, 23])

In [23]:
logits[:, :-1].shape

torch.Size([8, 23, 50257])

torch.Size([8, 24, 50257])

In [37]:
pred = logits[:, :-1].reshape(-1, 50257)

In [41]:
target = in_x['input_ids'][:, 1:].reshape(-1)

In [42]:
crit(pred, target)

tensor(10.9579, grad_fn=<NllLossBackward0>)

tensor([[[ 8.3116e-01, -4.0233e-01, -8.4505e-01,  ...,  6.8999e-01,
           3.6214e-01,  1.5408e+00],
         [ 4.6434e-01,  1.9396e-01, -6.1156e-01,  ...,  6.8430e-01,
           5.2703e-01,  1.0217e+00],
         [ 2.2291e-01, -4.4124e-01, -1.7862e-02,  ...,  1.1990e+00,
           4.5777e-01,  7.8050e-01],
         ...,
         [ 5.6607e-01,  2.8459e-01, -2.1476e-01,  ...,  8.6783e-01,
           4.9642e-01,  7.7099e-01],
         [ 3.4211e-01,  2.3933e-01,  3.8665e-01,  ...,  1.0125e+00,
           5.9718e-01,  1.1299e+00],
         [ 6.4597e-01,  1.8843e-01, -5.6215e-01,  ...,  3.4234e-01,
           6.4102e-01,  1.3953e+00]],

        [[ 6.4040e-01, -1.9757e-01, -2.4741e-01,  ...,  7.8582e-01,
           4.0753e-01,  1.3126e+00],
         [ 2.3055e-01,  1.4652e-01, -5.4134e-02,  ...,  8.4779e-01,
          -1.5070e-01,  1.2744e+00],
         [ 4.0272e-01, -3.9435e-01, -1.3298e-01,  ...,  1.1853e+00,
           7.0849e-02,  1.7149e+00],
         ...,
         [ 4.9057e-01,  1

In [25]:
from transformers import BatchEncoding

In [26]:
BatchEncoding(data=sample, encoding=tokenizer)

TypeError: 'GPT2Tokenizer' object is not subscriptable

In [10]:
sample = tokenizer(ds[0]['text'], return_tensors='pt', return_attention_mask=False, max_length=CONTEXT_LENGTH, truncation=True)

In [11]:
model = GPT2(GPT2Config())

In [12]:
x = model(sample['input_ids'])

In [13]:
x

tensor([[[ 0.3097,  0.4870,  0.6961,  ...,  0.5891, -0.4094,  0.4599],
         [ 0.5984,  0.6257,  0.1372,  ...,  0.5681,  0.3522,  0.5912],
         [ 0.1387,  0.5918,  0.5854,  ...,  0.4782,  0.6790,  0.5634],
         ...,
         [ 0.5673,  0.1265, -0.0827,  ...,  0.2595,  0.7240,  0.4301],
         [ 0.4464,  0.5339, -0.3811,  ...,  0.3145,  0.7945,  0.1410],
         [ 0.8457,  0.3049,  0.1613,  ...,  0.0021,  0.4955,  0.2551]]],
       grad_fn=<ViewBackward0>)

In [152]:
act = getattr(t.nn, 'GELU')()

In [153]:
out = act(t.Tensor([1., 2., 3.]))

In [154]:
out

tensor([0.8413, 1.9545, 2.9960])

In [148]:
m = t.nn.GELU()
input = t.randn(2)
output = m(input)
output

tensor([-0.1697, -0.0300])

In [132]:
attn = CausalSelfAttention(GPT2Config())

In [103]:
block = Block(GPT2Config())

In [105]:
block(x).shape

torch.Size([1, 1024, 768])

In [106]:
conf = GPT2Config()

In [107]:
model = GPT2(conf)

In [110]:
out = model(sample['input_ids'])

In [112]:
out.shape

torch.Size([1, 1024, 50257])

In [115]:
out.log_softmax(-1).max(-1)

torch.return_types.max(
values=tensor([[-8.4164, -8.4164, -8.4164,  ..., -8.4158, -8.4158, -8.4158]],
       grad_fn=<MaxBackward0>),
indices=tensor([[28404, 28404, 28404,  ..., 28404, 28404, 28404]]))

In [94]:
Q = attn(x)

In [95]:
Q.shape

torch.Size([1, 1024, 768])

In [55]:
Q.shape

torch.Size([1, 1024, 12, 64])

In [56]:
Q.permute(0, 2, 1, 3)

tensor([[[[ 0.4051,  1.6578, -0.1196,  ..., -1.6676, -0.5051, -2.2094],
          [-0.2219,  0.3631,  0.1526,  ..., -1.5622,  0.9381, -0.4839],
          [ 0.1367,  0.1490,  0.2739,  ...,  0.4365,  0.0528,  0.9219],
          ...,
          [-0.9549, -0.4559, -1.0144,  ..., -0.9441,  0.2906,  0.0647],
          [ 0.6279,  0.8872,  0.8348,  ..., -1.4402,  0.9681, -0.4357],
          [ 0.0086, -0.4049, -0.5526,  ...,  0.2327, -0.0858,  1.1537]],

         [[ 0.2602, -0.4546,  0.4854,  ..., -0.0907, -0.0111,  0.6941],
          [ 0.1380,  1.3028, -0.2632,  ...,  1.1256, -0.0943, -1.3199],
          [-0.5720, -0.1067,  1.1817,  ..., -0.3961, -0.2230,  0.2691],
          ...,
          [-0.5954, -0.0263, -0.7146,  ...,  0.7584,  0.6328, -2.0703],
          [ 1.0791, -0.1267, -0.7209,  ...,  0.1225, -1.2627, -0.2249],
          [ 0.5651, -0.7779,  0.0820,  ..., -0.9344, -0.8203, -0.7762]],

         [[-0.1204, -0.4468, -0.6602,  ..., -0.9360,  0.7501,  0.2150],
          [-1.8273, -2.4964, -

In [67]:
batch_size, seq_len, n_embd, n_head = 2, 128, 64, 4
d_head = int(n_embd / n_head)

In [68]:
a = t.randn(batch_size, seq_len, n_embd)

In [69]:
a.shape

torch.Size([2, 128, 64])

In [70]:
a.view(batch_size, seq_len, n_head, d_head).shape

torch.Size([2, 128, 4, 16])

In [73]:
out1 = a.view(batch_size, seq_len, n_head, d_head).permute(0, 2, 1, 3)

In [78]:
out2 = a.view(batch_size, n_head, seq_len, d_head)

In [80]:
out2 == out1

tensor([[[[ True,  True,  True,  ...,  True,  True,  True],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False,

In [75]:
out1

tensor([[[[ 4.3545e-02, -1.9569e+00,  1.4791e+00,  ..., -9.6137e-01,
            9.6278e-01, -2.8604e-02],
          [-6.3019e-01,  7.8473e-01,  3.0342e-01,  ..., -9.3203e-01,
           -7.9259e-01,  6.9010e-01],
          [ 1.6866e+00,  2.7878e-01, -9.4930e-02,  ..., -9.5732e-01,
           -1.2384e+00,  7.2290e-01],
          ...,
          [ 1.0535e+00,  8.7548e-01, -7.7688e-02,  ..., -8.7108e-01,
           -9.5720e-02,  5.8334e-01],
          [ 5.3082e-01,  2.8144e+00, -2.0082e-01,  ..., -1.3521e+00,
           -2.9785e-01,  1.1507e+00],
          [-1.3991e+00,  9.8552e-01, -5.2617e-01,  ...,  3.5328e-01,
           -2.9659e-01,  3.6307e-01]],

         [[-9.1917e-01, -2.8233e-01,  4.6615e-01,  ..., -7.5273e-01,
            6.9479e-01, -1.0137e+00],
          [ 1.9282e+00, -6.7646e-02,  4.4653e-01,  ...,  9.0008e-01,
            1.0312e+00,  3.7749e-01],
          [-8.2837e-01, -1.4259e+00, -6.0044e-01,  ..., -1.2208e+00,
            7.8010e-01, -1.0375e+00],
          ...,
     

In [58]:
a

tensor([[[[ 1.3311e+00,  4.0612e-01, -9.6838e-01,  ...,  7.6811e-01,
            2.0105e+00, -2.4304e-01],
          [-1.0193e-01, -8.1819e-02, -1.2097e+00,  ...,  1.8281e+00,
           -1.5259e+00,  1.0163e+00],
          [-1.3933e+00,  3.2227e-01,  6.7729e-01,  ...,  1.4077e+00,
           -2.5910e+00, -1.2534e+00],
          ...,
          [-4.8123e-01,  8.4734e-01, -5.4851e-01,  ...,  7.2403e-01,
            3.0951e-01, -3.4596e-01],
          [ 8.0987e-01,  1.0177e+00, -1.0420e+00,  ..., -1.1196e+00,
           -1.3367e+00, -4.1055e-01],
          [ 5.3703e-01,  2.2097e+00,  1.2387e-01,  ..., -1.4005e+00,
           -1.7899e+00,  6.4690e-01]],

         [[ 5.1508e-01,  8.2310e-01, -3.0063e+00,  ...,  9.2919e-01,
           -9.7570e-01, -8.9491e-01],
          [ 1.6003e+00, -1.7011e+00, -2.0192e+00,  ...,  2.4410e-01,
           -5.2997e-01,  1.1098e+00],
          [-1.8489e+00,  1.8971e-01,  1.0997e+00,  ...,  9.3536e-01,
            5.5458e-01,  1.8069e+00],
          ...,
     

In [117]:
from transformers import GPT2LMHeadModel, GPT2Model

In [118]:
pretrained = GPT2Model.from_pretrained("gpt2")

Downloading: 100%|██████████| 523M/523M [00:48<00:00, 11.3MB/s] 


In [129]:
for el in pretrained.named_parameters():
    print(el[0])

wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
h.1.ln_1.weight
h.1.ln_1.bias
h.1.attn.c_attn.weight
h.1.attn.c_attn.bias
h.1.attn.c_proj.weight
h.1.attn.c_proj.bias
h.1.ln_2.weight
h.1.ln_2.bias
h.1.mlp.c_fc.weight
h.1.mlp.c_fc.bias
h.1.mlp.c_proj.weight
h.1.mlp.c_proj.bias
h.2.ln_1.weight
h.2.ln_1.bias
h.2.attn.c_attn.weight
h.2.attn.c_attn.bias
h.2.attn.c_proj.weight
h.2.attn.c_proj.bias
h.2.ln_2.weight
h.2.ln_2.bias
h.2.mlp.c_fc.weight
h.2.mlp.c_fc.bias
h.2.mlp.c_proj.weight
h.2.mlp.c_proj.bias
h.3.ln_1.weight
h.3.ln_1.bias
h.3.attn.c_attn.weight
h.3.attn.c_attn.bias
h.3.attn.c_proj.weight
h.3.attn.c_proj.bias
h.3.ln_2.weight
h.3.ln_2.bias
h.3.mlp.c_fc.weight
h.3.mlp.c_fc.bias
h.3.mlp.c_proj.weight
h.3.mlp.c_proj.bias
h.4.ln_1.weight
h.4.ln_1.bias
h.4.attn.c_attn.weight
h.4.attn.c_at

In [29]:
x.size()

torch.Size([1, 1024, 768])

In [32]:
from gpt2.model import MLP

ImportError: cannot import name 'MLP' from 'gpt2.model' (/home/gerold/projects/gpt2-all-in/gpt2/model.py)

In [30]:
attn = gpt2.model.CausalSelfAttention(gpt2.config.GPT2Config())

AttributeError: module 'gpt2.model' has no attribute 'CausalSelfAttention'

In [27]:
mask = t.tril(t.ones(10, 10))

In [26]:
rand = t.randn(10, 10)

In [28]:
rand.masked_fill(mask == 0, -1e5)

tensor([[ 2.5514e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05,
         -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [ 1.4752e+00,  1.9423e+00, -1.0000e+05, -1.0000e+05, -1.0000e+05,
         -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [ 2.7961e-01, -7.6509e-01, -1.8800e+00, -1.0000e+05, -1.0000e+05,
         -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [-2.6870e-01,  9.1219e-01, -7.0824e-01,  9.7453e-01, -1.0000e+05,
         -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [ 4.1186e-03, -7.5786e-01,  1.3436e+00,  1.1266e+00,  8.9082e-01,
         -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [-3.2767e-01, -1.1534e-01, -1.1756e+00,  7.5173e-02,  1.5129e-01,
         -3.0663e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05],
        [-4.4853e-01, -1.3076e+00,  7.4041e-01, -4.8402e-01,  6.1193e-02,
         -2.1940e+00, -6.8383e-0