In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2" # define GPU id, remove if you want to use all GPUs available
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import *
from medusa.model.utils import *
#from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=122


  warn((f'\n\n{"="*80}\n'


[2023-12-25 00:20:17,087] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)




In [2]:
def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values,attention_mask):
    """
    Initializes the Medusa structure for a given model.

    This function performs the following operations:
    1. Forward pass through the model to obtain the Medusa logits, original model outputs, and logits.
    2. Sets the Medusa attention mask within the base model.

    Args:
    - input_ids (torch.Tensor): The input tensor containing token ids.
    - model (MedusaLMHead): The model containing the Medusa layers and base model.
    - medusa_attn_mask (torch.Tensor): The attention mask designed specifically for the Medusa structure.
    - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.

    Returns:
    - medusa_logits (torch.Tensor): Logits from the Medusa heads.
    - logits (torch.Tensor): Original logits from the base model.
    """
    print(attention_mask)
    medusa_logits, outputs, logits = model(
        input_ids, attention_mask = attention_mask , output_orig=True# past_key_values=past_key_values
    )
    model.base_model.model.medusa_mask = medusa_attn_mask
    return medusa_logits, logits

In [3]:
@contextmanager
def timed(wall_times, key):
    start = time.time()
    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)

def medusa_forward(input_ids, model, tokenizer, medusa_buffers, medusa_topk, temperature, posterior_threshold, posterior_alpha, past_key_values, past_key_values_data, current_length_data, attention_mask ,steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    print(attention_mask)
    with timed(wall_times, 'init'):
        reset_medusa_mode(model)
        input_len = input_ids.shape[1]
        medusa_logits, logits = initialize_medusa(input_ids, model, medusa_buffers['medusa_attn_mask'], past_key_values,attention_mask)
    
    new_token = 0

    for idx in range(steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(medusa_logits, logits, medusa_topk, medusa_buffers['tree_indices'], temperature)

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(model, tree_candidates, past_key_values, medusa_buffers['medusa_position_ids'], input_ids, medusa_buffers['retrieve_indices'])

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(logits, candidates, temperature, posterior_threshold, posterior_alpha)
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(input_ids, candidates, best_candidate, accept_length, medusa_buffers['retrieve_indices'], outputs, logits, medusa_logits, new_token, past_key_values_data, current_length_data)

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


In [4]:
model_name = '../../../../idea5_3gram_4fastlayer_t1_skipbert_teacherstudent_2_medusa_mlp_vicuna-7b-v1.3_medusa_1_lr_0.0001_layers_1/checkpoint-1800/pytorch_model.bin'

In [5]:
model = torch.load(model_name)

In [6]:
model_name2 = '../../../../idea5_3gram_4fastlayer_t1_skipbert_teacherstudent_2_medusa_mlp_vicuna-7b-v1.3_medusa_1_lr_0.0001_layers_1'
model2 = MedusaModel.from_pretrained(
    model_name2,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
   
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


path ../../../../../model/vicuna-7b-v1.3
path:  ../../../../../model/vicuna-7b-v1.3


In [7]:
model2.load_state_dict(model)

<All keys matched successfully>

In [8]:
tokenizer = model2.get_tokenizer()

In [9]:
temperature = 0.
posterior_threshold = 0.09
posterior_alpha = 0.3

In [10]:
prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi, could you share a tale about a charming llama that grows Medusa-like hair and starts its own coffee shop? ASSISTANT:"

In [11]:
model = model2

In [12]:
tokenizer = model2.get_tokenizer()

medusa_choices = torch.tensor([2,5])#([5,5,5,5,5])#4,5
num_heads = len(medusa_choices) - 1
medusa_topk = medusa_choices[1:]

medusa_buffers = generate_medusa_buffers(medusa_choices, device=model2.base_model.device)

  medusa_choices = torch.tensor(medusa_choices)


In [13]:
def single_forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        output_orig=False,
        position_ids=None,
        
    ):
        """Forward pass of the MedusaModel.

        Args:
            input_ids (torch.Tensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Ground truth labels for loss computation.
            past_key_values (tuple, optional): Tuple containing past key and value states for attention.
            output_orig (bool, optional): Whether to also output predictions from the original LM head.
            position_ids (torch.Tensor, optional): Position IDs.

        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
            (Optional) Original predictions from the base model's LM head.
        """
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                #output_hidden_states=True,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
        
        #####1.获取fastlayer层#####
        embed =self.base_model.model.embed_tokens(input_ids)
        embedtrigram = torch.cat((embed[:,:-2],embed[:,1:-1],embed[:,2:],),dim=-1)
        embed = self.trimlp(embedtrigram )
        from modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
        batch_size, seq_length = embed.shape[:2]
        attention_mask = _prepare_4d_causal_attention_mask(
                         attention_mask[:,:-2], (batch_size, seq_length), embed, 0
                    )
        attention_mask  = attention_mask.to(self.base_model.device)
        # embedtrigram = torch.cat((embed[:,:-2],embed[:,1:-1],embed[:,2:]),dim=-1)
        #for  i in self.fast_layer :      
        embed = self.fast_layer1(embed ,attention_mask = attention_mask)
        embed = self.fast_layer2(embed[0] ,attention_mask = attention_mask)
        embed = self.fast_layer3(embed[0] ,attention_mask = attention_mask)
        embed = self.fast_layer4(embed[0] ,attention_mask = attention_mask)
        embed = embed[0]
        loss_fct = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
        hsloss =loss_fct( outputs[0][:,2:].clone(),embed[:,:])
        embed3 = torch.cat((outputs[0][:,1:-1],embed[:,:]),dim=-1)#output2[0][:,-seq_length+2:-1]
        medusa_logits = []
        # TODO: Consider parallelizing this loop for efficiency?
        for i in range(self.medusa):
            #######修改后输出######
            medusa_logits.append(self.medusa_head[i](embed3.unsqueeze(0)))#self.medusa_head[i]embed3.unsqueeze(0)(outputs[0]))#hidden_states[i*4].clone()))gruout.to(self.base_model.dtype)
            ######原输出######
            #medusa_logits.append(self.medusa_head[i]((outputs[0].clone())))
        # if output_orig:
        #     return torch.stack(medusa_logits, dim=0), outputs, orig
        if output_orig:
            return torch.stack(medusa_logits, dim=0), outputs, orig
        return {"logits":torch.stack(medusa_logits, dim=0),"hsloss":hsloss}

In [20]:
with torch.inference_mode():
    for i in range(10):
        input = tokenizer([prompt])
        input_ids = torch.as_tensor(input.input_ids).cuda()
        attention_mask =torch.as_tensor(input.attention_mask).cuda()
        output = single_forward(model,input_ids,attention_mask=attention_mask,output_orig = True)
        token1 = torch.argmax(output[-1][-1][-1])
        
        input_ids = torch.cat((input_ids,token1.unsqueeze(0).unsqueeze(0)),dim=-1)
        prompt = tokenizer.decode(
                        input_ids[0],
                        spaces_between_special_tokens=False,
                    )
        input = tokenizer([prompt])
        input_ids = torch.as_tensor(input.input_ids).cuda()
        attention_mask =torch.as_tensor(input.attention_mask).cuda()
        output1 = model(input_ids,attention_mask=attention_mask,output_orig = True)
        token_ref = torch.argmax(output1[-1][-1][-1])
        token2 = torch.topk(output[0][-1][-1][-1][-1],5)[1]
        # print( token_ref)
        print(token2)
        correct = token2.eq(token_ref.unsqueeze(-1)).any(-1)
        print(correct)
        
        

tensor([ 263, 2793, 2998,  278, 4249], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 1339, 26373, 24600, 18708,  9914], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 8238,   347, 18708,   368,   993], device='cuda:0')
tensor(False, device='cuda:0')
tensor([5697,  322,  282, 4509, 1766], device='cuda:0')
tensor(False, device='cuda:0')
tensor([  310,   491,  4249, 29892,   322], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 278, 1749,  902,  263,  365], device='cuda:0')
tensor(False, device='cuda:0')
tensor([4870, 3942, 7881, 6350, 3815], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 3942,  7881,  2291,  6289, 29889], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 4870,    13,  2296,  2318, 29889], device='cuda:0')
tensor(False, device='cuda:0')
tensor([ 8126,   471,   750, 10398,   727], device='cuda:0')
tensor(False, device='cuda:0')


In [21]:
torch.argmax(output[0][0][0][0][-1])

tensor(29902, device='cuda:0')

In [29]:
token = torch.argmax(output[1][0][0][-1])

In [22]:
torch.argmax(output[1][0][0][0][-1])

tensor(0, device='cuda:0')

In [13]:
# past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model)

In [36]:
# with torch.inference_mode():
#     input = tokenizer([prompt])
#     input_ids = input.input_ids
#     attention_mask =torch.as_tensor(input.attention_mask).cuda()
#     print(attention_mask)
#     output_ids, new_token, idx, wall_time = medusa_forward(
#                     torch.as_tensor(input_ids).cuda(),
#                     model,
#                     tokenizer,
#                     medusa_buffers,
#                     medusa_topk,
#                     temperature,
#                     posterior_threshold,
#                     posterior_alpha,
#                     past_key_values,
#                     past_key_values_data,
#                     current_length_data,
#                     attention_mask = attention_mask  ,
#                 )
#     output_ids = output_ids[0][len(input_ids[0]) :]
#     print("Output length:", output_ids.size(-1))
#     print("Compression ratio:", new_token / idx)

In [None]:
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)

In [None]:
max_length = 50

def format_string(text, value, max_length):
    value_str = "{:.3f}".format(value)
    return f"{text:<{max_length - len(value_str)}}{value_str}"

time_init = np.sum(wall_time['init'] )
time_medusa = np.sum(wall_time['medusa'] )
time_tree = np.sum(wall_time['tree'] )
time_posterior = np.sum(wall_time['posterior'] )
time_update = np.sum(wall_time['update'] )
time_total = time_init + time_medusa + time_tree + time_posterior + time_update

print('='*max_length)
print(format_string("Wall time init: ", time_init, max_length))
print(format_string("Wall time medusa: ", time_medusa, max_length))
print(format_string("Wall time Tree: ", time_tree, max_length))
print(format_string("Wall time Posterior: ", time_posterior, max_length))
print(format_string("Wall time Update: ", time_update, max_length))
print('-'*max_length)
print(format_string("Wall time portion medusa: ", time_medusa / time_total, max_length))
print(format_string("Wall time portion Tree: ", time_tree / time_total, max_length))
print(format_string("Wall time portion Posterior: ", time_posterior / time_total, max_length))
print(format_string("Wall time portion Update: ", time_update / time_total, max_length))
print('-'*max_length)
print(format_string("Tokens/second: ", new_token / time_total, max_length))
print('='*max_length)