In [6]:
import sys
sys.path.append("..")

from src.utils.triplet_utils import TripletUtils

import jsonlines
import pickle
import json 
import os

mappings_folder = "../data/id2name_mappings"

In [None]:
def process_mapping(mapping, output_dir, output_file_name):
    """Write an id -> label mapping to a JSON Lines file.

    Every entry of ``mapping`` whose key is not ``None`` becomes one record of
    the form ``{"id": key, "en_label": value}``.

    Args:
        mapping: Dict-like object exposing ``items()``.
        output_dir: Directory that receives the file; created when missing.
        output_file_name: Name of the file to write inside ``output_dir``.
    """
    records = [{"id": key, "en_label": value} for key, value in mapping.items() if key is not None]

    # On a fresh checkout the mappings folder may not exist yet, and opening a
    # file for writing does not create parent directories.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    output_file_path = os.path.join(output_dir, output_file_name)
    with jsonlines.open(output_file_path, "w") as writer:
        writer.write_all(records)


# Load the pickled entity mapping and export it as JSON Lines.
# NOTE(review): pickle.load can execute arbitrary code; only run this on pickle files from a trusted source.
path_to_entity_mapping = os.path.join("../data/rebel", "english_mapping.pickle")
with open(path_to_entity_mapping, "rb") as f:
    ent_m = pickle.load(f)
process_mapping(ent_m, mappings_folder, "entity_mapping.jsonl")


# Same procedure for the pickled relation mapping.
path_to_relations_mapping = os.path.join("../data/rebel", "relations.pickle")
with open(path_to_relations_mapping, "rb") as f:
    rel_m = pickle.load(f)
process_mapping(rel_m, mappings_folder, "relation_mapping.jsonl")

In [35]:
def _read_world(constrained_worlds_dir, world_id):
    assert world_id in {None, "genie"}, "Invalid constrained world identifier {}".format(world_id)

    if world_id is None:
        return

    path_to_constrained_world_dir = os.path.join(constrained_worlds_dir, world_id)
    
    with open(os.path.join(path_to_constrained_world_dir, "entities.json")) as json_file:
        entities = json.load(json_file)

    with open(os.path.join(path_to_constrained_world_dir, "relations.json")) as json_file:
        relations = json.load(json_file)

    return entities, relations


In [36]:
# ~~~ Maximal relation set from training and maximal entity set from training + val + test ~~~~
# Flags forwarded to TripletUtils.normalize_spaces for entity and relation names respectively.
keep_spaces_entity = False
keep_spaces_relation = True


# Read the entity id -> English label mapping from its JSON Lines file.
ent_mapping_path = os.path.join(mappings_folder, "entity_mapping.jsonl")
with jsonlines.open(ent_mapping_path) as reader:
    ent_map = {obj['id']: obj['en_label'] for obj in reader}
# Read the relation id -> English label mapping the same way.
rel_mapping_path = os.path.join(mappings_folder, "relation_mapping.jsonl")
with jsonlines.open(rel_mapping_path) as reader:
    rel_map = {obj['id']: obj['en_label'] for obj in reader}


# Get the entity and relation ids of the "genie" constrained world.
ent_ids, rel_ids = _read_world("../data/constrained_worlds", "genie")

# Map ids to names and normalize them; ids that are missing from the mapping are dropped.
ent_names = [TripletUtils.normalize_spaces(ent_map[_id], keep_spaces=keep_spaces_entity) for _id in ent_ids if _id in ent_map]
rel_names = [TripletUtils.normalize_spaces(rel_map[_id], keep_spaces=keep_spaces_relation) for _id in rel_ids if _id in rel_map]

In [37]:
# valid_rebel_ent_ids = [ent_id for ent_id in ent_ids if ent_id in ent_map]

# import json
# with open('entities.json', 'w') as f:
#     json.dump(valid_rebel_ent_ids, f)

# Report how many ids of the constrained world have no label in the loaded mappings.
print(f"{len(ent_ids) - len(ent_names)} out of {len(ent_ids)} entities were not found in the mapping")
# This line counts relations, so the message names relations (it previously said "entities").
print(f"{len(rel_ids) - len(rel_names)} out of {len(rel_ids)} relations were not found in the mapping")

0 out of 2724925 entities were not found in the mapping
0 out of 1088 entities were not found in the mapping


### Test

In [119]:
# target_linearization:
#   formatted_triplet:
#     [
#       "{subject_id}",
#       " ",
#       "{subject}",
#       " ",
#       "{relation_id}",
#       " ",
#       "{relation}",
#       " ",
#       "{object_id}",
#       " ",
#       "{object}",
#       " ",
#       "{et_id}",
#     ]
#   keep_spaces: false
#   surface_forms:
#     "{subject_id}": "[s]"
#     "{relation_id}": "[r]"
#     "{object_id}": "[o]"
#     "{et_id}": "[et]"
#     "{separator}": " "

# linearized_triplets.extend([surface_forms.get(item, item) for item in formatted_triplet])
# linearized_triplets.append(surface_forms["{separator}"])

from transformers import T5Tokenizer, T5Config

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
config = T5Config.from_pretrained("google/flan-t5-base")

def encode(text, keep_eos: bool):
    """Tokenize ``text`` with the notebook-level ``tokenizer``.

    Returns the complete id list when ``keep_eos`` is truthy; otherwise the
    same list without its final element (the position where the tokenizer
    places the end-of-sequence id in the outputs shown in this notebook).
    """
    token_ids = tokenizer.encode(text)
    return token_ids if keep_eos else token_ids[:-1]

In [203]:
tokenizer.eos_token_id

1

In [216]:
tokenizer.encode(" [s"), tokenizer.encode(" s")

([784, 7, 1], [3, 7, 1])

In [232]:
tokenizer.encode("784")

[489, 4608, 1]

In [233]:
tokenizer.decode(784)

'['

In [214]:
tokenizer.encode("[s]")

[784, 7, 908, 1]

In [204]:
x = lambda x: x

In [205]:
x.eleni = 3

In [206]:
x.eleni

3

In [56]:
# Example linearized target: three triplets, each written as "[s] subject [r] relation [o] object [et]",
# followed by the "</s>" end-of-sequence marker.
target_text = "[s] Trinity_Peninsula [r] part of [o] Graham_Land [et] [s] Trinity_Peninsula [r] continent [o] Antarctica [et] [s] Graham_Land [r] continent [o] Antarctica [et]</s>"
# Hand-segmented token ids of target_text. The first triplet is annotated piece by piece;
# the sequence matches the tokenizer.encode(target_text) output recorded later in this notebook.
target_ids = [784, 7, 908, # sub_id
              20699, 834, 345, 35, 15953, 9,  # sub
              784, 52, 908, # rel_id
              294, 13, # rel
              784, 32, 908, # obj_id
              15146, 834, 434, 232, # obj
              784, 15, 17, 908, # et_id
              784, 7, 908, # sub_id of the second triplet
              20699, 834, 345, 35, 15953, 9, 784, 52, 908, 10829, 784, 32, 908, 26461, 9, 784, 15, 17, 908, 784, 7, 908, 15146, 834, 434, 232, 784, 52, 908, 10829, 784, 32, 908, 26461, 9, 784, 15, 17, 908, 1] # rest of triplets 2 and 3, then EOS (1)
# (784, 7, 908)

In [122]:
tokenizer.decode(0) # decoder_start_token_id

'<pad>'

In [131]:
tokenizer.encode("[et]")

[784, 15, 17, 908, 1]

In [123]:
prefix_allowed_tokens_fn_params = {
    "subject_token": "s",
    "relation_token": "r",
    "object_token": "o",
    "end_of_triple_token": "et",
    "start_of_triple_tag": "[",
    "end_of_triple_tag": "]",
}

In [None]:
codes = ["[s]", "[r]", "[o]", "[et]"]

# if inside identifier, return options

# if ouside identifier, return options
# codes = [entity trie, relation trie, entity trie, exit or start of [s]]


In [None]:
import regex

patterns_of_interest = ["[s]", "[r]", "[o]", "[et]"]

def find_total_occurence

In [102]:
tokenizer.encode(target_text)



[784,
 7,
 908,
 20699,
 834,
 345,
 35,
 15953,
 9,
 784,
 52,
 908,
 294,
 13,
 784,
 32,
 908,
 15146,
 834,
 434,
 232,
 784,
 15,
 17,
 908,
 784,
 7,
 908,
 20699,
 834,
 345,
 35,
 15953,
 9,
 784,
 52,
 908,
 10829,
 784,
 32,
 908,
 26461,
 9,
 784,
 15,
 17,
 908,
 784,
 7,
 908,
 15146,
 834,
 434,
 232,
 784,
 52,
 908,
 10829,
 784,
 32,
 908,
 26461,
 9,
 784,
 15,
 17,
 908,
 1]

In [77]:
(tokenizer.decode([784, 7, 908]), 
tokenizer.decode([20699, 834, 345, 35, 15953, 9]), 
tokenizer.decode([784, 52, 908]), 
tokenizer.decode([294, 13]),
tokenizer.decode([784, 32, 908]),
tokenizer.decode([15146, 834, 434, 232]),
tokenizer.decode([784, 15, 17, 908]),
tokenizer.decode([784, 7, 908])) # And so on until you close the triplet with 784, 15, 17, 908 and then 1

('[s]',
 'Trinity_Peninsula',
 '[r]',
 'part of',
 '[o]',
 'Graham_Land',
 '[et]',
 '[s]')

In [81]:
encode("Graham_Land", keep_eos=True)

[15146, 834, 434, 232, 1]

In [147]:
# Bare identifier names plus the opening/closing tag characters that wrap them.
special_tokens_text = {
    "sub_id": "s",
    "rel_id": "r",
    "obj_id": "o",
    "et_id": "et",
    # Required by the assertion below, which composes start_of_tag + name + end_of_tag;
    # without these two keys the cell raises KeyError on a fresh kernel.
    "start_of_tag": "[",
    "end_of_tag": "]",
}

full_identifiers = {"sub_id": "[s]", "rel_id": "[r]", "obj_id": "[o]", "et_id": "[et]"}

# verify that tokenized partial and full identifiers are equivalent
for key, value in full_identifiers.items():
    print(value)
    assert tokenizer.encode(value)[:-1] == tokenizer.encode(f"{special_tokens_text['start_of_tag']}{special_tokens_text[key]}{special_tokens_text['end_of_tag']}")[:-1]

[s]
[r]
[o]
[et]


{'sub_id': 's',
 'rel_id': 'r',
 'obj_id': 'o',
 'et_id': 'et',
 'start_of_tag': '[',
 'end_of_tag': ']'}

In [95]:
tokenizer.encode("[s]A [s] A")

[784, 7, 908, 188, 784, 7, 908, 71, 1]

In [100]:
tokenizer.decode([12, 188]), tokenizer.decode([12, 71], skip_special_tokens=True)

('toA', 'to A')

In [91]:
tokenizer.decode(tokenizer.encode("[s]A [s] A"))

'[s]A [s] A</s>'

In [158]:
# Full bracketed identifier string for each triplet state.
# NOTE(review): this rebinds `special_tokens_text`, which an earlier cell uses for the bare names ("s", "r", ...).
special_tokens_text = {
    "sub_id": "[s]",
    "rel_id": "[r]",
    "obj_id": "[o]",
    "et_id": "[et]",
}

# Token ids of each identifier, encoded with keep_eos=False (final id dropped), as numpy arrays.
# NOTE(review): `np` is only imported in later cells of this notebook — confirm numpy is imported before this cell runs.
special_tokens_ids = {key: np.array(encode(value, keep_eos=False)) for key, value in special_tokens_text.items()}
special_tokens_ids

{'sub_id': array([784,   7, 908]),
 'rel_id': array([784,  52, 908]),
 'obj_id': array([784,  32, 908]),
 'et_id': array([784,  15,  17, 908])}

In [159]:
sent_ids = np.array([1, 2, 3, 4, 5])

In [None]:
import numpy as np 

def encode(text, keep_eos: bool):
    """Tokenize ``text`` with the notebook-level ``tokenizer``.

    Returns the complete id list when ``keep_eos`` is truthy; otherwise the
    same list without its final element (the position where the tokenizer
    places the end-of-sequence id in the outputs shown in this notebook).
    """
    token_ids = tokenizer.encode(text)
    return token_ids if keep_eos else token_ids[:-1]

def get_prefix_allowed_tokens_fn(model, entities_trie, relation_trie, sub_id_token = "[s]", rel_id_token = "[r]", obj_id_token = "[o]", et_id_token = "[et]"):
    """Build a callback that restricts decoding to linearized triplets.

    The decoded sequence is treated as a cycle of four states, each opened by
    its identifier: subject (``sub_id_token``), relation (``rel_id_token``),
    object (``obj_id_token``) and end-of-triplet (``et_id_token``). Subject and
    object names are constrained by ``entities_trie``, relation names by
    ``relation_trie``. After an end-of-triplet identifier the sequence may
    either stop (EOS) or open a new subject.

    Args:
        model: Object exposing ``tokenizer`` with ``eos_token_id`` and
            ``encode(text)``. Identifiers are encoded with this tokenizer so
            that they share a vocabulary with the EOS id.
        entities_trie: Object whose ``get(prefix)`` returns the token ids
            allowed after ``prefix`` (passed as a list of ints). The EOS id in
            the result is taken to mean "name complete" — assumption about the
            trie implementation, confirm against the trie class in use.
        relation_trie: Same contract as ``entities_trie``, for relation names.
        sub_id_token: Text of the subject identifier.
        rel_id_token: Text of the relation identifier.
        obj_id_token: Text of the object identifier.
        et_id_token: Text of the end-of-triplet identifier.

    Returns:
        A function ``prefix_allowed_tokens_fn(batch_id, sent_ids)`` returning
        the list of token ids allowed as the next token.
    """
    # Local import keeps this cell self-contained; the notebook has no typing import.
    from typing import Iterable

    EOS_TOKEN = model.tokenizer.eos_token_id

    def _encode_identifier(text):
        """Encode an identifier and strip the trailing EOS id, if present."""
        token_ids = list(model.tokenizer.encode(text))
        if token_ids and token_ids[-1] == EOS_TOKEN:
            token_ids = token_ids[:-1]
        return np.array(token_ids)

    state_id2token_ids = {
        "sub_id": _encode_identifier(sub_id_token),
        "rel_id": _encode_identifier(rel_id_token),
        "obj_id": _encode_identifier(obj_id_token),
        "et_id": _encode_identifier(et_id_token),
    }

    state_id2next_state_id = {"sub_id": "rel_id", "rel_id": "obj_id", "obj_id": "et_id", "et_id": "sub_id"}

    def _get_next_state_id(state_id):
        """Return the token ids of the identifier that follows ``state_id``."""
        return state_id2token_ids[state_id2next_state_id[state_id]]

    def _get_allowed_tokens_from_trie(suffix, trie, next_state_first_token_id):
        """Look up continuations of ``suffix``; a finished name may open the next identifier."""
        # Copy the result so the trie's own storage is never mutated.
        allowed_tokens = list(trie.get(suffix.tolist()))

        if EOS_TOKEN in allowed_tokens:
            # The name is complete here: instead of ending the sequence,
            # allow the first token of the next state's identifier.
            allowed_tokens.remove(EOS_TOKEN)
            if next_state_first_token_id not in allowed_tokens:
                allowed_tokens.append(next_state_first_token_id)

        return allowed_tokens

    def get_allowed_tokens(state_id, suffix):
        """Return the allowed next tokens given the current state and the ids generated since its identifier."""
        next_identifier = _get_next_state_id(state_id)
        next_state_first_token_id = int(next_identifier[0])

        # ~~~ if currently generating a state identifier ~~~
        # Try the longest proper prefix of the next identifier first; when the
        # suffix ends with it, the only valid continuation is the next id of
        # that identifier.
        for prefix_len in range(next_identifier.size - 1, 0, -1):
            if suffix.size < prefix_len:
                continue

            if np.array_equal(next_identifier[:prefix_len], suffix[-prefix_len:]):
                return [int(next_identifier[prefix_len])]

        # ~~~ otherwise ~~~
        if state_id == "et_id":
            # A triplet was just closed: either finish or start a new subject.
            return [EOS_TOKEN, next_state_first_token_id]

        if state_id == "rel_id":
            return _get_allowed_tokens_from_trie(suffix, relation_trie, next_state_first_token_id)

        return _get_allowed_tokens_from_trie(suffix, entities_trie, next_state_first_token_id)

    def _get_state_id_and_suffix_start(sent_ids):
        """Find the most recent complete identifier in ``sent_ids``.

        Returns ``(state_id, index)`` where ``index`` is the position right
        after that identifier. When no identifier is present, returns
        ``("et_id", 0)`` so that a new subject (or EOS) is expected.
        """
        last_token_idx_plus_one = len(sent_ids)

        while last_token_idx_plus_one > 0:
            for state_id, pattern in state_id2token_ids.items():
                pat_size = pattern.size

                if last_token_idx_plus_one < pat_size:
                    continue

                window = sent_ids[last_token_idx_plus_one - pat_size:last_token_idx_plus_one]
                if np.array_equal(window, pattern):
                    return state_id, last_token_idx_plus_one

            last_token_idx_plus_one -= 1

        return "et_id", 0

    def prefix_allowed_tokens_fn(batch_id: int, sent_ids: "torch.Tensor") -> Iterable[int]:
        """Return the token ids allowed after ``sent_ids`` (``batch_id`` is unused)."""
        # Tensors are moved to host memory and converted; plain sequences and
        # numpy arrays are accepted as well.
        if hasattr(sent_ids, "detach"):
            sent_ids = sent_ids.detach().cpu().numpy()
        else:
            sent_ids = np.asarray(sent_ids)

        # ToDo: Is this necessary? It was for genie for some weird reason that I didn't figure out
        if len(sent_ids) > 1 and sent_ids[-1] == EOS_TOKEN:
            return []

        # The helper returns the suffix START INDEX; slice to get the ids
        # generated since the last identifier.
        state_id, suffix_start = _get_state_id_and_suffix_start(sent_ids)
        return get_allowed_tokens(state_id, sent_ids[suffix_start:])

    return prefix_allowed_tokens_fn

        

In [186]:
import numpy as np 

sent_ids = np.array([0, 1, 2, 3, 4])





In [187]:
temp = special_tokens_ids.copy()
temp

{'sub_id': array([784,   7, 908]),
 'rel_id': array([784,  52, 908]),
 'obj_id': array([784,  32, 908]),
 'et_id': array([784,  15,  17, 908])}

In [195]:
# Checks for the backwards scan that finds the most recent pattern in a sequence.
# NOTE(review): `get_state_and_suffix_start` is not defined in the visible cells of this notebook
# (only the nested `_get_state_id_and_suffix_start` is) — confirm where it comes from.
sent_ids = np.array([1, 2, 3, 4, 5])

# Shallow copy so the extra test patterns do not end up in special_tokens_ids.
temp = special_tokens_ids.copy()

# Case 1: a single matching pattern [2, 3]; the suffix starts right after it (index 3).
temp["test"] = np.array([2, 3])
sc, idx = get_state_and_suffix_start(sent_ids, temp)
assert sc == "test" and idx == 3

# Case 2: a pattern that ends earlier ([1, 2]) must not override the later match.
temp["older"] = np.array([1, 2])
sc, idx = get_state_and_suffix_start(sent_ids, temp)
assert sc == "test" and idx == 3

# Case 3: a pattern that ends later ([3, 4]) wins; the suffix now starts at index 4.
temp["newer"] = np.array([3, 4])
sc, idx = get_state_and_suffix_start(sent_ids, temp)
assert sc == "newer" and idx == 4

In [198]:
suffix = sent_ids[idx:]
suffix

array([5])

In [191]:
sc, idx

('test', 3)

In [2]:
import os

constrained_worlds_dir = os.path.join("../data", "constrained_worlds")
world_id = "genie"

In [3]:
entities, relations = _read_world(constrained_worlds_dir, world_id)

In [4]:
entities[:5]

['Q1607497', 'Q22958416', 'Q1848116', 'Q1513683', 'Q19880251']

In [None]:
# Print every entry of the mapping whose key starts with "P"; None keys are skipped.
# NOTE(review): `english_mapping` is not defined in the visible cells — confirm which cell creates it.
for key, value in english_mapping.items():
    if key is not None and key.startswith("P"):
        print(key, value)

In [None]:
# Same loop as the preceding cell: print entries whose key starts with "P", skipping None keys.
# NOTE(review): duplicate cell that relies on an `english_mapping` not defined in the visible cells.
for key, value in english_mapping.items():
    if key is not None and key.startswith("P"):
        print(key, value)

In [7]:


# Turn the entity mapping into a list of {"id", "en_label"} records, skipping entries whose key is None.
en_map = [{"id": key, "en_label": value} for key, value in ent_m.items() if key is not None]

# Destination of the exported entity mapping.
folder = "../data/id2name_mappings"
file_name = "entity_mapping.jsonl"


In [None]:
import jsonlines

# Turn the entity mapping into a list of {"id", "en_label"} records, skipping entries whose key is None.
en_map = [{"id": key, "en_label": value} for key, value in ent_m.items() if key is not None]

folder = "../data/id2name_mappings"
# The assignment was left unfinished (a syntax error); this is the file name
# used for the entity mapping elsewhere in this notebook.
file_name = "entity_mapping.jsonl"

# Write one JSON record per line.
output_file_path = os.path.join(folder, file_name)
with jsonlines.open(output_file_path, "w") as writer:
    writer.write_all(en_map)



In [34]:
english_mapping[None]

'Herangi_Range'

In [33]:
list(keys)[:5]

['Q191069', 'Q47716', 'Q15643', 'Q2143143', 'Q1899035']