In [None]:
# Automatically re-import edited modules (e.g. the local modeling_llama / decoding
# imports below) before each cell runs, so code edits take effect without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer
from decoding import clip_input, infer_input_ids
from rouge_score import rouge_scorer
import numpy as np

In [None]:
# Load Llama-2-13B in bfloat16 plus its tokenizer, then move the model to the eval GPU.
MODEL_PATH = '../llama-2-13b/'
device = 'cuda:6'

# Skip nn.Linear's random weight initialisation while the model is constructed, to
# save load time. NOTE(review): this assumes every Linear weight is then loaded from
# the checkpoint; any weight missing from it would be left uninitialised.
# The original initializer is restored afterwards, even if loading fails, so the
# patch does not leak into the rest of the session.
_original_linear_reset_parameters = torch.nn.Linear.reset_parameters
torch.nn.Linear.reset_parameters = lambda self: None
try:
    model = LlamaForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
finally:
    torch.nn.Linear.reset_parameters = _original_linear_reset_parameters

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = model.to(device).eval()

In [None]:
# Log the CUDA / PyTorch / transformers versions. Results are environment-sensitive:
# the output filename used later encodes torch 1.13.1 / CUDA 11.6.
import transformers

for component_version in (torch.version.cuda, torch.__version__, transformers.__version__):
    print(component_version)

In [None]:
# Show the attention / MLP layer-skip configuration exposed by the custom Llama model
# (presumably the layers bypassed when drafting in ESSG mode — confirm in modeling_llama).
_attn_skip_layer_id_set, _mlp_skip_layer_id_set = model.get_skip_layers()
print(f'{_attn_skip_layer_id_set} {_mlp_skip_layer_id_set}')

In [None]:
# Seed every RNG the run may touch (Python's `random`, NumPy, and torch on all
# devices) so any stochastic step is reproducible.
import random

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Build the evaluation set (a seeded 1000-example shuffle of the test split) and the
# few-shot prompt prefix by downloading from the Hugging Face hub.
# NOTE(review): the next cell reloads `data`/`shots` from local copies on disk and
# rebuilds `prompt_shots`, overwriting everything set here; run only one of the two.
from datasets import load_dataset

n_shot = 1
task_name = 'cnndm'
N_EVAL = 1000

if task_name == 'xsum':
    data = load_dataset('xsum', split='test').shuffle(seed=seed).select(range(N_EVAL))
    shots = load_dataset('xsum', split='train').shuffle(seed=seed).select(range(n_shot))
    prompt_keys = ['document', 'summary']
elif task_name == 'cnndm':
    data = load_dataset('cnn_dailymail', name='3.0.0', split='test').shuffle(seed=seed).select(range(N_EVAL))
    shots = load_dataset('cnn_dailymail', name='3.0.0', split='train').shuffle(seed=seed).select(range(n_shot))
    prompt_keys = ['article', 'highlights']
else:
    # Fail here instead of with a NameError on `data`/`shots` further down.
    raise ValueError(f"unknown task_name {task_name!r}; expected 'xsum' or 'cnndm'")

# One "Article: ...\nSummary: ...\n" block per shot. Newlines inside the summary
# (CNN/DM highlights are one sentence per line) become spaces, so adjacent
# sentences don't run together ("Monday .Young actor ...").
article_key, summary_key = prompt_keys
prompt_shots = ''.join(
    'Article: ' + shots[i][article_key] + '\nSummary: ' + shots[i][summary_key].replace('\n', ' ') + '\n'
    for i in range(n_shot)
)

In [None]:
# Load the evaluation set and the few-shot examples from local copies on disk and
# build the prompt prefix. This replaces whatever the previous (hub) cell loaded.
from datasets import load_from_disk

n_shot = 1
task_name = 'cnndm'

# task -> (local dataset root, [document field, summary field])
TASK_SOURCES = {
    'xsum': ('../xsum', ['document', 'summary']),
    'cnndm': ('../cnndm', ['article', 'highlights']),
}
if task_name not in TASK_SOURCES:
    # Fail here instead of with a NameError on `data`/`shots` further down.
    raise ValueError(f"unknown task_name {task_name!r}; expected one of {sorted(TASK_SOURCES)}")
data_root, prompt_keys = TASK_SOURCES[task_name]
data = load_from_disk(f'{data_root}/test/1000')
shots = load_from_disk(f'{data_root}/shots')
if len(shots) < n_shot:
    raise ValueError(f"n_shot={n_shot} but {data_root}/shots holds only {len(shots)} examples")

# One "Article: ...\nSummary: ...\n" block per shot. Newlines inside the summary
# (CNN/DM highlights are one sentence per line) become spaces, so adjacent
# sentences don't run together.
article_key, summary_key = prompt_keys
prompt_shots = ''.join(
    'Article: ' + shots[i][article_key] + '\nSummary: ' + shots[i][summary_key].replace('\n', ' ') + '\n'
    for i in range(n_shot)
)

In [None]:
# Show the few-shot prefix that is passed to clip_input for every evaluation sample.
print(f'{prompt_shots}')

In [None]:
# Benchmark plain greedy decoding ('base') against ESSG decoding at four fixed
# draft-stop thresholds plus one auto-tuned threshold. For every sample that is kept,
# one line with the running means of ROUGE-2, latency, speed-up and draft statistics
# is appended to OUTPUT_PATH.

# NOTE(review): the name hard-codes CNNDM / 1-shot / 512 tokens / step 12; keep it in sync with the settings.
OUTPUT_PATH = 'CNNDM_llama13b_1shot_bayesian-prompt-new_maxtoken512_maxstep12_essg_autoth0.60-st1-mat0.90-var1e-2-mom1-0.50-mom2-0.90_abla-th2468_env4-33-1_torch_1-13-1_cuda11.6_dec1.txt'
MAX_NEW_TOKENS = 512
MAX_STEP_DRAFT = 12

# (metrics key, initial th_stop_draft, auto_parameters or None for a fixed threshold).
# The order fixes both the generation call order and the column order of the threshold log line.
ESSG_VARIANTS = [
    ('essg1', 0.20, None),
    ('essg2', 0.40, None),
    ('essg3', 0.60, None),
    ('essg4', 0.80, None),
    ('essg_autoth', 0.60, (1, 0.50, 0.90, 1e-2, 0.90)),
]
ESSG_NAMES = [name for name, _, _ in ESSG_VARIANTS]

# Sections of the per-sample summary, in output order:
# (label template, main_metrics key prefix, include the 'base' row, report base/variant ratio).
SUMMARY_SECTIONS = [
    ('mean rouge-2 {}', 'rouge2_', True, False),
    ('mean time {}', 'time_', True, False),
    ('E2E mean speed up {}', 'time_', False, True),
    ('mean token time {}', 'token_time_', True, False),
    ('E2E mean token speed up {}', 'token_time_', False, True),
    ('mean matchness {}', 'matchness_', False, False),
    ('mean num_drafted_tokens {}', 'num_drafted_tokens_', False, False),
]

# Dataset field holding the reference summary; fail before any generation runs.
REFERENCE_KEYS = {'xsum': 'summary', 'cnndm': 'highlights'}
if task_name not in REFERENCE_KEYS:
    raise ValueError(f"unsupported task_name {task_name!r}; expected one of {sorted(REFERENCE_KEYS)}")
reference_key = REFERENCE_KEYS[task_name]


def _clip_prediction(completion):
    """Truncate a completion before the first "\\nArticle:" (the model starting a new shot).

    A marker at index 0 yields an empty prediction. Only a missing marker (-1) keeps
    the full text.
    """
    cut = completion.find("\nArticle:")
    return completion if cut == -1 else completion[:cut]


def _variant_label(name, threshold):
    """Label used inside summary keys for one ESSG variant at the given threshold."""
    return 'essg autoth' if name == 'essg_autoth' else f'essg th {threshold}'


def _summarize(metrics, labels):
    """Return running means and base/variant speed-ups as a dict of 4-decimal strings.

    metrics: main_metrics-style dict of per-sample value lists.
    labels:  maps each ESSG variant name to its label for the current sample.
    """
    summary = {}
    for template, prefix, with_base, as_ratio in SUMMARY_SECTIONS:
        base_mean = np.mean(metrics[prefix + 'base']) if (with_base or as_ratio) else None
        if with_base:
            summary[template.format('base')] = base_mean
        for name in ESSG_NAMES:
            value = np.mean(metrics[prefix + name])
            summary[template.format(labels[name])] = base_mean / value if as_ratio else value
    # np.mean returns np.float64 (a float subclass), so each statistic is rendered here.
    return {k: f"{v:.4f}" if isinstance(v, float) else v for k, v in summary.items()}


rouge = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)

# Per-method rouge2 / time / token_time lists, then matchness / num_drafted_tokens
# for every ESSG variant.
main_metrics = {prefix + name: []
                for prefix in ('rouge2_', 'time_', 'token_time_')
                for name in ['base'] + ESSG_NAMES}
for name in ESSG_NAMES:
    main_metrics['matchness_' + name] = []
    main_metrics['num_drafted_tokens_' + name] = []

# Threshold each variant uses on the current sample. After every sample it is
# replaced by the value the run reports back in result['th_stop_draft'].
thresholds = {name: initial_th for name, initial_th, _ in ESSG_VARIANTS}

with open(OUTPUT_PATH, 'w') as f:
    for i, x in enumerate(data):
        input_ids = clip_input(tokenizer, x, task_name, max_new_tokens=MAX_NEW_TOKENS, prompt_shots=prompt_shots)

        # Labels reflect the thresholds in effect for *this* sample, captured before the refresh below.
        labels = {name: _variant_label(name, thresholds[name]) for name in ESSG_NAMES}
        f.write('essg th1: {:.4f}, essg th2: {:.4f}, essg th3: {:.4f}, essg th4: {:.4f}, essg autoth: {:.4f} \n'.format(
            *(thresholds[name] for name in ESSG_NAMES)))

        result_base = infer_input_ids(model, tokenizer, input_ids, generate_fn='base',
                                      max_new_tokens=MAX_NEW_TOKENS, do_sample=False, early_stop=True)
        results = [('base', result_base)]
        for name, _, auto_params in ESSG_VARIANTS:
            # A fresh list on every call, so the callee can never mutate the shared config.
            extra = {} if auto_params is None else {'auto_parameters': list(auto_params)}
            results.append((name, infer_input_ids(
                model, tokenizer, input_ids, generate_fn='essg',
                max_new_tokens=MAX_NEW_TOKENS, early_stop=True, max_step_draft=MAX_STEP_DRAFT,
                th_stop_draft=thresholds[name], auto_th_stop_draft=auto_params is not None,
                do_sample=False, do_sample_draft=False, **extra)))

        # Refresh before the skip check, so the next sample always starts from the
        # latest reported thresholds whether or not this one is scored.
        thresholds = {name: result['th_stop_draft'] for name, result in results[1:]}

        completion = result_base['completion']
        if len(completion) < 5 or completion.startswith('.....'):
            print("too short, skip")
            continue

        references = x[reference_key]
        for key, result in results:
            main_metrics['time_' + key].append(result['time'])
            main_metrics['token_time_' + key].append(result['time'] / result['generate_ids'].shape[1])
            if key != 'base':
                main_metrics['matchness_' + key].append(result['matchness'])
                main_metrics['num_drafted_tokens_' + key].append(result['num_drafted_tokens'])
            prediction = _clip_prediction(result['completion'])
            # RougeScorer.score takes (target, prediction). F1 is symmetric, so the
            # recorded value is the same either way, but precision/recall now mean what they say.
            rouge_score = rouge.score(references, prediction)
            main_metrics['rouge2_' + key].append(rouge_score['rouge2'].fmeasure)

        metric = _summarize(main_metrics, labels)
        f.write(f'data {i},{metric} \n')
        f.flush()