In [1]:
# Reload edited project modules automatically before each cell executes,
# so changes to imported source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Work from the parent of the notebook's directory (presumably the repository
# root that holds `mega/` and `envs/` — confirm).
# The starting directory is remembered on first execution so that re-running
# this cell always lands in the same place; a bare os.chdir("../") would climb
# one more level on every re-run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [3]:
import time
from typing import List
import spacy
import openai
import numpy as np
import wandb
from datasets import load_dataset
from mega.data.load_datasets import load_xnli_dataset
from mega.data.data_utils import choose_few_shot_examples
from mega.prompting.instructions import INSTRUCTIONS
from mega.prompting.prompting_utils import load_prompt_template
from mega.utils.env_utils import load_openai_env_variables
from mega.models.completion_models import get_model_pred, gpt3x_completion
from mega.prompting.prompting_utils import construct_prompt, construct_qa_prompt
from tqdm.notebook import tqdm
from evaluate import load

In [4]:
# Make sure that {env_name}.env file is present in the envs/ directory
# `load_env` is not among the names imported at the top of the notebook (only
# `load_openai_env_variables` is), so on a fresh kernel this cell raised
# NameError. Import it here, next to its single use.
# NOTE(review): assumes `load_env` lives in mega.utils.env_utils alongside
# `load_openai_env_variables` — confirm.
from mega.utils.env_utils import load_env

env_name = "vellm"
load_env(env_name=env_name)

In [5]:
# Sanity check: display the API version currently set on the openai module.
# (To pin it explicitly, uncomment: openai.api_version = "2023-03-15-preview")
openai.api_version

'2023-03-15-preview'

In [6]:
# Sanity check: display the API endpoint currently set on the openai module.
openai.api_base

'https://vellmapi.openai.azure.com/'

In [7]:
# --- Experiment configuration ---

# Model and decoding
model = "gptturbo"       # model name handed to gpt3x_completion
max_tokens = 20          # max_tokens value passed to gpt3x_completion

# Data
dataset = "indicqa"      # dataset name passed to load_qa_dataset
pivot_lang = "en"        # language requested for the train (few-shot) split
tgt_lang = "ta"          # language requested for the evaluation split

# Prompting
prompt_name = "answer_given_context_and_question"  # key into PROMPTS_DICT
few_shot_k = 0           # number of few-shot examples to select
short_contexts = False   # if True, train contexts are cut down to one answer-bearing sentence

In [8]:
# Settings recorded with the W&B run; each key mirrors the variable it is read from.
config = dict(
    model=model,
    pivot_lang=pivot_lang,
    tgt_lang=tgt_lang,
    prompt_name=prompt_name,
    few_shot_k=few_shot_k,
    dataset=dataset,
    short_contexts=short_contexts,
    max_tokens=max_tokens,
)

# Start the W&B run that the evaluation loop logs to.
wandb.init(project="GPT-4-eval", entity="scai-msri", config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: kabirahuja2431 (scai-msri). Use `wandb login --relogin` to force relogin


In [9]:
class SpacySentenceTokenizer:
    """Callable that splits text into sentences with a spaCy pipeline.

    The pipeline is the `xx_ent_wiki_sm` model with a `sentencizer`
    component added to it.
    """

    def __init__(self):
        pipeline = spacy.load('xx_ent_wiki_sm')
        pipeline.add_pipe("sentencizer")
        self.nlp = pipeline

    def __call__(self, text: str) -> List[str]:
        """Return the text of every sentence the pipeline finds in `text`, in order."""
        return [sentence.text for sentence in self.nlp(text).sents]


In [10]:
# Maps the language-name prefix of a TyDiQA example id (the part before the
# first "-") to the language code that `lang` is compared against.
# This name was referenced by load_qa_dataset but never defined in the
# notebook, so the "tydiqa" branch raised NameError.
TYDIQA_LANG2CODES = {
    "arabic": "ar",
    "bengali": "bn",
    "english": "en",
    "finnish": "fi",
    "indonesian": "id",
    "korean": "ko",
    "russian": "ru",
    "swahili": "sw",
    "telugu": "te",
}


def load_qa_dataset(dataset_name, lang, split, dataset_frac = 1, translate_test = False):
    """Load one split of a QA dataset for a language, optionally truncated.

    Args:
        dataset_name: One of "indicqa", "xquad", "tydiqa" or "mlqa".
        lang: Language code. Selects the dataset configuration for
            "indicqa", "xquad" and "mlqa"; for "tydiqa" it selects which
            examples are kept.
        split: Name of the split to index. For "indicqa" and "xquad" a
            "train" request is served from the "squad" dataset (`lang` is
            ignored in that case). For "mlqa" a "train" request is replaced
            by "validation".
        dataset_frac: Fraction of the split to keep; the first
            int(len(split) * dataset_frac) examples are selected.
        translate_test: "mlqa" only. When true, the
            "mlqa-translate-test.{lang}" configuration is loaded instead of
            "mlqa.{lang}.{lang}".

    Returns:
        The result of `.select(...)` on the loaded split.

    Raises:
        NotImplementedError: If `dataset_name` is not one of the supported names.
    """
    if dataset_name == "indicqa":
        if split != "train":
            dataset = load_dataset("ai4bharat/IndicQA", f"indicqa.{lang}")[split]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "xquad":
        if split != "train":
            dataset = load_dataset("xquad", f"xquad.{lang}")[split]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "tydiqa":
        dataset = load_dataset("tydiqa", 'secondary_task')[split]
        # Tag every example with its language code, then keep only `lang`.
        dataset = dataset.map(lambda example: {"lang" : TYDIQA_LANG2CODES[example["id"].split("-")[0]]})
        dataset = dataset.filter(lambda example: example["lang"] == lang)
    elif dataset_name == "mlqa":
        if split == "train":
            print("No Training Data for MLQA, switching to validation!")
            split = "validation"
        # Use a separate local for the configuration name instead of
        # rebinding the `dataset_name` parameter.
        if translate_test:
            config_name = f"mlqa-translate-test.{lang}"
        else:
            config_name = f"mlqa.{lang}.{lang}"

        dataset = load_dataset("mlqa", config_name)[split]
    else:
        raise NotImplementedError(f"Unsupported dataset_name: {dataset_name!r}")

    # Keep only the leading fraction of the split.
    N = len(dataset)
    selector = np.arange(int(N * dataset_frac))
    return dataset.select(selector)

In [11]:
# The train split is requested in the pivot language and the evaluation
# ("validation") split in the target language.
train_dataset = load_qa_dataset(dataset, lang=pivot_lang, split="train")
test_dataset = load_qa_dataset(dataset, lang=tgt_lang, split="validation")

Found cached dataset squad (/home/t-kabirahuja/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset indic_qa (/home/t-kabirahuja/.cache/huggingface/datasets/ai4bharat___indic_qa/indicqa.ta/1.0.0/f410c3a04e1e13303ea2e04267c0767261a938879f5ad7abf5ea57610444b55f)


  0%|          | 0/1 [00:00<?, ?it/s]

In [12]:
if short_contexts:
    sent_tokenizer = SpacySentenceTokenizer()

    def _shorten_context(example):
        """Cut an example's context down to the first sentence containing its first answer.

        Falls back to the unchanged context when the example has no answer
        text, or when no single sentence contains the answer (for instance
        when the answer crosses a sentence boundary). Indexing `[0]` on the
        empty match list previously raised IndexError in those cases.
        """
        answer_texts = example["answers"]["text"]
        if not answer_texts:
            return {"context": example["context"]}
        answer = answer_texts[0]
        matching = (sent for sent in sent_tokenizer(example["context"]) if answer in sent)
        return {"context": next(matching, example["context"])}

    train_dataset = train_dataset.map(_shorten_context, num_proc = 8)

In [13]:
# Pick `few_shot_k` examples from the train split using the "random" criterion.
train_examples = choose_few_shot_examples(
    train_dataset,
    few_shot_k,
    selection_criteria="random",
)

In [14]:
# Prompt templates keyed by name. Both templates have `context`, `question`
# and `answer` placeholders; the "lang_instruct_..." variant additionally has
# a `language` placeholder.
# NOTE(review): the continuation lines of each triple-quoted string are
# indented inside the literal, so those four leading spaces are part of the
# prompt text — confirm this is intended before changing it, since it alters
# the text of the prompt.
PROMPTS_DICT = {
    "answer_given_context_and_question" : """{context}
    Q: {question}

    Referring to the passage above, the correct answer to the given question is:
    {answer}""",
    
    "lang_instruct_answer_given_context_and_question" : """{context}
    Q: {question}

    Referring to the passage above, the correct answer to the given question is? Please try to answer in {language} and ensure that the answer appears as it is in the passage.
    A: {answer}""",
    
}

In [15]:
# Template selected by `prompt_name`; it is passed below as both the train
# and the test prompt template.
prompt_template = PROMPTS_DICT[prompt_name]

In [16]:
# Loading instruction for the task.
# NOTE(review): the "xquad" entry is used regardless of the `dataset`
# setting (which is "indicqa" in this notebook) — confirm this is intended.
instruction = INSTRUCTIONS["xquad"]
print(instruction)

You are an NLP assistant whose purpose is to solve reading comprehension problems. You will be provided questions on a set of passages and you will need to provide the answer as it appears in the passage. The answer should be in the same language as the question and the passage.


In [17]:
# "squad" metric loaded through `evaluate.load`; its compute() result is read
# for the "exact_match" and "f1" keys in the cells below.
squad_metric = load("squad")

In [18]:
# Spot check on a single test example before the full evaluation loop.
# NOTE(review): index 132 is presumably an arbitrary pick — confirm.
test_example = test_dataset[132]

# Build the chat-style prompt (system instruction + user message) and the
# gold label for this example, using the same template for train and test.
prompt, label = construct_qa_prompt(
    train_examples,
    test_example,
    train_prompt_template=prompt_template,
    test_prompt_template=prompt_template,
    chat_prompt=True,
    instruction=instruction
)
prompt

[{'role': 'system',
  'content': 'You are an NLP assistant whose purpose is to solve reading comprehension problems. You will be provided questions on a set of passages and you will need to provide the answer as it appears in the passage. The answer should be in the same language as the question and the passage.'},
 {'role': 'user',
  'content': '1962 ல் பத்மஸ்ரீ விருது வழங்கப்படுவதற்கு கால் நூற்றாண்டுக்கு முன்பே இந்திய அரசால் அன்னை தெரேசா அடையாளங்காணப்பட்டுள்ளார். 1972-ல், பன்னாட்டு புரிந்துணர்வுக்கான ஜவகர்லால் நேரு விருது, 1980-ல் இந்தியாவின் உயரிய குடிமக்கள் விருதான பாரத ரத்னா உட்பட இந்திய உயர்விருதுகளை அடுத்த பத்தாண்டுகளில் பெற்றார். அவரது அதிகாரபூர்வ வாழ்க்கைச்சரித்திரம், இந்திய ஆட்சிப் பணியாளரான நவீன் சாவ்லாவால் எழுதப்பட்டு 1992இல் வெளியிடப்பட்டது. அன்னை தெரசாவைப் பற்றிய எல்லா இந்தியாரும் உயர்வாகப் பார்க்கவில்லை. கல்கத்தாவில் பிறந்து லண்டனில் வாழ்ந்து கொண்டிருக்கும் அவரது விமர்சகரான அரூப் ச்சேட்டர்ஜி அவர் வாழ்ந்த காலத்தில் கல்கத்தாவின் முக்கிய அங்கமாக இருக்கவில்லையெனக் குறிப்பிட்

In [19]:
# Query the model for the spot-check example.
# Uses the `max_tokens` setting from the configuration cell rather than a
# hard-coded 20, so this call stays consistent with the evaluation loop and
# with the value recorded in the W&B config.
pred = gpt3x_completion(
    prompt,
    model,
    temperature=0,
    max_tokens=max_tokens
)

In [20]:
print(f"Prediction: {pred}")
print(f"Label: {label}")

# Score the single prediction against the example's reference answers.
prediction = {"prediction_text": pred, "id": test_example["id"]}
reference = {
    "answers": test_example["answers"],
    "id": test_example["id"],
}
results = squad_metric.compute(predictions=[prediction], references=[reference])
print(results)

Prediction: 1962.
Label: 1962
{'exact_match': 100.0, 'f1': 100.0}


In [21]:
# Running sums of the per-example scores and the averages derived from them.
f1_sum = 0
em_sum = 0
avg_em = 0
avg_f1 = 0

# Dict handed to gpt3x_completion on every call and logged to W&B each step.
run_details = {"num_calls": 0}

# enumerate() has no length, so without `total` the bar can only show a raw
# iteration count ("0it"); passing it enables the percentage and ETA.
pbar = tqdm(enumerate(test_dataset), total=len(test_dataset))

for i, test_example in pbar:
    # Build the chat prompt for this example and query the model.
    prompt, label = construct_qa_prompt(
        train_examples,
        test_example,
        train_prompt_template=prompt_template,
        test_prompt_template=prompt_template,
        chat_prompt=True,
        instruction=instruction
    )
    pred = gpt3x_completion(
        prompt,
        model,
        temperature=0,
        run_details=run_details,
        max_tokens=max_tokens
    )

    # Score this single prediction against the example's reference answers.
    prediction = {"prediction_text": pred, "id": test_example["id"]}
    reference = {}
    reference["answers"] = test_example["answers"]
    reference["id"] = test_example["id"]
    results = squad_metric.compute(
                predictions=[prediction],
                references=[reference])
    f1_sum += results["f1"]
    em_sum += results["exact_match"]

    # Averages over the examples seen so far (i is zero-based).
    avg_f1 = f1_sum / (i+1)
    avg_em = em_sum / (i+1)

    wandb.log({"f1": avg_f1, "em": avg_em}, step = i+1)
    wandb.log(run_details, step = i+1)
    pbar.set_description(f"em: {avg_em} f1: {avg_f1}")
    # Half-second pause between requests (presumably to stay under the API
    # rate limit — confirm).
    time.sleep(1/2)

0it [00:00, ?it/s]