In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# Checkpoint to load, and where to run it: the first CUDA device when one is
# available, otherwise the CPU.
model_name = 'openai-community/gpt2'
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

# Batched tokenization below pads shorter prompts; the end-of-sequence token is
# reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Reference model: loaded, moved to the target device and put in eval mode.
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

# Dimensions read from the model config; the per-head width is the embedding
# width split evenly across the attention heads.
d_model = base_model.config.n_embd
num_heads = base_model.config.n_head
num_layers = base_model.config.n_layer
d_head = d_model // num_heads

# dtype handed to the KV block manager further down.
kv_dtype = torch.float32

In [2]:
# Ten short prompts; they tokenize to different lengths, so the batch built
# from them contains padding.
prompts = [
  "Once upon a time, there was a",
  "In the future, AI will",
  "The meaning of life is",
  "FastAPI is a great framework for",
  "Transformers models are powerful for",
  "It was a sunny day when",
  "Quantum computing will change",
  "The secret to happiness is",
  "Long ago in a galaxy far",
  "Python is the best language for",
]

In [3]:
# Tokenize the whole batch in one call; shorter prompts are padded up to the
# longest one and the result is returned as PyTorch tensors.
inputs = tokenizer(
  prompts,
  padding=True,
  return_tensors='pt',
)
inputs = inputs.to(device)

# Padded token ids, one row per prompt (displayed as the cell output).
input_ids = inputs['input_ids']
input_ids

tensor([[ 7454,  2402,   257,   640,    11,   612,   373,   257],
        [  818,   262,  2003,    11,  9552,   481, 50256, 50256],
        [  464,  3616,   286,  1204,   318, 50256, 50256, 50256],
        [22968, 17614,   318,   257,  1049,  9355,   329, 50256],
        [41762,   364,  4981,   389,  3665,   329, 50256, 50256],
        [ 1026,   373,   257, 27737,  1110,   618, 50256, 50256],
        [24915,   388, 14492,   481,  1487, 50256, 50256, 50256],
        [  464,  3200,   284, 12157,   318, 50256, 50256, 50256],
        [14617,  2084,   287,   257, 16161,  1290, 50256, 50256],
        [37906,   318,   262,  1266,  3303,   329, 50256, 50256]],
       device='cuda:0')

In [4]:
import uuid


# One random hex identifier per prompt; these are passed to the model as
# `seq_ids` in the generation loop.
batch_size = len(prompts)
seq_ids = [uuid.uuid4().hex for _ in prompts]

In [5]:
# Two int64 values per sequence, all zero to begin with.
# NOTE(review): the generation loop keeps column 0 fixed and advances column 1
# by the number of tokens fed, so this looks like a (start, end) range of
# cached positions — confirm against pagebrain.
cache_pos = torch.zeros((batch_size, 2), dtype=torch.long, device=device)
cache_pos

tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]], device='cuda:0')

In [6]:
# Number of real (non-padding) tokens in each prompt.
input_lens = inputs['attention_mask'].sum(dim=-1)

# Two int64 values per sequence: column 0 is zero, column 1 is the prompt length.
# NOTE(review): presumably (offset, token count) of the tokens being fed —
# confirm against pagebrain.
input_pos = torch.stack(
  [
    torch.zeros_like(input_lens, dtype=torch.long),
    input_lens.to(torch.long),
  ],
  dim=-1,
)
input_pos

tensor([[0, 8],
        [0, 6],
        [0, 5],
        [0, 7],
        [0, 6],
        [0, 6],
        [0, 5],
        [0, 5],
        [0, 6],
        [0, 6]], device='cuda:0')

In [7]:
from pagebrain.block import BlockManager
from pagebrain.cache import CacheManager


# Sizing of the paged KV store.
# NOTE(review): presumably `num_blocks` pages holding `page_size` token slots
# each — confirm against pagebrain.block.
page_size = 32
num_blocks = 1000

# The block manager is built from the pool sizing, the model dimensions and the
# target device/dtype; the cache manager wraps it.
block_manager = BlockManager(
  num_blocks,
  num_layers,
  num_heads,
  d_head,
  page_size,
  device,
  dtype=kv_dtype,
)
cache_manager = CacheManager(block_manager)

In [8]:
from pagebrain.models.gpt2 import PagedGPT2LMHeadModel


# Wrap the reference GPT-2 together with the cache manager and put the wrapper
# in eval mode; the trailing expression shows the module tree.
model = PagedGPT2LMHeadModel(base_model, cache_manager).eval()
model

PagedGPT2LMHeadModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (h): ModuleList(
    (0-11): 12 x PagedGPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (attn): GPT2PagedAttention(
        (base_attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
)

In [9]:
# Greedy decoding. The first pass feeds every padded prompt; each later pass
# feeds only the single token chosen in the previous pass.
# NOTE(review): this cell overwrites input_ids / input_pos / cache_pos (and the
# model call presumably advances the paged KV cache), so re-running it on its
# own continues from the advanced state. Re-run from the tokenization cell
# (with fresh seq_ids) to start over.
# NOTE(review): generation does not stop at the end-of-sequence token; every
# row always produces `max_new_tokens` tokens.
max_new_tokens = 30

# Loop-invariant tensors, built once on the target device instead of every step.
batch_index = torch.arange(batch_size, device=device)
single_token = torch.ones([batch_size], device=device, dtype=torch.long)

gen_tokens = [[] for _ in range(batch_size)]     # per-token decoded strings
gen_token_ids = [[] for _ in range(batch_size)]  # raw ids of the same tokens
for step in range(max_new_tokens):
  with torch.no_grad():
    logits = model(
      input_ids=input_ids,
      seq_ids=seq_ids,
      input_pos=input_pos,
      cache_pos=cache_pos,
    )

    # For every row, take the logits at the last real (non-padding) position
    # of what was just fed: prompt length - 1 on the first pass, 0 afterwards.
    next_logits = logits[batch_index, input_pos[:, 1] - 1]
    next_token_ids = next_logits.argmax(dim=-1)
    next_token_id_list = next_token_ids.tolist()
    next_tokens = tokenizer.batch_decode(next_token_id_list)

    # Next pass: feed just the chosen token, extend the cached range by what
    # was fed this pass, and place the new token right after it.
    input_ids = next_token_ids[:, None]
    cache_pos = torch.stack([cache_pos[:, 0], input_pos[:, 0] + input_pos[:, 1]], dim=-1)
    input_pos = torch.stack([cache_pos[:, 1], single_token], dim=-1)

  print(f'========== step: {step} ==========')
  for sample_idx, (next_token, next_token_id) in enumerate(zip(next_tokens, next_token_id_list)):
    gen_tokens[sample_idx].append(next_token)
    gen_token_ids[sample_idx].append(next_token_id)
    # Decode the whole continuation from ids rather than joining per-token
    # strings: a character whose bytes span several tokens cannot be
    # reconstructed from tokens decoded one at a time.
    text = prompts[sample_idx] + tokenizer.decode(gen_token_ids[sample_idx])
    print(f'{sample_idx}: {text}')
  print()

0: Once upon a time, there was a man
1: In the future, AI will be
2: The meaning of life is not
3: FastAPI is a great framework for building
4: Transformers models are powerful for the
5: It was a sunny day when I
6: Quantum computing will change the
7: The secret to happiness is to
8: Long ago in a galaxy far,
9: Python is the best language for building

0: Once upon a time, there was a man who
1: In the future, AI will be able
2: The meaning of life is not the
3: FastAPI is a great framework for building a
4: Transformers models are powerful for the game
5: It was a sunny day when I was
6: Quantum computing will change the way
7: The secret to happiness is to be
8: Long ago in a galaxy far, far
9: Python is the best language for building web

0: Once upon a time, there was a man who was
1: In the future, AI will be able to
2: The meaning of life is not the same
3: FastAPI is a great framework for building a REST
4: Transformers models are powerful for the game,
5: It was a sunny day 