In [None]:
#%pip install -r requirements-synthie.txt

In [2]:
# Storage locations used throughout the notebook:
#   DATA_DIR   - joined below with "constrained_worlds", ".cache" and "id2name_mappings"
#   MODELS_DIR - joined below with the checkpoint file name
DATA_DIR = "/ceph/mlautenb/synthIE/data"
MODELS_DIR = "/ceph/mlautenb/synthIE/models"

In [3]:
import os
import sys
import pathlib
# Put both local checkouts at the front of sys.path. Each insert goes to
# index 0, so the CIExMAS root ends up ahead of the SynthIE root.
synthie_root = pathlib.Path('/home/mlautenb/SynthIE').resolve()
ciexmas_root = pathlib.Path('/home/mlautenb/CIExMAS').resolve()
for repo_root in (synthie_root, ciexmas_root):
    sys.path.insert(0, str(repo_root))

# Load the model checkpoint stored under MODELS_DIR.
# NOTE(review): `src.models` is presumably resolved from the SynthIE checkout
# added to sys.path above - confirm.
from src.models import GenIEFlanT5PL

ckpt_name = "genie_base_sc.ckpt"
path_to_checkpoint = os.path.join(MODELS_DIR, ckpt_name)
model = GenIEFlanT5PL.load_from_checkpoint(checkpoint_path=path_to_checkpoint)
# The trailing semicolon stops the notebook from echoing the module repr.
model.to("cuda");



In [4]:
"""Load constrained decoding module"""
from src.constrained_generation import IEConstrainedGeneration

# Keyword arguments forwarded to IEConstrainedGeneration.from_constrained_world.
params = {
    'constrained_worlds_dir': os.path.join(DATA_DIR, "constrained_worlds"),
    # specifies the folder name from which the constrained world is loaded
    'constrained_world_id': "genie_t5_tokenizeable",
    # specifies the cache subfolder where the trie will be stored
    'identifier': "genie_t5_tokenizeable",
    'path_to_trie_cache_dir': os.path.join(DATA_DIR, ".cache"),
    'path_to_entid2name_mapping': os.path.join(DATA_DIR, "id2name_mappings", "entity_mapping.jsonl"),
    'path_to_relid2name_mapping': os.path.join(DATA_DIR, "id2name_mappings", "relation_mapping.jsonl"),
}

# The linearization class is taken from the hyper-parameters stored on the
# loaded model.
constraint_module = IEConstrainedGeneration.from_constrained_world(
    model=model,
    linearization_class_id=model.hparams.linearization_class_id,
    **params,
)

# Attach the constraint module to the model.
model.constraint_module = constraint_module

In [5]:
import jsonlines

# Read the two JSON-lines lookup tables into dicts keyed by each record's
# "id" field, with the record's "en_label" field as the value.
path_to_entity_id2name_mapping = os.path.join(DATA_DIR, "id2name_mappings", "entity_mapping.jsonl")
with jsonlines.open(path_to_entity_id2name_mapping) as reader:
    entity_id2name_mapping = dict(
        (record["id"], record["en_label"]) for record in reader
    )

path_to_relation_id2name_mapping = os.path.join(DATA_DIR, "id2name_mappings", "relation_mapping.jsonl")
with jsonlines.open(path_to_relation_id2name_mapping) as reader:
    relation_id2name_mapping = dict(
        (record["id"], record["en_label"]) for record in reader
    )

In [6]:
def get_entity_id_by_name(entity_name, mapping):
    """Return the ID of the first entry in ``mapping`` labelled ``entity_name``.

    Args:
        entity_name: Label to search for (compared with ``==``).
        mapping: Dict from ID to label; scanned in iteration order.

    Returns:
        The first key whose value equals ``entity_name``.

    Raises:
        KeyError: If no value in ``mapping`` equals ``entity_name``.
    """
    for entity_id, label in mapping.items():
        if label == entity_name:
            return entity_id
    # An exhausted bare next() would surface as a message-less StopIteration;
    # raise a lookup error that names the missing label instead.
    raise KeyError(f"No entity with name {entity_name!r} found in mapping")

def get_relation_id_by_name(relation_name, mapping):
    """Return the ID of the first entry in ``mapping`` labelled ``relation_name``.

    Args:
        relation_name: Label to search for (compared with ``==``).
        mapping: Dict from ID to label; scanned in iteration order.

    Returns:
        The first key whose value equals ``relation_name``.

    Raises:
        KeyError: If no value in ``mapping`` equals ``relation_name``.
    """
    for relation_id, label in mapping.items():
        if label == relation_name:
            return relation_id
    # An exhausted bare next() would surface as a message-less StopIteration;
    # raise a lookup error that names the missing label instead.
    raise KeyError(f"No relation with name {relation_name!r} found in mapping")

In [7]:
from helper_tools import parser

# Which dataset slice to evaluate; these three constants are also embedded
# in the name of the result file written at the end of the notebook.
DATASET = "synthie_text"
SPLIT = "test"
NUMBER_OF_SAMPLES = 50

# NOTE(review): `docs` is used later as a table with "docid" and "text"
# columns; the layout of the other two results is not visible here.
triple_df, entity_df, docs = parser.unified_parser(
    DATASET,
    SPLIT,
    NUMBER_OF_SAMPLES,
    upload=False,
)

Fetching 27 files:   0%|          | 0/27 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 39666.20it/s]


In [8]:
from tqdm import tqdm
import pandas as pd

# Generation settings passed to model.sample as keyword arguments for every
# document.
override_models_default_hf_generation_parameters = {
    "num_beams": 10,
    "num_return_sequences": 1,
    "return_dict_in_generate": True,
    "output_scores": True,
    "seed": 123,
    "length_penalty": 0.8
}

# Turtle requires every prefix used in a prefixed name to be declared. The
# triples below use both `wd:` (entities) and `wdt:` (relations), so both are
# declared; `wdt:` uses Wikidata's direct-property namespace.
turtle_prefixes = (
    "@prefix wd: <http://www.wikidata.org/entity/> .\n"
    "@prefix wdt: <http://www.wikidata.org/prop/direct/> .\n"
)

# docid -> Turtle document holding the triples generated for that text.
turtle_string_docs = dict()

for i in tqdm(range(len(docs))):
    target_doc = docs.iloc[i]
    doc_id = target_doc["docid"]
    text = target_doc["text"]
    output = model.sample([text],
                        convert_to_triplets=True,
                        **override_models_default_hf_generation_parameters)
    turtle_parts = [turtle_prefixes]
    # NOTE(review): [0][0] presumably selects the first (only) input text and
    # its top-ranked returned sequence - confirm against model.sample.
    for triple in output['grouped_decoded_outputs'][0][0]:
        # Each triple holds (subject label, relation label, object label);
        # map the labels back to their IDs before serialising.
        subject_id = get_entity_id_by_name(triple[0], entity_id2name_mapping)
        relation_id = get_relation_id_by_name(triple[1], relation_id2name_mapping)
        object_id = get_entity_id_by_name(triple[2], entity_id2name_mapping)
        turtle_parts.append(f"wd:{subject_id} wdt:{relation_id} wd:{object_id} .\n")
    turtle_string = "".join(turtle_parts)
    turtle_string_docs[doc_id] = turtle_string

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [03:32<00:00,  4.25s/it]


In [9]:
import pickle
from datetime import datetime

# Persist the docid -> Turtle mapping for later evaluation.
# NOTE(review): the payload is a pickle even though the file name ends in
# ".xlsx"; the name is left unchanged so existing readers still find it.
# If LLM_MODEL_PROVIDER is unset, the name contains the literal text "None".
evaluation_log_path = f"{ciexmas_root}/approaches/evaluation_logs/One_Agent/{DATASET}-{SPLIT}-{NUMBER_OF_SAMPLES}-evaluation_log-{os.getenv('LLM_MODEL_PROVIDER')}_{ckpt_name.split('.')[0]}-{datetime.now().strftime('%Y-%m-%d-%H%M')}.xlsx"

# Use a context manager so the handle is flushed and closed even if
# pickling fails; the original left the file object unclosed.
with open(evaluation_log_path, "wb") as evaluation_log_file:
    pickle.dump(turtle_string_docs, evaluation_log_file)