In [None]:
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Device on which both models and every tensor in this notebook are placed.
device_type = 'cuda'
# Hugging Face access token for the Gemma checkpoints. It is read from the
# HF_TOKEN environment variable so that no credential is stored in the notebook.
# When the variable is unset this is None, which is `from_pretrained`'s default
# for `token` (it then relies on whatever login is configured locally).
huggingface_token = os.environ.get('HF_TOKEN')

# A single tokenizer (loaded from the 7b checkpoint) is used to build the ids
# that are fed to both models below.
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-7b",
    token=huggingface_token
)
# Target ("big") model: its distribution p is the one the final samples follow.
big_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    token=huggingface_token,
    torch_dtype=torch.bfloat16
).to(device_type)
# Draft ("small") model: proposes candidate tokens that the big model verifies.
small_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    token=huggingface_token,
    torch_dtype=torch.bfloat16
).to(device_type)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Settings for the step-by-step walkthrough below:
# `gamma` is the number of tokens the small model drafts in one step,
# `input_text` is the prompt both models continue.
gamma = 5
input_text = 'once upon'

In [None]:
# Draft step: the small model proposes `gamma` continuation tokens.
# The prompt is tokenized and reduced to a 1-D id tensor on the target device.
prompt_encoding = tokenizer(input_text, return_tensors='pt')
input_ids = prompt_encoding['input_ids'].to(device_type)[0]

# NOTE(review): no `do_sample` argument is passed, so the decoding strategy is
# whatever the model's generation config selects. The accept/reject rule used
# later only reproduces the big model's distribution if the drafted tokens are
# really sampled from q -- confirm the intended strategy.
small_model_output = small_model.generate(
    input_ids.unsqueeze(0),        # add the batch dimension back
    max_new_tokens=gamma,
    return_dict_in_generate=True,  # gives access to `.sequences` and `.scores`
    output_scores=True             # keep the per-step scores used to build q
)

In [None]:
# q: the small model's next-token distribution at every drafted position.
# The last `gamma` ids of the returned sequence are taken as the drafted tokens.
small_model_generated_tokens = small_model_output.sequences[0, -gamma:]
# `scores` holds one tensor per generation step; stack the steps along dim 1,
# pick batch item 0, and normalise over the last dimension.
small_model_scores = torch.stack(small_model_output.scores, dim=1)[0]
q = torch.softmax(small_model_scores, dim=-1)
# Probability q gave to each token that was actually drafted.
q_of_generated = q[torch.arange(gamma), small_model_generated_tokens]
q_of_generated

tensor([0.9676, 0.9292, 0.2872, 0.2792, 0.7979], device='cuda:0')

In [None]:
# calculate p (the probability the big model assigned to the generated tokens)
# The forward pass is wrapped in `torch.no_grad()`: this is pure inference, and
# without it the result carries a `grad_fn`, i.e. an autograd graph is built and
# kept alive for the 7b model for no benefit.
with torch.no_grad():
  # Keep the last gamma + 1 positions: one distribution for each drafted token
  # plus a final row (`p[-1]`) that is used when every draft is accepted.
  p = big_model(small_model_output.sequences).logits.softmax(dim=-1)[0, -gamma-1:]
# Probability p gave to each drafted token (row i <-> drafted token i).
p_of_generated = p[torch.arange(gamma), small_model_generated_tokens]
p_of_generated

tensor([0.9662, 0.9092, 0.2634, 0.3173, 0.7331], device='cuda:0',
       grad_fn=<IndexBackward0>)

In [None]:
# Accept/reject: drafted token i passes when a uniform draw in [0, 1) is below
# p_i / q_i.
ratio = p_of_generated / q_of_generated
uniform_draws = torch.rand(gamma).to(device_type)
is_accepted = uniform_draws < ratio

# Append a sentinel False and take argmin over the 0/1 flags. This relies on
# argmin returning the first minimal index, so the result is the position of
# the first rejection, or `gamma` when every drafted token passed.
sentinel = torch.tensor([False]).to(device_type)
acceptance_flags = torch.cat([is_accepted, sentinel]).to(int)
index_to_reject = torch.argmin(acceptance_flags)

In [None]:
# Keep the drafted tokens that come before the first rejection.
accepted_tokens = small_model_generated_tokens[:index_to_reject]

if index_to_reject != gamma:
  # A draft was rejected: sample the replacement from the positive part of
  # (p - q) at that position, rescaled so it sums to 1.
  residual = (p[index_to_reject] - q[index_to_reject]).clamp(min=0)
  p_for_sample = residual / residual.sum()
else:
  # Every draft was accepted: use the last row of p for the extra token.
  p_for_sample = p[-1]

In [None]:
# Finish the step: draw one token from `p_for_sample` and append it after the
# drafted tokens that were accepted.
big_model_generated_token = torch.multinomial(p_for_sample, 1)
generated_tokens = torch.cat((accepted_tokens, big_model_generated_token), dim=0)
generated_tokens

tensor([   476,   1069, 235269,   1104,    729,    476], device='cuda:0')

In [None]:
import colorama

In [None]:
# glue all the above code together and wrap it inside a while loop that
# generates the number of tokens requested by the user
input_text = 'once upon'
gamma = 5     # tokens drafted by the small model per iteration
seqlen = 30   # number of new tokens to produce after the prompt
input_ids = tokenizer(
    input_text,
    return_tensors='pt'
)['input_ids'].to(device_type)[0]
all_ids = input_ids

# NOTE(review): the loop assumes the small model always returns exactly `gamma`
# new tokens and never stops early (e.g. on an end-of-sequence token); in that
# case the `-gamma` slices below would be misaligned -- confirm or add handling.
while len(all_ids) - len(input_ids) < seqlen:
  # generate gamma tokens using the small model
  small_model_output = small_model.generate(
      all_ids[None],
      max_new_tokens=gamma,
      return_dict_in_generate=True,
      output_scores=True
  )

  # calculate q (the probability the small model assigned to the generated tokens)
  small_model_generated_tokens = small_model_output.sequences[0, -gamma:]
  q = torch.stack(small_model_output.scores, dim=1)[0].softmax(dim=-1)
  q_of_generated = q[torch.arange(gamma), small_model_generated_tokens]

  # calculate p (the probability the big model assigned to the generated tokens)
  # Run the big model without autograd: this is inference only, and building a
  # graph on every iteration wastes memory.
  with torch.no_grad():
    # last gamma + 1 positions: one row per drafted token plus the row used
    # when all drafts are accepted
    p = big_model(small_model_output.sequences).logits.softmax(dim=-1)[0, -gamma-1:]
  p_of_generated = p[torch.arange(gamma), small_model_generated_tokens]

  # acceptance rule: token i is accepted when a uniform draw is below p_i / q_i
  ratio = p_of_generated / q_of_generated
  is_accepted = torch.rand(gamma).to(device_type) < ratio
  # the appended False makes argmin return `gamma` when nothing was rejected,
  # otherwise the index of the first rejected token
  index_to_reject = torch.argmin(torch.cat([
      is_accepted,
      torch.tensor([False]).to(device_type)
  ]).to(int))
  accepted_tokens = small_model_generated_tokens[:index_to_reject]

  if index_to_reject == gamma:  # all tokens are accepted
    p_for_sample = p[-1]
  else:  # some token was rejected
    # resample from the positive part of (p - q), renormalised to sum to 1
    p_for_sample = p[index_to_reject] - q[index_to_reject]
    p_for_sample[p_for_sample<0] = 0
    p_for_sample /= p_for_sample.sum()

  # generate one token using the big model
  big_model_generated_token = torch.multinomial(p_for_sample, num_samples=1)

  # progress display: black = text so far, blue = accepted drafts,
  # red = the rejected draft (if any), green = the big model's token
  print(colorama.Fore.BLACK, tokenizer.decode(all_ids), end='')
  print(colorama.Fore.BLUE, tokenizer.decode(accepted_tokens), end='')
  if index_to_reject < gamma:
    print(colorama.Fore.RED, tokenizer.decode(small_model_generated_tokens[index_to_reject]), end='')
  print(colorama.Fore.GREEN, tokenizer.decode(big_model_generated_token), end='')
  print()

  all_ids = torch.cat([all_ids, accepted_tokens, big_model_generated_token])

# Each iteration appends up to gamma + 1 tokens, so the loop can run past the
# requested length; trim back to exactly `seqlen` new tokens.
all_ids = all_ids[:len(input_ids) + seqlen]

[30m <bos>once upon[34m  a time, there was[32m  a
[30m <bos>once upon a time, there was a[34m [31m  little[32m  place
[30m <bos>once upon a time, there was a place[34m  called[31m  the[32m  yak
[30m <bos>once upon a time, there was a place called yak[34m [31m sha[32m ima
[30m <bos>once upon a time, there was a place called yakima[34m . it was a place[32m  where
[30m <bos>once upon a time, there was a place called yakima. it was a place where[34m [31m  people[32m  heaven
[30m <bos>once upon a time, there was a place called yakima. it was a place where heaven[34m  and earth[31m  met[32m  meet
[30m <bos>once upon a time, there was a place called yakima. it was a place where heaven and earth meet[34m . it was a place[32m  of
[30m <bos>once upon a time, there was a place called yakima. it was a place where heaven and earth meet. it was a place of[34m [31m  peace[32m  gentle
[30m <bos>once upon a time, there was a place called yakima. it was a place where he