In [None]:
!pip install -qqq -U wandb --progress-bar off
import wandb
from huggingface_hub import login
from google.colab import userdata

# Authenticate with the Hugging Face Hub using the Colab secret named 'HF_TOKEN'.
login(userdata.get('HF_TOKEN'))

# Authenticate with Weights & Biases using the Colab secret named 'wandb'.
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

In [None]:
# Install the modelling stack used below.
# transformers, accelerate and peft are installed from their GitHub default
# branches and the other packages are unpinned, so the environment can differ
# between runs; pin versions/commits if the results must be reproducible.
!pip install -q -U git+https://github.com/huggingface/transformers.git --progress-bar off
!pip install -q -U git+https://github.com/huggingface/accelerate.git --progress-bar off
!pip install datasets evaluate --progress-bar off
!pip install -q -U bitsandbytes --progress-bar off
!pip install -q -U git+https://github.com/huggingface/peft.git --progress-bar off

In [None]:
from transformers import AutoTokenizer

# Hub id of the base model; the trailing comment keeps the alternative
# checkpoint the author also tried.
base_model_id = "microsoft/phi-2"#"microsoft/phi-3-mini-4k-instruct"#"microsoft/phi-2"
# Tokenizer used for evaluation. Author's note: unlike a training tokenizer it
# is set up without an EOS token and without padding; the padding-related
# options are left commented out below for reference.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
    use_fast=False,
    #truncate=True,
    #padding_side="left", # https://ai.stackexchange.com/questions/41485/while-fine-tuning-a-decoder-only-llm-like-llama-on-chat-dataset-what-kind-of-pa
)
# Disabled: reusing EOS as the pad token. Generation passes
# pad_token_id=tokenizer.eos_token_id explicitly instead.
#tokenizer.pad_token = tokenizer.eos_token

In [None]:
from datasets import load_dataset
import evaluate

# Dataset split to evaluate on; "test" is the alternative kept in the trailing comment.
split = "validation"#"test"#
# Load the chosen split of the "enriquesaou/mrqa-squadded-sample" dataset.
mrqa_eval = load_dataset("enriquesaou/mrqa-squadded-sample", split=split)

In [None]:
# source: https://github.com/mrqa/MRQA-Shared-Task-2019/blob/master/mrqa_official_eval.py

import string
import re
import json
import gzip
from collections import Counter

def normalize_answer(s):
    """Canonicalise an answer string before comparison.

    Steps, applied in this order: lowercase, delete every ASCII punctuation
    character, replace the standalone words "a", "an" and "the" with a space,
    then collapse all whitespace runs into single spaces (also trimming the
    ends).
    """
    lowered = s.lower()
    # Deleting punctuation happens before article removal, so "a.n." becomes
    # the article "an" and is then dropped.
    without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one reference answer.

    Both strings are normalised with ``normalize_answer`` and split on
    whitespace. The overlap is the multiset intersection of the two token
    lists. Returns the integer 0 when no token is shared, otherwise the
    harmonic mean of precision and recall as a float.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Counter & Counter keeps, per token, the smaller of the two counts.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return True when prediction and reference are equal after normalisation."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against every reference and keep the best score.

    ``metric_fn`` is called as ``metric_fn(prediction, reference)`` for each
    item of ``ground_truths``. Raises ValueError (from ``max``) when
    ``ground_truths`` is empty.
    """
    return max([metric_fn(prediction, reference) for reference in ground_truths])


def read_predictions(prediction_file):
    """Load a predictions JSON file and return the parsed object.

    The file is decoded as UTF-8 explicitly so that answers containing
    non-ASCII characters are read the same way on every platform, instead of
    depending on the locale's default encoding.
    """
    with open(prediction_file, encoding='utf-8') as f:
        return json.load(f)


def read_answers(gold_file):
    """Read a gzipped JSON-lines gold file into a ``{qid: answers}`` dict.

    Each line is parsed as JSON. The first line is skipped when it contains a
    'header' key. For every other line, each entry of its 'qas' list
    contributes ``qa['qid'] -> qa['answers']``; a repeated qid keeps the value
    read last.
    """
    gold = {}
    with gzip.open(gold_file, 'rb') as stream:
        for line_no, raw_line in enumerate(stream):
            record = json.loads(raw_line)
            is_header = line_no == 0 and 'header' in record
            if not is_header:
                gold.update((qa['qid'], qa['answers']) for qa in record['qas'])
    return gold


def evaluate_predictions(answers, predictions, skip_no_answer=False):
    """Compute exact-match and F1 percentages over a set of questions.

    Args:
        answers: mapping ``qid -> list of reference answer strings``.
        predictions: mapping ``qid -> predicted answer string``.
        skip_no_answer: when False (default), a question missing from
            ``predictions`` is reported on stdout and counted with score 0;
            when True, it is left out of the totals entirely.

    Returns:
        ``{'exact_match': float, 'f1': float}``, both on a 0-100 scale. For
        each question the best score over all of its references is used.
        When no question is counted at all (empty ``answers``, or every
        question skipped), both scores are 0.0 instead of raising
        ZeroDivisionError.
    """
    f1 = exact_match = total = 0

    for qid, ground_truths in answers.items():
        if qid not in predictions:
            if not skip_no_answer:
                message = 'Unanswered question %s will receive score 0.' % qid
                print(message)
                total += 1
            continue
        total += 1
        prediction = predictions[qid]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

    # Nothing was scored: avoid dividing by zero and report zero scores.
    if total == 0:
        return {'exact_match': 0.0, 'f1': 0.0}

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}

In [None]:
from peft import LoraConfig

# LoRA adapter settings: rank 32, alpha 16, dropout 0.05, no bias terms,
# applied to the modules named "Wqkv", "fc1" and "fc2" for a causal-LM task.
# NOTE(review): `config` is not referenced by the evaluation cells visible in
# this notebook (adapters are loaded via PeftModel.from_pretrained); presumably
# kept as a record of the training setup — confirm or remove.
config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["Wqkv", "fc1", "fc2"], #="all-linear",
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

In [None]:
def format_cqa(context, question):
    """Build the extractive-QA prompt for one (context, question) pair.

    The prompt is an instruction line, then "Context: ...", "Question: ..."
    and a trailing "Answer: " for the model to complete.
    """
    pieces = (
        "Answer the question extracting from the context below.\nContext: ",
        context,
        "\nQuestion: ",
        question,
        "\nAnswer: ",
    )
    return "".join(pieces)

In [None]:
def tokenize_and_generate(test_model, prompt, new_tokens=16):
    """Generate a completion for ``prompt`` and return the decoded text.

    Uses the notebook-level ``tokenizer``. The prompt is tokenized without an
    attention mask and moved to 'cuda', ``test_model.generate`` is run under
    ``torch.no_grad()`` with at most ``new_tokens`` new tokens and the
    tokenizer's EOS id as pad id, and the first decoded sequence is returned
    with special tokens skipped. Callers in this notebook strip the prompt
    from the returned string themselves.
    """
    encoded = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    encoded = encoded.to('cuda')

    with torch.no_grad():
        generated = test_model.generate(
            **encoded,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=new_tokens,
        )

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

In [None]:
# Model ids to evaluate. An entry equal to base_model_id is evaluated as the
# plain base model; any other entry is loaded as a PEFT adapter on top of it.
models_to_evaluate = ["enriquesaou/phi-2-mrqa"]

In [None]:
import torch
from tqdm.auto import tqdm
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


# Maximum number of new tokens generated per answer.
new_tok = 5

# The gold answers depend only on the dataset, so build the
# {example id: [reference texts]} mapping once rather than once per model.
eval_columns = mrqa_eval.to_dict()
answers = {
    qid: aws['text']
    for qid, aws in zip(eval_columns['id'], eval_columns['answers'])
}

for model_id in models_to_evaluate:

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # Load the PEFT adapter unless the base model itself is being evaluated.
    if model_id != base_model_id:
        # note that base_model may be modified in place
        model_for_eval = PeftModel.from_pretrained(base_model, model_id)
    else:
        model_for_eval = base_model

    model_for_eval.to('cuda').eval()

    all_predictions = {}
    for example in tqdm(mrqa_eval):
        prompt = format_cqa(example['context'], example['question'])
        outs = tokenize_and_generate(model_for_eval, prompt, new_tokens=new_tok)
        # Drop the echoed prompt. If the decoded text does not reproduce the
        # prompt exactly, fall back to the text after the first 'Answer:'.
        outs = outs.replace(prompt, '')
        outs = outs.split('Answer:')[1] if 'Answer:' in outs else outs
        all_predictions[example['id']] = outs.strip()

    # compute metrics
    metrics = evaluate_predictions(answers, predictions=all_predictions)

    print(model_id, split, json.dumps(metrics))

    # Per-example debugging (disabled):
    # for k in answers.keys():
    #     print(all_predictions[k], answers[k])
    #     metrics = evaluate_predictions({k: answers[k]}, {k: all_predictions[k]})
    #     print(json.dumps(metrics))


# Test token length: EM/F1 as a function of the number of generated tokens

In [None]:
import torch
from tqdm.auto import tqdm
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


from transformers import BitsAndBytesConfig

# NOTE(review): this id differs from the "enriquesaou/phi-2-mrqa" evaluated in
# the earlier cell — confirm which adapter repository is intended.
models_to_evaluate = ["enriquesaou/phi2-mrqa"]

# `bnb_config` is not defined anywhere else in the visible notebook, so on a
# fresh kernel the model load below would raise NameError. Fall back to a
# 4-bit NF4 configuration only when no config is already defined.
# NOTE(review): these fallback settings are an assumption — confirm they match
# the quantization the reported numbers were produced with.
if "bnb_config" not in globals():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

# Evaluate on a fixed random subset; cap at the dataset size so the selection
# cannot go out of range on small splits.
sample_size = min(150, len(mrqa_eval))
mrqa_eval = mrqa_eval.shuffle(seed=27).select(range(sample_size))

# Generation budgets (max new tokens) to sweep, and the scores for each one.
tk = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 40, 50]
em = []
f1 = []

# Gold answers depend only on the (sub-sampled) dataset: build them once.
eval_columns = mrqa_eval.to_dict()
answers = {
    qid: aws['text']
    for qid, aws in zip(eval_columns['id'], eval_columns['answers'])
}

for model_id in models_to_evaluate:

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

    # Load the PEFT adapter unless the base model itself is being evaluated.
    if model_id != base_model_id:
        # note that base_model may be modified in place
        model_for_eval = PeftModel.from_pretrained(base_model, model_id)
    else:
        model_for_eval = base_model

    model_for_eval.eval()

    for new_tok in tk:
        all_predictions = {}
        for example in tqdm(mrqa_eval):
            prompt = format_cqa(example['context'], example['question'])
            # Drop the echoed prompt. If the decoded text does not reproduce
            # the prompt exactly, fall back to the text after the first 'Answer:'.
            outs = tokenize_and_generate(model_for_eval, prompt, new_tokens=new_tok).replace(prompt, '')
            outs = outs.split('Answer:')[1] if 'Answer:' in outs else outs
            all_predictions[example['id']] = outs.strip()

        # compute metrics
        metrics = evaluate_predictions(answers, predictions=all_predictions)

        em.append(metrics['exact_match'])
        f1.append(metrics['f1'])

        print(model_id, json.dumps(metrics))

        # Per-example debugging (disabled):
        # for k in answers.keys():
        #     print(all_predictions[k], answers[k])
        #     metrics = evaluate_predictions({k: answers[k]}, {k: all_predictions[k]})
        #     print(json.dumps(metrics))


In [None]:
# Dump the raw sweep results: exact-match scores, the token budgets they
# correspond to, and F1 scores (the three lists are index-aligned).
print('em',em)
print('tk',tk)
print('f1',f1)

em [0.0, 6.0, 16.666666666666668, 22.666666666666668, 21.333333333333332, 17.333333333333332, 10.0, 4.0, 2.0, 0.6666666666666666, 0.6666666666666666, 0.6666666666666666, 1.3333333333333333, 0.6666666666666666, 1.3333333333333333, 0.6666666666666666, 0.6666666666666666, 0.6666666666666666, 0.6666666666666666]
tk [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 40, 50]
f1 [0.020833333333333332, 15.39861111111111, 34.326172438672415, 44.17516788766787, 46.310329022829, 45.94180125430125, 44.630497742997704, 40.63119149369145, 37.634488196988165, 34.20407384010324, 29.894118819776697, 26.823178809208216, 24.60567481063679, 22.43074010270301, 19.86675099408612, 16.641160022233066, 15.1014561754907, 13.43372937169635, 13.721398427479853]



In [None]:
import matplotlib.pyplot as plt

# Font size shared by the title, axis labels, tick labels and legend.
fontsize = 12

# Explicit figure/axes objects instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 5))

# EM and F1 (0-100) against the maximum number of generated tokens.
ax.plot(tk, em, label='EM', marker='o')
ax.plot(tk, f1, label='F1', marker='o')

ax.set_title('EM and F1 vs. generated answer length', fontsize=fontsize)
ax.set_xlabel('Generated answer length (# of new tokens)', fontsize=fontsize)
ax.set_ylabel('Score', fontsize=fontsize)

# Label only a readable subset of the evaluated token budgets.
ax.set_xticks([1, 4, 7, 9, 12, 16, 20, 25, 30, 40, 50])
# tick_params sets the label size on both axes and survives tick regeneration.
ax.tick_params(axis='both', labelsize=fontsize)

# Start the x axis at zero; the upper limit stays automatic.
ax.set_xlim(left=0)

ax.legend(fontsize=fontsize)
plt.show()
