In [1]:
import time

import numpy as np
from transformers import GPT2Model, GPT2Tokenizer

In [2]:
def sample(p):
  """Draw one index from the (possibly unnormalized) weight vector `p`.

  Args:
    p: 1-D array-like of non-negative weights (list or ndarray); it is
      renormalized to sum to 1 before sampling.

  Returns:
    int index into `p`, chosen with probability proportional to its weight.

  Raises:
    ValueError: if the weights sum to zero, since there is no distribution to
      sample from (numpy also raises ValueError for negative or NaN weights).
  """
  # coerce lists/float32 arrays to float64 so `.sum()` works and the
  # normalized vector passes np.random.choice's sum-to-1 tolerance check
  weights = np.asarray(p, dtype=np.float64)
  total = weights.sum()
  if total <= 0:
    raise ValueError("cannot sample from an all-zero distribution")
  return int(np.random.choice(len(weights), p=weights / total))

class GPT2:
  """GPT-2 language model whose forward pass is implemented in plain numpy.

  Weights are loaded from a Hugging Face `GPT2Model` checkpoint and kept as
  numpy arrays in `self.ws`. `generate` supports plain autoregressive sampling
  and speculative sampling driven by a smaller draft model.
  """

  # per-checkpoint architecture sizes; forward reads only n_layer and n_head
  hparams_dict = {
    "gpt2":        dict(n_layer=12, n_head=12, n_embed=768),   # 124M params
    "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),  # 350M params
    "gpt2-large":  dict(n_layer=36, n_head=20, n_embed=1280),  # 774M params
    "gpt2-xl":     dict(n_layer=48, n_head=25, n_embed=1600),  # 1558M params
  }

  # maximum number of token positions passed to a single forward call
  context_len = 1024

  def __init__(self, model_type):
    """Load pretrained weights for `model_type` (a key of `hparams_dict`)."""
    self.hparams = self.hparams_dict[model_type]
    self.ws = {k: v.numpy() for k, v in GPT2Model.from_pretrained(model_type).state_dict().items()}

  def generate(self, start_ids, max_new_tokens, temperature=1.0, draft_model=None, K=4, stream=True, stream_printer=None):
    """Generate exactly `max_new_tokens` token ids following `start_ids`.

    Without a `draft_model` this is plain autoregressive sampling. With one it
    is speculative sampling: the draft proposes up to K tokens per step, and
    this (target) model accepts or rejects them. Rejected positions are
    resampled from the target, so the output follows this model's distribution.

    Args:
      start_ids: list of prompt token ids.
      max_new_tokens: number of tokens to produce (<= 0 produces none).
      temperature: softmax temperature used by both models (0 -> greedy).
      draft_model: optional model exposing the same `generate` interface.
      K: maximum number of draft tokens proposed per speculative step.
      stream: unused; kept for backward compatibility.
      stream_printer: optional callable, called with the prompt ids and then
        with each list of newly produced ids.

    Returns:
      (probs, ids): `ids` is the list of generated token ids and `probs` is a
      (len(ids), vocab) array holding this model's next-token distribution at
      each generated position.
    """
    ret_p, ret_ids = [], []
    if stream_printer is not None: stream_printer(start_ids)

    while len(ret_ids) < max_new_tokens:
      remaining = max_new_tokens - len(ret_ids)
      if draft_model is not None:  # speculative sampling
        # never draft more than needed, and keep >= 1 real context position
        k = max(0, min(K, remaining, self.context_len - 1))
        # leave room for the k draft tokens so the target forward pass stays
        # within the context_len positional embeddings
        ids_cond = (start_ids + ret_ids)[-(self.context_len - k):]
        new_p, new_ids = self._speculative_step(ids_cond, draft_model, k, temperature)
      else:  # autoregressive sampling
        ids_cond = (start_ids + ret_ids)[-self.context_len:]
        p = self.forward(ids_cond, temperature=temperature)[-1]
        new_p, new_ids = [p], [sample(p)]

      # a speculative step can yield up to k+1 tokens; never exceed the budget
      new_p, new_ids = new_p[:remaining], new_ids[:remaining]
      ret_p += new_p
      ret_ids += new_ids
      if stream_printer is not None: stream_printer(new_ids)

    if stream_printer is not None: print()
    if not ret_p:
      # np.vstack([]) raises, so build the empty (0, vocab) result explicitly
      wte = self.ws["wte.weight"]
      return np.empty((0, wte.shape[0]), dtype=wte.dtype), ret_ids
    return np.vstack(ret_p), ret_ids

  def _speculative_step(self, ids_cond, draft_model, k, temperature):
    """Run one speculative-sampling step and return (new_p, new_ids), 1..k+1 tokens."""
    # 1. sample k tokens (and their draft distributions q) from the draft model
    p_draft, ids_draft = draft_model.generate(ids_cond, k, temperature=temperature)
    # 2. score context + drafts with the target in one pass; row i is the target
    #    distribution for draft position i, the last row is for the token after them
    p = self.forward(ids_cond + ids_draft, temperature=temperature)[-k-1:]
    # 3. accept draft token j with probability min(1, p[j] / q[j]); on the first
    #    rejection, resample from normalized max(0, p - q) and stop
    new_p, new_ids = [], []
    for i, j in enumerate(ids_draft):
      if np.random.uniform() >= min(1, p[i][j] / p_draft[i][j]):
        new_ids.append(sample(np.maximum(p[i] - p_draft[i], 0)))
        new_p.append(p[i])
        return new_p, new_ids
      new_ids.append(j)
      new_p.append(p[i])
    # every draft token was accepted: sample one extra token x_{n+k+1}
    new_ids.append(sample(p[-1]))
    new_p.append(p[-1])
    return new_p, new_ids

  def forward(self, ids, only_last=True, temperature=1.0):
    """Minimal numpy implementation of the GPT-2 forward pass.

    Args:
      ids: list of token ids; its length must not exceed the number of rows
        of `wpe.weight` (one positional embedding per position).
      only_last: ignored; kept for backward compatibility (all positions are
        always returned).
      temperature: logits are divided by (temperature + 1e-8) before the
        softmax, so 0 gives an (almost) one-hot, greedy distribution.

    Returns:
      (len(ids), vocab) array of next-token probabilities for every position.
    """
    ws, hparams = self.ws, self.hparams

    def layer_norm(x, w, b, eps=1e-5):
      # normalize over the last axis, then apply the learned scale and shift
      mean = np.mean(x, axis=-1, keepdims=True)
      var = np.var(x, axis=-1, keepdims=True)
      return ((x - mean) / (var + eps)**0.5) * w + b

    def softmax(x, axis=-1):
      # max-subtracted softmax; overwrites `x` in place (callers pass temporaries)
      x -= x.max(axis=axis, keepdims=True)
      x = np.exp(x, x)
      x /= x.sum(axis=axis, keepdims=True)
      return x

    def transformer_block(x, i):

      def linear(x, w, b):
        return x @ w + b

      def gelu(x):
        # tanh approximation of GELU
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

      def mha(x, i):
        T, C = x.shape
        # fused q/k/v projection, split into (n_head, T, hs) tensors
        x = linear(x, ws[f"h.{i}.attn.c_attn.weight"], ws[f"h.{i}.attn.c_attn.bias"])
        n_head, hs = hparams["n_head"], C // hparams["n_head"]
        q, k, v = [np.transpose(h.reshape((T, n_head, hs)), (1,0,2)) for h in np.split(x, 3, axis=-1)]
        # causal mask: a large negative bias above the diagonal hides future positions
        attn = softmax(q @ np.transpose(k, (0,2,1)) / hs**0.5 + (1 - np.tri(T, dtype=np.float32)) * -1e10)
        # merge heads back to (T, C) and apply the output projection
        x = np.transpose(attn @ v, (1,0,2)).reshape((T, C))
        x = linear(x, ws[f"h.{i}.attn.c_proj.weight"], ws[f"h.{i}.attn.c_proj.bias"])
        return x

      def mlp(x, i):
        x = gelu(linear(x, ws[f"h.{i}.mlp.c_fc.weight"], ws[f"h.{i}.mlp.c_fc.bias"]))
        x = linear(x, ws[f"h.{i}.mlp.c_proj.weight"], ws[f"h.{i}.mlp.c_proj.bias"])
        return x

      # pre-norm residual block: attention, then MLP
      x = x + mha(layer_norm(x, ws[f"h.{i}.ln_1.weight"], ws[f"h.{i}.ln_1.bias"]), i)
      x = x + mlp(layer_norm(x, ws[f"h.{i}.ln_2.weight"], ws[f"h.{i}.ln_2.bias"]), i)
      return x

    wte, wpe = ws["wte.weight"], ws["wpe.weight"]
    x = wte[ids] + wpe[range(len(ids))]
    for i in range(hparams["n_layer"]):
      x = transformer_block(x, i)
    x = layer_norm(x, ws["ln_f.weight"], ws["ln_f.bias"])
    # output head is tied to the token embedding matrix
    logits = (x @ wte.T) / (temperature + 1e-8)
    return softmax(logits)

In [3]:
# configs
target_model_name = "gpt2-xl"  # model whose output distribution we sample from
draft_model_name = "gpt2"      # small, fast model that proposes draft tokens
max_new_tokens = 50
temperature = 0  # large temperature -> more random, 0 -> greedy
K = 4  # draft tokens proposed per speculative step
prompt = "Alan Turing theorized that computers would one day become"
seed = 0  # fixed seed so sampling runs (temperature > 0) are reproducible
np.random.seed(seed)

In [4]:
draft_model = GPT2(draft_model_name)
target_model = GPT2(target_model_name)

# a single tokenizer serves both models: speculative sampling compares their
# distributions index-by-index, so they must share a vocabulary
tokenizer = GPT2Tokenizer.from_pretrained(target_model_name)
start_ids = tokenizer(prompt)["input_ids"]

def stream_printer(ids):
  """Decode a chunk of token ids and print it inline, flushing immediately."""
  text = tokenizer.decode(ids)
  print(text, end="", flush=True)

In [5]:
print("regular autoregressive sampling", "-" * 50, sep="\n")
st = time.monotonic()
# baseline: one target forward pass per generated token
_, ids = target_model.generate(
  start_ids, max_new_tokens, temperature, stream_printer=stream_printer
)
cost = time.monotonic() - st
print("-" * 50, f"cost: {cost:.2f}s, {len(ids)/cost:.2f} tokens/s", sep="\n")

regular autoregressive sampling
--------------------------------------------------
Alan Turing theorized that computers would one day become so powerful that they would be able to think like humans.

In the 1950s, he proposed a way to build a computer that could think like a human. He called it the "Turing machine."

The machine was a mechanical
--------------------------------------------------
cost: 53.53s, 0.93 tokens/s


In [6]:
print("speculative sampling", "-" * 50, sep="\n")
st = time.monotonic()
# draft model proposes K tokens, target model verifies them in one pass
_, ids = target_model.generate(
  start_ids, max_new_tokens, temperature,
  draft_model=draft_model, K=K, stream_printer=stream_printer,
)
cost = time.monotonic() - st
print("-" * 50, f"cost: {cost:.2f}s, {len(ids)/cost:.2f} tokens/s", sep="\n")

speculative sampling
--------------------------------------------------
Alan Turing theorized that computers would one day become so powerful that they would be able to think like humans.

In the 1950s, he proposed a way to build a computer that could think like a human. He called it the "Turing machine."

The machine was a mechanical
--------------------------------------------------
cost: 31.53s, 1.59 tokens/s
