In [31]:
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


In [32]:
# Load the pretrained GPT-2 checkpoint and its matching tokenizer from the Hub.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# .eval() returns the model itself; calling it makes inference mode explicit so
# the Dropout(p=0.1) layers inside the network are disabled during generation,
# instead of relying on the loader's default train/eval state.
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

In [33]:
# Show the module tree: token/position embeddings, 12 GPT2Block layers with a
# 768-dim hidden size, and an lm_head projecting onto the 50257-token vocabulary.
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [34]:
# Encode the prompt as PyTorch tensors; the result is a batch holding one sequence.
prompt = "The quick brown fox jumped over the"
inputs = tokenizer(text=prompt, return_tensors="pt")

In [35]:
# Inspect the tokenizer output: the prompt's token ids plus an all-ones attention mask.
inputs

{'input_ids': tensor([[  464,  2068,  7586, 21831, 11687,   625,   262]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [36]:
# Single forward pass over the whole prompt; no gradients are needed for inference.
with torch.no_grad():
    outputs = model(**inputs)
    # Scores for the next token at every input position: (batch, seq_len, vocab).
    logits = outputs.logits

In [37]:
# Scores at a single input position (index 2): one value per vocabulary entry.
logits[0, 2].shape

torch.Size([50257])

El modelo devuelve, para cada posición de la entrada, los logits (puntuaciones sin normalizar, previas al softmax) del próximo token sobre todo el vocabulario.
Para generar texto sólo nos interesa la última posición.

In [38]:
print(logits.size())  # (batch, seq_len, vocab_size)

torch.Size([1, 7, 50257])


In [39]:
# Greedy decoding: keep only the last position and take its highest-scoring vocabulary id.
last_logits = logits[0, -1]
next_token_id = torch.argmax(last_logits)


In [40]:
# The greedy choice is a 0-d tensor holding a single vocabulary index.
print(next_token_id)

tensor(13990)


In [41]:
# Map the id back to text; the decoded token includes its leading space.
tokenizer.decode(next_token_id)

' fence'

In [42]:
# The ten highest-scoring next-token candidates, each decoded on its own so
# every list entry is the text of exactly one token.
top_k = last_logits.topk(10)
tokens = list(map(tokenizer.decode, top_k.indices))

In [43]:
# Display the candidates in topk order (descending score by default); the first matches the greedy pick.
tokens

[' fence',
 ' edge',
 ' railing',
 ' wall',
 ' table',
 ' tree',
 ' top',
 ' counter',
 ' ground',
 ' side']

In [44]:
# Re-display the original prompt encoding; the steps above did not modify it.
inputs

{'input_ids': tensor([[  464,  2068,  7586, 21831, 11687,   625,   262]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [45]:
# Inputs for the next step: append the chosen id to the prompt and extend the
# attention mask with a 1 so the new token is attended to.
next_inputs = dict(
    input_ids=torch.cat((inputs["input_ids"], next_token_id.view(1, 1)), dim=1),
    attention_mask=torch.cat((inputs["attention_mask"], torch.tensor([[1]])), dim=1),
)

In [46]:
# Both tensors should now be one position longer than the prompt encoding.
print(next_inputs["input_ids"], next_inputs["input_ids"].shape)
print(next_inputs["attention_mask"], next_inputs["attention_mask"].shape)


tensor([[  464,  2068,  7586, 21831, 11687,   625,   262, 13990]]) torch.Size([1, 8])
tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) torch.Size([1, 8])


In [47]:
def generate_token(inputs, lm=None):
    """Greedily pick the next token id for a single tokenized sequence.

    Args:
        inputs: dict of tensors passed as keyword arguments to the model's
            forward call (here ``input_ids`` and ``attention_mask``, batch size 1).
        lm: causal language model to run. Defaults to the notebook-global
            ``model`` (looked up at call time), so existing calls
            ``generate_token(inputs)`` behave exactly as before.

    Returns:
        0-d tensor holding the vocabulary index with the highest logit at the
        last position of the first sequence in the batch.
    """
    if lm is None:
        lm = model  # resolved from the notebook namespace only when no model is passed

    with torch.no_grad():  # inference only: skip building the autograd graph
        outputs = lm(**inputs)

    # logits are (batch, seq_len, vocab); only the last position predicts the next token.
    last_logits = outputs.logits[0, -1, :]
    return last_logits.argmax()

In [48]:
# Fresh state for the decoding loop: start from the tokenized prompt and
# collect decoded tokens plus per-step timings.
next_inputs = inputs
generated_tokens, duration_s = [], []

In [49]:
# Greedy decoding without a KV cache: every step re-feeds the full, growing
# sequence to the model. Only the model call inside generate_token is timed.
# NOTE: next_inputs / generated_tokens / duration_s come from the previous cell;
# re-running this cell alone continues from the already-extended context.
MAX_NEW_TOKENS = 10

for _ in range(MAX_NEW_TOKENS):
    t0 = time.perf_counter()  # monotonic, high-resolution clock meant for durations
    next_token_id = generate_token(next_inputs)
    duration_s.append(time.perf_counter() - t0)

    # Append the new id and extend the attention mask by one attended position.
    next_inputs = {
        "input_ids": torch.cat(
            [next_inputs["input_ids"], next_token_id.reshape((1, 1))],
            dim=1,
        ),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1,
        ),
    }

    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)


In [50]:
# Total seconds spent in generate_token across all decoding steps.
print(sum(duration_s), "s")

2.839244842529297 s


In [51]:
# Greedy continuation of the prompt, one decoded string per generated token.
print(generated_tokens)

[' fence', ' and', ' ran', ' to', ' the', ' other', ' side', ' of', ' the', ' fence']
