In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir("../")

In [3]:
import time
import random
from typing import List
import spacy
import openai
import numpy as np
import wandb
from datasets import load_dataset
from mega.data.load_datasets import load_xnli_dataset
from mega.data.data_utils import choose_few_shot_examples
from mega.prompting.instructions import INSTRUCTIONS
from mega.prompting.prompting_utils import load_prompt_template
from mega.utils.env_utils import load_env
# from mega.models.completion_models import get_model_pred, gpt3x_completion
from mega.models.tag_models import gpt3x_tagger
from mega.models.completion_models import gpt3x_completion
from mega.prompting.prompting_utils import construct_prompt, construct_qa_prompt, construct_tagging_prompt
from mega.data.load_datasets import load_tagging_dataset
from seqeval.metrics import f1_score
from tqdm.notebook import tqdm
from evaluate import load

# Set seed
random.seed(42)
np.random.seed(42)

In [4]:
# Make sure that {env_name}.env file is present in the envs/ directory
env_name = "melange"
load_env(env_name=env_name)

In [5]:
openai.api_version = "2023-03-15-preview"
openai.api_version

'2023-03-15-preview'

In [6]:
openai.api_base

'https://gpttesting1.openai.azure.com/'

In [27]:
model = "gpt-35-turbo-deployment"
pivot_lang = "fr"
tgt_lang = "fr"
prompt_name = "structure_prompting_chat"
few_shot_k = 8
dataset = "udpos"
# short_contexts = False
max_tokens = 100

In [28]:
config = {
    "model" : model,
    "pivot_lang": pivot_lang,
    "tgt_lang": tgt_lang,
    "prompt_name": prompt_name,
    "few_shot_k": few_shot_k,
    "dataset": dataset,
    "max_tokens": max_tokens
}

# wandb.init(project="GPT-4-eval", entity="scai-msri", config=config)

In [29]:
class SpacySentenceTokenizer:
    
    def __init__(self):
        self.nlp = spacy.load('xx_ent_wiki_sm')
        self.nlp.add_pipe("sentencizer")
        
    def __call__(self, text: str) -> List[str]:
        return list(map(lambda span: span.text, self.nlp(text).sents))


In [30]:
# def load_tagging_dataset(
#     dataset: str,
#     lang: str,
#     split: str,
#     dataset_frac: float = 1.0,
#     xtreme_dir: str = "xtreme/download",
#     delimiter: str = "_",
# ):

#     split = "dev" if split == "validation" else split

#     filename = f"{xtreme_dir}/{dataset}/{split}-{lang}.tsv"
#     inputs, labels = read_conll_data(filename)

#     dataset = Dataset.from_dict({"tokens": inputs, "tags": labels})
#     dataset = dataset.map(
#         lambda example: {
#             "tagged_tokens": [f"{token}{delimiter}{tag}"
#             for token, tag in zip(example["tokens"], example["tags"])]
#         }
#     )

#     N = len(dataset)
#     selector = np.arange(int(N * dataset_frac))
#     return dataset.select(selector)


In [31]:
train_dataset = load_tagging_dataset(dataset,
                                lang = pivot_lang,
                                split="dev")
test_dataset = load_tagging_dataset(dataset,
                                lang = tgt_lang,
                                split="test")

Map:   0%|          | 0/2902 [00:00<?, ? examples/s]

Map:   0%|          | 0/5008 [00:00<?, ? examples/s]

In [32]:
train_dataset[0]

{'tokens': ['FLEXIBLE'], 'tags': ['ADJ'], 'tagged_tokens': ['FLEXIBLE_ADJ']}

In [33]:
train_examples = choose_few_shot_examples(
        train_dataset, few_shot_k, selection_criteria="random")

In [71]:
PROMPTS_DICT = {
    "structure_prompting": """C: {context}\nT: {tagged}""",
    "structure_prompting_chat": """Tag the following sentence: "{context}"\n{tagged}"""
}

In [72]:
prompt_template = PROMPTS_DICT[prompt_name]

In [73]:
# Loading instruction for the task
instruction = INSTRUCTIONS[dataset]
print(instruction)

You are an NLP assistant whose purpose is to perform Part of Speech (PoS) Tagging. PoS tagging is the process of marking up a word in a text (corpus) as corresponding to a particular part of speech, based on both its definition and its context. You will need to use the tags defined below:
    1. ADJ: adjective
    2. ADP: adposition
    3. ADV: adverb
    4. AUX: auxiliary
    5. CCONJ: coordinating-conjunction
    6. DET: determiner
    7. INTJ: interjection
    8. NOUN: noun
    9. NUM: numeral
    10. PART: particle
    11. PRON: pronoun
    12. PROPN: proper-noun
    13. PUNCT: punctuation
    14. SCONJ: subordinating-conjunction
    15. SYM: symbol
    16. VERB: verb
    17. X: other


In [37]:
train_examples

[{'tokens': ['En',
   '1925',
   ',',
   'il',
   'adhère',
   'au',
   'Parti',
   'communiste',
   'palestinien',
   ';'],
  'tags': ['ADP',
   'NUM',
   'PUNCT',
   'PRON',
   'VERB',
   'ADP',
   'NOUN',
   'ADJ',
   'ADJ',
   'PUNCT'],
  'tagged_tokens': ['En_ADP',
   '1925_NUM',
   ',_PUNCT',
   'il_PRON',
   'adhère_VERB',
   'au_ADP',
   'Parti_NOUN',
   'communiste_ADJ',
   'palestinien_ADJ',
   ';_PUNCT']},
 {'tokens': ['Gibamond',
   'est',
   "l'",
   'un',
   'des',
   'trois',
   'commandants',
   'des',
   'corps',
   "d'",
   'armée',
   'vandales',
   'présent',
   'à',
   'la',
   'bataille',
   'de',
   "l'",
   'Ad',
   'Decimum',
   '.'],
  'tags': ['PROPN',
   'AUX',
   'DET',
   'PRON',
   'ADP',
   'NUM',
   'NOUN',
   'ADP',
   'NOUN',
   'ADP',
   'NOUN',
   'ADJ',
   'ADJ',
   'ADP',
   'DET',
   'NOUN',
   'ADP',
   'DET',
   'PROPN',
   'PROPN',
   'PUNCT'],
  'tagged_tokens': ['Gibamond_PROPN',
   'est_AUX',
   "l'_DET",
   'un_PRON',
   'des_ADP',
   'tro

In [74]:
valid_labels = set()
for example in train_examples:
    valid_labels.update(example["tags"])
valid_labels = list(valid_labels)
valid_labels

['VERB',
 'ADP',
 'DET',
 'ADJ',
 'INTJ',
 'CCONJ',
 'ADV',
 'NUM',
 'AUX',
 'PUNCT',
 'PROPN',
 'NOUN',
 'PRON',
 'SCONJ']

In [81]:
test_example = test_dataset[186]

prompt, label = construct_tagging_prompt(
    train_examples,
    test_example,
    prompt_template=prompt_template,
    chat_prompt=True,
    instruction="Do not try to answer the question. Just tag each token in the sentence."
)
prompt

[{'role': 'system',
  'content': 'Do not try to answer the question. Just tag each token in the sentence.'},
 {'role': 'user',
  'content': 'Tag the following sentence: "En 1925 , il adhère au Parti communiste palestinien ;"'},
 {'role': 'assistant',
  'content': 'En_ADP 1925_NUM ,_PUNCT il_PRON adhère_VERB au_ADP Parti_NOUN communiste_ADJ palestinien_ADJ ;_PUNCT'},
 {'role': 'user',
  'content': 'Tag the following sentence: "Gibamond est l\' un des trois commandants des corps d\' armée vandales présent à la bataille de l\' Ad Decimum ."'},
 {'role': 'assistant',
  'content': "Gibamond_PROPN est_AUX l'_DET un_PRON des_ADP trois_NUM commandants_NOUN des_ADP corps_NOUN d'_ADP armée_NOUN vandales_ADJ présent_ADJ à_ADP la_DET bataille_NOUN de_ADP l'_DET Ad_PROPN Decimum_PROPN ._PUNCT"},
 {'role': 'user',
  'content': 'Tag the following sentence: "enfant je je lisais tout"'},
 {'role': 'assistant',
  'content': 'enfant_NOUN je_PRON je_PRON lisais_VERB tout_PRON'},
 {'role': 'user',
  'conte

In [82]:
test_example["tokens"]

['Donner',
 'le',
 'nom',
 "d'",
 'une',
 'équipe',
 'pour',
 'laquelle',
 'Tim',
 'Crews',
 'a',
 'joué',
 '.']

In [83]:
preds = gpt3x_completion(
    prompt,
    model,
    test_example["tokens"],
    temperature=0,
    max_tokens=100
)

In [84]:
preds

"Donner_VERB le_DET nom_NOUN d'_ADP une_DET équipe_NOUN pour_ADP laquelle_PRON Tim_PROPN Crews_PROPN a_AUX joué_VERB ._PUNCT"

In [85]:
" ".join(test_example["tagged_tokens"])

"Donner_VERB le_DET nom_NOUN d'_ADP une_DET équipe_NOUN pour_ADP laquelle_PRON Tim_PROPN Crews_PROPN a_AUX joué_VERB ._PUNCT"

Bad pipe message: %s [b'c1', b'F\x88\xe8C\xac\xe6\x97\xcf\xb1\xffv\xd4\xa2r \xae\xf3!>#\x17.\xbd\x84\xa8v\x16\xf1d\x7f\xbb\x163\x9c\x1c\xc6\xfb)\xab,\x9a\x99$\xd4!\x85\x89\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$']
Bad pipe message: %s [b'L\xfau\x03\x033\x803;\xc7O\xd0\n\xaa\xd9\xc9\x1b\xe6 \xa1\xae\x9c\xb6\xa8%Lu \x93`\xdfl\xd9Nj:A\x9eS\x91\xab\xfd8\xce\x89\x00\x8fRk\x82R\x00', b'\x02\x13\x03\x13\x01\x00\xff', b'']
Bad pipe message: %s [b'c\xff\x90\xa1\xeb<\xcf\xf6\x1dT\xdf\x99\x1d\xaf\xf9Xa.\x00\x00\xa2\xc0\x14\xc0\n\x009\x008\x007\x006\x00\x88\x00\x87\x00\x86\x00\x85\xc0\x19\x00:\x00\x89\xc0\x0f\xc0\x05\

Bad pipe message: %s [b'\xe3\x01\x8f\x8at\nH\x88r\xa1\xa0FH\xe7\x1f\x10\xbc\x1c K\x9b\xd2\x94\x18\xe9\xd3^X=_\x99\xc4\xf9\x86\xe1g\n\xf32\x06\xb7$\xdeZr\x80\xb0\xc5\x0fQH\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04']
Bad pipe message: %s [b'\x83']
Bad pipe message: %s [b'`W\xdfr\xbc+\xef\xca\xedp\xe2\xbbM\xaf\xd8 \xb1,\x87/\xdfh\xc2,\xdd#\xde\xd5 V\x0f\xee;B\x03\x01\x86\xf7\xc6\x87$fLB\x9bd`\x86\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00']
Bad pipe message: %s [b'\x00\t127.0.0.1']
Bad pipe message: %s [b'W\x83\x99~\nI\t\rT\xc8E\x16\x08\x1f\x96\x19\xd4\x10\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\

Bad pipe message: %s [b'\xed\xf1\xd2)\x1c\x9d\xb8S"\x9e\xd9\xb16pT\x7f\xf0\xe9 \xa6\x9a\xce\x06\xfc\x82\xf9\xab\xba!\xfaWR\x82wIUw\x8c\xf8m\xe74\x1e.\xbc\xa5\x0cp\x14\xb3\x98\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 U\x0e\xd1']
Bad pipe message: %s [b"\xe5iJ\xb1z:\x0e\xaf\xa1\x02\xe5\x06\x15\xe5\xe0\xa7^P\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\x

In [20]:
test_example = test_dataset[0]
prompt, label = construct_tagging_prompt(
    train_examples,
    test_example,
    prompt_template=prompt_template,
    chat_prompt=False,
    instruction=instruction
)
print(prompt)

C: James Graham était le chef du clan Graham .
T: James_PROPN Graham_PROPN était_AUX le_DET chef_NOUN du_ADP clan_NOUN Graham_PROPN ._PUNCT

C: C' est là qu' à 26 ans , il devient le plus jeune chef récompensé par deux étoiles .
T: C'_PRON est_AUX là_ADV qu'_SCONJ à_ADP 26_NUM ans_NOUN ,_PUNCT il_PRON devient_VERB le_DET plus_ADV jeune_ADJ chef_NOUN récompensé_VERB par_ADP deux_NUM étoiles_NOUN ._PUNCT

C: La formation de craquelures est appelée craquelage ( crazing en anglais ) .
T: La_DET formation_NOUN de_ADP craquelures_NOUN est_AUX appelée_VERB craquelage_NOUN (_PUNCT crazing_NOUN en_ADP anglais_NOUN )_PUNCT ._PUNCT

C: l' incidence des fractures en % ( IC )
T: l'_DET incidence_NOUN des_ADP fractures_NOUN en_ADP %_NOUN (_PUNCT IC_PROPN )_PUNCT

C: Pour remédier à cela , les Chinois conduisaient d' abord le gaz dans un grand réservoir en bois de forme conique , placé 3 m sous le niveau du sol , où un autre conduit amenait l' air .
T: Pour_ADP remédier_VERB à_ADP cela_PRON ,_PUNCT l

In [26]:
preds = gpt3x_tagger(
    prompt,
    model,
    test_example["tokens"],
    one_shot_tag=False,
    temperature=0,
    max_tokens=5
)

> [0;32m/home/t-kabirahuja/work/repos/MultilingualBlanketEval/mega/models/tag_models.py[0m(116)[0;36mpredict_tag[0;34m()[0m
[0;32m    114 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 116 [0;31m        [0;32mreturn[0m [0mresponse[0m[0;34m[[0m[0;34m"choices"[0m[0;34m][0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;34m"text"[0m[0;34m][0m[0;34m.[0m[0mstrip[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0msplit[0m[0;34m([0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;32mdef[0m [0mpredict_one_shot[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> response["choices"][0]["text"].strip()
'PRON est_AUX -'
ipdb> response["choices"][0]["text"].strip().split()[0]
'PRON'
ipdb> 

In [24]:
preds

['PRON',
 'ce_PRON',
 "qu'_PRON",
 'une_DET',
 'aide_NOUN',
 'au_logement_NOUN',
 'logement_NOUN',
 '?',
 'PUNCT']

In [84]:
test_example["tags"]

['PRON', 'AUX', 'PRON', 'SCONJ', 'DET', 'NOUN', 'ADP', 'NOUN', 'PUNCT']

In [26]:
get_model_pred(
    train_examples,
    test_example,
    prompt_template,
    verbalizer={},
    model=model,
    chat_prompt=True,
    instruction=instruction,
    one_shot_tag=True,
    max_tokens=max_tokens
    
)

{'prediction': ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B-LOC',
  'O',
  'O',
  'O',
  'O',
  'B-LOC',
  'O',
  'O',
  ''],
 'ground_truth': ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B-LOC',
  'O',
  'O',
  'O',
  'O',
  'B-LOC',
  'O',
  'O',
  'O']}

In [27]:
preds = [pred if pred != "" else np.random.choice(valid_labels) for pred in preds]

In [28]:
print(preds)

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'I-PER']


In [29]:
print(f"Prediction: {preds}")
print(f"Label: {label}")

f1_score([preds], [label])

# prediction = {"prediction_text": pred, "id": test_example["id"]}
# reference = {}
# reference["answers"] = test_example["answers"]
# reference["id"] = test_example["id"]
# results = squad_metric.compute(
#             predictions=[prediction],
#             references=[reference]
#         )

Prediction: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'I-PER']
Label: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O']


0.8

In [31]:
f1_sum = 0
em_sum = 0
avg_em = 0
avg_f1 = 0

run_details = {"num_calls": 0}

pbar = tqdm(enumerate(test_dataset.select(range(1000))))

for i, test_example in pbar:    
    prompt, label = construct_tagging_prompt(
        train_examples,
        test_example,
        prompt_template=prompt_template,
        chat_prompt=True,
        instruction=instruction
    )
    preds = gpt3x_tagger(
        prompt,
        model,
        test_example["tokens"],
        one_shot_tag=True,
        temperature=0,
        max_tokens=100
    )
    preds = [pred if pred != "" else np.random.choice(valid_labels) for pred in preds]
    f1_sum += f1_score([preds], [label])
        
    avg_f1 = f1_sum / (i+1)
    
#     wandb.log({"f1": avg_f1})
#     wandb.log(run_details)
    pbar.set_description(f"f1: {avg_f1}")
#     time.sleep(1/2)

0it [00:00, ?it/s]



KeyboardInterrupt: 