In [1]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Move the working directory one level up from where the notebook started
# (presumably so the `mega` package imported in the next cell resolves --
# confirm against the repo layout).
#
# A bare os.chdir("../") is relative to the *current* directory, so running
# this cell a second time would climb one more level and break later relative
# paths. Remember the starting directory the first time the cell runs and
# always resolve the parent from it, which makes the cell safe to re-run.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

In [3]:
import time
from typing import List
import spacy
import openai
import numpy as np
import wandb
from datasets import load_dataset
from mega.data.load_datasets import load_xnli_dataset
from mega.data.data_utils import choose_few_shot_examples
from mega.prompting.instructions import INSTRUCTIONS
from mega.prompting.prompting_utils import load_prompt_template
from mega.utils.env_utils import load_env
from mega.models.completion_models import get_model_pred, gpt3x_completion
from mega.prompting.prompting_utils import construct_prompt, construct_qa_prompt
from tqdm.notebook import tqdm
from evaluate import load

In [4]:
# Make sure that {env_name}.env file is present in the envs/ directory
# `load_env` is a project helper (mega.utils.env_utils); presumably it reads
# that file and configures the API client settings -- confirm in the helper.
env_name = "gpt4v2"
load_env(env_name=env_name)

In [5]:
# Display the API version currently set on the `openai` module as a sanity
# check of the loaded environment.
openai.api_version

'2023-03-15-preview'

In [6]:
# Run parameters. These globals are read by the cells below.
model = "gpt-4-32k"  # model name passed to gpt3x_completion
pivot_lang = "en"  # language requested for the training split (few-shot pool)
tgt_lang = "ta"  # language requested for the evaluation split
prompt_name = "answer_given_context_and_question"  # key into PROMPTS_DICT
few_shot_k = 0  # number of few-shot examples requested from choose_few_shot_examples
dataset = "indicqa"  # dataset name passed to load_qa_dataset
short_contexts = False  # if True, training contexts are cut down to the sentence containing the answer
max_tokens = 20  # completion length limit passed to gpt3x_completion

In [7]:
# Snapshot of the run parameters, keyed by parameter name, in the shape
# expected by the (currently disabled) wandb.init call below.
config = dict(
    model=model,
    pivot_lang=pivot_lang,
    tgt_lang=tgt_lang,
    prompt_name=prompt_name,
    few_shot_k=few_shot_k,
    dataset=dataset,
    short_contexts=short_contexts,
    max_tokens=max_tokens,
)

# wandb.init(project="GPT-4-eval", entity="scai-msri", config=config)

In [8]:
class SpacySentenceTokenizer:
    """Callable that splits a string into sentence strings using spaCy.

    The constructor loads the ``xx_ent_wiki_sm`` pipeline and appends spaCy's
    ``sentencizer`` component to it; calling the instance runs that pipeline
    on the given text.
    """

    def __init__(self):
        pipeline = spacy.load('xx_ent_wiki_sm')
        pipeline.add_pipe("sentencizer")
        self.nlp = pipeline

    def __call__(self, text: str) -> List[str]:
        """Return the text of every sentence span found in ``text``, in order."""
        doc = self.nlp(text)
        return [sentence.text for sentence in doc.sents]


In [9]:
def load_qa_dataset(dataset_name, lang, split, dataset_frac = 1, translate_test = False):
    """Load one split of a question-answering dataset via ``load_dataset``.

    Args:
        dataset_name: One of "indicqa", "xquad", "tydiqa" or "mlqa".
        lang: Language code. Selects the dataset configuration for indicqa,
            xquad and mlqa; for tydiqa it is used to filter examples.
        split: Name of the split to return. For indicqa and xquad a "train"
            request is served from the ``squad`` dataset; for mlqa a "train"
            request is replaced by "validation".
        dataset_frac: Fraction in [0, 1] of the split to keep. The first
            ``int(N * dataset_frac)`` examples are selected.
        translate_test: mlqa only. If true, the
            ``mlqa-translate-test.{lang}`` configuration is loaded instead of
            ``mlqa.{lang}.{lang}``.

    Returns:
        The requested split restricted to its leading ``dataset_frac`` portion.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
        ValueError: If ``dataset_frac`` is outside [0, 1].
    """
    supported = ("indicqa", "xquad", "tydiqa", "mlqa")
    if dataset_name not in supported:
        raise NotImplementedError(
            f"Unsupported QA dataset {dataset_name!r}; expected one of {supported}"
        )
    # Validate before loading so a bad fraction fails fast instead of after a
    # potentially slow download (a value above 1 would otherwise only fail
    # inside ``select`` with out-of-range indices).
    if not 0 <= dataset_frac <= 1:
        raise ValueError(f"dataset_frac must be within [0, 1], got {dataset_frac!r}")

    if dataset_name in ("indicqa", "xquad"):
        if split == "train":
            # Neither dataset is loaded for training here; SQuAD is used instead.
            data = load_dataset("squad")[split]
        elif dataset_name == "indicqa":
            data = load_dataset("ai4bharat/IndicQA", f"indicqa.{lang}")[split]
        else:
            data = load_dataset("xquad", f"xquad.{lang}")[split]
    elif dataset_name == "tydiqa":
        data = load_dataset("tydiqa", 'secondary_task')[split]
        # NOTE(review): TYDIQA_LANG2CODES is not defined or imported in the
        # visible notebook cells -- confirm it is in scope before using tydiqa.
        # The prefix of each example id (before the first "-") is mapped to a
        # language code, which is then used to keep only the requested language.
        data = data.map(lambda example: {"lang" : TYDIQA_LANG2CODES[example["id"].split("-")[0]]})
        data = data.filter(lambda example: example["lang"] == lang)
    else:  # "mlqa"
        if split == "train":
            print("No Training Data for MLQA, switching to validation!")
            split = "validation"
        # Use a separate name for the configuration instead of overwriting the
        # dataset_name argument.
        if translate_test:
            config_name = f"mlqa-translate-test.{lang}"
        else:
            config_name = f"mlqa.{lang}.{lang}"
        data = load_dataset("mlqa", config_name)[split]

    num_keep = int(len(data) * dataset_frac)
    return data.select(np.arange(num_keep))

In [10]:
# Few-shot pool: the "train" split requested in the pivot language.
train_dataset = load_qa_dataset(
    dataset_name=dataset,
    lang=pivot_lang,
    split="train",
)

# Evaluation data: the "validation" split requested in the target language.
test_dataset = load_qa_dataset(
    dataset_name=dataset,
    lang=tgt_lang,
    split="validation",
)

Found cached dataset squad (/home/t-kabirahuja/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset indic_qa (/home/t-kabirahuja/.cache/huggingface/datasets/ai4bharat___indic_qa/indicqa.ta/1.0.0/f410c3a04e1e13303ea2e04267c0767261a938879f5ad7abf5ea57610444b55f)


  0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
if short_contexts:
    sent_tokenizer = SpacySentenceTokenizer()

    def _shorten_context(example):
        """Replace the context with the first sentence containing the first answer.

        Falls back to the unmodified context when the example has no answer
        text or when no single sentence contains the answer (for instance when
        the answer spans a sentence boundary); indexing ``[0]`` on the empty
        match list would otherwise raise IndexError and abort the whole map.
        """
        context = example["context"]
        answer_texts = example["answers"]["text"]
        if not answer_texts:
            return {"context": context}
        answer = answer_texts[0]
        shortened = next(
            (sent for sent in sent_tokenizer(context) if answer in sent),
            context,
        )
        return {"context": shortened}

    train_dataset = train_dataset.map(_shorten_context, num_proc = 8)

In [12]:
# Pick `few_shot_k` in-context examples from the training split.
# NOTE(review): the selection criterion is "random" and no seed is set in the
# visible cells -- confirm whether choose_few_shot_examples seeds internally
# if reproducible few-shot prompts are needed.
train_examples = choose_few_shot_examples(
        train_dataset, few_shot_k, selection_criteria="random")

In [13]:
# Prompt templates keyed by prompt name; `prompt_name` selects one of them.
# Placeholders in braces ({context}, {question}, {answer}, {language}) are
# presumably filled in by construct_qa_prompt -- confirm in that helper.
# NOTE(review): the continuation lines sit inside triple-quoted strings, so
# their leading four spaces are part of the template text that is sent to the
# model -- confirm this indentation is intended.
PROMPTS_DICT = {
    # Plain template: passage, question, then the answer slot.
    "answer_given_context_and_question" : """{context}
    Q: {question}

    Referring to the passage above, the correct answer to the given question is:
    {answer}""",
    
    # Variant that additionally asks for the answer in {language} and verbatim
    # from the passage.
    "lang_instruct_answer_given_context_and_question" : """{context}
    Q: {question}

    Referring to the passage above, the correct answer to the given question is? Please try to answer in {language} and ensure that the answer appears as it is in the passage.
    A: {answer}""",
    
}

In [14]:
# Select the template named by `prompt_name`; it is used for both the few-shot
# and the test portion of each prompt in the cells below.
prompt_template = PROMPTS_DICT[prompt_name]

In [15]:
# Loading instruction for the task
instruction = INSTRUCTIONS["xquad"]
print(instruction)

You are an NLP assistant whose purpose is to solve reading comprehension problems. You will be provided questions on a set of passages and you will need to provide the answer as it appears in the passage. The answer should be in the same language as the question and the passage.


In [16]:
# Load the "squad" metric through `evaluate.load`; its compute() result is
# read below through the 'exact_match' and 'f1' keys.
squad_metric = load("squad")

In [17]:
# Smoke test on a single evaluation example before running the full loop.
# The index 132 is an arbitrary hand-picked example.
test_example = test_dataset[132]

# Build the chat-style prompt (system instruction + user message) and the
# expected label for this example, using the same template for the few-shot
# examples and the test example.
prompt, label = construct_qa_prompt(
    train_examples,
    test_example,
    train_prompt_template=prompt_template,
    test_prompt_template=prompt_template,
    chat_prompt=True,
    instruction=instruction
)
# Display the constructed prompt for inspection.
prompt

[{'role': 'system',
  'content': 'You are an NLP assistant whose purpose is to solve reading comprehension problems. You will be provided questions on a set of passages and you will need to provide the answer as it appears in the passage. The answer should be in the same language as the question and the passage.'},
 {'role': 'user',
  'content': '1962 ல் பத்மஸ்ரீ விருது வழங்கப்படுவதற்கு கால் நூற்றாண்டுக்கு முன்பே இந்திய அரசால் அன்னை தெரேசா அடையாளங்காணப்பட்டுள்ளார். 1972-ல், பன்னாட்டு புரிந்துணர்வுக்கான ஜவகர்லால் நேரு விருது, 1980-ல் இந்தியாவின் உயரிய குடிமக்கள் விருதான பாரத ரத்னா உட்பட இந்திய உயர்விருதுகளை அடுத்த பத்தாண்டுகளில் பெற்றார். அவரது அதிகாரபூர்வ வாழ்க்கைச்சரித்திரம், இந்திய ஆட்சிப் பணியாளரான நவீன் சாவ்லாவால் எழுதப்பட்டு 1992இல் வெளியிடப்பட்டது. அன்னை தெரசாவைப் பற்றிய எல்லா இந்தியாரும் உயர்வாகப் பார்க்கவில்லை. கல்கத்தாவில் பிறந்து லண்டனில் வாழ்ந்து கொண்டிருக்கும் அவரது விமர்சகரான அரூப் ச்சேட்டர்ஜி அவர் வாழ்ந்த காலத்தில் கல்கத்தாவின் முக்கிய அங்கமாக இருக்கவில்லையெனக் குறிப்பிட்

In [18]:
# Query the model for the single smoke-test prompt. Use the configured
# `max_tokens` rather than a hard-coded literal so this check runs with the
# same completion limit as the full evaluation loop.
pred = gpt3x_completion(
    prompt,
    model,
    temperature=0,
    max_tokens=max_tokens
)

In [19]:
# Show the model output next to the expected label.
print(f"Prediction: {pred}")
print(f"Label: {label}")

# Score this one example with the squad metric, which takes parallel lists of
# predictions and references matched by "id".
prediction = {"prediction_text": pred, "id": test_example["id"]}
reference = {
    "answers": test_example["answers"],
    "id": test_example["id"],
}
results = squad_metric.compute(predictions=[prediction], references=[reference])
print(results)

Prediction: 1962
Label: 1962
{'exact_match': 100.0, 'f1': 100.0}


Bad pipe message: %s [b"\xb4\xf1\xdaR\x86@\x03\x1b\x12|\xabR\x91\xbb-\x82\x1c\xe2 \xbc\x87'\x83\xdaR$\x07\x85\n\x7f\xc7\xb1r\x9d\xaf\xb4,\xd7Q\x94\xe3B\xc1\xe0\n\x91\x99\x9a6\xea<\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127"]
Bad pipe message: %s [b'.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00']
Bad pipe message: %s [b'\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01']
Bad pipe message: %s [b"\xd3\x13\x122\xee\xa2?x&8\x95\x81\xb3\xb0\xb8\xc4\x01\x9a\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0

In [33]:
# Full evaluation: query the model for every test example and keep running
# averages of the squad metric's exact-match and F1 scores.
f1_sum = 0
em_sum = 0
avg_em = 0
avg_f1 = 0

# Passed to gpt3x_completion on every call; presumably updated there with call
# statistics -- confirm in the helper.
run_details = {"num_calls": 0}

# enumerate() hides the dataset length from tqdm, so pass it explicitly to get
# a real progress bar instead of a bare iteration counter.
pbar = tqdm(enumerate(test_dataset), total=len(test_dataset))

for i, test_example in pbar:
    prompt, label = construct_qa_prompt(
        train_examples,
        test_example,
        train_prompt_template=prompt_template,
        test_prompt_template=prompt_template,
        chat_prompt=True,
        instruction=instruction
    )
    pred = gpt3x_completion(
        prompt,
        model,
        temperature=0,
        run_details=run_details,
        max_tokens=max_tokens
    )

    # Score the single example; the metric matches prediction and reference by id.
    prediction = {"prediction_text": pred, "id": test_example["id"]}
    reference = {
        "answers": test_example["answers"],
        "id": test_example["id"],
    }
    results = squad_metric.compute(
        predictions=[prediction],
        references=[reference],
    )
    f1_sum += results["f1"]
    em_sum += results["exact_match"]

    # Running averages over the i + 1 examples seen so far.
    avg_f1 = f1_sum / (i + 1)
    avg_em = em_sum / (i + 1)

    # Only log when a wandb run is active: wandb.init is commented out in the
    # config cell, and calling wandb.log without a run raises an error that
    # would abort the evaluation on the first example.
    if wandb.run is not None:
        wandb.log({"f1": avg_f1, "em": avg_em}, step = i + 1)
        wandb.log(run_details, step = i + 1)
    pbar.set_description(f"em: {avg_em} f1: {avg_f1}")
    # Short pause between requests (presumably to stay under API rate limits).
    time.sleep(1/2)

0it [00:00, ?it/s]

Bad pipe message: %s [b'\xfc\x9c\xd3\xeaO\x18\x90%\xef\x9a\x90']
Bad pipe message: %s [b'\xc4M\x18\xdfo\xe7\xd8z\x11-*s\x1a\xfb|\xbeR0\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0']
Bad pipe message: %s [b"I\x0cA\xd1qb\x99\xb7\x8d\xd9\xf5\xa24\xa8e\x96J\xb6\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00\xc0\x00<\x00\xba\x005\x00\x84\x00/\x00\x96\x00A\x00\x05\x00\n\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x

Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/t-kabirahuja/work/repos/MultilingualBlanketEval/envs/megaenv/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 100, in run
    sreq = self._sock_client.read_server_request()
  File "/home/t-kabirahuja/work/repos/MultilingualBlanketEval/envs/megaenv/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 274, in read_server_request
    data = self._read_packet_bytes()
  File "/home/t-kabirahuja/work/repos/MultilingualBlanketEval/envs/megaenv/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 248, in _read_packet_bytes
    rec = self._extract_packet_bytes()
  File "/home/t-kabirahuja/work/repos/MultilingualBlanketEval/envs/megaenv/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 230, in _extract_packet_bytes
    assert magic == ord("W")
AssertionError
Exception in thr