In [1]:
import re

In [2]:
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import argparse
import logging
import torch

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)


from peft import PeftModel, PeftConfig

logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: the parsed options, i.e. the baseline and
        finetuned model paths, the decoding settings (``num_beams``,
        ``num_beam_groups``, ``top_k``, ``penalty_alpha``,
        ``num_return_sequences``, ``max_new_tokens``) and ``language``.
    """
    parser = argparse.ArgumentParser(description="Eval the finetuned SFT model")
    parser.add_argument(
        "--model_name_or_path_baseline",
        type=str,
        help="Path to baseline model",
    )
    parser.add_argument(
        "--model_name_or_path_finetune",
        type=str,
        help="Path to finetuned model",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help='Specify num of beams',
    )
    parser.add_argument(
        "--num_beam_groups",
        type=int,
        default=1,
        help='Specify num of beam groups',
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=4,
        help='Specify the top_k value used for contrastive search',
    )
    parser.add_argument(
        "--penalty_alpha",
        type=float,
        default=0.6,
        help='Specify the penalty_alpha value used for contrastive search',
    )
    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help='Specify num of return sequences',
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help='Specify max num of new tokens to generate',
    )
    parser.add_argument("--language",
                        type=str,
                        default="English",
                        choices=["English", "Chinese", "Japanese"],
                        help='Specify the language of the prompts')

    args = parser.parse_args()

    return args


def generate(model,
             tokenizer,
             inputs,
             num_beams=1,
             num_beam_groups=1,
             do_sample=False,
             num_return_sequences=1,
             max_new_tokens=100):
    """Generate continuations for ``inputs`` and decode them to text.

    Args:
        model: object exposing ``generate(**kwargs)`` (a causal LM or a
            wrapper around one).
        tokenizer: tokenizer exposing ``batch_decode``.
        inputs: tokenized prompt exposing ``input_ids`` and, optionally,
            ``attention_mask`` as attributes.
        num_beams: number of beams forwarded to ``model.generate``.
        num_beam_groups: number of beam groups forwarded to ``model.generate``.
        do_sample: whether sampling is enabled in ``model.generate``.
        num_return_sequences: number of sequences requested per prompt.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        list[str]: the decoded sequences with special tokens removed.
    """
    generation_kwargs = dict(num_beams=num_beams,
                             num_beam_groups=num_beam_groups,
                             do_sample=do_sample,
                             num_return_sequences=num_return_sequences,
                             max_new_tokens=max_new_tokens)

    # Forward the tokenizer's attention mask when it is available, so that
    # generate() does not have to infer it from the pad token id (ambiguous
    # when pad and eos share the same id).
    attention_mask = getattr(inputs, "attention_mask", None)
    if attention_mask is not None:
        generation_kwargs["attention_mask"] = attention_mask

    # input_ids is passed by keyword so the call also works with wrappers
    # whose generate() only accepts keyword arguments.
    generate_ids = model.generate(input_ids=inputs.input_ids,
                                  **generation_kwargs)

    result = tokenizer.batch_decode(generate_ids,
                                    skip_special_tokens=True,
                                    clean_up_tokenization_spaces=False)
    return result

def print_utils(gen_output):
    """Print every generated sequence surrounded by one blank line on each side."""
    for position in range(len(gen_output)):
        # Empty first/last fields joined with newlines give the blank lines.
        print("", gen_output[position], "", sep="\n")


def prompt_eval(baseline_model, baseline_tokenizer, finetuned_model, finetuned_tokenizer, device,
                prompts, num_beams=2, num_return_sequences=1, max_new_tokens=100):
    """Print the baseline and finetuned continuations of one prompt side by side.

    Args:
        baseline_model: model used for the "Baseline" section.
        baseline_tokenizer: tokenizer paired with ``baseline_model``.
        finetuned_model: model used for the "finetune" section.
        finetuned_tokenizer: tokenizer paired with ``finetuned_model``.
        device: device the tokenized inputs are moved to.
        prompts: dict with the keys ``'prompt'`` (task instruction) and
            ``'dialogue'`` (dialogue history); the model input is their
            concatenation.
        num_beams: number of beams passed to ``generate`` (default 2, i.e.
            beam search rather than greedy decoding).
        num_return_sequences: number of sequences requested per model; every
            returned sequence is printed.
        max_new_tokens: upper bound on the number of generated tokens.

    Other decoding strategies can be tried through the remaining parameters
    of ``generate`` (for example ``do_sample=True`` for multinomial sampling
    or ``num_beam_groups`` for diverse beam search).
    """
    prompt = prompts['prompt']
    dialogue = prompts['dialogue']
    model_input = prompt + dialogue

    def _strip_echo(decoded):
        # The decoded text starts with the model input; drop only that
        # leading echo so that repetitions produced by the model itself stay
        # visible in the printed answer.
        if decoded.startswith(model_input):
            return decoded[len(model_input):]
        # The decoded text did not reproduce the input verbatim: fall back to
        # removing the first occurrence of each part.
        return decoded.replace(prompt, '', 1).replace(dialogue, '', 1)

    def _report(header, model, tokenizer, inputs):
        print(header)
        outputs = generate(model,
                           tokenizer,
                           inputs,
                           num_beams=num_beams,
                           num_return_sequences=num_return_sequences,
                           max_new_tokens=max_new_tokens)
        for decoded in outputs:
            print(_strip_echo(decoded))

    # Tokenize for both models up front so a tokenizer failure surfaces
    # before anything is printed.
    b_inputs = baseline_tokenizer(model_input, return_tensors="pt").to(device)
    f_inputs = finetuned_tokenizer(model_input, return_tensors="pt").to(device)

    print(dialogue)

    _report("----------Baseline--------------------",
            baseline_model, baseline_tokenizer, b_inputs)
    _report("----------finetune------------------------",
            finetuned_model, finetuned_tokenizer, f_inputs)
    print("====================prompt end=============================")

def get_model(config, model_path, tokenizer):
    """Load a causal LM from ``model_path`` and align it with ``tokenizer``.

    Args:
        config: model configuration forwarded to ``from_pretrained``.
        model_path: checkpoint location; ``from_tf`` is enabled when the
            string contains ``".ckpt"``.
        tokenizer: tokenizer whose length sizes the embedding matrix. It is
            modified in place: its pad token becomes its eos token.

    Returns:
        The loaded model with resized token embeddings and updated config.
    """
    is_tf_checkpoint = ".ckpt" in model_path
    causal_lm = AutoModelForCausalLM.from_pretrained(model_path,
                                                     from_tf=is_tf_checkpoint,
                                                     config=config)

    vocab_size = len(tokenizer)
    causal_lm.resize_token_embeddings(vocab_size)

    # Reuse the eos token for padding on the tokenizer side ...
    tokenizer.pad_token = tokenizer.eos_token

    # ... and mirror the end/pad ids on the model config.
    # NOTE(review): `end_token_id` is not a standard config field as far as
    # this file shows; confirm whether `eos_token_id` was intended.
    lm_config = causal_lm.config
    lm_config.end_token_id = tokenizer.eos_token_id
    lm_config.pad_token_id = lm_config.eos_token_id

    return causal_lm

def make_hf_model(path, device):
    """Load a Hugging Face causal LM and its tokenizer for inference.

    Args:
        path: model name or checkpoint path used for the tokenizer, the
            config and the model weights.
        device: device the model is moved to.

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode on ``device``.
    """
    # `use_fast` is the option AutoTokenizer.from_pretrained recognises for
    # selecting the fast tokenizer; the previous `fast_tokenizer` keyword is
    # not one of its options.
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
    config = AutoConfig.from_pretrained(path)
    model = get_model(config, path, tokenizer)

    model.eval()
    model.to(device)

    return model, tokenizer

def peft_model(path, device):
    """Load a PEFT adapter on top of its base model, plus the base tokenizer.

    Args:
        path: location of the PEFT checkpoint; the base model name is read
            from the adapter config stored there.
        device: device the adapted model is moved to.

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode on ``device``.
    """
    config = PeftConfig.from_pretrained(path)
    base_name = config.base_model_name_or_path

    # `use_fast` is the option AutoTokenizer.from_pretrained recognises for
    # selecting the fast tokenizer; the previous `fast_tokenizer` keyword is
    # not one of its options.
    tokenizer = AutoTokenizer.from_pretrained(base_name, use_fast=True)

    model = AutoModelForCausalLM.from_pretrained(base_name)

    # Wrap the base model with the adapter weights stored at `path`.
    model = PeftModel.from_pretrained(model, path)

    model.to(device)
    model.eval()

    return model, tokenizer


def main():
    """Load the baseline and the PEFT-finetuned model and compare them on one prompt."""

    device = torch.device("cuda:0")

    path = 'google/gemma-2b'

    baseline_model, baseline_tokenizer = make_hf_model(path, device)

    peft_path = '/home/chanho/Model/SHARE/Refactorizing/result/output_path/2024-05-16-00.58.46/peft_checkpoint-14000'

    # The results are bound to names that differ from the `peft_model`
    # loader: assigning to `peft_model` here would make it a local variable
    # and the call itself would raise UnboundLocalError.
    finetuned_model, finetuned_tokenizer = peft_model(peft_path, device)

    # One observation: if the prompt ends with a space " ", there is a high chance that
    # the original model (without finetuning) will get stuck and produce no response.
    # Finetuned models have less of this issue.
    # The prompt and dialogue below are empty placeholders.

    prompts = {'prompt' : '' , 'dialogue': ''}

    prompt_eval(baseline_model, baseline_tokenizer, finetuned_model, finetuned_tokenizer, device,
                prompts)

In [3]:
# Load the baseline checkpoint once; later cells reuse `device`,
# `baseline_model` and `baseline_tokenizer`.
path = 'google/gemma-2b'
device = torch.device("cuda", 0)  # same device as "cuda:0"

baseline_model, baseline_tokenizer = make_hf_model(path, device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Location of the PEFT checkpoint to evaluate.
peft_path = (
    '/home/chanho/Model/SHARE/Refactorizing/result/output_path/'
    '2024-05-16-00.58.46/peft_checkpoint-14000'
)

# NOTE: this rebinds the name `peft_model` from the loader function to the
# loaded model, so this cell cannot be re-run without first re-running the
# cell that defines the function.
peft_model, peft_tokenizer = peft_model(peft_path, device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [109]:
# Build the evaluation input: a fixed task instruction plus the dialogue
# history extracted from one raw sample.
prompts = {'prompt' : '' , 'dialogue': ''}
prompts['prompt'] = "\nTask: Generate the next response in a dialogue by focusing on the contextual cues detailed within parentheses in the dialogue history. Responses should be tailored according to the type of cue provided:\n\n1. Memory-driven dialogues: If the cue within parentheses details specific character traits or background context, craft responses that reflect these memory-driven elements, ensuring character consistency and rich context.\n2. Everyday language dialogues: If the cue within parentheses is labeled \"Everyday Language,\" generate responses that are based on typical day-to-day interactions, free from specific personas or detailed context.\n\n"


# Raw sample: the task instruction followed by a "**Dialogue History**:" section.
sample = {"text": "\nTask: Generate the next response in a dialogue by focusing on the contextual cues detailed within parentheses in the dialogue history. Responses should be tailored according to the type of cue provided:\n\n1. Memory-driven dialogues: If the cue within parentheses details specific character traits or background context, craft responses that reflect these memory-driven elements, ensuring character consistency and rich context.\n2. Everyday language dialogues: If the cue within parentheses is labeled \"Everyday Language,\" generate responses that are based on typical day-to-day interactions, free from specific personas or detailed context.\n\n**Dialogue History**:\nBRANDT: (BRANDT deals with operations and finances in his role , BRANDT has a dismissive opinion of someone who relies on his father's money) \"... Senator, the failure of one operation shouldn't cause your committee to question financing everything else we're doing down there... ... I know it looks bad, and I appreciate your support. Together we'll get it done... Yeah. 'bye. Without his father's money, that asshole'd be keeping bees for a living...\"\nBRANDT: (Everyday Language) \"What? No.\"\nUPDEGRAF: (UPDEGRAF relays communications from others such as Mr. Pitt) \"He's called every day.\"\nBRANDT: (Everyday Language) \"I don't need it.\"\nUPDEGRAF: (UPDEGRAF is an intermediary or assistant to BRANDT) \"Mr. Pitt, Mr. Brandt'll have to get back to you.\"\n\n\n"}


# Keep only what follows the "**Dialogue History**:" header.
# NOTE(review): the header itself is therefore not part of the model input
# (prompt + dialogue) -- confirm this matches the format used for finetuning.
match = re.search(r"\*\*Dialogue History\*\*:\s*(.+)", sample['text'], re.DOTALL)
if match is None:
    # Fail with a clear message instead of an AttributeError on `None.group`.
    raise ValueError("sample text has no '**Dialogue History**:' section")

text = match.group(1).strip()

prompts['dialogue'] = text

In [110]:
# Print the baseline and finetuned continuations for the prompt built above.
prompt_eval(
    baseline_model=baseline_model,
    baseline_tokenizer=baseline_tokenizer,
    finetuned_model=peft_model,
    finetuned_tokenizer=peft_tokenizer,
    device=device,
    prompts=prompts,
)

BRANDT: (BRANDT deals with operations and finances in his role , BRANDT has a dismissive opinion of someone who relies on his father's money) "... Senator, the failure of one operation shouldn't cause your committee to question financing everything else we're doing down there... ... I know it looks bad, and I appreciate your support. Together we'll get it done... Yeah. 'bye. Without his father's money, that asshole'd be keeping bees for a living..."
BRANDT: (Everyday Language) "What? No."
UPDEGRAF: (UPDEGRAF relays communications from others such as Mr. Pitt) "He's called every day."
BRANDT: (Everyday Language) "I don't need it."
UPDEGRAF: (UPDEGRAF is an intermediary or assistant to BRANDT) "Mr. Pitt, Mr. Brandt'll have to get back to you."
----------Baseline--------------------

BRANDT: (Everyday Language) "I don't need it."
UPDEGRAF: (Everyday Language) "Mr. Pitt, Mr. Brandt'll have to get back to you."
BRANDT: (Everyday Language) "I don't need it."
UPDEGRAF: (Everyday Language) "Mr