In [None]:
%%capture
!pip install tuned-lens

In [None]:
import os
import time

import numpy as np
import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import TunedLens

# Read the Hugging Face access token from the environment instead of leaving a
# placeholder in the notebook (an undefined name raises NameError on a fresh
# run, and a pasted token would be committed with the notebook).
# If HF_TOKEN is unset this is None; transformers is then expected to fall back
# to any locally stored login -- TODO confirm for your transformers version.
huggingface_token = os.environ.get('HF_TOKEN')
device_type = torch.device('cuda')
model_name = 'meta-llama/Llama-2-7b-hf'

# Full ("big") model, loaded in bfloat16 and moved to the target device.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=huggingface_token
)
model = model.to(device_type)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=huggingface_token)

# Pretrained tuned lens for this model; it decodes intermediate hidden states
# into next-token scores and acts as the cheap draft model further below.
tuned_lens = TunedLens.from_model_and_pretrained(model)
tuned_lens = tuned_lens.to(device_type)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

params.pt:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

(…)ens/meta-llama/Llama-2-7b-hf/config.json:   0%|          | 0.00/258 [00:00<?, ?B/s]

In [None]:
# Open the training split of one FineWeb configuration as a stream, so
# examples are fetched lazily while iterating instead of being downloaded
# up front.
ds = load_dataset(
    path="HuggingFaceFW/fineweb",
    name="CC-MAIN-2014-10",
    split="train",
    streaming=True,
)

Downloading readme:   0%|          | 0.00/38.7k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/212 [00:00<?, ?it/s]

In [None]:
def generate_with_tuned_lens(model, tuned_lens, layer_index, prefix_tokens, gamma):
  """Sample `gamma` draft tokens from an early-exit (tuned-lens) view of `model`.

  The decoder layers above `layer_index` are temporarily popped from
  `model.model.layers`, so every forward pass only runs the remaining layers.
  The hidden state at index `layer_index` is turned into next-token scores by
  `tuned_lens`, one token is sampled from their softmax and appended, and the
  whole (growing) sequence is fed through the model again for the next token.
  The removed layers are always put back, even if a forward pass raises.

  Args:
    model: callable causal LM exposing `model.model.layers` (supporting
      `pop`/`extend`) and `model.config.num_hidden_layers`. Calling it with
      `output_hidden_states=True` must return something indexable by
      'hidden_states'.
    tuned_lens: object whose `forward(hidden_state, layer_index)` returns
      scores indexed below as [batch, position, vocabulary].
    layer_index: number of decoder layers to keep while drafting.
    prefix_tokens: 1-D tensor of token ids to continue.
    gamma: number of tokens to draft; must be at least 1.

  Returns:
    A dict with
      'sequences': `prefix_tokens` followed by the drafted tokens, with a
        leading batch dimension of size 1.
      'scores': tuple with the scores of the last `gamma` positions of the
        final forward pass (one tensor per position).

  Raises:
    ValueError: if `gamma` is smaller than 1.
  """
  if gamma < 1:
    # Without this guard the loop body never runs and the return statement
    # fails on an unbound name after the model has already been modified.
    raise ValueError(f'gamma must be a positive integer, got {gamma}')

  removed_layers = []
  try:
    # Pop from the end so the kept layers stay in their original order.
    for _ in range(model.config.num_hidden_layers - layer_index):
      removed_layers.append(model.model.layers.pop(-1))

    for _ in range(gamma):
      # Call the module itself (not `.forward`) so registered hooks still run.
      h = model(
          prefix_tokens[None],
          output_hidden_states=True
      )['hidden_states']
      # NOTE(review): with the upper layers removed, h[layer_index] is the last
      # hidden state the truncated model reports. Some Hugging Face decoders
      # report that last entry after their final norm -- confirm this is the
      # input the lens expects.
      tuned_lens_scores = tuned_lens.forward(h[layer_index], layer_index)
      p_for_next_token = tuned_lens_scores.softmax(dim=-1)[0, -1, :]
      next_token = torch.multinomial(p_for_next_token, num_samples=1)
      prefix_tokens = torch.cat([prefix_tokens, next_token])
  finally:
    # Layers were popped last-to-first; reversing restores the original order.
    # Doing this in `finally` keeps the shared model intact on errors.
    model.model.layers.extend(reversed(removed_layers))

  return {
      'sequences': prefix_tokens[None],
      # Scores come from the final pass only; presumably they match what each
      # drafted token was sampled from because the model is causal.
      'scores': tuned_lens_scores.unbind(dim=1)[-gamma:]
  }

In [None]:
# --- Experiment settings ---
gamma = 1             # draft tokens proposed by the small model per step
seqlen = 10           # generate at least this many new tokens per example
layer_index = 3       # number of decoder layers the draft model runs
num_prefix_tokens=10  # prompt length (in tokens) taken from each document
num_of_examples=100   # number of dataset examples to evaluate

# --- Running statistics over all examples ---
n_accepted = 0          # draft tokens accepted by the big model
n_generated = 0         # draft tokens judged (accepted + first rejected)
speculative_times = []  # seconds per example, speculative loop
vanilla_times = []      # seconds per example, model.generate


def _wait_for_device():
  """Block until queued CUDA work is done so wall-clock timings are meaningful."""
  if device_type.type == 'cuda':
    torch.cuda.synchronize()


transformers.set_seed(42)
with torch.no_grad():
  for idx, example in enumerate(ds):
    # enumerate() is 0-based, so `>=` stops after exactly num_of_examples.
    if idx >= num_of_examples:
      break

    input_text = example['text']
    input_ids = tokenizer(
        input_text,
        return_tensors='pt'
    )['input_ids'].to(device_type)[0][:num_prefix_tokens]
    all_ids = input_ids

    _wait_for_device()
    start = time.perf_counter()
    while len(all_ids) - len(input_ids) < seqlen:
      # generate gamma tokens using the small model
      small_model_output = generate_with_tuned_lens(
          model, tuned_lens, layer_index, all_ids, gamma
      )

      # calculate q (the probability the small model assigned to the generated tokens)
      small_model_generated_tokens = small_model_output['sequences'][0, -gamma:]
      q = torch.stack(small_model_output['scores'], dim=1)[0].softmax(dim=-1)
      q_of_generated = q[torch.arange(gamma), small_model_generated_tokens]

      # calculate p (the probability the big model assigned to the generated tokens)
      # Only the last gamma + 1 positions are needed, so slice before the
      # softmax: rows 0..gamma-1 score the draft tokens, row gamma scores the
      # token that would follow them.
      p = model(small_model_output['sequences']).logits[0, -gamma-1:].softmax(dim=-1)
      p_of_generated = p[torch.arange(gamma), small_model_generated_tokens]

      # acceptance rule: keep draft token i with probability min(1, p_i / q_i)
      ratio = p_of_generated / q_of_generated
      is_accepted = torch.rand(gamma).to(device_type) < ratio
      # Index of the first rejected draft token. The appended False makes the
      # result equal gamma when every draft token was accepted. `.item()`
      # yields a plain int so the counters below stay Python numbers.
      index_to_reject = torch.argmin(torch.cat([
          is_accepted,
          torch.tensor([False], device=device_type)
      ]).to(int)).item()
      accepted_tokens = small_model_generated_tokens[:index_to_reject]

      if index_to_reject == gamma:  # all tokens are accepted
        p_for_sample = p[-1]
      else:  # some token was rejected
        # Resample from the normalised positive part of (p - q).
        p_for_sample = (p[index_to_reject] - q[index_to_reject]).clamp(min=0)
        p_for_sample /= p_for_sample.sum()

      n_accepted += index_to_reject
      n_generated += index_to_reject
      if index_to_reject < gamma:
        n_generated += 1

      # generate one token using the big model
      big_model_generated_token = torch.multinomial(p_for_sample, num_samples=1)

      all_ids = torch.cat([all_ids, accepted_tokens, big_model_generated_token])

    _wait_for_device()
    speculative_times.append(time.perf_counter() - start)

    # Baseline: let the big model generate the same number of new tokens.
    # NOTE(review): generate() runs with the model's default generation
    # settings and may stop early or use caching the loop above does not;
    # confirm the comparison is the intended one.
    _wait_for_device()
    start = time.perf_counter()
    model.generate(
        input_ids[None],
        max_new_tokens=len(all_ids) - len(input_ids)
    )
    _wait_for_device()
    vanilla_times.append(time.perf_counter() - start)

    print(
        f'alpha: {n_accepted / n_generated:.2%} ' +
        f'improvement: {1 - np.mean(speculative_times) / np.mean(vanilla_times)}'
    )

alpha: 11.11% improvement: 0.005758333114685987
alpha: 23.53% improvement: 0.10597612045324756
alpha: 23.08% improvement: 0.10002485134778183
alpha: 20.00% improvement: 0.07788200524251421
alpha: 20.93% improvement: 0.08382580690680907
alpha: 21.57% improvement: 0.08900450709500152
alpha: 22.03% improvement: 0.09166837587321475
alpha: 18.84% improvement: 0.06729916547623727
alpha: 17.95% improvement: 0.05987469311739724
alpha: 18.39% improvement: 0.0649356401354253
alpha: 17.71% improvement: 0.060310228805825594
alpha: 16.04% improvement: 0.04772288038074657
alpha: 15.65% improvement: 0.0439835091145806
alpha: 16.26% improvement: 0.04822504449396203
alpha: 16.67% improvement: 0.05110563650517341
alpha: 17.86% improvement: 0.06100516118450594
alpha: 17.45% improvement: 0.05801927923634065
alpha: 18.59% improvement: 0.06574689564586256
alpha: 18.18% improvement: 0.06228616392113384
alpha: 19.19% improvement: 0.07083550240753667
alpha: 20.00% improvement: 0.07785068525738714
alpha: 19.58%