In [None]:
%pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7
# TODO: pin these as well once a known-good set of versions is recorded.
%pip install -q scipy tensorboard huggingface_hub json5

# Authenticate to the Hugging Face Hub with a token taken from the environment
# (set HF_TOKEN before starting the kernel) instead of a literal in the notebook,
# so the token never ends up in the saved .ipynb or its outputs.
import os
from huggingface_hub import login
login(token=os.environ["HF_TOKEN"])

In [None]:

from tqdm import tqdm
import os
import torch
from datasets import load_dataset, Dataset, load_from_disk
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

import json5
import string
import re
import csv

In [None]:
def load_model_tokenizer(model_name, adapter_name, quantization=False):
    """Load a causal LM (optionally 4-bit quantized) and an optional PEFT adapter.

    Args:
        model_name: Hugging Face Hub id or local path of the base model; also
            used to load the tokenizer.
        adapter_name: Hub id or path of a PEFT adapter to load on top of the
            base model, or a falsy value to use the base model alone.
        quantization: if True, load the base weights in 4-bit NF4 through
            bitsandbytes with float16 compute.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads with its EOS token, on the right.
    """
    if quantization:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=False,
        )
    else:
        bnb_config = None

    # bitsandbytes settings belong in `quantization_config`; `config` is the slot
    # for the model's PretrainedConfig, and passing a BitsAndBytesConfig there
    # does not quantize the model.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb_config, device_map="auto"
    )
    if adapter_name:
        model = PeftModel.from_pretrained(model, adapter_name, device_map="auto")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Reuse EOS as the pad token. Do not register a brand-new "[PAD]" token:
    # it would enlarge the vocabulary past the model's embedding matrix, which
    # is never resized here.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer


def predict_response(text, max_new_tokens=100):
    """Generate a continuation of `text` with the notebook-level model.

    Uses the module-level ``model``, ``tokenizer`` and ``device`` created in the
    model-loading cell.

    Args:
        text: the complete prompt string.
        max_new_tokens: upper bound on the number of generated tokens. The
            prompt asks for step-by-step reasoning before the JSON answer, so
            a small limit can cut the JSON off.

    Returns:
        ``tokenizer.decode`` of the first generated sequence with special tokens
        skipped. The whole sequence is decoded, so the prompt text is echoed
        at the start (callers rely on finding "[/INST]" in it).
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def normalize_answer(text):
    """Lower-case `text` and delete ASCII punctuation so answers can be compared.

    Any falsy input (None, empty string) yields None.
    """
    if not text:
        return None
    # Translation table mapping every character of string.punctuation to None.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    return text.lower().translate(strip_punctuation)

def extract_final_answer(text):
    """Extract the value of the "answer" key from a model response.

    The JSON object starts at the first ``{`` in `text` and ends at the first
    ``}`` after it. If no closing brace follows, one is appended so that an
    output cut off mid-object can still be parsed. The object is parsed with
    ``json5.loads``.

    Returns:
        The parsed "answer" value (any JSON type, e.g. a string or a list);
        the string "impossible to answer" when `text` has no ``{`` but contains
        that phrase; otherwise None (no object, unparsable object, a non-object
        JSON value, or no "answer" key).
    """
    start = text.find("{")
    if start == -1:
        # No JSON object at all: fall back to the refusal phrase used in the prompt.
        return "impossible to answer" if "impossible to answer" in text else None

    # Look for the closing brace *after* the opening one, so a stray "}" in the
    # reasoning text before the JSON does not produce an empty slice.
    end = text.find("}", start)
    if end == -1:
        # Presumably truncated by the generation limit: close the object ourselves.
        candidate = text[start:] + " }"
    else:
        candidate = text[start:end + 1]

    try:
        parsed = json5.loads(candidate)
    except ValueError:
        # Malformed JSON5 counts as "no answer" so the benchmark keeps running.
        return None
    if not isinstance(parsed, dict):
        return None
    return parsed.get("answer")

In [None]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Fine-tuned model name
adapter_name = "TANK/Llama-2-7b-chat-hf-squad2_v2"

# Base model with the fine-tuned PEFT adapter on top, loaded in 4-bit.
model, tokenizer = load_model_tokenizer(
    model_name=model_name,
    adapter_name=adapter_name,
    quantization=True,
)

# device_map="auto" has already placed the weights, so the model is not moved
# with .to(): transformers rejects .to() on 4-/8-bit quantized models. Inputs
# only need to be sent to the device the model reports.
device = model.device

In [None]:
# SQuAD v2 validation split; load_from_disk expects a directory written by
# Dataset.save_to_disk (relative to the notebook's working directory).
SQUAD_V2_VALIDATION_DIR = 'data/squad_v2/validation'
dataset = load_from_disk(SQUAD_V2_VALIDATION_DIR)

In [None]:
# Llama-2 chat style system prompt prepended to every example.
# NOTE(review): the line "If the context does not provide the content to answer
# the question," is an unfinished duplicate of the line above it. It is kept
# verbatim because the prompt presumably has to match the text the adapter was
# fine-tuned on — confirm before editing it.
SYSTEM_PROMPT = """\
<s>[INST] <<SYS>>\n
You are a helpful, respectful, and honest assistant. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
You are given a context and a question, and your task is to answer the question using the content provided in the context.
If the context does not provide content to answer the question, answer: "impossible to answer"
If the context does not provide the content to answer the question,
If you don't know the answer to a question, please don't share false information.
If a question does not make sense, explain why instead of answering something incorrectly.

Think step by step and explain your reasoning, then answer in JSON format as follows:
```json
{
  "answer": ...
}
```
\n<</SYS>>\n\n
"""


def _to_prompted_example(example):
    """Return {'text': ...}: system prompt, user turn, then the gold answer turn.

    The gold answer is rendered from the raw answers["text"] list; the benchmark
    cell cuts the text at "[/INST]" so this part is never shown to the model.
    """
    context = example['context']
    question = example['question']
    gold_texts = example["answers"]["text"]
    user_turn = f"Context: {context} \n\nQuestion: {question} [/INST]\n\n"
    assistant_turn = f'```json\n{{"answer": {gold_texts}}}```\n</s>'
    return {'text': SYSTEM_PROMPT + user_turn + assistant_turn}


dataset_prompted = dataset.map(_to_prompted_example)
print(dataset_prompted[0]['text'])

In [None]:
# Benchmark with only context: generate an answer for every validation example
# and stream one CSV row per example (flushed after each row, so partial results
# survive an interrupted run).
save_path = 'benchmark_fine_tuned.csv'
INST_MARKER = '[/INST]\n'
NO_ANSWER = 'impossible to answer'

# The csv module expects newline="" (otherwise blank rows appear on Windows);
# utf-8 keeps non-ASCII context text independent of the platform locale.
with open(save_path, "w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file)
    writer.writerow(["Context", "Question", "Answer", "Prediction", "Full prediction"])
    # Iterate rows once instead of materialising the whole 'context' column for
    # its length and re-reading dataset_prompted[i] several times per example.
    for example in tqdm(dataset_prompted, total=len(dataset_prompted)):
        context = example['context']
        question = example['question']
        gold_texts = example['answers']['text']
        # An empty answer list marks an unanswerable question.
        answer = normalize_answer(gold_texts[0]) if gold_texts else NO_ANSWER

        # Cut the text right after "[/INST]\n\n", where the gold assistant turn
        # begins. "</s>" is deliberately NOT appended: it is the end-of-sequence
        # token, and generating after it no longer continues the assistant turn.
        text = example['text']
        prompt = text[:text.index(INST_MARKER) + len(INST_MARKER)] + '\n'

        full_prediction = predict_response(prompt)

        # The decoded output echoes the prompt; keep only what follows the marker
        # (or everything, if the marker did not survive decoding).
        marker_pos = full_prediction.find(INST_MARKER)
        if marker_pos == -1:
            generated = full_prediction
        else:
            generated = full_prediction[marker_pos + len(INST_MARKER):]

        parsed = extract_final_answer(generated)
        # str(None) would become the answer text "none"; keep unparsable output as
        # an empty cell instead. str() is still applied to real values because the
        # parsed answer may be a non-string (e.g. a list).
        prediction = normalize_answer(str(parsed)) if parsed is not None else None

        writer.writerow([context, question, answer, prediction, full_prediction])
        file.flush()

