In [1]:
import torch

from eagle.model.ea_model import EaModel


# Load Llama-3-8B-Instruct as the base model together with its EAGLE draft head,
# on CPU. Paths and EAGLE draft-tree settings are named once here so they are
# easy to find and change.
BASE_MODEL_PATH = '/models/Meta-Llama-3-8B-Instruct'  # local checkpoint; adjust per machine
EA_MODEL_PATH = 'yuhuili/EAGLE-LLaMA3-Instruct-8B'
TOTAL_TOKEN = 60  # forwarded unchanged as EaModel's `total_token`
DEPTH = 5         # forwarded unchanged as EaModel's `depth`
TOP_K = 10        # forwarded unchanged as EaModel's `top_k`

model = EaModel.from_pretrained(
    base_model_path=BASE_MODEL_PATH,
    ea_model_path=EA_MODEL_PATH,
    total_token=TOTAL_TOKEN,
    depth=DEPTH,
    top_k=TOP_K,
    # Pin float32 explicitly instead of relying on the library default, so the
    # numbers in this notebook don't shift with the installed version. float16
    # is avoided because half precision is often slow or unsupported on CPU.
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map="cpu",
)
model

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /models/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'm

EaModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNorm

In [2]:
# Switch to inference mode before running any forward pass.
model.eval()

# Tokenizer exposed by EaModel; used below to render and encode the chat prompt.
tokenizer = model.get_tokenizer()

In [3]:
# Single-turn chat asking for a one-word color name, so the model's first
# generated token is easy to inspect.
messages = [
    {'role': 'user', 'content': 'Give a name of a color, answer in one word.'},
]

# Render the conversation to text only (no tokenization yet). The generation
# prompt appends the assistant header, so the next token is the reply's start.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
prompt

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nGive a name of a color, answer in one word.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [4]:
# The rendered prompt already starts with <|begin_of_text|>, so special tokens
# are not added again (the tokenizer would likely prepend a second BOS otherwise).
input_ids = tokenizer([prompt], add_special_tokens=False)['input_ids']
input_ids

[[128000,
  128006,
  882,
  128007,
  271,
  36227,
  264,
  836,
  315,
  264,
  1933,
  11,
  4320,
  304,
  832,
  3492,
  13,
  128009,
  128006,
  78191,
  128007,
  271]]

In [20]:
from eagle.model.kv_cache import initialize_past_key_values
from eagle.model.utils import reset_tree_mode


# Pre-allocate the base model's KV cache and also expose its three parts as
# attributes on the EaModel instance.
past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(
    model.base_model
)
model.past_key_values, model.past_key_values_data, model.current_length_data = (
    past_key_values,
    past_key_values_data,
    current_length_data,
)

# NOTE(review): presumably clears any tree-attention state left on the model
# before a plain prefill; confirm against eagle.model.utils.reset_tree_mode.
reset_tree_mode(model)

In [22]:
# Prefill: run the base model over the whole prompt and keep its logits (`orig`)
# and hidden states, which are the features the EAGLE draft head consumes below.
input_ids_tensor = torch.as_tensor(input_ids)

# current_length_data tracks how much of the pre-allocated KV cache is filled
# (EAGLE's own generate loop rewinds it with .zero_() before each prefill).
# Rewind it here too, so re-running this cell starts from an empty cache
# instead of attending to keys/values left over from the previous run.
current_length_data.zero_()

with torch.no_grad():
    outputs, orig, hidden_states = model(
        input_ids_tensor, past_key_values=past_key_values, output_orig=True
    )
outputs

BaseModelOutputWithPast(last_hidden_state=tensor([[[ 4.1851, -0.2059, -1.8382,  ..., -2.8908,  1.3605,  0.3110],
         [-0.0892, -0.0484,  0.0638,  ...,  0.7794, -0.8125,  0.8029],
         [ 2.6079,  0.7637, -3.9777,  ..., -0.0527, -0.3737,  1.9626],
         ...,
         [-2.6085,  4.6601, -4.3184,  ...,  0.4044, -0.4416,  0.1648],
         [-2.3383,  1.4609, -2.3124,  ...,  2.0149,  0.0813, -0.9260],
         [-2.5767,  3.5765,  6.0180,  ..., -0.2055,  2.0939, -0.1105]]]), past_key_values=(None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None), hidden_states=None, attentions=None)

In [27]:
# Compare the shape of the base model's last_hidden_state with the separately returned hidden_states.
(outputs.last_hidden_state.shape, hidden_states.shape)

(torch.Size([1, 22, 4096]), torch.Size([1, 22, 4096]))

In [25]:
# Base-model logits: one vocabulary-sized row per prompt position.
orig.size()

torch.Size([1, 22, 128256])

In [26]:
# Base model's next-token distribution after the prompt (last position of batch item 0).
sample_p = torch.softmax(orig[0, -1], dim=-1)
sample_p

tensor([2.1700e-08, 4.4491e-04, 2.8210e-07,  ..., 3.6314e-11, 3.6319e-11,
        3.6317e-11])

In [18]:
# Ten most likely next tokens under the base model (values = probabilities).
topk = sample_p.topk(10)
topk

torch.return_types.topk(
values=tensor([0.3908, 0.1933, 0.1912, 0.0796, 0.0599, 0.0225, 0.0105, 0.0099, 0.0090,
        0.0053]),
indices=tensor([10544, 75613,  1451,  6161, 38062,    34,    53, 43069,    45, 48799]))

In [19]:
# Decode the base model's top-10 candidate ids back into token strings.
tokenizer.convert_ids_to_tokens([int(token_id) for token_id in topk.indices])

['Blue', 'Purple', 'Ind', 'Red', 'Tur', 'C', 'V', 'Orange', 'N', 'Yellow']

In [28]:
# Drop the final position's feature: the draft call below pairs these features
# with input_ids[:, 1:], i.e. each feature with the token that follows it, and
# the last prompt position has no following prompt token.
# Slicing to an absolute length (seq_len - 1) rather than `[:, :-1]` keeps this
# cell idempotent: re-running it no longer strips one more position each time.
hidden_states = hidden_states[:, : input_ids_tensor.shape[1] - 1]
hidden_states.shape

torch.Size([1, 21, 4096])

In [32]:
# Prompt length in tokens (batch of one).
input_ids_tensor.size()

torch.Size([1, 22])

In [33]:
# EAGLE draft step: run the draft layer on the base features, each paired with
# the token that follows it (input_ids shifted by one).
# - Inference only, so no autograd graph is built.
# - The returned cache belongs to the draft layer. It is kept under its own name
#   so it does not overwrite the base model's `past_key_values` from the KV-cache
#   cell; otherwise re-running the base forward would be fed the wrong cache.
with torch.no_grad():
    out_hidden, ea_past_key_values = model.ea_layer(
        hidden_states, input_ids=input_ids_tensor[:, 1:], use_cache=True
    )
out_hidden.shape, out_hidden

torch.Size([1, 21, 4096])

In [35]:
# Project the draft layer's last-position output through the base model's LM
# head to get the draft's next-token logits (comparable to orig[0, -1]).
last_hidden = out_hidden[:, -1]
# Inference only: skip building an autograd graph over the 128256-way projection.
with torch.no_grad():
    last_headout = model.base_model.lm_head(last_hidden)
last_headout.shape

torch.Size([1, 128256])

In [36]:
# Draft head's next-token distribution, to compare against the base model's sample_p.
ea_p = torch.softmax(last_headout, dim=-1)
(ea_p.shape, ea_p)

(torch.Size([1, 128256]),
 tensor([[1.7842e-07, 1.0383e-02, 3.1208e-06,  ..., 3.0543e-07, 3.0549e-07,
          3.0547e-07]], grad_fn=<SoftmaxBackward0>))

In [41]:
# Probability the draft head assigns to each of the base model's top-10 tokens,
# in the same order as topk.values.
ea_top_p = ea_p[0][topk.indices]
ea_top_p

tensor([0.1553, 0.0490, 0.0010, 0.1466, 0.0018, 0.0069, 0.0102, 0.0328, 0.0003,
        0.0461], grad_fn=<IndexBackward0>)