In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,default_data_collator
import torch
from peft import prepare_model_for_kbit_training,LoraConfig,get_peft_model,PeftModel
from transformers import GenerationConfig
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
# Guards sampling in model.generate(do_sample=True) against non-finite logits.
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Replace NaN/inf logits with a harmless placeholder row.

    NaN or inf scores cause a runtime error when ``model.generate`` samples
    (``do_sample=True``). For every sequence whose scores contain such a
    value, this processor zeroes that row and sets the logit of token id 0
    to 1.0. Rows that are entirely finite are left untouched, so one bad
    sequence does not scramble the distributions of the rest of the batch.

    ``scores`` is modified in place and also returned.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # One flag per row (reduced over the last, vocabulary, dim); True when
        # the row holds any NaN, +inf or -inf. keepdim=True lets the mask
        # broadcast back over the last dim in masked_fill_.
        bad_rows = ~torch.isfinite(scores).all(dim=-1, keepdim=True)
        if bad_rows.any():
            scores.masked_fill_(bad_rows, 0.0)
            # scores[..., :1] is a view, so this writes into `scores` itself.
            scores[..., :1].masked_fill_(bad_rows, 1.0)
        return scores


def get_logits_processor() -> LogitsProcessorList:
    """Build the logits-processor list handed to ``model.generate``.

    Returns:
        A fresh ``LogitsProcessorList`` containing one
        ``InvalidScoreLogitsProcessor``; a new list is created on every call.
    """
    return LogitsProcessorList([InvalidScoreLogitsProcessor()])

# Project root, base checkpoint snapshot, and the fine-tuned LoRA adapter.
abs_path = "/data/home/chenpz/git_clone_project"
model_path = f"{abs_path}/All_base_model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/d3aa29f914761e8ea0298051fbaf8dd173e94db5"
adapter_path = f"{abs_path}/LLaMA-Factory/saves/llama3_5000k_anli_r3/checkpoint-30"

# 4-bit NF4 weights with double quantisation.
# NOTE(review): bnb_4bit_compute_dtype=None leaves the compute dtype to the
# BitsAndBytesConfig default — confirm it matches what the adapter was
# trained with.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=None,
)

# Quantised base model on GPU 0, then wrapped with the adapter (not trainable).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=nf4_config,
    device_map="cuda:0",
)
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=False)

# Left padding keeps every prompt flush against the tokens generate() appends.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
    split_special_tokens=False,
    padding_side="left",
    trust_remote_code=True,
    cache_dir=None,
    revision="main",
    use_auth_token=None,
)
# Pad batches with the end-of-turn token.
tokenizer.pad_token = '<|eot_id|>'

In [2]:
# Llama-3-Instruct chat format for a single user turn, ending with the
# assistant header so the model continues with its answer. The chat format
# puts "\n\n" after every <|end_header_id|>, including the final assistant
# header; omitting it there makes the prompt differ from that format.
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "{instruction}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

In [None]:
from datasets import load_dataset

# Evaluation data (file name is spelled "vaild" on disk), located under the
# same project root as the model files. A single JSON file is loaded by
# `datasets` as the "train" split, which the following cells index.
data_path = f"{abs_path}/nlpData/p3/anli_r3_json_file/anli_r3_vaild.json"
dataset = load_dataset("json", data_files=data_path)

def preprocess_supervised_data(examples, max_length: int = 400):
    """Tokenize a batch of alpaca-style rows into fixed-length generation prompts.

    Intended for ``Dataset.map(..., batched=True)``. Each row's ``instruction``
    (plus its ``input``, when that column exists and the value is non-empty)
    is wrapped in the global ``prompt`` template and tokenized with the global
    ``tokenizer``.

    Args:
        examples: batch mapping of column name -> list of values; must
            contain ``'instruction'``.
        max_length: every sequence is padded to exactly this many tokens so
            ``default_data_collator`` can stack a batch into one tensor.

    Returns:
        dict with list-valued ``'input_ids'`` and ``'attention_mask'``.

    Raises:
        ValueError: if a prompt is longer than ``max_length`` tokens. Padding
            does not truncate, so such a row would otherwise make a batch
            ragged and fail later inside the data collator.
    """
    instructions = examples['instruction']
    # NOTE(review): follows the common alpaca convention (also used by
    # LLaMA-Factory) of joining a non-empty ``input`` onto the instruction
    # with a newline — confirm against the training data. Rows whose
    # ``input`` is empty produce exactly the same prompt as before.
    extras = examples['input'] if 'input' in examples else [None] * len(instructions)
    texts = []
    for instruction, extra in zip(instructions, extras):
        content = f"{instruction}\n{extra}" if extra else instruction
        texts.append(prompt.format(instruction=content))

    # ``prompt`` already starts with a literal <|begin_of_text|>; the Llama-3
    # tokenizer also prepends BOS by default, so disable special tokens to
    # avoid a duplicated BOS (check the decode printed after map()).
    encoded = tokenizer(
        texts,
        padding='max_length',
        max_length=max_length,
        add_special_tokens=False,
    )

    for row, ids in enumerate(encoded['input_ids']):
        if len(ids) > max_length:
            raise ValueError(
                f"prompt {row} of this batch has {len(ids)} tokens, "
                f"more than max_length={max_length}"
            )
    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
    }

# Drop every original column (whatever the JSON schema contains) so only the
# input_ids / attention_mask produced by the preprocessor reach the collator;
# a hard-coded column list breaks if the file lacks one of those columns.
dataset2 = dataset.map(
    preprocess_supervised_data,
    batched=True,
    remove_columns=dataset['train'].column_names,
    num_proc=16,
)
# Sanity check: the exact token sequence of the first prompt, padding included.
print(tokenizer.decode(dataset2['train'][0]['input_ids']))


In [None]:
import os

from torch.utils.data import DataLoader
from tqdm import tqdm

# Decoding settings for the evaluation run.
# NOTE(review): 128001 / 128009 are presumably <|end_of_text|> and <|eot_id|>
# in the Llama-3 vocabulary — confirm with tokenizer.convert_ids_to_tokens.
# pad_token_id is set explicitly so finished rows are padded with the
# tokenizer's pad token rather than an id generate() picks on its own.
g_config = GenerationConfig(**{
  "do_sample": True,
  "max_new_tokens": 10,
  "eos_token_id": [
    128001,
    128009
  ],
  "pad_token_id": tokenizer.pad_token_id,
})

# Sampling is stochastic; fix the seed so the prediction file is reproducible.
torch.manual_seed(42)

eval_dataloader = DataLoader(dataset2['train'], batch_size=20, pin_memory=True,
                             collate_fn=default_data_collator, shuffle=False)

device = 'cuda:0'
output_file = f"{abs_path}/LLaMA-Factory/saves/llama3_5000k_anli_r3/vaild_predict_result/output.txt"
# open(..., 'w') fails if the directory does not exist yet.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

model.eval()
with open(output_file, 'w', encoding='utf-8') as file:
    for batch in tqdm(eval_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                generation_config=g_config,
                logits_processor=get_logits_processor()
            )
        # For a causal LM, generate() returns prompt + continuation. Slice off
        # the prompt columns instead of splitting the decoded text on
        # "assistant", which picks the wrong span whenever an instruction
        # itself contains that word.
        prompt_len = batch['input_ids'].shape[1]
        new_tokens = outputs[:, prompt_len:].detach().cpu()
        for item in tokenizer.batch_decode(new_tokens, skip_special_tokens=True):
            # Keep exactly one line per example so output.txt lines up with
            # the dataset rows even if the model emits newlines.
            final_pred = ' '.join(item.strip().splitlines())
            file.write('%s\n' % final_pred)