In [179]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='3'
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from dataclasses import dataclass, field
import json
import math
import pathlib
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import transformers
from transformers import Trainer, BitsAndBytesConfig
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
import os
from medusa.model.medusa_model import MedusaModel, MedusaConfig,SingleMedusa
import torch.nn.functional as F
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


In [180]:

def timed(wall_times, key):
    start = time.time()
    torch.cuda.synchronize()
    yield
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)

def medusa_forward(input_ids, model, tokenizer, medusa_buffers, medusa_topk, temperature, posterior_threshold, posterior_alpha, past_key_values, past_key_values_data, current_length_data, steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        reset_medusa_mode(model)
        input_len = input_ids.shape[1]
        medusa_logits, logits = initialize_medusa(input_ids, model, medusa_buffers['medusa_attn_mask'], past_key_values)
    
    new_token = 0

    for idx in range(steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(medusa_logits, logits, medusa_topk, medusa_buffers['tree_indices'], temperature)

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(model, tree_candidates, past_key_values, medusa_buffers['medusa_position_ids'], input_ids, medusa_buffers['retrieve_indices'])

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(logits, candidates, temperature, posterior_threshold, posterior_alpha)
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(input_ids, candidates, best_candidate, accept_length, medusa_buffers['retrieve_indices'], outputs, logits, medusa_logits, new_token, past_key_values_data, current_length_data)

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times


In [181]:
state_name = '../../../../idea8_t1_medusa_mlp_TinyLlama-1.1B-Chat-v0.6_medusa_1_lr_0.0005_layers_1/checkpoint-7200/pytorch_model.bin'

In [182]:
dict =torch.load(state_name)

In [183]:
model_name_or_path="../../../../../model/TinyLlama-1.1B-Chat-v0.6"

In [184]:
config = transformers.AutoConfig.from_pretrained(
    model_name_or_path,
)

In [185]:
model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        low_cpu_mem_usage=True,
    )


Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ../../../../../model/TinyLlama-1.1B-Chat-v0.6 and are newly initialized: ['model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.15.se

In [186]:
medusa_lm_head = MedusaModel(
        model,
        medusa_num_heads=1,
        medusa_num_layers=1,
        base_model_name_or_path=model_name_or_path
    )

path:  ../../../../../model/TinyLlama-1.1B-Chat-v0.6


In [187]:
medusa_lm_head.load_state_dict(dict)

<All keys matched successfully>

In [188]:
model_max_length=2048
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)

    

In [189]:
def generate(fastmodel,input_ids,attention_mask,outputs,k=5):
        embed =fastmodel.base_model.model.embed_tokens(input_ids)
        embedtrigram = torch.cat((embed[:,:-2],embed[:,1:-1],embed[:,2:]),dim=-1)
        gram1 = torch.cat((embed[:,0],embed[:,1],embed[:,1]),dim=-1).unsqueeze(1)
        embedtrigram = torch.cat((gram1,embedtrigram),dim=-2)
        embed = fastmodel.trimlp(embedtrigram )
        from modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
        batch_size, seq_length = embed.shape[:2]
        attention_mask = _prepare_4d_causal_attention_mask(
                         attention_mask[:,:-1], (batch_size, seq_length), embed, 0
                    )
        attention_mask  = attention_mask.to(fastmodel.base_model.device) 
        input1 = torch.cat((outputs[0],embed[:,:]),dim=-1)
        output_fastlayer =    fastmodel.fast_layer1(input1,attention_mask = attention_mask)
        output_fastlayer = fastmodel.fastoutput(output_fastlayer[0])
        logits = fastmodel.medusa_head[0](output_fastlayer)
        candidate = logits[0][-1].topk(k=5)[1]
        return logits,candidate
def calacc(input,max_length = 100,k=3):
    input = tokenizer([inputs])
    input_ids = torch.tensor(input.input_ids)
    attention_mask = torch.tensor(input.attention_mask)
    count = 0
    for i in range(max_length):
        outputs = medusa_lm_head.base_model.model(input_ids ,attention_mask = attention_mask ,output_hidden_states=True)
        orig =  medusa_lm_head.base_model.lm_head(outputs[0])
        t0 = torch.argmax(orig[0][-1])
        input_ids  = torch.cat((input_ids,t0.unsqueeze(0).unsqueeze(0)),dim=-1)
        attention_mask = torch.cat((attention_mask,torch.tensor([[1]])),dim=-1)
        
        l,ca = generate(medusa_lm_head,input_ids,attention_mask,outputs,k=k)
        realt1 =torch.argmax( medusa_lm_head.base_model(input_ids = input_ids)[0][0][-1])
        count+=sum(ca.eq(realt1))
    print(count/100)
    return input_ids

In [190]:
inputs = "who are you?Assistant:"

In [191]:
output = calacc(inputs)

tensor(0.9800)


In [192]:
tokenizer.decode(output[0])

'<s>who are you?Assistant: I am a person who is always looking for new experiences and learning new things. I am curious about the world around me and I love to explore new places and meet new people. I am always looking for ways to make a positive impact on the world around me. I am passionate about helping others and making a difference in their lives. I am always looking for ways to make a positive impact on the world around me. I am always looking for ways to make a positive impact on the world around'