In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel, MedusaConfig
from medusa.model.kv_cache import *
from medusa.model.utils import *
from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
@contextmanager
def timed(wall_times, key):
    start = time.time()
    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            medusa_buffers = model.medusa_buffers
        else:
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


In [3]:
model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'

config = MedusaConfig.from_pretrained(
    model_name,
    medusa_num_heads=4,
    medusa_num_layers=1,
)

model = MedusaModel.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)
tokenizer = model.get_tokenizer()

medusa_choices = mc_sim_7b_63



You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  return torch.load(checkpoint_file, map_location=map_location)
Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.13s/it]
Some weights of MedusaModelLlama were not initialized from the model checkpoint at lmsys/vicuna-7b-

In [4]:
temperature = 0.
posterior_threshold = 0.09
posterior_alpha = 0.3

prompt1

"Explain the laws of thermodynamics in simple terms."
"What are the key differences between mitosis and meiosis?"
"How does machine learning improve image recognition?"
"Describe the process of photosynthesis step by step."
"What are the applications of calculus in physics?"
"Explain the significance of the Pythagorean theorem in geometry."
"How do electric circuits work in basic electronic devices?"

In [5]:
prompt="Create a program that takes two numbers and an operator (+, -, *, /) as input and returns the result"

In [6]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 512
Compression ratio: tensor(1.0020, device='cuda:0')


In [7]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

of the operation.

Example:
Input: 5 + 3, 7 - 2, 4 * 6 / 2
Output: 8, 5, 24

Note:

* The input numbers will be integers.
* The operator will be one of the four basic arithmetic operators (+, -, *, /).
* The output will be the result of the operation between the two numbers.
* The result will be an integer.
* The input will be valid, meaning that the numbers will be positive and less than 10^5.
* The output will be in the same format as the example above.
* The program will not use any built-in libraries or functions.
* The program will not use any arrays or dynamic memory allocation.
* The program will not use any sorting or searching algorithms.
* The program will not use any recursion.
* The program will not use any conditional statements other than if-else statements.
* The program will not use any loops other than while loops.
* The program will not use any functions other than the basic arithmetic operators.
* The program will not use any pre-defined constants other than the four

In [9]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              1.539
Wall time medusa:                            0.266
Wall time Tree:                            314.477
Wall time Posterior:                         0.237
Wall time Update:                            0.339
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.992
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.616


prompt2

In [12]:
prompt = "Write a function that converts temperatures from Celsius to Fahrenheit and vice versa. in C language"

In [13]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 512
Compression ratio: tensor(1.0020, device='cuda:0')


In [14]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

.

Write a program that calculates the average of a list of numbers. in C language.

Write a program that calculates the sum of the squares of the numbers in a list. in C language.

Write a program that calculates the factorial of a number. in C language.

Write a program that calculates the greatest common divisor (GCD) of two numbers. in C language.

Write a program that calculates the least common multiple (LCM) of two numbers. in C language.

Write a program that calculates the area of a rectangle. in C language.

Write a program that calculates the volume of a cube. in C language.

Write a program that calculates the surface area of a sphere. in C language.

Write a program that calculates the circumference of a circle. in C language.

Write a program that calculates the area of a triangle. in C language.

Write a program that calculates the area of a trapezoid. in C language.

Write a program that calculates the area of a parallelogram. in C language.

Write a program that calcul

In [15]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.812
Wall time medusa:                            0.282
Wall time Tree:                            331.880
Wall time Posterior:                         0.220
Wall time Update:                            0.373
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.995
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.535


prompt-3

In [16]:
prompt = "Develop a function that counts the number of words in a given string. in Python"

In [17]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 369
Compression ratio: tensor(1.0027, device='cuda:0')


In [18]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

.

Here is an example of how you might implement this function:
```
def count_words(string):
    # Initialize a counter for the number of words
    word_count = 0
    
    # Iterate through each character in the string
    for char in string:
        # If the character is a space, increment the word count
        if char == ' ':
            word_count += 1
    
    # Return the number of words in the string
    return word_count
```
This function takes a string as input and returns the number of words in the string. It does this by iterating through each character in the string and incrementing a counter for each space character it encounters.

You can test this function by calling it with a string as an argument and printing the result. For example:
```
string = "The quick brown fox jumps over the lazy dog."
print(count_words(string))  # Output: 6
```
In this example, the function correctly counts the number of words in the string as 6.

You can also use this function to count the num

In [19]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.835
Wall time medusa:                            0.189
Wall time Tree:                            239.357
Wall time Posterior:                         0.151
Wall time Update:                            0.264
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.994
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.532


prompt4

In [21]:
prompt = "Write a function that checks if a number is even or odd."

In [22]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)

Output length: 45
Compression ratio: tensor(1.0227, device='cuda:0')


In [23]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)



Example:

* 10 is even
* 5 is odd
* 7 is odd

Your function should return true if the number is even and false if the number is odd.</s>


In [24]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.954
Wall time medusa:                            0.057
Wall time Tree:                             28.553
Wall time Posterior:                         0.016
Wall time Update:                            0.019
--------------------------------------------------
Wall time portion medusa:                    0.002
Wall time portion Tree:                      0.965
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.520


prompt5

In [33]:
prompt="What are the applications of calculus in physics?"

In [34]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)


Output length: 286
Compression ratio: tensor(1.0035, device='cuda:0')


In [35]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)


Calculus is a branch of mathematics that deals with the study of rates of change and accumulation. It has numerous applications in physics, including:
1. Describing the motion of objects: Calculus is used to describe the motion of objects, including the motion of planets, the motion of particles in fluids, and the motion of waves.
2. Finding the maximum and minimum values of functions: Calculus is used to find the maximum and minimum values of functions that describe physical quantities, such as the position of an object as a function of time or the velocity of an object as a function of distance.
3. Solving optimization problems: Calculus is used to solve optimization problems in physics, such as finding the path of least resistance in a fluid or the path of a projectile in the absence of air resistance.
4. Modeling physical phenomena: Calculus is used to model physical phenomena, such as the motion of particles in a gas, the spread of diseases, and the behavior of electrical circuit

In [36]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.835
Wall time medusa:                            0.082
Wall time Tree:                            176.313
Wall time Posterior:                         0.174
Wall time Update:                            0.287
--------------------------------------------------
Wall time portion medusa:                    0.000
Wall time portion Tree:                      0.992
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.002
--------------------------------------------------
Tokens/second:                               1.610


prompt6

In [37]:

prompt = "Explain the significance of the Pythagorean theorem in geometry"

In [38]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)


Output length: 93
Compression ratio: tensor(1.0109, device='cuda:0')


In [39]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

.
The Pythagorean theorem is a fundamental theorem in geometry that states that in a right triangle, the square of the length of the hypotenuse (the side opposite the right angle) is equal to the sum of the squares of the lengths of the other two sides. This theorem has a number of important applications in mathematics and science, including the calculation of distances, the design of buildings and bridges, and the analysis of sound waves.</s>


In [40]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.839
Wall time medusa:                            0.036
Wall time Tree:                             57.301
Wall time Posterior:                         0.018
Wall time Update:                            0.081
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.983
Wall time portion Posterior:                 0.000
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.596


prompt7

In [41]:
prompt = "How do electric circuits work in basic electronic devices?"

In [42]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)


Output length: 255
Compression ratio: tensor(1.0039, device='cuda:0')


In [43]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)



Electric circuits in basic electronic devices work by using a flow of electrons to perform a specific task. The flow of electrons is controlled by a variety of components, including resistors, capacitors, and transistors. These components are connected together in a specific way to create a circuit that allows the flow of electrons to perform a specific task.

For example, a simple circuit might consist of a battery, a resistor, and a light bulb. When the circuit is completed by connecting the battery to the resistor and the resistor to the light bulb, the flow of electrons through the circuit causes the light bulb to light up. This is an example of a basic electric circuit that uses the flow of electrons to perform a specific task.

In more complex electronic devices, the flow of electrons is controlled by a variety of components and circuits that work together to perform more complex tasks. For example, a computer uses a complex circuit of transistors and other components to perfor

In [44]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              0.831
Wall time medusa:                            0.112
Wall time Tree:                            156.873
Wall time Posterior:                         0.068
Wall time Update:                            0.200
--------------------------------------------------
Wall time portion medusa:                    0.001
Wall time portion Tree:                      0.992
Wall time portion Posterior:                 0.000
Wall time portion Update:                    0.001
--------------------------------------------------
Tokens/second:                               1.613


In [5]:
prompt = "Given a text, extract all email addresses present."

In [6]:
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    print("Compression ratio:", new_token / idx)


Output length: 53
Compression ratio: tensor(1.0192, device='cuda:0')


In [7]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)



Example:

Input:

"Hello, my email is [example@example.com](mailto:example@example.com). How are you? - John"

Output:

["example@example.com"]</s>


In [8]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)

Wall time init:                              2.678
Wall time medusa:                            0.057
Wall time Tree:                            112.259
Wall time Posterior:                         0.165
Wall time Update:                            0.033
--------------------------------------------------
Wall time portion medusa:                    0.000
Wall time portion Tree:                      0.975
Wall time portion Posterior:                 0.001
Wall time portion Update:                    0.000
--------------------------------------------------
Tokens/second:                               0.460
