In [1]:
import pandas as pd
import os

from datasets import YagoECQ
from utils import format_query, extract_name_from_yago_uri

In [2]:
# Construct the YagoECQ dataset for a single relation URI.
uri = "reverse-http://schema.org/leader"
kb_name, relation = extract_name_from_yago_uri(uri)
cc = YagoECQ(
    subname=f"{kb_name}:{relation}",
    query_id=uri,
    entity_types=["entities", "gpt_fake_entities"],
    max_contexts=100,
)

Failed to load entities, answers, contexts, and queries from paths None, None, None, and None.
Manually reconstructing dataset and saving to aforementioned paths.


In [3]:
qe_df = cc.get_contexts_per_query_entity_df()
# DataFrame.info() prints its own summary and returns None, so wrapping it in
# print() only adds a stray "None" line. Let the notebook render head() richly.
qe_df.info()
qe_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4000 entries, 0 to 3999
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   q_id        4000 non-null   object
 1   query_form  4000 non-null   object
 2   entity      4000 non-null   object
 3   answer      4000 non-null   object
 4   contexts    4000 non-null   object
dtypes: object(5)
memory usage: 156.4+ KB
None
                               q_id  \
0  reverse-http://schema.org/leader   
1  reverse-http://schema.org/leader   
2  reverse-http://schema.org/leader   
3  reverse-http://schema.org/leader   
4  reverse-http://schema.org/leader   

                                   query_form               entity  \
0  Q: Is {entity} the leader of {answer}?\nA:     (Étienne Blanc,)   
1  Q: Is {entity} the leader of {answer}?\nA:  (Laurent Wauquiez,)   
2  Q: Is {entity} the leader of {answer}?\nA:     (Bernard Piras,)   
3  Q: Is {entity} the leader of {answer}?\nA:   (Ime

In [4]:
qc_df = cc.get_entities_per_query_context_df()
# DataFrame.info() prints its own summary and returns None, so wrapping it in
# print() only adds a stray "None" line. Let the notebook render head() richly.
qc_df.info()
qc_df.head()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   q_id        200 non-null    object
 1   query_form  200 non-null    object
 2   context     200 non-null    object
 3   entities    200 non-null    object
dtypes: object(4)
memory usage: 6.4+ KB
None
                               q_id  \
0  reverse-http://schema.org/leader   
1  reverse-http://schema.org/leader   
2  reverse-http://schema.org/leader   
3  reverse-http://schema.org/leader   
4  reverse-http://schema.org/leader   

                                   query_form  \
0  Q: Is {entity} the leader of {answer}?\nA:   
1  Q: Is {entity} the leader of {answer}?\nA:   
2  Q: Is {entity} the leader of {answer}?\nA:   
3  Q: Is {entity} the leader of {answer}?\nA:   
4  Q: Is {entity} the leader of {answer}?\nA:   

                                             context  \
0         Xiaohui Zhao is

In [5]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import torch

MODEL_ID = "EleutherAI/pythia-70m-deduped"
LOAD_IN_8BIT = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First try the requested quantization with automatic device placement; if that
# raises, fall back to a plain (non-8-bit) load moved onto `device`.
try:
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID, load_in_8bit=LOAD_IN_8BIT, device_map="auto"
    )
except Exception as err:  # not a bare except: KeyboardInterrupt/SystemExit still propagate
    print(
        f"Failed to load model {MODEL_ID} with load_in_8bit={LOAD_IN_8BIT} "
        f"and device_map='auto' ({err!r}). Attempting to load normally."
    )
    model = GPTNeoXForCausalLM.from_pretrained(
        MODEL_ID,
        load_in_8bit=False,
    ).to(device)

# Batched generation needs left padding so every prompt ends at the same column.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    padding_side="left",
)

if tokenizer.padding_side != "left":
    raise ValueError(
        f"Expected tokenizer {tokenizer} to have padding side of `left` for batch generation, instead has padding side of `{tokenizer.padding_side}`. Please make sure you initialize the tokenizer to use left padding."
    )

# If the tokenizer defines no pad token, reuse EOS so padding=True works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keep the model's pad id consistent with the EOS-based padding above.
if model.config.pad_token_id != model.config.eos_token_id:
    print("Setting model.config.pad_token_id to model.config.eos_token_id")
    model.config.pad_token_id = model.config.eos_token_id

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


Setting model.config.pad_token_id to model.config.eos_token_id


In [6]:
from enum import Enum
class AnswerType(Enum):
    """Which source a model's answer agrees with."""

    ORIGINAL = 1
    CONTEXT = 2
    OTHER = 3

In [7]:
row = qe_df.iloc[0]
# One prompt per context for the first (query, entity) pair.
queries = [
    format_query(row["query_form"], row["entity"], context, answer=row["answer"])
    for context in row["contexts"]
]  # shape: (len(contexts),)

# Batch-generate a short continuation for every prompt. This must actually run:
# `outputs` is consumed when classifying answers further down.
max_output_length = 8  # number of new tokens generated per prompt
tokens = tokenizer(
    queries,
    padding=True,
    return_tensors="pt",
).to(model.device)
output_tokens = model.generate(**tokens, max_new_tokens=max_output_length)
# With left padding every prompt occupies exactly the first `prompt_length`
# columns, so slicing them off keeps only generated tokens even if generation
# stopped before max_output_length (a fixed tail slice would then include prompt).
prompt_length = tokens["input_ids"].shape[1]
outputs = tokenizer.batch_decode(
    output_tokens[:, prompt_length:], skip_special_tokens=True
)  # shape: (len(contexts),)


In [75]:
def is_answer_original_or_context(output, original_answer, context_answer):
    """Label a model continuation by which reference answer it starts with.

    Args:
        output: generated text; surrounding whitespace is ignored.
        original_answer: answer expected from the model's prior (e.g. "Yes").
        context_answer: answer asserted by the context (e.g. "No", or the
            string extracted from the context). May be None when extraction
            failed.

    Returns:
        "original_answer", "context_answer", or "other".
    """
    text = output.strip()
    candidates = [
        ("original_answer", original_answer),
        ("context_answer", context_answer),
    ]
    # Only non-empty answers can match: "x".startswith("") is always True, and
    # startswith(None) would raise TypeError.
    matched = [
        (label, answer)
        for label, answer in candidates
        if answer and text.startswith(answer)
    ]
    if not matched:
        return "other"
    # If both answers prefix the output (e.g. "Paris" vs "Paris Saint-Germain"),
    # the longer one is the more specific match; ties keep the original answer.
    return max(matched, key=lambda pair: len(pair[1]))[0]

In [70]:
# Relation URI of the row being inspected.
row.loc["q_id"]

'reverse-http://schema.org/leader'

In [21]:
import re
def construct_regex_pattern_for_entity_and_answer(
    template,
    entity_pattern=r"(?:.+)",
    answer_pattern=r"(.*?)",
    terminator_pattern=r"(?=\.\n|$)",
):
    """Build a regex that matches a filled-in ``template`` and captures its answer.

    The template text is escaped so it matches literally. The ``{entity}``
    placeholder is replaced by ``entity_pattern`` and ``{answer}`` by
    ``answer_pattern``, then ``terminator_pattern`` is appended. With the
    defaults, ``{entity}`` is non-capturing and ``{answer}`` is the only
    capturing group, so ``match.group(1)`` is the answer.

    Args:
        template: template string containing ``{entity}`` and/or ``{answer}``.
        entity_pattern: regex substituted for ``{entity}``; keep it
            non-capturing so the answer stays at ``group(1)``.
        answer_pattern: regex substituted for ``{answer}``.
        terminator_pattern: regex appended after the template (by default a
            lookahead for a period followed by a newline, or end of string).

    Returns:
        The regex pattern as a string.
    """
    escaped = re.escape(template)
    # Match the placeholders in their escaped form, whatever re.escape emits.
    filled = escaped.replace(re.escape("{entity}"), entity_pattern).replace(
        re.escape("{answer}"), answer_pattern
    )
    return filled + terminator_pattern

# Example: build the pattern for a template and check that the {answer} slot
# is recovered from each filled-in sentence.
template_with_answer = "{entity} is the leader of {answer}.\n"
sentences_with_answer = [
    "Nandita Bose is the leader of the Stanishev Cabinet.\n",
    "Prue Hackett is the leader of Innlandet.\n",
    "Amintore Fanfani is the leader of Hainan.\n",
]
regex_pattern_with_answer = construct_regex_pattern_for_entity_and_answer(template_with_answer)

# group(1) is the {answer} capture; record a sentinel string when nothing matches.
matches_with_answer = []
for sentence in sentences_with_answer:
    match = re.search(regex_pattern_with_answer, sentence)
    matches_with_answer.append(match.group(1) if match else "No match found")

matches_with_answer


['the Stanishev Cabinet', 'Innlandet', 'Hainan']

In [22]:
import re

def construct_regex_pattern_for_template_general(template):
    """Build a regex for an open-ended template that stops right before the answer.

    The escaped ``{entity}`` placeholder becomes a run of word characters and
    whitespace. The text after the template and a single space, up to the first
    period-newline, is captured as group 1.
    """
    entity_regex = r"[\w\s]+"
    answer_suffix = r" (.*?)(?=\.\n)"
    escaped_template = re.escape(template)
    return escaped_template.replace("\\{entity\\}", entity_regex) + answer_suffix

def extract_answer(query_template: str, sentence: str, regex_constructor):
    """Return capture group 1 of ``regex_constructor(query_template)`` in ``sentence``.

    Prints the sentence and returns None when the pattern does not match.
    """
    found = re.search(regex_constructor(query_template), sentence)
    if found is None:
        print("No match found ahhh", sentence)
        return None
    return found.group(1)

import json
with open("../data/YagoECQ/yago_qec.json", "rb") as f:
    yago_qec = json.load(f)

# For every context of this row, recover the answer it asserts by matching the
# relation's "base" context template. (Alternative: the open query form with
# construct_regex_pattern_for_template_general.)
answers = list(
    map(
        lambda c: extract_answer(
            yago_qec[row["q_id"]]["context_templates"]["base"],
            c,
            construct_regex_pattern_for_entity_and_answer,
        ),
        row["contexts"],
    )
)
answers

['Ahmedabad',
 'Hainan',
 'Finland',
 'Kansas City Southern',
 'Agricultural University of Berlin',
 'Université de Montréal',
 'Morelábor',
 'Médecins Sans Frontières',
 'Southwest Papua',
 'Inca Empire',
 'Beijing',
 'East Francia',
 'Montevideo Department',
 'East Francia',
 'Kursk Oblast',
 'Cabinet Santkohi',
 'Nikšić Municipality',
 'Nobel Prize Museum',
 'Fosun International Limited',
 '2nd FitzGerald ministry',
 'Andorra la Vella',
 'Palau',
 'Luhansk',
 'Te Pāti Māori',
 'Grand Paris',
 'General Electric',
 'College Board',
 'Heilongjiang',
 'Boulogne-Billancourt',
 'Cazeneuve Ministry',
 'MetLife',
 'The Catholic University of America',
 'Second Lubbers cabinet',
 "Alexander Stubb's cabinet",
 'Commonwealth realm of Uganda',
 'Cagnes-sur-Mer',
 'Yanka Kupala National Academic Theatre',
 'Cabinet of Hassan Diab',
 'Fukui Prefecture',
 "Dunkin' Brands",
 'Municipio Libertador',
 'Anguilla',
 'Raykov Government',
 'WEHCO Media',
 'Norderstedt',
 'Publicis',
 'Oscorp',
 'Finnish 

In [76]:
# Closed queries (the form contains "{answer}") are scored as Yes vs. No; open
# queries compare against the entity vs. the answer extracted from each context.
is_closed_query = "{answer}" in row["query_form"]
# NOTE(review): for open queries the "original" answer is row["entity"][0]; for a
# completion like "{entity} is the leader of" one might expect row["answer"]
# instead. Confirm against the dataset semantics.
og_or_ctx_answers = [
    is_answer_original_or_context(
        output,
        "Yes" if is_closed_query else row["entity"][0],
        "No" if is_closed_query else answers[i],
    )
    for i, output in enumerate(outputs)
]
og_or_ctx_answers

['original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'other',
 'original_answer',
 'original_answer',
 'original_answer',
 'other',
 'original_answer',
 'original_answer',
 'original_answer',
 'other',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'other',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'other',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original_answer',
 'original

In [46]:
# Show the kernel's working directory; the data paths used in this notebook
# ("../data/YagoECQ/...") are relative to it.
!pwd

/home/kevin/code/rycolab/measureLM/preprocessing


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [48]:
import json
# Reload the YagoECQ metadata (query forms and context templates, keyed by query id).
with open("../data/YagoECQ/yago_qec.json", mode="rb") as f:
    yago_qec = json.loads(f.read())

In [54]:
open_query_forms = yago_qec[row["q_id"]]["query_forms"]["open"]
open_query_forms[1]

'{entity} is the leader of'