In [37]:
import torch

from eagle.model.ea_model import EaModel


model = EaModel.from_pretrained(
    base_model_path='/models/Meta-Llama-3-8B-Instruct',
    ea_model_path='yuhuili/EAGLE-LLaMA3-Instruct-8B',
    total_token=60,
    depth=5,
    top_k=10,
    #torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cpu"
)

model.eval()
tokenizer = model.get_tokenizer()

messages = [
	{
		'role': 'user',
		'content': 'Hello World!',
	}
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
input_ids = tokenizer([prompt], add_special_tokens=False, ).input_ids

prompt, input_ids

Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.84s/it]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /models/Meta-Llama-3-8B-Instruct and are newly initialized: ['model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'mod

('<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n',
 [[128000,
   128006,
   882,
   128007,
   271,
   9906,
   4435,
   0,
   128009,
   128006,
   78191,
   128007,
   271]])

---
## eagenerate

In [38]:
from eagle.model.kv_cache import initialize_past_key_values


input_ids = torch.as_tensor(input_ids)
padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)

model.ea_layer.reset_kv()

(
    past_key_values,
    past_key_values_data,
    current_length_data,
) = initialize_past_key_values(model.base_model)
model.past_key_values = past_key_values
model.past_key_values_data = past_key_values_data
model.current_length_data = current_length_data

# Reset the past key and value states
current_length_data.zero_()

model.base_model.model.tree_mask = None
model.base_model.model.tree_mode = None

input_ids

tensor([[128000, 128006,    882, 128007,    271,   9906,   4435,      0, 128009,
         128006,  78191, 128007,    271]])

---
## initialize_tree

In [39]:
outputs, orig, hidden_states = model(
    input_ids, past_key_values=past_key_values, output_orig=True
)

orig.shape, hidden_states.shape

(torch.Size([1, 13, 128256]), torch.Size([1, 13, 4096]))

In [40]:
token = torch.argmax(orig[:, -1])
token

tensor(9906)

In [41]:
token = token[None, None]
input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)

input_ids.shape, input_ids

(torch.Size([1, 14]),
 tensor([[128000, 128006,    882, 128007,    271,   9906,   4435,      0, 128009,
          128006,  78191, 128007,    271,   9906]]))

In [42]:
head = model.base_model.lm_head
head

Linear(in_features=4096, out_features=128256, bias=False)

In [43]:
hidden_states.shape, input_ids.shape

(torch.Size([1, 13, 4096]), torch.Size([1, 14]))

In [44]:
self = model.ea_layer
self

Model(
  (embed_tokens): Embedding(128256, 4096, padding_idx=0)
  (layers): ModuleList(
    (0): LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (fc): Linear(in_features=8192, out_features=4096, bias=False)
  (act): SiLU()
  (logsoftmax): LogSoftmax(dim=-1)
)

---
## topK_genrate

`genrate` is a typo in EAGLE

In [45]:
total_tokens = self.total_tokens
depth = self.depth
top_k = self.top_k

total_tokens, depth, top_k

(59, 5, 10)

In [46]:
sample_token = input_ids[:, -1]

scores_list = []
parents_list = []
ss_token = []

input_ids = input_ids[:, 1:]
input_ids = input_ids.to(hidden_states.device)

len_posi = input_ids.shape[1]
self.reset()

sample_token, len_posi

(tensor([9906]), 13)

In [47]:
print(hasattr(self, "stable_kv") and self.stable_kv)

None


In [48]:
out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True)
out_hidden.shape, past_key_values[0][0].shape, past_key_values[0][1].shape

(torch.Size([1, 13, 4096]),
 torch.Size([1, 8, 13, 128]),
 torch.Size([1, 8, 13, 128]))

In [49]:
self.stable_kv = past_key_values
last_hidden = out_hidden[:, -1]

last_headout = head(last_hidden)
last_hidden.shape, last_headout.shape

(torch.Size([1, 4096]), torch.Size([1, 128256]))

In [50]:
top_k

10

In [51]:
last_p = self.logsoftmax(last_headout)
top = torch.topk(last_p, top_k, dim=-1)
topk_index, topk_p = top.indices, top.values
topk_index, topk_p

(tensor([[1070,    0, 4435,   11, 1578, 2684, 2268, 1203, 1917,  323]]),
 tensor([[-0.6652, -0.8468, -3.4629, -5.3292, -5.6466, -6.3770, -6.5991, -6.8702,
          -7.3081, -7.5023]], grad_fn=<TopkBackward0>))

In [52]:
scores = topk_p[0]
scores_list.append(scores[None])
parents_list.append(torch.zeros(1, dtype=torch.long, device=scores.device))
ss_token.append(topk_index)

scores_list, parents_list, ss_token

([tensor([[-0.6652, -0.8468, -3.4629, -5.3292, -5.6466, -6.3770, -6.5991, -6.8702,
           -7.3081, -7.5023]], grad_fn=<UnsqueezeBackward0>)],
 [tensor([0])],
 [tensor([[1070,    0, 4435,   11, 1578, 2684, 2268, 1203, 1917,  323]])])

In [53]:
input_ids = topk_index
input_hidden = last_hidden[None].repeat(1, top_k, 1)
tree_mask = self.tree_mask_init
topk_cs_index = torch.arange(top_k, device=self.embed_tokens.weight.device)

input_ids.shape, input_hidden.shape, tree_mask.shape, topk_cs_index.shape

(torch.Size([1, 10]),
 torch.Size([1, 10, 4096]),
 torch.Size([1, 1, 10, 10]),
 torch.Size([10]))

In [54]:
tree_mask, topk_cs_index

(tensor([[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
           [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
           [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
           [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
           [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]]]),
 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

In [55]:
depth

5

In [56]:
##########################################################################################################
# the loop

for i in range(depth):
    self.tree_mask = tree_mask
    position_ids = len_posi + self.position_ids
    # with Timer("draft one"):
    out_hidden, past_key_values = self(input_hidden, input_ids=input_ids, past_key_values=past_key_values,
                                       position_ids=position_ids, use_cache=True)
    len_posi += 1

    # with Timer("sort1"):
    bias1 = top_k if i > 0 else 0
    bias2 = max(0, i - 1)
    bias = 1 + top_k ** 2 * bias2 + bias1
    parents = (topk_cs_index + bias)
    parents_list.append(parents)

    last_headout = head(out_hidden[0])
    last_p = self.logsoftmax(last_headout)

    top = torch.topk(last_p, top_k, dim=-1)
    topk_index, topk_p = top.indices, top.values

    cu_scores = topk_p + scores[:, None]

    topk_cs = torch.topk(cu_scores.view(-1), top_k, dim=-1)
    topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
    scores = topk_cs_p

    out_ids = topk_cs_index // top_k
    input_hidden = out_hidden[:, out_ids]
    # with Timer("2index"):
    #     in_ids = topk_cs_index % top_k
    #     input_ids = topk_index[out_ids, in_ids][None]
    # with Timer("1index"):
    input_ids = topk_index.view(-1)[topk_cs_index][None]
    # print(input_ids.equal(input_ids0))

    ss_token.append(topk_index)
    scores_list.append(cu_scores)
    tree_mask = torch.cat((tree_mask[:, :, out_ids], self.tree_mask_init), dim=3)


In [57]:
dict(
	len_posi=len_posi,
	parents_list=parents_list,
	ss_token=ss_token,
	scores_list=scores_list,
)

{'len_posi': 18,
 'parents_list': [tensor([0]),
  tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]),
  tensor([11, 21, 31, 22, 23, 24, 25, 26, 51, 27]),
  tensor([111, 121, 112, 122, 131, 113, 114, 115, 116, 117]),
  tensor([211, 221, 222, 223, 231, 212, 251, 224, 271, 225]),
  tensor([311, 312, 351, 313, 314, 341, 321, 391, 381, 342])],
 'ss_token': [tensor([[1070,    0, 4435,   11, 1578, 2684, 2268, 1203, 1917,  323]]),
  tensor([[    0,  2268,    11,   758,  4999, 15114,  9135, 91597,   374,  1578],
          [ 1102, 20776,   353,   358, 14262, 22691,  2181,  9906,  1666, 29959],
          [    0,  2268, 91597,  4999,  9135, 15114, 16715,   758, 69322, 36675],
          [ 4435,   433, 20776,  1070, 10343, 14262, 10788,  9906, 22691, 24748],
          [    0,  2268,   758,    11,  1070,  4999, 91597,  9135, 15114, 16715],
          [    0,  2268,    11,   374,   758,  4999, 91597, 15114,  9135,  1578],
          [ 9906, 14262,  2181,  2170, 92886,   791,    32, 46078,  3947, 10343],
 

In [95]:
[ss.shape for ss in ss_token]

[torch.Size([1, 10]),
 torch.Size([10, 10]),
 torch.Size([10, 10]),
 torch.Size([10, 10]),
 torch.Size([10, 10]),
 torch.Size([10, 10])]

In [58]:
scores_list = torch.cat(scores_list, dim=0).view(-1)
ss_token_list = torch.cat(ss_token, dim=0).view(-1)
top_scores = torch.topk(scores_list, total_tokens, dim=-1)

scores_list.shape, ss_token_list.shape

(torch.Size([510]), torch.Size([510]))

In [59]:
top_scores_index = top_scores.indices
top_scores_index.shape, top_scores_index

(torch.Size([59]),
 tensor([  0,  10,   1, 110,  20, 210, 120, 111, 310, 311,   2,  30, 220, 221,
         121,  21, 222, 230, 350, 211, 420, 130, 312, 250, 410, 440, 112, 313,
         340, 320, 430, 113, 223, 421, 270, 390, 380, 490, 341, 500, 411,  22,
         321, 224, 114, 314,  23,  24, 225, 226, 280, 431, 115, 315, 316, 116,
         117, 227, 140]))

In [60]:
top_scores_index = torch.sort(top_scores_index).values
top_scores_index.shape, top_scores_index

(torch.Size([59]),
 tensor([  0,   1,   2,  10,  20,  21,  22,  23,  24,  30, 110, 111, 112, 113,
         114, 115, 116, 117, 120, 121, 130, 140, 210, 211, 220, 221, 222, 223,
         224, 225, 226, 227, 230, 250, 270, 280, 310, 311, 312, 313, 314, 315,
         316, 320, 321, 340, 341, 350, 380, 390, 410, 411, 420, 421, 430, 431,
         440, 490, 500]))

In [62]:
sample_token.shape

torch.Size([1])

In [61]:
draft_tokens = ss_token_list[top_scores_index]
draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
draft_tokens.shape

torch.Size([60])

In [64]:
draft_tokens

tensor([  9906,   1070,      0,   4435,      0,   1102,  20776,    353,    358,
         14262,      0,   1102,  20776,  14262,  22691,   9906, 128009,    353,
          2181,    596,    374,   1102,   4435,    596,    374,  10788,    374,
         20776,  24748,   1027,   2751,   6555,      6,   4435,    596,   4435,
          4435,  10788,  20776,  24748,    374,  14262,  92886,   6555,      0,
          4435,      0,   4435,      0,   4435,      0,      0,   4435,      0,
          4435,   1102, 128009,   4435,      0,      0])

In [63]:
draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long()
draft_parents.shape, draft_parents

(torch.Size([59]),
 tensor([  0,   0,   0,   1,   2,   2,   2,   2,   2,   3,  11,  11,  11,  11,
          11,  11,  11,  11,  21,  21,  31,  22, 111, 111, 121, 121, 121, 121,
         121, 121, 121, 121, 112, 131, 114, 115, 211, 211, 211, 211, 211, 211,
         211, 221, 221, 223, 223, 231, 224, 271, 311, 311, 312, 312, 351, 351,
         313, 381, 342]))

In [65]:
mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False)
mask_index.shape, mask_index

(torch.Size([59]),
 tensor([ 0,  0,  0,  0,  1,  1,  1,  1,  1,  2,  3,  3,  3,  3,  3,  3,  3,  3,
          4,  4,  9,  5, 10, 10, 18, 18, 18, 18, 18, 18, 18, 18, 11, 20, 13, 14,
         22, 22, 22, 22, 22, 22, 22, 24, 24, 26, 26, 32, 27, 34, 36, 36, 37, 37,
         47, 47, 38, 48, 46]))

In [66]:
mask_index[draft_parents == 0] = -1
mask_index

tensor([-1, -1, -1,  0,  1,  1,  1,  1,  1,  2,  3,  3,  3,  3,  3,  3,  3,  3,
         4,  4,  9,  5, 10, 10, 18, 18, 18, 18, 18, 18, 18, 18, 11, 20, 13, 14,
        22, 22, 22, 22, 22, 22, 22, 24, 24, 26, 26, 32, 27, 34, 36, 36, 37, 37,
        47, 47, 38, 48, 46])

In [67]:
# mask_index & mask_index_list is the parent indices in draft tokens (60)
mask_index = mask_index + 1
mask_index_list = mask_index.tolist()
mask_index_list

[0,
 0,
 0,
 1,
 2,
 2,
 2,
 2,
 2,
 3,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 5,
 5,
 10,
 6,
 11,
 11,
 19,
 19,
 19,
 19,
 19,
 19,
 19,
 19,
 12,
 21,
 14,
 15,
 23,
 23,
 23,
 23,
 23,
 23,
 23,
 25,
 25,
 27,
 27,
 33,
 28,
 35,
 37,
 37,
 38,
 38,
 48,
 48,
 39,
 49,
 47]

In [69]:
tree_mask = torch.eye(total_tokens + 1).bool()
tree_mask.shape, tree_mask

(torch.Size([60, 60]),
 tensor([[ True, False, False,  ..., False, False, False],
         [False,  True, False,  ..., False, False, False],
         [False, False,  True,  ..., False, False, False],
         ...,
         [False, False, False,  ...,  True, False, False],
         [False, False, False,  ..., False,  True, False],
         [False, False, False,  ..., False, False,  True]]))

In [70]:
tree_mask[:, 0] = True
tree_mask

tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False,  True,  ..., False, False, False],
        ...,
        [ True, False, False,  ...,  True, False, False],
        [ True, False, False,  ..., False,  True, False],
        [ True, False, False,  ..., False, False,  True]])

In [71]:
for i in range(total_tokens):
    tree_mask[i + 1].add_(tree_mask[mask_index_list[i]])
tree_mask

tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True, False,  True,  ..., False, False, False],
        ...,
        [ True,  True, False,  ...,  True, False, False],
        [ True, False,  True,  ..., False,  True, False],
        [ True, False,  True,  ..., False, False,  True]])

In [80]:
print('\n'.join([''.join(map(lambda x: '1' if x else '-', line)) for line in tree_mask.int().tolist()]))

1-----------------------------------------------------------
11----------------------------------------------------------
1-1---------------------------------------------------------
1--1--------------------------------------------------------
11--1-------------------------------------------------------
1-1--1------------------------------------------------------
1-1---1-----------------------------------------------------
1-1----1----------------------------------------------------
1-1-----1---------------------------------------------------
1-1------1--------------------------------------------------
1--1------1-------------------------------------------------
11--1------1------------------------------------------------
11--1-------1-----------------------------------------------
11--1--------1----------------------------------------------
11--1---------1---------------------------------------------
11--1----------1--------------------------------------------
11--1-----------1-------

In [81]:
tree_position_ids = torch.sum(tree_mask, dim=1) - 1
tree_position_ids.shape, tree_position_ids

(torch.Size([60]),
 tensor([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4,
         4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6]))

In [82]:
tree_mask = tree_mask.float()[None, None]
draft_tokens = draft_tokens[None]
tree_mask.shape, draft_tokens.shape

(torch.Size([1, 1, 60, 60]), torch.Size([1, 60]))

In [83]:
max_depth = torch.max(tree_position_ids) + 1
max_depth

tensor(7)

In [85]:
noleaf_index = torch.unique(mask_index).tolist()
len(noleaf_index), noleaf_index

(26,
 [0,
  1,
  2,
  3,
  4,
  5,
  6,
  10,
  11,
  12,
  14,
  15,
  19,
  21,
  23,
  25,
  27,
  28,
  33,
  35,
  37,
  38,
  39,
  47,
  48,
  49])

In [86]:
noleaf_num = len(noleaf_index) - 1
leaf_num = total_tokens - noleaf_num
noleaf_num, leaf_num

(25, 34)

In [88]:
retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1
print(retrieve_indices.shape)

retrieve_indices = retrieve_indices.tolist()
retrieve_indices

torch.Size([34, 7])


[[-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1],
 [-1, -1

In [90]:
tree_position_ids.shape, tree_position_ids

(torch.Size([60]),
 tensor([0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4,
         4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6]))

In [92]:
rid = 0
position_ids_list = tree_position_ids.tolist()

for i in range(total_tokens + 1):
    if i not in noleaf_index:
        cid = i
        depth = position_ids_list[i]
        for j in reversed(range(depth + 1)):	# from the leaf to the root
            retrieve_indices[rid][j] = cid
            cid = mask_index_list[cid - 1]
        rid += 1

rid, retrieve_indices

(34,
 [[0, 2, 7, -1, -1, -1, -1],
  [0, 2, 8, -1, -1, -1, -1],
  [0, 2, 9, -1, -1, -1, -1],
  [0, 1, 4, 13, -1, -1, -1],
  [0, 1, 4, 16, -1, -1, -1],
  [0, 1, 4, 17, -1, -1, -1],
  [0, 1, 4, 18, -1, -1, -1],
  [0, 2, 5, 20, -1, -1, -1],
  [0, 2, 6, 22, -1, -1, -1],
  [0, 1, 4, 11, 24, -1, -1],
  [0, 2, 5, 19, 26, -1, -1],
  [0, 2, 5, 19, 29, -1, -1],
  [0, 2, 5, 19, 30, -1, -1],
  [0, 2, 5, 19, 31, -1, -1],
  [0, 2, 5, 19, 32, -1, -1],
  [0, 3, 10, 21, 34, -1, -1],
  [0, 1, 4, 15, 36, -1, -1],
  [0, 1, 4, 11, 23, 40, -1],
  [0, 1, 4, 11, 23, 41, -1],
  [0, 1, 4, 11, 23, 42, -1],
  [0, 1, 4, 11, 23, 43, -1],
  [0, 2, 5, 19, 25, 44, -1],
  [0, 2, 5, 19, 25, 45, -1],
  [0, 2, 5, 19, 27, 46, -1],
  [0, 1, 4, 14, 35, 50, -1],
  [0, 1, 4, 11, 23, 37, 51],
  [0, 1, 4, 11, 23, 37, 52],
  [0, 1, 4, 11, 23, 38, 53],
  [0, 1, 4, 11, 23, 38, 54],
  [0, 1, 4, 12, 33, 48, 55],
  [0, 1, 4, 12, 33, 48, 56],
  [0, 1, 4, 11, 23, 39, 57],
  [0, 2, 5, 19, 28, 49, 58],
  [0, 2, 5, 19, 27, 47, 59]])

In [93]:
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
tree_position_ids = tree_position_ids.to(hidden_states.device)

draft_tokens.shape, retrieve_indices.shape, tree_mask.shape, tree_position_ids.shape

(torch.Size([1, 60]),
 torch.Size([34, 7]),
 torch.Size([1, 1, 60, 60]),
 torch.Size([60]))