# LLM Reasoning (1) Fact Retrieval

## Input
- Ranked rules(ranked_rules/icews14/confidence.json)

## Output
- Retrieved facts(retrieved_facts/icews14/prediction_icews14.txt)

## Setup requirements
- python 3.12



In [1]:
!pip install datasets==2.13.1 torch==2.4.0 fairscale==0.4.13 fire==0.5.0 numpy<2 tokenizers transformers==4.35.2 tqdm wandb==0.16.0 sentencepiece peft==0.10 PyYAML==6.0.1 setuptools bitsandbytes==0.41.0 scipy

/bin/bash: line 1: 2: No such file or directory


In [None]:
# uv run ./data_utils/retrieve.py --name_of_rules_file confidence.json --dataset icews14 -s ./data/processed_rule2

## Utils


In [37]:
import json


def read_txt_as_list(path_txt):
    with open(path_txt, 'r', encoding='utf-8-sig') as file:
        data = file.readlines()
    return data


def read_json(json_dir):
    with open(json_dir, "r", encoding="utf-8") as f:
        json_data = json.load(f)
    return json_data

def write_txt(txt_dir, out_list, head='\t'):
    with open(txt_dir, 'w', encoding='utf-8') as txtfile:
        for sublist in out_list:
            txtfile.write(head.join(map(str, sublist)) + '\n')

def flip_dict(original_dict):
    return {v: k for k, v in original_dict.items()}

def str_dict(original_dict):
    return {str(k): str(v) for k, v in original_dict.items()}


def convert(str_in, dict_in):
    return dict_in[str_in]

def id_words(li, dict_ent, dict_r, dict_t, end=str(0), period=1):
    li_new = []
    for line in li:
        try:
            columns = line.strip().split("\t")
    
            # columns[0] = str(convert(columns[0], dict_ent))
            # columns[1] = str(convert(columns[1], dict_r))
            # columns[2] = str(convert(columns[2], dict_ent))
            columns[0] = str(columns[0])
            columns[1] = str(columns[1])
            columns[2] = str(columns[2])
            # columns[3] = str(convert(str(int(columns[3]) * period), dict_t))
            columns[3] = str(columns[3])
            line = "\t".join([columns[0], columns[1], columns[2], columns[3], end])
            li_new.append(line)
        except Exception as e:
            print(f"Error converting line: {line}")
            print(f"Error: {e}")
            raise e
    return li_new


def convert_dataset(li_to_convert, path_workspace, end=str(0), period=1):
    relations = read_json(path_workspace + "relation2id.json")
    entities = read_json(path_workspace + "entity2id.json")
    times_id = read_json(path_workspace + "ts2id.json")
    test_ans = id_words(
        li_to_convert,
        # str_dict(flip_dict(entities)),
        # str_dict(flip_dict(relations)),
        # str_dict(flip_dict(times_id)),
        str_dict(entities),
        str_dict(relations),
        str_dict(times_id),
        end,
        period,
    )  # convert list in ids to list in words
    return test_ans


In [None]:
## TLR

In [32]:
import numpy as np
import time as ti
from tqdm import tqdm


class Retriever:
    def __init__(
        self,
        test,
        all_facts,
        entities,
        relations,
        times_id,
        num_relations,
        chains,
        rel_keys,
        dataset,
        retrieve_type="TLogic",
    ):
        self.retrieve_type = retrieve_type
        self.dataset = dataset
        self.test = test
        self.all_facts = all_facts

        self.entities = entities
        self.relations = relations
        self.times_id = times_id
        self.num_relations = num_relations
        self.chains = chains

        self.entities_flip = flip_dict(self.entities)
        self.relations_flip = flip_dict(self.relations)
        col_sub = []
        col_rel = []
        col_obj = []
        col_time = []
        for row in all_facts:
            row = row.strip().split("\t")
            col_sub.append(row[0])  # take sub
            col_rel.append(row[1])  # Get the relation column of all facts
            col_obj.append(row[2])  # Take obj
            col_time.append(row[3])  # Time in Str form
        self.col_obj = np.array(col_obj)  # To get cand and then search for facts use
        self.col_sub = np.array(col_sub)  # To get cand and then search for facts use
        self.col_time = np.array(col_time)  # time, str form
        self.col_rel = np.array(col_rel)  # relation, str form
        all_facts_array = np.array(all_facts)

        self.rel_keys = np.array(rel_keys)

    def prepare_bs(self, i):
        sub, rel, _, time, _ = self.test[i].strip().split("\t")
        idx_t = np.where(self.col_time < time)[0]  # cannot be equal to
        s_t = set(idx_t)
        idx0 = np.where(self.col_sub == sub)[0]
        s0 = set(idx0)
        idx = list(s0 & s_t)
        idx.sort(reverse=True)
        time = self.times_id[time]
        return time, sub, rel, idx

    def build_bs(self):
        # Pure Entity mode
        test_text = []
        test_idx = []

        for i in tqdm(range(0, len(self.test))):  # csv has a header, txt has no header. len(self.test)
            num_facts = 50  # 20 or 100
            time, sub, rel, idx = self.prepare_bs(i)

            facts = []
            idx = idx[0:num_facts]
            for k in idx:
                facts.append(self.all_facts[k])  # Get the facts where sub and rel are the same

            if len(facts) < num_facts:
                num_facts = len(facts)

            histories = self.collect_hist(i, facts, num_facts)
            history_query = self.build_history_query(time, sub, rel, histories=histories)

            test_idx.append(idx)
            test_text.append(history_query)
        return test_idx, test_text

    def tlogic_prepro(self, i):
        test_sub, test_rel, _, test_time, _ = self.test[i].strip().split("\t")
        # First of all, there must be a time premise of retrieve s_t
        # Here we need to find out the idx of test in all_facts so that it can be removed
        idx_test = len(self.all_facts) - (len(self.test) - 1) + i - 1
        # #The major premise is that retrieval must be performed from those ranges earlier than test_time
        idx_t = np.where(self.col_time < test_time)[0]
        s_t = set(idx_t)
        if idx_test in s_t:
            s_t.remove(idx_test)  # Remove the test item itself
            # Second, the search for the beginning of the chain needs to be restricted to rel==test_sub
        idx_test_sub = np.where(self.col_sub == test_sub)[0]
        s_test_sub = set(idx_test_sub)
        s_0 = s_t & s_test_sub  # Get: major premise
        head_rel = self.relations[test_rel]  # Get the idx corresponding to test_relation: 0,1,2,...
        time = self.times_id[
            test_time
        ]  # To move forward the time according to the id corresponding relationship of time. Get int from str
        return s_0, s_t, head_rel, int(time), test_sub, test_rel

    def build_tl(self):
        test_text = []
        test_idx = []

        for i in tqdm(range(0, len(self.test))):  # Starting from 1 because there is a header. 1, len(test)

            num_facts = 50  # Set it again for each question, as the following may change.
            s_0, s_t, head_rel, time, test_sub, test_rel = self.tlogic_prepro(i)
            facts = []
            idx = []

            if not str(head_rel) in self.chains:  # If the test relation has no chain, nothing will be done.
                # l = ['Just repeat "No Chains."\n']
                history_query = [
                    str(int(time / 24)) + ": [" + test_sub + ", " + test_rel + ",\n"
                ]  # TODO : 24로 나누는 부분 추가함
                test_idx.append([])  # At this time idx is a blank line
                test_text.append(history_query)
                # print(i, 'no chain in this line')
                continue
            s_0 = np.array(list(s_0))  # Convert collection to NumPy array
            # After the above preparations, start searching for facts by chain.
            idx_chain = []
            for k in range(0, len(self.chains[str(head_rel)])):  # There are len(chains[str(head_rel)]) chains
                body_rel_len = len(self.chains[str(head_rel)][k]["body_rels"])  # The length of this chain
                if body_rel_len == 1:
                    idx_chain.append(k)
            for k in idx_chain:  # TLogic (or TLogic-3), as long as the shortest len=1 chain
                idx_case = []
                body_rel_len = len(self.chains[str(head_rel)][k]["body_rels"])  # The length of this chain
                rel = self.chains[str(head_rel)][k]["body_rels"][-1] % self.num_relations
                idx_rel = np.where(self.col_rel == self.rel_keys[rel])[0]
                idx_rel = np.intersect1d(idx_rel, s_0)  # Using NumPy functions for intersection operations
                idx_case = idx_rel
                idx_case.tolist()
                if len(idx_case) != 0:  # If it is not empty, retrieve it.
                    idx_case = list(set(idx_case))
                    idx = list(set(idx + idx_case))
                else:
                    continue  # If no such chain exists, jump to the next chain
                if len(idx) >= num_facts:
                    break  # Break out of the loop on chains and go to the next test
            # time reordering
            idx.sort(reverse=True)
            # Idx with chain.sort(reverse=true)
            if len(idx) > num_facts:
                idx = idx[0:num_facts]
                for a in idx:
                    facts.append(self.all_facts[a])
            else:
                for a in idx:
                    facts.append(self.all_facts[a])
            test_idx.append(idx)
            if len(facts) == 0:  # If the test relation has no chain, nothing will be done.
                history_query = self.build_history_query(time, test_sub, test_rel)  # Self.times id[time]
                test_text.append(history_query)
                # test_idx.append([]) The above line has been appended
                continue
            if num_facts >= len(facts):
                num_facts = len(facts)

            histories = self.collect_hist(i, facts, num_facts)
            history_query = self.build_history_query(time, test_sub, test_rel, histories=histories)
            test_text.append(history_query)
            ti.sleep(0.001)
        return test_idx, test_text

    def collect_hist(self, i, facts, num_facts):
        period = 1
        if self.dataset == "icews14" or self.dataset == "icews18":
            period = 24
        histories = []
        facts = facts[0:num_facts]  #
        facts.reverse()  # Replace the order so that the last output is the one closest in time.
        for b in range(num_facts):
            fact = facts[b].strip().split("\t")
            time_in_id = self.times_id[fact[3]]
            sub_in_word = fact[0]
            rel_in_word = fact[1]

            obj_in_word = fact[2]
            id_obj = self.entities[obj_in_word]
            histories = histories + [
                str(int(int(time_in_id) / int(period)))
                + ": ["
                + sub_in_word
                + ", "
                + rel_in_word
                + ", "
                + str(id_obj)
                + "."
                + obj_in_word
                + "] \n"
            ]
        return histories

    def build_history_query(self, time, test_sub, test_rel, histories=""):
        period = 1
        if self.dataset == "icews14" or self.dataset == "icews18":
            period = 24
        # time_in_id = self.times_id[time]  # TODO: 이미 id형태의 time이 나와서 로직 변경함.
        time_in_id = time

        return [
            "".join(histories)
            + str(int(int(time_in_id) / int(period)))
            + ": ["
            + test_sub
            + ", "
            + test_rel
            + ",\n"  # times id[time]
        ]

    def call_function(self, func_name):
        func = getattr(self, func_name)
        if func and callable(func):
            test_idx, test_text = func()
        else:
            print("Retrieve function not found")
        return test_idx, test_text

    def get_output(self):
        type_retr = "bs" if self.retrieve_type == "bs" else "tl"
        test_idx, test_text = self.call_function("build_" + type_retr)

        return test_idx, test_text


## Main

In [38]:
import os, glob
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", default="icews14", type=str)
parser.add_argument("--retrieve_type", "-t", default="TLogic-3", type=str)
parser.add_argument("--name_of_rules_file", "-r", default="confidence.json", type=str)
parser.add_argument("--path_save", "-s", default="./retrieved_facts/", type=str)

parsed, _ = parser.parse_known_args()
parsed = vars(parsed)

retrieve_type = parsed["retrieve_type"]
type_dataset = parsed["dataset"]
name_rules = parsed["name_of_rules_file"]

path_workspace = "./datasets/" + type_dataset + "/"  
path_out_tl = "./ranked_rules/" + type_dataset + "/"
print(path_out_tl)

path_save = parsed["path_save"] + type_dataset + "/"
if not os.path.exists(path_save):
    os.makedirs(path_save)

period = 1
if type_dataset == "icews18":
    num_relations = 256  # for ICEWS18 #set before np.array
    period = 24
elif type_dataset == "icews14":
    num_relations = 230
    # period = 24 # 데이터 바꾸면서 필요없어짐
elif type_dataset == "GDELT":
    num_relations = 238  # GDELT and
else:
    num_relations = 24  # YAGO

test_ans = []
li_files = ["test"]  #  ['train','valid','test'] or  ['test'] when only test set is needed

for files in li_files:
    # print("existing rules:", glob.glob(path_out_tl + "*rules.json"))
    dir_rules = glob.glob(path_out_tl + "*rules.json")[0] if name_rules == "" else path_out_tl + name_rules
    # print("files", files)
    test_ans = read_txt_as_list(path_workspace + files + ".txt")

    relations = read_json(path_workspace + "relation2id.json")
    entities = read_json(path_workspace + "entity2id.json")
    times_id = read_json(path_workspace + "ts2id.json")

    test_ans = convert_dataset(test_ans, path_workspace, period=period)

    chains = read_json(path_out_tl + name_rules)
    rel_keys = list(relations.keys())
    ent_idx = list(entities.keys())  # [0, 1, ...]
    times_id_keys = list(times_id.keys())
    all_facts = []
    with open(path_workspace + "all_facts.txt", "r", encoding="utf-8") as f:
        all_facts = f.readlines()

    rtr = Retriever(
        test_ans, all_facts, entities, relations, times_id, num_relations, chains, rel_keys, dataset=type_dataset
    )
    test_idx, test_text = rtr.get_output()

    path_file = (
        path_save + files + "/history_facts/" + "history_facts_" + type_dataset
    )  # "history_facts_"+retrieve_type+type_dataset
    path_file_word = path_file + ".txt"
    path_file_id = path_file + "_idx_fine_tune_all.txt"

    if not os.path.exists(path_save + files + "/history_facts/"):
        os.makedirs(path_save + files + "/history_facts/")
    write_txt(path_file_id, test_text)
    with open(path_file_word, "w", encoding="utf-8") as f:
        for i in range(len(test_text)):
            f.write(test_text[i][0] + "\n")
    print("saved as ", path_file_word, "and ", path_file_id)

    path_answer = path_save + files + "/test_answers/" + "test_answers_" + type_dataset + ".txt"
    if not os.path.exists(path_save + files + "/test_answers/"):
        os.makedirs(path_save + files + "/test_answers/")
    write_txt(path_answer, test_ans, head="")
    print("saved as ", path_answer)


./ranked_rules/icews14/


100%|██████████| 7371/7371 [01:15<00:00, 97.51it/s] 

saved as  ./retrieved_facts/icews14/test/history_facts/history_facts_icews14.txt and  ./retrieved_facts/icews14/test/history_facts/history_facts_icews14_idx_fine_tune_all.txt
saved as  ./retrieved_facts/icews14/test/test_answers/test_answers_icews14.txt



