In [13]:
import torch
from medusa.model.medusa_model import MedusaModel


model_path = '/models/train/0626_medusa_mlp_Meta-Llama-3-8B-Instruct_medusa_3_lr_0.0001_layers_1'


model = MedusaModel.from_pretrained(
	model_path,
	torch_dtype=torch.float32,
	low_cpu_mem_usage=True,
	device_map='cpu',
)
model.eval()
tokenizer = model.get_tokenizer()

model

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.14s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at /models/Meta-Llama-3-8B-Instruct and are newly initialized: ['medusa_head.1.0.linear.weight', 'medusa_head.0.0.linear.bias', 'medusa_head.3.1.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.0.1.weight', 'medusa_head.0.0.linear.weight', 'medusa_head.2.0.linear.weight', 'medusa_head.1.1.weight', 'medusa_head.3.0.linear.weight', 'medusa_head.2.0.linear.bias', 'medusa_head.4.0.linear.weight', 'medusa_head.2.1.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.4.0.linear.bias', 'medusa_head.4.1.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MedusaModelLlama(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Li

In [115]:
query = 'Give a name of a color, answer in one word.'

In [116]:
from fastchat.model.model_adapter import get_conversation_template


conv = get_conversation_template(model_path)
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

prompt

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nSay a word, answer in one word.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [117]:
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids

tensor([[128000, 128000, 128006,    882, 128007,    271,  46864,    264,   3492,
             11,   4320,    304,    832,   3492,     13, 128009, 128006,  78191,
         128007,    271]])

In [118]:
from medusa.model.kv_cache import initialize_past_key_values
from medusa.model.utils import reset_medusa_mode


(
	past_key_values,
	past_key_values_data,
	current_length_data,
) = initialize_past_key_values(model.base_model)
model.past_key_values = past_key_values
model.past_key_values_data = past_key_values_data
model.current_length_data = current_length_data

reset_medusa_mode(model)

In [119]:
logits = model(input_ids, past_key_values=model.past_key_values, output_orig=True, medusa_forward=False).logits
logits

tensor([[[  4.8914,   6.0422,  10.7782,  ...,  -3.6068,  -3.6067,  -3.6066],
         [  4.8914,   6.0422,  10.7782,  ...,  -3.6068,  -3.6067,  -3.6066],
         [ -2.6605,   1.5771,  -2.3671,  ...,   6.2346,   6.2343,   6.2343],
         ...,
         [-15.8961, -12.5432, -12.6700,  ...,  11.0865,  11.0863,  11.0865],
         [ -0.9013,   1.4144,  -4.5220,  ...,   4.8689,   4.8691,   4.8691],
         [  7.2457,  17.8507,   4.8114,  ...,  -0.0681,  -0.0680,  -0.0680]]],
       grad_fn=<UnsafeViewBackward0>)

In [120]:
sample_p = torch.softmax(logits[:, -1], dim=-1)
sample_p

tensor([[1.4926e-06, 6.0201e-02, 1.3084e-07,  ..., 9.9444e-10, 9.9453e-10,
         9.9453e-10]], grad_fn=<SoftmaxBackward0>)

In [121]:
sample_p.shape

torch.Size([1, 128256])

In [122]:
p10 = torch.topk(sample_p, 10, dim=-1)
p10

torch.return_types.topk(
values=tensor([[0.6802, 0.0654, 0.0602, 0.0467, 0.0160, 0.0137, 0.0092, 0.0058, 0.0056,
         0.0054]], grad_fn=<TopkBackward0>),
indices=tensor([[ 9906, 27665,     1, 16440,  3968, 38432, 77119, 50040,  6670, 68583]]))

In [123]:
tokenizer.decode(p10.indices.view(-1), skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)

'HelloApple"CloudFlComputerMoonDogTreeSpark'

In [124]:
[tokenizer.decode([id]) for id in p10.indices.view(-1)]

['Hello',
 'Apple',
 '"',
 'Cloud',
 'Fl',
 'Computer',
 'Moon',
 'Dog',
 'Tree',
 'Spark']

In [125]:
count = {}

for i in range(100000):
    id = torch.multinomial(sample_p, 1).item()
    if id not in count:
        count[id] = 0
    count[id] += 1

count

{9906: 6753,
 1: 616,
 16440: 462,
 27665: 631,
 61570: 1,
 11839: 4,
 3968: 167,
 6670: 68,
 38432: 126,
 77119: 96,
 25099: 5,
 43069: 42,
 62190: 24,
 10713: 11,
 50040: 64,
 44: 42,
 31998: 16,
 34: 45,
 47416: 23,
 33413: 11,
 22691: 12,
 1966: 1,
 37: 49,
 94357: 1,
 42737: 14,
 97824: 3,
 8325: 2,
 74627: 70,
 29707: 26,
 68781: 10,
 10370: 1,
 13347: 30,
 15339: 23,
 10115: 9,
 30233: 2,
 49540: 1,
 10902: 24,
 25821: 1,
 33813: 30,
 51812: 1,
 91056: 1,
 68583: 50,
 11116: 2,
 112584: 1,
 27899: 33,
 31955: 12,
 31192: 2,
 83380: 7,
 97283: 20,
 23182: 7,
 111491: 1,
 33947: 23,
 95570: 1,
 24581: 1,
 75613: 4,
 30197: 1,
 62816: 3,
 6219: 4,
 24818: 4,
 19753: 12,
 29353: 3,
 77610: 2,
 62528: 3,
 79178: 1,
 25025: 1,
 47: 13,
 33274: 7,
 330: 1,
 25173: 1,
 119581: 1,
 36152: 1,
 103724: 1,
 1671: 3,
 29296: 4,
 9642: 4,
 14588: 1,
 51787: 9,
 31380: 1,
 2118: 2,
 25310: 3,
 55471: 1,
 87703: 10,
 46240: 1,
 17863: 6,
 80871: 2,
 7280: 14,
 81117: 4,
 60139: 4,
 44638: 2,
 8

In [126]:
# sort count dict
sc = sorted(count.items(), key=lambda x: -x[1])
sc

[(9906, 6753),
 (27665, 631),
 (1, 616),
 (16440, 462),
 (3968, 167),
 (38432, 126),
 (77119, 96),
 (74627, 70),
 (6670, 68),
 (50040, 64),
 (68583, 50),
 (37, 49),
 (34, 45),
 (43069, 42),
 (44, 42),
 (27899, 33),
 (13347, 30),
 (33813, 30),
 (29707, 26),
 (62190, 24),
 (10902, 24),
 (47416, 23),
 (15339, 23),
 (33947, 23),
 (97283, 20),
 (31998, 16),
 (42737, 14),
 (7280, 14),
 (47, 13),
 (22691, 12),
 (31955, 12),
 (19753, 12),
 (10713, 11),
 (33413, 11),
 (68781, 10),
 (87703, 10),
 (10115, 9),
 (51787, 9),
 (83380, 7),
 (23182, 7),
 (33274, 7),
 (8586, 7),
 (29351, 7),
 (17863, 6),
 (51341, 6),
 (87655, 6),
 (25099, 5),
 (12331, 5),
 (11839, 4),
 (75613, 4),
 (6219, 4),
 (24818, 4),
 (29296, 4),
 (9642, 4),
 (81117, 4),
 (60139, 4),
 (45443, 4),
 (3513, 4),
 (97824, 3),
 (62816, 3),
 (29353, 3),
 (62528, 3),
 (1671, 3),
 (25310, 3),
 (91963, 3),
 (12988, 3),
 (9028, 3),
 (11087, 3),
 (58922, 3),
 (24748, 3),
 (65641, 3),
 (26208, 3),
 (48799, 3),
 (8325, 2),
 (30233, 2),
 (11116, 