In [2]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = model.to('mps')
model.device

device(type='mps', index=0)

In [3]:
sum(p.numel() for p in model.parameters()) / 10**6

124.439808

## Generation

In [4]:
prompt = "The great gatsby was writen by"
inputs = tokenizer(prompt, return_tensors='pt').to('mps')
inputs

{'input_ids': tensor([[ 464, 1049,  308, 1381, 1525,  373, 1991,  268,  416]],
       device='mps:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')}

In [5]:
for tok in inputs['input_ids'][0]:
    print(f'Token: {tok} --> "{tokenizer.decode(tok)}"')

Token: 464 --> "The"
Token: 1049 --> " great"
Token: 308 --> " g"
Token: 1381 --> "ats"
Token: 1525 --> "by"
Token: 373 --> " was"
Token: 1991 --> " writ"
Token: 268 --> "en"
Token: 416 --> " by"


In [6]:
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=50)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [8]:
# output_ids.shape
generation = tokenizer.decode(output_ids[0])
print(generation)

The great gatsby was writen by the great gatsby, and the great gatsby was writen by the great gatsby.

The great gatsby was writen by the great gatsby, and the great gatsby was writen by the


## Looking at `state_dict`

In [9]:
sd = model.state_dict()
for k,v in sd.items():
    print(f'{k} --> {v.shape}')

transformer.wte.weight --> torch.Size([50257, 768])
transformer.wpe.weight --> torch.Size([1024, 768])
transformer.h.0.ln_1.weight --> torch.Size([768])
transformer.h.0.ln_1.bias --> torch.Size([768])
transformer.h.0.attn.c_attn.weight --> torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias --> torch.Size([2304])
transformer.h.0.attn.c_proj.weight --> torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias --> torch.Size([768])
transformer.h.0.ln_2.weight --> torch.Size([768])
transformer.h.0.ln_2.bias --> torch.Size([768])
transformer.h.0.mlp.c_fc.weight --> torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias --> torch.Size([3072])
transformer.h.0.mlp.c_proj.weight --> torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias --> torch.Size([768])
transformer.h.1.ln_1.weight --> torch.Size([768])
transformer.h.1.ln_1.bias --> torch.Size([768])
transformer.h.1.attn.c_attn.weight --> torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias --> torch.Size([2304])
transformer.h.1.att

In [10]:
for name, buf in model.named_buffers():
    print(f'{name} --> {buf.shape}')

transformer.h.0.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.0.attn.masked_bias --> torch.Size([])
transformer.h.1.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.1.attn.masked_bias --> torch.Size([])
transformer.h.2.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.2.attn.masked_bias --> torch.Size([])
transformer.h.3.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.3.attn.masked_bias --> torch.Size([])
transformer.h.4.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.4.attn.masked_bias --> torch.Size([])
transformer.h.5.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.5.attn.masked_bias --> torch.Size([])
transformer.h.6.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.6.attn.masked_bias --> torch.Size([])
transformer.h.7.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.7.attn.masked_bias --> torch.Size([])
transformer.h.8.attn.bias --> torch.Size([1, 1, 1024, 1024])
transformer.h.8.attn.masked_bias --

## Understanding concepts: trial & error

In [8]:
import torch
b,t,n_embd,nh = 4,8,32,2
attn = torch.randn((b,nh,t,t))

tril = torch.tril(torch.ones(t,t))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [9]:
print(attn[0][0])
print(attn[0][1])

tensor([[ 1.0778,  0.7493, -1.5499,  0.9596,  0.8623,  1.4935, -0.4171,  2.5074],
        [ 0.4212,  0.0917,  0.8075, -0.2261,  0.4150,  0.5669,  0.5562,  1.5054],
        [-0.3372,  0.3031, -2.5702, -0.2523,  0.0713,  0.7166,  0.3191, -0.1069],
        [ 0.1994, -3.3223, -2.2039, -1.4927,  0.7381,  0.6496, -1.1768, -1.7605],
        [-0.6342, -0.5386,  1.2512, -1.0820, -0.3892,  0.9771, -0.0119, -0.4760],
        [-0.3823, -0.6987, -1.0331, -0.1693,  1.2746,  1.4286, -0.4293,  2.3684],
        [-0.0249,  0.4048, -1.1959, -0.4532, -0.0941, -0.7197,  0.1757, -0.6452],
        [-1.8539, -1.0931, -0.8180,  2.2152,  1.3110, -0.3808, -0.4433,  1.5723]])
tensor([[ 0.0167, -0.0424, -0.0515, -0.4647, -0.8529,  1.5306, -0.2215, -1.1013],
        [ 1.0849,  0.7927, -0.4453,  0.5311, -1.2426, -0.7290,  0.6556,  0.6728],
        [ 0.5365,  0.4651,  0.9445, -0.5449,  0.5725,  0.2194, -0.5900, -0.1532],
        [-0.3240,  0.0746, -1.3033, -1.0187,  1.0816, -1.1399,  1.5229,  1.2672],
        [-0.444

In [10]:
print((attn.masked_fill(tril == 0, float('-inf')))[0][0])
print((attn.masked_fill(tril == 0, float('-inf')))[0][1])

tensor([[ 1.0778,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.4212,  0.0917,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3372,  0.3031, -2.5702,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1994, -3.3223, -2.2039, -1.4927,    -inf,    -inf,    -inf,    -inf],
        [-0.6342, -0.5386,  1.2512, -1.0820, -0.3892,    -inf,    -inf,    -inf],
        [-0.3823, -0.6987, -1.0331, -0.1693,  1.2746,  1.4286,    -inf,    -inf],
        [-0.0249,  0.4048, -1.1959, -0.4532, -0.0941, -0.7197,  0.1757,    -inf],
        [-1.8539, -1.0931, -0.8180,  2.2152,  1.3110, -0.3808, -0.4433,  1.5723]])
tensor([[ 0.0167,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 1.0849,  0.7927,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.5365,  0.4651,  0.9445,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3240,  0.0746, -1.3033, -1.0187,    -inf,    -inf,    -inf,    -inf],
        [-0.444

In [11]:
tril1 = torch.tril(torch.ones(t,t))
tril2 = torch.tril(torch.ones(t,t)).view(1,1,t,t)
tril1.shape, tril2.shape

(torch.Size([8, 8]), torch.Size([1, 1, 8, 8]))

In [12]:
att1 = attn.masked_fill(tril1==0, float('-inf'))
att2 = attn.masked_fill(tril2[:,:,]==0, float('-inf'))

In [13]:
torch.allclose(att1, att2)

True

In [14]:
emb = 2
qkv = torch.randint(0, 10, (emb, 3*emb))
# print(qkv.shape)
print(qkv)

q,k,v = qkv.split(emb, dim=-1)
# q.shape, k.shape, v.shape
print(q)
print(k)
print(v)

tensor([[2, 7, 0, 5, 5, 2],
        [9, 0, 8, 4, 5, 5]])
tensor([[2, 7],
        [9, 0]])
tensor([[0, 5],
        [8, 4]])
tensor([[5, 2],
        [5, 5]])


## `state_dict` keys manipulations

In [17]:
print(list(sd.keys()), '\n')
for key in list(sd.keys()):
    if key.endswith('.attn.bias'):
        print(key)

['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.c_attn.bias

## testing `any` function in string

In [22]:
my_string = "transformer.h.11.ln_2.weight"
text_list = ["ln_1", "ln_3", "bias"]

In [24]:
if any(w in my_string for w in text_list):
    print('Fuck')
else:
    print("suck")

suck


In [25]:
any("")

False

In [26]:
any("fuck")

True