In [1]:
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16
%env CUDA_VISIBLE_DEVICES=4
%load_ext autoreload
%autoreload 2
import sys; sys.path.insert(0, "..")

env: HF_HOME=/mnt/LLM
env: OMP_NUM_THREADS=16
env: CUDA_VISIBLE_DEVICES=4


In [2]:
import torch
from src.find_important_tokens_eagle import EaModelForAutoJudge
from prompts import GSM8KPrompts, llama_assistant_turn_end

# Base LLM plus its EAGLE draft head. The EAGLE-3 variant is selected from the checkpoint name.
base_model_path = "meta-llama/Llama-3.1-8B-Instruct"
EAGLE_model_path = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
window_size = 16  # passed below as depth=window_size - 1 and total_token=window_size

uses_eagle3 = "eagle3" in EAGLE_model_path.lower()
# do_sample=False with top_k=1 and the other sampling knobs unset => greedy decoding.
greedy_kwargs = dict(do_sample=False, top_p=None, top_k=1, temperature=None)

# The factory also sets up `model.ref_model` (used for the reference generation further down).
# NOTE: per its printed warning, it currently loads the model weights twice.
model = EaModelForAutoJudge.from_pretrained_with_tied_ref_model(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    use_eagle3=uses_eagle3,
    torch_dtype="auto",  # previously torch.bfloat16
    device_map="auto",
    low_cpu_mem_usage=True,
    depth=window_size - 1,
    total_token=window_size,
    **greedy_kwargs,
)
tokenizer = model.get_tokenizer()
# Device of the first parameter; input ids are moved here before generation.
device = next(model.parameters()).device
model.eval();


ACHTUNG: Model will be loaded twice! This can be optimized.
LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# One GSM8K-style word problem, wrapped in the zero-shot prompt, followed by the answer
# formatting instructions and `llama_assistant_turn_end` (presumably closes the user turn
# and opens the assistant turn).
question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
prompt_with_shots = GSM8KPrompts.prompt_with_0_shots
prompt = "".join([
    prompt_with_shots,
    question,
    "\n",
    GSM8KPrompts.formatting_prompt,
    llama_assistant_turn_end,
])
# NOTE(review): tokenizer(...) adds special tokens by default. If the prompt template already
# starts with <|begin_of_text|>, the BOS token ends up duplicated -- confirm.
encoded_prompt = tokenizer(prompt, return_tensors='pt')
batch_input_ids = encoded_prompt['input_ids'].to(device)

In [4]:
# find_important_tokens_greedy logs heavily; the actual sanity checks are in the cells below.
# It returns the final response ids and the list of recorded draft/target mismatches.
current_response, changed_token_indices = model.find_important_tokens_greedy(
    input_ids=batch_input_ids,
    max_new_tokens=256,
)

banner = '=' * 50
print(banner)
print('=' * 17, 'FINAL RESPONSE', '=' * 17)
print(banner)
decoded_response = tokenizer.batch_decode(current_response)
print(*decoded_response)



NOT IMPORTANT
accept_length before tensor(6, device='cuda:0')
accept_length after 7
NOT IMPORTANT
accept_length before 7
accept_length after 9
BAD logp: ... t_id|><|start_header_id|>assistant<|end_header_id|>\n\nTo find the total number of clips sold in April |  April [logp=0.00000]
IMPORTANT
BAD logp: ... <|end_header_id|>\n\nTo find the total number of clips sold in April and May, we need to calculate the |  how [logp=0.00000]
IMPORTANT
BAD logp: ...  find the total number of clips sold in April and May, we need to calculate the number of clips sold |  by [logp=0.00002]
IMPORTANT
NOT IMPORTANT
accept_length before tensor(0, device='cuda:0')
accept_length after 1
BAD logp: ... he total number of clips sold in April and May, we need to calculate the number of clips sold in May |  separately [logp=0.00019]
IMPORTANT
BAD logp: ... al number of clips sold in April and May, we need to calculate the number of clips sold in May first |  clip [logp=0.00000]
IMPORTANT
BAD logp: ... umber of cl

In [5]:
# Sanity check: at each recorded mismatch position, row 0 of `current_response` must hold
# `mismatch_target_token` when the mismatch was judged important, and `mismatch_draft_token`
# otherwise. The assertion message names the position and both tokens so a failure is
# diagnosable without re-running the expensive cell above.
for mismatch in changed_token_indices:
    position = mismatch['mismatch_index']
    expected_key = 'mismatch_target_token' if mismatch["is_important"] else 'mismatch_draft_token'
    expected_token = mismatch[expected_key]
    actual_token = current_response[0, position].item()
    assert actual_token == expected_token, (
        f"position {position}: expected {expected_key}={expected_token!r}, "
        f"got {actual_token!r} (is_important={mismatch['is_important']})"
    )
print("Important tokens check passed!")

Important tokens check passed!


In [6]:
from lm_eval_utils import GSM8KParser, GSM8KEvaluator

parser = GSM8KParser()
evaluator = GSM8KEvaluator()
# Reference generation from `model.ref_model`, with the same 256-new-token budget as above.
target_response = model.ref_model.eagenerate(batch_input_ids, max_new_tokens=256)

# Decode each response once; the parser extracts one answer per decoded text.
our_texts = tokenizer.batch_decode(current_response)
target_texts = tokenizer.batch_decode(target_response)
our_answer, target_answer = parser([*our_texts, *target_texts])

# Compare via the evaluator rather than string equality: e.g. '72' vs '72.' still matches.
# `== 1` presumably means "all generations matched" -- confirm against GSM8KEvaluator.
is_match = evaluator(generations=our_texts, references=[target_answer]) == 1
print(f"{our_answer=}, {target_answer=}, match={is_match}")
assert is_match

our_answer='72', target_answer='72.', match=True


In [7]:
# Fraction of recorded draft/target mismatches that were judged important, i.e. where the
# response kept the target token (see the check above). Guarded so that a run with no
# mismatches reports n/a instead of raising ZeroDivisionError.
n_mismatches = len(changed_token_indices)
n_important = sum(mismatch["is_important"] for mismatch in changed_token_indices)
if n_mismatches:
    print("Important rate:", n_important / n_mismatches)
else:
    print("Important rate: n/a (no mismatches were recorded)")

Important rate: 0.7073170731707317
