In [1]:
import torch

from eagle.model.ea_model import EaModel


model = EaModel.from_pretrained(
    base_model_path='/models/Meta-Llama-3-8B-Instruct',
    ea_model_path='yuhuili/EAGLE-LLaMA3-Instruct-8B',
    total_token=60,
    depth=5,
    top_k=10,
    #torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cpu"
)
model

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:15<00:00,  3.94s/it]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /models/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'm

EaModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNorm

In [2]:
model.eval()
tokenizer = model.get_tokenizer()

In [3]:
messages = [
	{
		'role': 'user',
		'content': 'Hello World!',
	}
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
prompt

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [4]:
input_ids = tokenizer([prompt], add_special_tokens=False, ).input_ids
input_ids

[[128000,
  128006,
  882,
  128007,
  271,
  9906,
  4435,
  0,
  128009,
  128006,
  78191,
  128007,
  271]]

In [5]:
gen = model.ea_generate(
    torch.as_tensor(input_ids),
    temperature=0,
    log=True,
    is_llama3=True,
)
gen

<generator object EaModel.ea_generate at 0x7f4ee0ed3140>

In [6]:
next(gen)

dtype=torch.float32 13 cpu
dtype=torch.float32 60 cpu


tensor([[128000, 128006,    882, 128007,    271,   9906,   4435,      0, 128009,
         128006,  78191, 128007,    271,   9906,   4435,      0,   1102,    596]])

In [7]:
input_ids = torch.as_tensor(input_ids)
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)

model.ea_layer.reset_kv()

(input_ids, padding)

(tensor([[128000, 128006,    882, 128007,    271,   9906,   4435,      0, 128009,
          128006,  78191, 128007,    271]]),
 tensor([[-1]]))

In [9]:
past_key_values = model.past_key_values
past_key_values_data = model.past_key_values_data
current_length_data = model.current_length_data
# Reset the past key and value states
current_length_data.zero_()

past_key_values, past_key_values_data, current_length_data

([[<eagle.model.kv_cache.KVCache at 0x7f50880ad2d0>,
   <eagle.model.kv_cache.KVCache at 0x7f50880aea40>],
  [<eagle.model.kv_cache.KVCache at 0x7f50880afcd0>,
   <eagle.model.kv_cache.KVCache at 0x7f508808a290>],
  [<eagle.model.kv_cache.KVCache at 0x7f50880af520>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f47df0>],
  [<eagle.model.kv_cache.KVCache at 0x7f508808b430>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f47ee0>],
  [<eagle.model.kv_cache.KVCache at 0x7f4ee0f47d60>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f44c70>],
  [<eagle.model.kv_cache.KVCache at 0x7f4ee0f450c0>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f47f10>],
  [<eagle.model.kv_cache.KVCache at 0x7f4ee0f44df0>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f44e80>],
  [<eagle.model.kv_cache.KVCache at 0x7f4ee0f45300>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f47e50>],
  [<eagle.model.kv_cache.KVCache at 0x7f4ee0f45390>,
   <eagle.model.kv_cache.KVCache at 0x7f4ee0f44fa0>],
  [<eagle.model.kv_cache.KVCache at 0

In [10]:
from eagle.model.utils import initialize_tree


input_len = input_ids.shape[1]

model.base_model.model.tree_mask = None
model.base_model.model.tree_mode = None

draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
    input_ids, model, past_key_values, None
)
new_token = 0


dict(
	input_len=input_len,
	draft_tokens=draft_tokens,
	retrieve_indices=retrieve_indices,
	tree_mask=tree_mask,
	tree_position_ids=tree_position_ids,
	logits=logits,
	hidden_state=hidden_state,
	sample_token=sample_token,
)

dtype=torch.float32 13 cpu


{'input_len': 13,
 'draft_tokens': tensor([[  9906,   1070,      0,   4435,      0,   1102,  20776,    353,    358,
           14262,      0,   1102,  20776,  14262,  22691,   9906, 128009,    353,
            2181,    596,    374,   1102,   4435,    596,    374,  10788,    374,
           20776,  24748,   1027,   2751,   6555,      6,   4435,    596,   4435,
            4435,  10788,  20776,  24748,    374,  14262,  92886,   6555,      0,
            4435,      0,   4435,      0,   4435,      0,      0,   4435,      0,
            4435,   1102, 128009,   4435,      0,      0]]),
 'retrieve_indices': tensor([[ 0,  2,  7, -1, -1, -1, -1],
         [ 0,  2,  8, -1, -1, -1, -1],
         [ 0,  2,  9, -1, -1, -1, -1],
         [ 0,  1,  4, 13, -1, -1, -1],
         [ 0,  1,  4, 16, -1, -1, -1],
         [ 0,  1,  4, 17, -1, -1, -1],
         [ 0,  1,  4, 18, -1, -1, -1],
         [ 0,  2,  5, 20, -1, -1, -1],
         [ 0,  2,  6, 22, -1, -1, -1],
         [ 0,  1,  4, 11, 24, -1, -1],
   

In [11]:
from eagle.model.utils import tree_decoding


model.base_model.model.tree_mask = tree_mask

logits, hidden_state_new, outputs = tree_decoding(
    model,
    draft_tokens,
    past_key_values,
    tree_position_ids,
    input_ids,
    retrieve_indices,
)

logits.shape, hidden_state_new.shape, outputs.last_hidden_state.shape

dtype=torch.float32 60 cpu


(tensor([[[20.6799,  1.7042, -1.0207,  ..., -0.7479, -0.7478, -0.7478],
          [-5.5799, -3.5172, -2.7334,  ...,  1.3887,  1.3891,  1.3892],
          [ 3.4875,  3.2434,  3.4027,  ..., -1.1164, -1.1162, -1.1159],
          ...,
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663],
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663],
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663]],
 
         [[20.6799,  1.7042, -1.0207,  ..., -0.7479, -0.7478, -0.7478],
          [-5.5799, -3.5172, -2.7334,  ...,  1.3887,  1.3891,  1.3892],
          [ 5.3050,  9.5382,  3.2356,  ..., -2.4858, -2.4862, -2.4860],
          ...,
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663],
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663],
          [ 1.7278,  0.0927, -0.6355,  ...,  0.2661,  0.2663,  0.2663]],
 
         [[20.6799,  1.7042, -1.0207,  ..., -0.7479, -0.7478, -0.7478],
          [-5.5799, -3.5172,

In [13]:
logits.shape, hidden_state_new.shape, outputs.last_hidden_state.shape

(torch.Size([34, 7, 128256]),
 torch.Size([1, 60, 4096]),
 torch.Size([1, 60, 4096]))

In [15]:
draft_tokens = torch.cat((draft_tokens,padding),dim=1)
candidates = draft_tokens[0,retrieve_indices]

draft_tokens.shape, candidates.shape

(torch.Size([1, 62]), torch.Size([34, 7]))

In [18]:
posterior_mask = (
        candidates[:, 1:].to(logits.device) == torch.argmax(logits[:, :-1], dim=-1)
).int()
posterior_mask.shape, posterior_mask

(torch.Size([34, 6]),
 tensor([[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0],
         [0, 1, 0, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 0, 1, 1, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 0, 0, 1, 0],
         [0, 1, 0, 0, 1, 1],
         [0, 1, 1, 1, 0, 0],
         [0, 1, 1, 0, 0, 1],
         [0, 1, 1, 0,

In [21]:
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
candidates_accept_length.shape, candidates_accept_length

(torch.Size([34]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [22]:
accept_length = candidates_accept_length.max()
accept_length

tensor(4)

In [23]:
best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
best_candidate

tensor(15)