# Analysis1-Prompt Selection

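This notebook studies how language-specific prompt tuning influences the final results: each candidate prompt template is evaluated separately for every language on a fixed-size validation subset.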

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Switch the working directory to the parent of the directory the kernel was
# in when this cell first ran (presumably the project root -- confirm).
#
# A bare `os.chdir("..")` climbs one more level every time the cell is re-run.
# Remembering the starting directory once per kernel session makes the cell
# idempotent: every run lands in the same place.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [3]:
import os
import argparse
import sys
import time
import random
import json
import wandb
from tqdm.notebook import tqdm
import numpy as np
from promptsource.templates import Template, DatasetTemplates
from mega.data.load_datasets import load_xnli_dataset, load_xnli_translate_test
from mega.data.data_utils import choose_few_shot_examples
from mega.eval.eval_cls import evaluate_model
from mega.prompting.prompting_utils import load_prompt_template
from mega.prompting.instructions import INSTRUCTIONS
from mega.utils.parser import parse_args
from mega.utils.env_utils import load_env
from mega.prompting.create_lang_prompts import add_prompt_to_dataset

In [4]:
# Load the environment configuration named "gpt4v2" through
# mega.utils.env_utils.load_env.
# NOTE(review): presumably this sets up the API endpoint/credentials used by
# the model calls below -- confirm against load_env.
load_env("gpt4v2")

In [5]:
# Model identifier handed to evaluate_model in the evaluation loop.
# NOTE(review): "gpt-35-tunro" looks like a misspelling of "gpt-35-turbo", but
# the recorded run in this notebook executed with this exact string, so it may
# be the literal deployment name. Confirm before "correcting" it.
model = "gpt-35-tunro"

In [6]:
# Tunable settings for the whole analysis.
MAX_VAL_SIZE = 500  # number of validation rows kept per language
K = 8               # few-shot size passed to evaluate_model
TEMPERATURE = 0     # sampling temperature passed to evaluate_model

# The evaluation uses selection_criteria="random", so fix the global RNG seeds
# to make reruns comparable.
# NOTE(review): this only helps if the few-shot selection draws from the
# global `random` / NumPy generators -- confirm against evaluate_model.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

## XNLI

In [7]:
# XNLI language configs evaluated in this notebook: "sw" (Swahili), "ur" (Urdu).
langs = ["sw", "ur"]

In [8]:
def _first_rows(dataset, max_rows):
    """Return at most the first `max_rows` rows of `dataset`.

    Clamping to `len(dataset)` avoids the out-of-range index error that
    `select` raises when a split is smaller than the requested subset.
    Assumes `dataset` supports `len()` and `.select(indices)` (the original
    code already called `.select` on it).
    """
    n_rows = min(max_rows, len(dataset))
    return dataset.select(list(range(n_rows)))


# Per-language XNLI splits, keyed by language code.
lang2train_dataset = {
    lang: load_xnli_dataset(lang, split="train")
    for lang in langs
}

# Validation data is capped at MAX_VAL_SIZE rows (the first rows of the split)
# to bound the number of model calls per prompt.
lang2val_dataset = {
    lang: _first_rows(load_xnli_dataset(lang, split="validation"), MAX_VAL_SIZE)
    for lang in langs
}

lang2test_dataset = {
    lang: load_xnli_dataset(lang, split="test")
    for lang in langs
}

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/sw/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/ur/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/sw/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/ur/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/sw/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset xnli (/home/t-kabirahuja/.cache/huggingface/datasets/xnli/ur/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd)


  0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
# Names of the candidate XNLI prompt templates compared in this analysis.
# NOTE(review): "take the following as truth" is commented out and therefore
# not evaluated; the reason is not recorded here -- confirm before re-enabling.
prompt_names = [
#     "take the following as truth",
    "does this imply",
    "GPT-3 style",
    "based on the previous passage",
    "guaranteed true",
    "should assume",
    "must be true",
    "can we infer",
    "justified in saying",
    "claim true/false/inconclusive",
    "consider always/sometimes/never",
    "always/sometimes/never",
    "guaranteed/possible/impossible",
    "MNLI crowdsource"
]

In [18]:
def add_prompts_to_lang_xnli(lang, prompt_name):
    """Attach the English XNLI prompt `prompt_name` to the `lang` templates.

    Loads the "en" template named `prompt_name` for the "xnli" dataset, opens
    the `xnli/{lang}` template collection, and hands both to
    `add_prompt_to_dataset` together with `lang`, "en" and `translate=False`.

    Args:
        lang: Language code whose `xnli/{lang}` template collection receives
            the prompt.
        prompt_name: Name of the English prompt template to load.

    Returns:
        None. The function is called for its effect on the template collection.

    NOTE(review): the positional `lang` / "en" arguments are presumably the
    target and source languages, and `translate=False` presumably keeps the
    English wording as-is -- confirm against `add_prompt_to_dataset`.
    """
    en_template = load_prompt_template("en", prompt_name, dataset="xnli")
    lang_templates = DatasetTemplates(f"xnli/{lang}")
    add_prompt_to_dataset(lang_templates, en_template, lang, "en", translate=False)

In [11]:
# Register every candidate prompt for every language, language-major order.
for lang, prompt_name in [(code, name) for code in langs for name in prompt_names]:
    add_prompts_to_lang_xnli(lang, prompt_name)

In [19]:
# lang -> prompt name -> value returned by evaluate_model (stored as `acc`).
lang2prompt2acc = {}

# Keyword arguments that are identical for every (language, prompt) run.
eval_settings = dict(
    model=model,
    few_shot_size=K,
    selection_criteria="random",
    chat_prompt=True,
    instruction=INSTRUCTIONS.get("xnli", ""),
    save_preds_path=None,
    num_evals_per_sec=2,
    temperature=TEMPERATURE,
)

for lang in tqdm(langs):
    # Register the per-language dict up front so partial results survive an
    # interrupted run.
    prompt2acc = {}
    lang2prompt2acc[lang] = prompt2acc
    for prompt_name in tqdm(prompt_names):
        # The same language-specific template is used for the few-shot
        # examples and for the evaluated examples.
        prompt_template = load_prompt_template(lang, prompt_name, dataset="xnli")
        acc = evaluate_model(
            train_dataset=lang2train_dataset[lang],
            test_dataset=lang2val_dataset[lang],
            train_prompt_template=prompt_template,
            test_prompt_template=prompt_template,
            **eval_settings
        )
        prompt2acc[prompt_name] = acc

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]



  0%|                                                                                                                | 0/500 [00:00<?, ?it/s][A[A

Accuracy: 1.0:   0%|                                                                                                 | 0/500 [00:00<?, ?it/s][A[A

Accuracy: 1.0:   0%|▏                                                                                        | 1/500 [00:00<06:39,  1.25it/s][A[A

Accuracy: 1.0:   0%|▏                                                                                        | 1/500 [00:01<06:39,  1.25it/s][A[A

Accuracy: 1.0:   0%|▎                                                                                        | 2/500 [00:01<06:27,  1.28it/s][A[A

Accuracy: 1.0:   0%|▎                                                                                        | 2/500 [00:02<06:27,  1.28it/s][A[A

Accuracy: 1.0:   1%|▌                                                                                   

KeyboardInterrupt: 

In [33]:
# Display the results gathered so far (lang -> prompt name -> acc).
lang2prompt2acc

{'sw': {}}