In [1]:
import os

import torch

from eagle.model.ea_model import EaModel


# Checkpoint locations. Defaults match the original run; override via env vars
# instead of editing a hard-coded absolute path.
BASE_MODEL_PATH = os.environ.get('BASE_MODEL_PATH', '/models/Meta-Llama-3-8B-Instruct')
EA_MODEL_PATH = os.environ.get('EA_MODEL_PATH', 'yuhuili/EAGLE-LLaMA3-Instruct-8B')

# Load the Llama-3 base model together with the EAGLE draft head on CPU.
# total_token / depth / top_k are forwarded to EaModel (presumably they size the
# speculative draft tree — confirm in eagle.model.ea_model).
# No torch_dtype is passed, so weights load in the loader's default dtype
# (NOTE(review): presumably float32 here — pass torch_dtype to change it).
model = EaModel.from_pretrained(
    base_model_path=BASE_MODEL_PATH,
    ea_model_path=EA_MODEL_PATH,
    total_token=60,
    depth=5,
    top_k=10,
    low_cpu_mem_usage=True,
    device_map="cpu",
)
model

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /models/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'm

EaModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNorm

In [2]:
# Switch the wrapper into inference mode (eval() returns the module itself),
# then grab the tokenizer it exposes.
tokenizer = model.eval().get_tokenizer()

In [92]:
# A single-turn chat that asks for a one-word answer, so the very first
# generated token is (ideally) the whole answer.
messages = [
    dict(
        role='user',
        content='Give a name of capital city, answer in one word.',
    ),
]

# Render the chat template as text and append the assistant header so the
# model continues speaking as the assistant.
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
prompt

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nGive a name of capital city, answer in one word.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [93]:
# Tokenize the already-templated prompt. add_special_tokens=False because the
# chat template already emitted <|begin_of_text|>.
input_ids = tokenizer([prompt], add_special_tokens=False)['input_ids']
input_ids

[[128000,
  128006,
  882,
  128007,
  271,
  36227,
  264,
  836,
  315,
  6864,
  3363,
  11,
  4320,
  304,
  832,
  3492,
  13,
  128009,
  128006,
  78191,
  128007,
  271]]

In [94]:
from eagle.model.kv_cache import initialize_past_key_values
from eagle.model.utils import reset_tree_mode


# Build a KV cache from the base model, then attach the three returned objects
# to the wrapper (presumably where EaModel's forward looks for them — confirm).
past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(
    model.base_model
)
model.past_key_values, model.past_key_values_data, model.current_length_data = (
    past_key_values,
    past_key_values_data,
    current_length_data,
)

# NOTE(review): judging by its name, this clears tree-attention state on the
# model before a plain forward pass — confirm in eagle.model.utils.
reset_tree_mode(model)

In [95]:
input_ids_tensor = torch.as_tensor(input_ids)

# Reset the cache fill counter so re-running this cell starts from an empty
# KV cache instead of appending the prompt after the previous pass.
# NOTE(review): assumes current_length_data holds the per-layer fill lengths
# used by the EAGLE KVCache — confirm against eagle/model/kv_cache.py.
model.current_length_data.zero_()

# Read the cache from the model attribute rather than the notebook-global name,
# so this cell does not depend on that global still pointing at the base cache.
with torch.no_grad():
    outputs, orig, hidden_states = model(
        input_ids_tensor,
        past_key_values=model.past_key_values,
        output_orig=True,
    )
outputs

BaseModelOutputWithPast(last_hidden_state=tensor([[[ 4.1851, -0.2059, -1.8382,  ..., -2.8908,  1.3605,  0.3110],
         [-0.0892, -0.0484,  0.0638,  ...,  0.7794, -0.8125,  0.8029],
         [ 2.6079,  0.7637, -3.9777,  ..., -0.0527, -0.3737,  1.9626],
         ...,
         [-2.4613,  4.2603, -3.7489,  ...,  0.4261, -0.6128,  0.4078],
         [-1.4931,  0.7278, -2.0237,  ...,  0.8637,  1.0607, -1.2457],
         [ 1.6675, -0.1484,  2.3228,  ...,  0.9989,  4.8757, -2.0715]]]), past_key_values=(None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None), hidden_states=None, attentions=None)

In [96]:
# Compare the shapes of the two hidden-state tensors returned by the forward pass.
(outputs.last_hidden_state.shape, hidden_states.shape)

(torch.Size([1, 22, 4096]), torch.Size([1, 22, 4096]))

In [97]:
# Vocabulary-sized scores (128256 matches the embedding table) for every prompt position.
orig.size()

torch.Size([1, 22, 128256])

In [98]:
# Base model's next-token distribution at the final prompt position.
sample_p = torch.softmax(orig[0, -1], dim=-1)
sample_p

tensor([1.8789e-08, 4.6902e-05, 4.0770e-07,  ..., 2.7712e-10, 2.7716e-10,
        2.7716e-10])

In [99]:
# The ten most likely next tokens under the base model.
topk = sample_p.topk(10)
topk

torch.return_types.topk(
values=tensor([0.3381, 0.1700, 0.1417, 0.1064, 0.0568, 0.0268, 0.0260, 0.0259, 0.0145,
        0.0129]),
indices=tensor([60704, 53954,    46,  3513, 77406, 39231,    42,    49, 40672, 95509]))

In [100]:
# Decode the base model's top-10 ids into readable tokens.
tokenizer.convert_ids_to_tokens([int(token_id) for token_id in topk.indices])

['Paris', 'Tok', 'O', 'Be', 'Mos', 'Washington', 'K', 'R', 'London', 'Berlin']

In [101]:
# The draft layer pairs hidden state t with token t+1 (it is fed
# input_ids_tensor[:, 1:]), so the final position has no partner and is dropped.
# Slicing to a fixed length (prompt length - 1) keeps this cell idempotent:
# re-running it no longer trims one extra position each time.
hidden_states = hidden_states[:, : input_ids_tensor.shape[1] - 1]
hidden_states.shape

torch.Size([1, 21, 4096])

In [102]:
# Prompt length including the leading <|begin_of_text|> token.
input_ids_tensor.size()

torch.Size([1, 22])

In [112]:
# Run the EAGLE draft layer over the shifted pairs (hidden state t, token t+1).
# The second return value (presumably the draft layer's own cache, given
# use_cache=True) gets its own name so the base model's `past_key_values`
# global is not overwritten.
assert hidden_states.shape[1] == input_ids_tensor.shape[1] - 1, (
    'hidden_states must be exactly one position shorter than input_ids_tensor'
)
with torch.no_grad():
    out_hidden, ea_past_key_values = model.ea_layer(
        hidden_states,
        input_ids=input_ids_tensor[:, 1:],
        use_cache=True,
    )
out_hidden.shape, out_hidden

(torch.Size([1, 21, 4096]),
 tensor([[[-0.1511,  0.4875, -0.0915,  ...,  0.2875,  0.7770, -1.7313],
          [ 1.7614,  1.0684, -0.9005,  ..., -0.6003,  0.7143, -0.0374],
          [ 1.3904,  1.6742,  0.5768,  ..., -0.9929,  0.4055, -0.1562],
          ...,
          [-0.0364,  0.9519, -0.4526,  ...,  1.7459,  0.2040, -0.7841],
          [-0.2421,  1.9918, -0.4941,  ...,  1.3882,  1.1585, -1.1444],
          [-1.2170,  1.5401,  2.6356,  ..., -0.0735,  1.4186,  0.6648]]]))

In [119]:
# Project the draft layer's last-position hidden state through the base
# model's LM head to obtain draft next-token scores.
last_hidden = out_hidden[:, -1, :]
with torch.no_grad():
    last_headout = model.base_model.lm_head(last_hidden)
last_headout.shape

torch.Size([1, 128256])

In [120]:
# Draft (EAGLE) next-token distribution; batch dimension is kept.
ea_p = torch.softmax(last_headout, dim=-1)
(ea_p.shape, ea_p)

(torch.Size([1, 128256]),
 tensor([[7.7621e-07, 3.8634e-03, 5.9852e-06,  ..., 1.2224e-06, 1.2226e-06,
          1.2225e-06]]))

In [121]:
# Draft probabilities for the base model's top-10 tokens, in the same order as topk.
ea_top_p = ea_p[0].index_select(0, topk.indices)
ea_top_p

tensor([0.2160, 0.0140, 0.0005, 0.0123, 0.0009, 0.0022, 0.0010, 0.0110, 0.0830,
        0.0347])

In [122]:
# The draft distribution's own ten most likely tokens (still batched: shape [1, 10]).
ea_topk = ea_p.topk(10)
ea_topk

torch.return_types.topk(
values=tensor([[0.2160, 0.0830, 0.0347, 0.0238, 0.0180, 0.0176, 0.0140, 0.0123, 0.0110,
         0.0076]]),
indices=tensor([[60704, 40672, 95509, 72437, 63919, 92928, 53954,  3513,    49,  3648]]))

In [123]:
# Drop the batch dimension, then decode the draft top-10 ids.
tokenizer.convert_ids_to_tokens(ea_topk.indices.squeeze(0).tolist())

['Paris',
 'London',
 'Berlin',
 'Toronto',
 'Bang',
 'Singapore',
 'Tok',
 'Be',
 'R',
 'New']

In [124]:
# Side-by-side summary: the base model's top-10 tokens with their base
# probabilities (scores) and draft probabilities (ea_scores).
{
    'prompt': messages[0]['content'],
    'scores': topk.values,
    'ea_scores': ea_top_p,
    'tokens': tokenizer.convert_ids_to_tokens(topk.indices.tolist()),
}

{'prompt': 'Give a name of capital city, answer in one word.',
 'scores': tensor([0.3381, 0.1700, 0.1417, 0.1064, 0.0568, 0.0268, 0.0260, 0.0259, 0.0145,
         0.0129]),
 'ea_scores': tensor([0.2160, 0.0140, 0.0005, 0.0123, 0.0009, 0.0022, 0.0010, 0.0110, 0.0830,
         0.0347]),
 'tokens': ['Paris',
  'Tok',
  'O',
  'Be',
  'Mos',
  'Washington',
  'K',
  'R',
  'London',
  'Berlin']}