In [1]:
# Run configuration, mirroring the options of FastChat's interactive CLI.
# Only 'model', 'temperature', 'max_steps' and 'multiline' are read by the
# cells in this notebook; 'style', 'mouse' and 'debug' are kept for parity.
args = {
    'model': 'FasterDecoding/medusa-1.0-vicuna-7b-v1.5',
    'temperature': 0.7,
    'max_steps': 512,
    'style': 'simple',
    'multiline': False,
    'mouse': False,
    'debug': False,
}


In [2]:
import torch

from medusa.model.medusa_model import MedusaModel


# Load the Medusa-augmented checkpoint in fp16. device_map="auto" lets the
# Transformers loader decide where the weights live, and low_cpu_mem_usage
# avoids materialising a full extra copy in host RAM while loading.
model = MedusaModel.from_pretrained(
    args['model'],
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.78s/it]


MedusaModelLlama(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [3]:
# Fetch the tokenizer that ships with the Medusa checkpoint and display it.
tokenizer = model.get_tokenizer()
tokenizer

Using sep_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.


LlamaTokenizerFast(name_or_path='FasterDecoding/medusa-1.0-vicuna-7b-v1.5', vocab_size=32000, model_max_length=4096, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<unk>', 'additional_special_tokens': ['<unk>', '<s>', '</s>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [4]:
from fastchat.model.model_adapter import get_conversation_template


# Resolve FastChat's prompt template for this model name. For this checkpoint
# it resolves to the 'vicuna_v1.1' template (see output below).
conv = get_conversation_template(args['model'])
conv

Conversation(name='vicuna_v1.1', system_template='{system_message}', system_message="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=('USER', 'ASSISTANT'), messages=[], offset=0, sep_style=<SeparatorStyle.ADD_COLON_TWO: 2>, sep=' ', sep2='</s>', stop_str=None, stop_token_ids=None)

In [5]:
# Role labels used when appending messages: index 0 = user, index 1 = assistant.
conv.roles

('USER', 'ASSISTANT')

In [6]:
from fastchat.serve.cli import SimpleChatIO


chatio = SimpleChatIO(args['multiline'])

# Reading from stdin blocks "Restart & Run All" and makes the walk-through
# impossible to reproduce from code alone. If `args` carries a non-empty
# 'prompt', use it; otherwise fall back to the interactive CLI prompt
# exactly as before.
inp = args.get('prompt') or chatio.prompt_for_input(conv.roles[0])
inp

'hello'

In [7]:
# Add this turn's USER message plus an empty ASSISTANT slot, then render the
# prompt. Re-running the cell previously appended the pair again, so the
# prompt silently accumulated duplicate turns; skip the append when the
# conversation already ends with exactly this pair. Earlier history in
# `conv` is left intact.
this_turn = [[conv.roles[0], inp], [conv.roles[1], None]]
if [list(message) for message in conv.messages[-2:]] != this_turn:
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

prompt

"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: hello ASSISTANT:"

In [8]:
# Tokenize the rendered prompt (the output shows a leading BOS id 1) and move
# the ids to the same device as the base model.
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)
input_ids.shape, input_ids

(torch.Size([1, 40]),
 tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
          21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
            322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
          29889,  3148,  1001, 29901, 22172,   319,  1799,  9047, 13566, 29901]],
        device='cuda:0'))

---

In [9]:
# End-to-end reference run: stream text from Medusa decoding. Each next()
# yields a dict whose 'text' is the cumulative decoded output so far.
gen = model.medusa_generate(
    input_ids,
    max_steps=args['max_steps'],
    temperature=args['temperature'],
)

it = iter(gen)
next(it)

{'text': 'Hello! How can I'}

In [10]:
# Advance the stream once more; the yielded text extends the previous output.
next(it)

{'text': 'Hello! How can I help you today?'}

---

In [11]:
# Manual walk-through of one Medusa decoding step, using the helpers imported
# below. NOTE(review): presumably this mirrors medusa_generate's internals —
# confirm against medusa_model.py.
# Defensive copy of the prompt ids before stepping.
input_ids = input_ids.clone()
# Candidate-tree specification for this checkpoint; presumably each tuple is
# a path of top-k indices through the Medusa heads — confirm.
# NOTE(review): not referenced by the visible cells below, which use
# model.medusa_buffers instead.
medusa_choices = model.get_medusa_choice(model.base_model_name_or_path)

medusa_choices

[(0,),
 (0, 0),
 (1,),
 (0, 1),
 (0, 0, 0),
 (1, 0),
 (2,),
 (0, 2),
 (0, 0, 1),
 (0, 3),
 (3,),
 (0, 1, 0),
 (2, 0),
 (4,),
 (0, 0, 2),
 (0, 4),
 (1, 1),
 (1, 0, 0),
 (0, 0, 0, 0),
 (5,),
 (0, 0, 3),
 (0, 5),
 (0, 2, 0),
 (3, 0),
 (0, 1, 1),
 (0, 6),
 (6,),
 (0, 7),
 (0, 0, 4),
 (4, 0),
 (1, 2),
 (0, 8),
 (7,),
 (0, 3, 0),
 (0, 0, 0, 1),
 (0, 0, 5),
 (2, 1),
 (0, 0, 6),
 (1, 0, 1),
 (0, 0, 1, 0),
 (2, 0, 0),
 (5, 0),
 (0, 9),
 (0, 1, 2),
 (8,),
 (0, 4, 0),
 (0, 2, 1),
 (1, 3),
 (0, 0, 7),
 (0, 0, 0, 2),
 (0, 0, 8),
 (1, 1, 0),
 (0, 1, 0, 0),
 (6, 0),
 (9,),
 (0, 1, 3),
 (0, 0, 0, 3),
 (1, 0, 2),
 (0, 5, 0),
 (3, 1),
 (0, 0, 2, 0),
 (7, 0),
 (1, 4)]

In [12]:
# Tree-attention buffers cached on the model: attention mask, tree indices,
# position ids and retrieve indices (see output keys).
# NOTE(review): nothing in this notebook builds them explicitly; they appear
# to be populated by the medusa_generate cell above, so that cell must run
# first on a fresh kernel — confirm.
medusa_buffers = model.medusa_buffers
medusa_buffers

{'medusa_attn_mask': tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
           [1., 1., 0.,  ..., 0., 0., 0.],
           [1., 0., 1.,  ..., 0., 0., 0.],
           ...,
           [1., 1., 0.,  ..., 1., 0., 0.],
           [1., 1., 0.,  ..., 0., 1., 0.],
           [1., 1., 0.,  ..., 0., 0., 1.]]]], device='cuda:0'),
 'tree_indices': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 11, 12, 13, 14, 15, 11, 12, 11, 12, 11, 11, 11, 11, 21, 22,
         23, 24, 25, 26, 27, 28, 29, 21, 22, 23, 24, 21, 22, 21, 21, 21, 21, 22,
         23, 21, 21, 31, 32, 33, 34, 31, 31, 31], device='cuda:0'),
 'medusa_position_ids': tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4], device='cuda:0'),
 'retrieve_indices': tensor([[ 0,  1, 12, 43, 63],
         [ 0,  1, 11, 36, 62],
         [ 0,  

In [13]:
# KV-cache handles stored on the model; the output shows a list of
# [KVCache, KVCache] pairs (presumably key/value per decoder layer — confirm).
past_key_values = model.past_key_values
past_key_values

[[<medusa.model.kv_cache.KVCache at 0x7341e16df190>,
  <medusa.model.kv_cache.KVCache at 0x7341e16dee90>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df010>,
  <medusa.model.kv_cache.KVCache at 0x7341e16def50>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16deef0>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df460>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df310>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df370>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df3d0>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df1c0>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df160>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df6d0>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df400>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df610>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df640>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df850>],
 [<medusa.model.kv_cache.KVCache at 0x7341e16df6a0>,
  <medusa.model.kv_cache.KVCache at 0x7341e16df1f0>],
 [<medusa.model.kv_cache.KVCache at 0

In [14]:
# Backing tensor for the KV cache, shape (64, 1, 32, 4096, 128) per the output.
# Presumably 64 = 32 layers x {key, value}, then batch, heads, max positions
# and head dim — TODO confirm against medusa.model.kv_cache.
past_key_values_data = model.past_key_values_data
past_key_values_data.shape, past_key_values_data

(torch.Size([64, 1, 32, 4096, 128]),
 tensor([[[[[-4.4336e-01,  2.6489e-02,  1.8616e-02,  ...,  7.9224e-02,
             -7.9041e-02, -5.9998e-02],
            [-2.4634e-01, -5.1239e-02,  3.7292e-02,  ...,  2.6025e-01,
             -1.6504e-01,  2.5195e-01],
            [-1.0193e-02, -1.9775e-01, -4.5996e-01,  ...,  1.3757e-01,
              2.1448e-01,  2.4438e-01],
            ...,
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00],
            [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
              0.0000e+00,  0.0000e+00]],
 
           [[ 1.1064e+00,  9.3066e-01, -3.1470e-01,  ...,  4.9194e-01,
             -2.4622e-01,  4.5483e-01],
            [-2.2021e-01,  2.7075e-01,  2.5049e-01,  ..., -5.0488e-01,
              1.4160e-01, -3.7695e-01],
            [-1.2012e-01,  3.2251e-01, -3.7842e-02,  ..

In [15]:
# Fill length of each of the 64 cache slices; every entry reads 50 after the
# streamed generation above.
current_length_data = model.current_length_data
current_length_data.shape, current_length_data

(torch.Size([64]),
 tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50, 50, 50, 50, 50, 50, 50]))

In [16]:
# Reset every slice's fill length to 0 in place so the manual step starts from
# an empty cache. Only the lengths are reset; past_key_values_data still holds
# the old values (presumably overwritten as new positions are written).
# NOTE(review): the suspended generator `it` presumably shares this cache —
# do not call next(it) after this cell.
current_length_data.zero_()
current_length_data

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [17]:
# Prompt length in tokens (sequence dimension of input_ids); kept for
# reference — not used by the visible cells below.
input_len = input_ids.shape[1]
input_len

40

In [18]:
# Inspect the Medusa attention state left on the inner model by the run above:
# the mask is still set while the mode is None (see output).
model.base_model.model.medusa_mask, model.base_model.model.medusa_mode

(tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
           [1., 1., 0.,  ..., 0., 0., 0.],
           [1., 0., 1.,  ..., 0., 0., 0.],
           ...,
           [1., 1., 0.,  ..., 1., 0., 0.],
           [1., 1., 0.,  ..., 0., 1., 0.],
           [1., 1., 0.,  ..., 0., 0., 1.]]]], device='cuda:0'),
 None)

In [19]:
from medusa.model.utils import reset_medusa_mode, initialize_medusa, generate_candidates, tree_decoding, evaluate_posterior, update_inference_inputs

# Clear the Medusa attention state before prefill.
# NOTE(review): presumably resets medusa_mask / medusa_mode (inspected in the
# previous cell) — confirm in medusa.model.utils.
reset_medusa_mode(model)

In [20]:
# Prefill: run the prompt through the model with the tree attention mask and
# the KV cache (presumably filling the cache for the 40 prompt positions).
# Returns logits for every prompt position from the base LM head, plus a
# stacked tensor with a leading dim of 5 (presumably one slice per Medusa head).
medusa_logits, logits = initialize_medusa(
    input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
)

medusa_logits.shape, logits.shape

(torch.Size([5, 1, 40, 32000]), torch.Size([1, 40, 32000]))

In [21]:
# Running count of generated tokens; update_inference_inputs returns the
# updated value.
new_token = 0

In [22]:
# Propose candidate continuations from the base and Medusa logits, laid out
# on the tree described by medusa_buffers. Returns `candidates` (one row of
# token ids per tree path) and `tree_candidates` (the flattened tree tokens).
#
# The acceptance hyperparameters used to be inline magic numbers duplicated
# in the evaluate_posterior cell. They can now be overridden from `args`; the
# defaults are the original values, so behaviour is unchanged unless a key is
# set. Keep them in sync with the values passed to evaluate_posterior.
candidates, tree_candidates = generate_candidates(
    medusa_logits,
    logits,
    medusa_buffers["tree_indices"],
    medusa_buffers["retrieve_indices"],
    temperature=args['temperature'],
    posterior_alpha=args.get('posterior_alpha', 0.3),
    posterior_threshold=args.get('posterior_threshold', 0.09),
    top_p=args.get('top_p', 0.8),
    sampling=args.get('sampling', 'typical'),
    fast=args.get('fast', True),
)

candidates.shape, tree_candidates.shape

(torch.Size([42, 5]), torch.Size([1, 64]))

In [23]:
# Flattened token ids for every tree node: one row of 64 ids, the same length
# as medusa_buffers["tree_indices"].
tree_candidates

tensor([[15043, 29991,   727, 29892, 30166,   518,     2,   304,  1738, 14332,
           322,  1128, 29991,   306,  1763,   920,     2,  1317, 29871,   739,
           727,  1128, 29991,   306,  1763,   920,  1128, 29991,  1128, 29991,
          1128,  1128,  1128,  1128,   508, 29915,   306,  1128,   338,  1122,
         29991,   674,   366,   508, 29915,   306,  1128,   508, 29915,   508,
           508,   508,   508, 29915,   306,   508,   508,   306,   508,   366,
         29885,   306,   306,   306]], device='cuda:0')

In [24]:
# First four candidate paths; each row holds 5 token ids (candidates is 42 x 5).
candidates[:4]

tensor([[15043, 29991, 29991,   508,   306],
        [15043, 29991,  1128,   306,   306],
        [15043, 29991,  1128, 29915,   306],
        [15043, 29991,  1128,   508, 29885]], device='cuda:0')

In [25]:
# Verify all tree candidates in one forward pass over the cached prefix.
# The returned logits are laid out per candidate path, shape (42, 5, vocab),
# matching `candidates` — presumably gathered with retrieve_indices.
# `outputs` is the base model's output; it is passed to
# update_inference_inputs below.
medusa_logits, logits, outputs = tree_decoding(
    model,
    tree_candidates,
    past_key_values,
    medusa_buffers["medusa_position_ids"],
    input_ids,
    medusa_buffers["retrieve_indices"],
)

medusa_logits.shape, logits.shape, outputs

(torch.Size([5, 42, 5, 32000]),
 torch.Size([42, 5, 32000]),
 BaseModelOutputWithPast(last_hidden_state=tensor([[[ 0.0200, -0.3262,  0.2083,  ...,  0.9995, -0.4351, -1.2793],
          [-0.3823, -1.6719, -0.4983,  ...,  1.2344, -1.3271,  1.2070],
          [ 0.2930, -1.1426,  0.0505,  ..., -0.1533, -0.6206, -0.6357],
          ...,
          [-0.5840, -1.6250,  0.0761,  ...,  1.2686, -1.6855, -0.2402],
          [-0.1727, -0.8457, -0.5044,  ...,  1.3057, -2.3984,  0.5483],
          [ 0.6421, -0.9990,  0.3555,  ...,  1.1768, -1.9121,  0.4226]]],
        device='cuda:0', dtype=torch.float16), past_key_values=None, hidden_states=None, attentions=None))

In [26]:
# Pick the candidate path whose prefix is best supported by the base model's
# logits. Returns the chosen row index into `candidates` and how many of its
# speculative tokens were accepted.
#
# The threshold and alpha were previously passed positionally as `0.09, 0.3`.
# That is easy to swap because generate_candidates takes the same two
# settings by keyword in the opposite order. Name them explicitly here. They
# share `args` overrides and defaults with the generate_candidates cell.
best_candidate, accept_length = evaluate_posterior(
    logits,
    candidates,
    args['temperature'],
    posterior_threshold=args.get('posterior_threshold', 0.09),
    posterior_alpha=args.get('posterior_alpha', 0.3),
    top_p=args.get('top_p', 0.8),
    sampling=args.get('sampling', 'typical'),
    fast=args.get('fast', True),
)

best_candidate, accept_length

(tensor(6, device='cuda:0'), tensor(4, device='cuda:0'))

In [27]:
# Still the 40-token prompt at this point; the next cell appends the
# accepted tokens.
input_ids.shape, input_ids

(torch.Size([1, 40]),
 tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
          21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
            322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
          29889,  3148,  1001, 29901, 22172,   319,  1799,  9047, 13566, 29901]],
        device='cuda:0'))

In [28]:
# Commit the step: append the accepted tokens of the best candidate to
# input_ids, and return the logits / medusa_logits for the next step plus the
# updated token count (a tensor, per the output). Here accept_length = 4 grew
# input_ids from 40 to 45 tokens, i.e. accept_length + 1 tokens were added.
# NOTE(review): presumably also moves the accepted KV entries within
# past_key_values_data / current_length_data in place — confirm.
input_ids, logits, medusa_logits, new_token = update_inference_inputs(
    input_ids,
    candidates,
    best_candidate,
    accept_length,
    medusa_buffers["retrieve_indices"],
    outputs,
    logits,
    medusa_logits,
    new_token,
    past_key_values_data,
    current_length_data,
)

input_ids.shape, input_ids, logits.shape, medusa_logits.shape, new_token

(torch.Size([1, 45]),
 tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
          21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
            322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
          29889,  3148,  1001, 29901, 22172,   319,  1799,  9047, 13566, 29901,
          15043, 29991,  1128,   508,   306]], device='cuda:0'),
 torch.Size([1, 1, 32000]),
 torch.Size([5, 1, 1, 32000]),
 tensor(5, device='cuda:0'))