# Inference Example with Medusa

In this Jupyter notebook, we're going to demonstrate how to perform inference using the Medusa model on an interesting story prompt. Let's get the ball rolling!

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # define GPU id, remove if you want to use all GPUs available
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel, MedusaConfig
from medusa.model.kv_cache import *
from medusa.model.utils import *
from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


## Medusa Forward Function

We define the medusa_forward function that will be used for generating stories based on the provided prompts.


In [2]:
@contextmanager
def timed(wall_times, key):
    start = time.time()
    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = model.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


## Model Loading

We load the model and tokenizer using the specified paths and configurations.


In [3]:
from huggingface_hub import hf_hub_download
import json

model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'

# Download the config.json from the model repository
config_file = hf_hub_download(repo_id=model_name, filename="config.json")

# Load the config.json content
with open(config_file, 'r') as f:
    config_data = json.load(f)

# Add the model_type if it's missing
if 'model_type' not in config_data:
    config_data['model_type'] = 'llama'

# Save the modified config back to disk (if needed for future use)
with open(config_file, 'w') as f:
    json.dump(config_data, f, indent=2)

# Now load the configuration using transformers
config = MedusaConfig.from_pretrained(config_file)

# Proceed to load the model with the updated configuration
from transformers import AutoModel

model = MedusaModel.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

tokenizer = model.get_tokenizer()
medusa_choices = mc_sim_7b_63

You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  return torch.load(checkpoint_file, map_location=map_location)
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.03s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at lmsys/vicuna-7b-

## Setting Inference Parameters

Next, we set some parameters that will be used during the inference process.


In [4]:
temperature = 0.
posterior_threshold = 0.09
posterior_alpha = 0.3

## Setting The Prompt

The following is the story prompt we will use for generating our story in the demo.


In [5]:
prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi, could you share a tale about a charming llama that grows Medusa-like hair and starts its own coffee shop? ASSISTANT:"

In [6]:
import torch
print(torch.cuda.is_available())  # Should print True if CUDA is available
print(torch.cuda.device_count())  

True
1


## Performing Inference

Using the set parameters and the defined function, let's generate our story!


In [7]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 403
Compression ratio: tensor(1.0025, device='cuda:0')


## Decoding The Output

Let's decode the generated output to obtain our story.


In [8]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

Once upon a time, in a small village nestled in the Andes mountains, there lived a charming llama named Luna. Luna was known for her kind heart and her love of coffee. She would often spend her afternoons sipping on a steaming cup of joe at the local café, chatting with the villagers and enjoying the warmth of the sun on her back.

One day, as Luna was grazing on some fresh grass, she noticed that her hair was starting to grow longer and thicker. At first, she didn't think much of it, but as the days went on, her hair continued to grow and change. It became thick and wiry, with sharp spikes protruding from it.

Luna was confused and a little scared by her new appearance. She had always been a gentle creature, and now she looked like a monster. She knew that she couldn't stay in the village anymore, so she set off on a journey to find a new home.

As she wandered through the mountains, Luna stumbled upon a beautiful clearing. In the center of the clearing stood a small cottage, with a s

## Analyzing Wall Times

We will now break down and analyze the wall times during the inference process.

You might notice a significant time consumption during the initialization phase. This is primarily due to the GPU cache initialization process on the first run.

For a clearer perspective, you can try rerunning the decoding segment again to observe the differences.

In [9]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              1.450
Wall time medusa:                            0.132
Wall time Tree:                            251.298
Wall time Posterior:                         0.153
Wall time Update:                            0.143
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.993
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.592
