# Dynamic Rule Adaptation (2) Rule Generation & Dynamic Adaptation


## Input
- icews14 (/dataset/icews14/*)
- sampled rules (sampled_path/icews14/original/rules_var.json)

## Output
- llm adapted rules (gen_rules_iteration/final_summary/rules.txt)
- confidence files for each iteration (evaluation/eva/*)

## Setup requirements
- python 3.12



In [1]:
# Version specifiers containing ">" must be quoted: unquoted, the shell treats
# ">=1.5.0" as an output redirection and pip never sees the version constraint.
!pip install "joblib>=1.5.0" \
    matplotlib \
    networkx==3.2.1 \
    numpy==1.26.3 \
    openai==1.79.0 \
    pandas==2.2.0 \
    python-dotenv==1.0.0 \
    scipy==1.12.0 \
    seaborn==0.13.1 \
    "sentence-transformers>=3.0.1" \
    tiktoken==0.9.0 \
    torch==2.4.0 \
    tqdm==4.66.1 \
    transformers==4.35.2 \
    datasets

## Get LLM Model

In [2]:
import collections


class BaseLanguageModel:
    """
    Common interface for language-model backends.

    Every hook except ``add_args`` raises ``NotImplementedError`` here and is
    expected to be overridden by a concrete backend.

    Args:
        args: arguments for LM configuration (stored as ``self.args``)
    """

    def __init__(self, args):
        self.args = args

    @staticmethod
    def add_args(parser):
        """Hook for registering backend-specific parser options; does nothing here."""
        return None

    def load_model(self, **kwargs):
        """Load the underlying model. Must be overridden."""
        raise NotImplementedError

    def token_len(self, text):
        """
        Return the tokenized length of ``text``.

        Args:
            text (str): input text
        """
        raise NotImplementedError

    def prepare_for_inference(self, **model_kwargs):
        """Get the backend ready to serve generation requests. Must be overridden."""
        raise NotImplementedError

    def prepare_model_prompt(self, query):
        """
        Wrap ``query`` in the model-specific prompt format.

        Args:
            query: the input to wrap
        """
        raise NotImplementedError

    def generate_sentence(self, llm_input):
        """
        Generate a sentence with the language model.

        Args:
            llm_input: input handed to the LM
        """
        raise NotImplementedError

    def gen_rule_statistic(self, input_dir, output_file_path):
        """Summarise the rules under ``input_dir`` into ``output_file_path``. Must be overridden."""
        raise NotImplementedError
    

import time
import os
from openai import OpenAI
import openai
import dotenv
import tiktoken
import glob

# Load variables from a local .env file (if present) into the process environment;
# ChatGPT.prepare_for_inference reads OPENAI_API_KEY from os.environ.
dotenv.load_dotenv()

# Directory tiktoken is told to use as its cache.
os.environ["TIKTOKEN_CACHE_DIR"] = "./tmp"

# OpenAI model names; not referenced in the visible part of this file.
OPENAI_MODEL = ["gpt-4", "gpt-3.5-turbo"]


def get_token_limit(model="gpt-4"):
    """
    Look up the token limit configured for ``model``.

    Args:
        model (str): model name

    Returns:
        int: token limit registered for the model

    Raises:
        NotImplementedError: if the model name is not in the table below
    """
    limits_by_models = (
        (("gpt-4", "gpt-4-0613"), 8192),
        (("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0125"), 16384),
        (("gpt-3.5-turbo", "gpt-3.5-turbo-0613", "text-davinci-003", "text-davinci-002"), 4096),
        (("gpt-4.1-mini", "gpt-4.1-nano"), 16384),
    )
    for model_names, limit in limits_by_models:
        if model in model_names:
            return limit
    raise NotImplementedError(f"""get_token_limit() is not implemented for model {model}.""")


# Template with two named fields, `instruction` and `input`, separated by a blank line.
PROMPT = """{instruction}

{input}"""


class ChatGPT(BaseLanguageModel):
    """
    OpenAI chat-completions backend.

    Reads ``retry`` and ``model_name`` from ``args``; the per-model token limit
    comes from ``get_token_limit``. Call ``prepare_for_inference`` before
    ``generate_sentence`` so that ``self.client`` exists.
    """

    @staticmethod
    def add_args(parser):
        """Register the options this backend adds to an argparse parser."""
        parser.add_argument("--retry", type=int, help="retry time", default=5)
        parser.add_argument("--model_path", type=str, default="None")

    def __init__(self, args):
        super().__init__(args)
        self.retry = args.retry
        self.model_name = args.model_name
        self.maximum_token = get_token_limit(self.model_name)

    def token_len(self, text):
        """
        Return the number of tokens in ``text``.

        Raises:
            KeyError: if tiktoken has no encoding registered for the model name
        """
        try:
            # TODO: modified - GPT-4.1 uses the same tokenizer as GPT-4o.
            encoding = tiktoken.encoding_for_model("gpt-4o")
        except KeyError as err:
            raise KeyError(f"Warning: model {self.model_name} not found.") from err
        return len(encoding.encode(text))

    def prepare_for_inference(self, model_kwargs=None):
        """
        Create the OpenAI client and store it on ``self.client``.

        Args:
            model_kwargs: accepted for interface compatibility; not used
        """
        client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],  # this is also the default, it can be omitted
        )
        self.client = client

    def prepare_model_prompt(self, query):
        """
        Add model-specific prompt to the input (a pass-through for this backend).
        """
        return query

    def generate_sentence(self, llm_input):
        """
        Send ``llm_input`` as a single user message and return the reply text.

        Failed requests are retried up to ``self.retry`` additional times, sleeping
        between attempts.

        Args:
            llm_input (str): prompt text

        Returns:
            str | None: stripped reply, or None once all attempts have failed
        """
        # Check if the input is too long
        input_length = self.token_len(llm_input)
        if input_length > self.maximum_token:
            print(
                f"Input length {input_length} is too long. The maximum token is {self.maximum_token}.\n Right truncate the input to {self.maximum_token} tokens."
            )
            # NOTE: the cut is made on characters, not tokens, so the result is
            # typically well below the token limit.
            llm_input = llm_input[: self.maximum_token]
        query = [{"role": "user", "content": llm_input}]

        cur_retry = 0
        num_retry = self.retry
        while cur_retry <= num_retry:
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name, messages=query, timeout=60, temperature=0.0
                )
                return response.choices[0].message.content.strip()  # type: ignore
            except openai.APITimeoutError:
                wait_time = 30 + 10 * cur_retry  # linear backoff
                print(f"Request Time out. Retrying in {wait_time} seconds...")
            except openai.RateLimitError:
                wait_time = 30 + 10 * cur_retry  # linear backoff
                print(f"Rate limit exceeded. Retrying in {wait_time} seconds...")
            except openai.APIConnectionError as e:
                # Print the details of the exception.
                print("Failed to connect to OpenAI API.")
                print("Error message:", e.args[0] if e.args else "No details available.")
                if hasattr(e, "response") and e.response:
                    print("HTTP response status:", e.response.status_code)
                    print("HTTP response body:", e.response.text)
                else:
                    print("No HTTP response received.")
                wait_time = 30 + 10 * cur_retry  # linear backoff
                print(f"Retrying in {wait_time} seconds...")
            except Exception as e:
                # Anything else: dump the request for debugging, then retry.
                print("Message: ", llm_input)
                print("Number of token: ", self.token_len(llm_input))
                print(e)
                wait_time = 30
            # Only reached when the attempt raised; a successful call returns above.
            time.sleep(wait_time)
            cur_retry += 1
        print(f"Maximum retries reached. Unable to generate sentence")
        return None

    def gen_rule_statistic(self, input_dir, output_file_path):
        """
        Concatenate the rule lines of every ``*.txt`` file in ``input_dir``.

        Files whose name starts with ``fail`` are skipped, as are lines containing
        ``Rule_head`` or ``Sample``. The output ends with ``LL <count>``, where
        count is the number of lines written.
        """
        rule_count = 0
        with open(output_file_path, "w") as fout:
            for input_filepath in glob.glob(os.path.join(input_dir, "*.txt")):
                if os.path.basename(input_filepath).startswith("fail"):
                    continue
                with open(input_filepath, "r") as fin:
                    for rule in fin:
                        if "Rule_head" in rule or "Sample" in rule:
                            continue
                        fout.write(rule)
                        rule_count += 1

            fout.write(f"LL {rule_count}\n")


## Utils

In [3]:


def copy_folder_contents(source_folder, destination_folder):
    """
    Make sure ``destination_folder`` exists, creating it when it is missing.

    NOTE: despite the name, nothing is copied; ``source_folder`` is not read.
    """
    if os.path.exists(destination_folder):
        return
    os.makedirs(destination_folder)

def filter_candidates(test_query, candidates, test_data):
    """
    Drop candidates that are valid answers to the test query but are not the
    queried answer itself.

    Parameters:
        test_query (np.ndarray): test_query
        candidates (dict): answer candidates with corresponding confidence scores
        test_data (np.ndarray): test dataset

    Returns:
        candidates (dict): the same dict, filtered in place
    """
    same_col0 = test_data[:, 0] == test_query[0]
    same_col1 = test_data[:, 1] == test_query[1]
    other_col2 = test_data[:, 2] != test_query[2]
    same_col3 = test_data[:, 3] == test_query[3]

    # Rows that agree with the query on columns 0, 1 and 3 but differ on column 2.
    competing_rows = test_data[same_col0 & same_col1 & other_col2 & same_col3]
    for competing_answer in competing_rows[:, 2]:
        candidates.pop(competing_answer, None)

    return candidates



import argparse
import shutil


def save_json_data(data, file_path):
    """
    Write ``data`` to ``file_path`` as indented UTF-8 JSON.

    Best effort: any failure is reported on stdout instead of being raised.
    """
    try:
        with open(file_path, "w", encoding="utf-8") as out_handle:
            json.dump(data, out_handle, indent=4, ensure_ascii=False)
        print(f"Data has been converted to JSON and saved to {file_path}")
    except Exception as err:
        print(f"Error saving JSON data to {file_path}: {err}")

def load_json_data(file_path, default=None):
    """
    Load JSON from ``file_path``, falling back to ``default``.

    The file is read as UTF-8, matching how ``save_json_data`` writes it (UTF-8
    with non-ASCII characters left unescaped).

    Args:
        file_path (str): path of the JSON file
        default: value returned when the file is missing or cannot be parsed

    Returns:
        The decoded JSON content, or ``default``.
    """
    try:
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return default
        print(f"Use cache from: {file_path}")
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except Exception as e:
        # Best effort: report the problem and fall back to the default.
        print(f"Error loading JSON data from {file_path}: {e}")
    return default

def str_to_bool(value):
    """
    Convert a command-line style flag value to a bool.

    Booleans pass through unchanged; otherwise the lower-cased value must be one
    of the accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised spelling
    """
    if isinstance(value, bool):
        return value

    true_words = ("yes", "true", "t", "y", "1")
    false_words = ("no", "false", "f", "n", "0")

    lowered = value.lower()
    if lowered in true_words:
        return True
    if lowered in false_words:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

def clear_folder(folder_path):
    """Delete everything inside ``folder_path`` while keeping the folder itself."""
    # Nothing to do when the folder does not exist.
    if not os.path.exists(folder_path):
        return

    # Walk over every entry directly inside the folder.
    for entry_name in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry_name)
        if os.path.isfile(entry_path):
            # Plain files are removed directly.
            os.remove(entry_path)
            continue
        if os.path.isdir(entry_path):
            # Sub-folders are removed together with their contents.
            shutil.rmtree(entry_path)

def check_prompt_length(prompt, list_of_paths, model):
    """
    Join as many paths as fit into the model's token budget together with ``prompt``.

    If the prompt plus all paths is below ``model.maximum_token``, all paths are
    returned in their given order. Otherwise the paths are shuffled (in place,
    so the caller's list is reordered) and added one at a time until the next
    one would exceed the budget.

    Args:
        prompt (str): prompt text that precedes the paths
        list_of_paths (list[str]): candidate paths
        model: object exposing ``maximum_token`` and ``token_len(text)``

    Returns:
        str: the selected paths joined by newlines (possibly empty)
    """
    all_paths = "\n".join(list_of_paths)
    maximum_token = model.maximum_token
    if model.token_len(prompt + all_paths) < maximum_token:
        return all_paths

    # Shuffle the paths
    random.shuffle(list_of_paths)
    kept_paths = []
    for path in list_of_paths:
        candidate = "\n".join(kept_paths + [path])
        if model.token_len(prompt + candidate) > maximum_token:
            break
        kept_paths.append(path)
    # Also reached when no path overflowed the budget (e.g. the total length
    # equals the limit exactly); always hand back a string, never None.
    return "\n".join(kept_paths)


## Dataset helper

In [4]:
import copy
import random
from random import sample
import networkx as nx
from scipy import sparse


def construct_nx(idx2rel, idx2ent, ent2idx, fact_rdf):
    """
    Build an undirected networkx graph from the facts.

    Nodes are entity indices taken from ``ent2idx``; each edge carries the fact's
    relation in its ``relation`` attribute. ``idx2rel`` and ``idx2ent`` are
    accepted but not used.
    """
    graph = nx.Graph()
    for triple in fact_rdf:
        head, relation, tail = parse_rdf(triple)
        head_idx = ent2idx[head]
        tail_idx = ent2idx[tail]
        graph.add_edge(head_idx, tail_idx, relation=relation)
    return graph


def construct_fact_dict(fact_rdf):
    """Group the facts by relation: ``{relation: [fact, ...]}`` in input order."""
    facts_by_relation = {}
    for triple in fact_rdf:
        _, relation, _ = parse_rdf(triple)
        facts_by_relation.setdefault(relation, []).append(triple)
    return facts_by_relation


def construct_rmat(idx2rel, idx2ent, ent2idx, fact_rdf):
    """
    Build one sparse entity-by-entity matrix per relation.

    Every relation in ``idx2rel`` gets an empty DOK matrix sized by the number of
    entities; each fact then sets entry ``[head_idx, tail_idx]`` of its
    relation's matrix to 1.
    """
    n_entities = len(idx2ent)

    # One empty matrix per known relation.
    matrices = {rel: sparse.dok_matrix((n_entities, n_entities)) for rel in idx2rel.values()}

    # Mark every observed (head, tail) pair.
    for triple in fact_rdf:
        head, relation, tail = parse_rdf(triple)
        row, col = ent2idx[head], ent2idx[tail]
        matrices[relation][row, col] = 1
    return matrices


class RuleDataset:
    """
    Indexable view over per-relation rule sets.

    Item ``idx`` is ``(relation, path_count)`` where ``path_count`` sums, over all
    rules of that relation, the product of the rule body's relation matrices
    scaled by the rule's first confidence value.
    """

    def __init__(self, r2mat, rules, e_num, idx2rel, args):
        self.r2mat = r2mat
        self.rules = rules
        self.e_num = e_num
        self.idx2rel = idx2rel
        self.args = args
        self.len = len(self.rules)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        relation = self.idx2rel[idx]
        path_count = sparse.dok_matrix((self.e_num, self.e_num))
        for rule in self.rules[relation]:
            _, body, conf_1, _ = rule

            # Chain the body's relation matrices, starting from the identity.
            chained = sparse.eye(self.e_num)
            for body_rel in body:
                chained = chained * self.r2mat[body_rel]

            path_count += chained * conf_1

        return relation, path_count

    @staticmethod
    def collate_fn(data):
        """Split a batch of ``(relation, path_count)`` items into two parallel lists."""
        relations = [item[0] for item in data]
        counts = [item[1] for item in data]
        return relations, counts


def parse_rdf(rdf):
    """
    Return the fact as (head, relation, tail).

    The input is handed back unchanged; callers unpack it as head, relation, tail.
    """
    return rdf


class Dictionary:
    """Two-way mapping between relation names and consecutive integer ids."""

    def __init__(self):
        self.rel2idx_ = {}
        self.idx2rel_ = {}
        self.idx = 0  # next id to hand out

    def add_relation(self, rel):
        """Register ``rel`` under the next free id; known relations are left alone."""
        if rel in self.rel2idx_:
            return
        new_id = self.idx
        self.rel2idx_[rel] = new_id
        self.idx2rel_[new_id] = rel
        self.idx = new_id + 1

    @property
    def rel2idx(self):
        return self.rel2idx_

    @property
    def idx2rel(self):
        return self.idx2rel_

    def __len__(self):
        return len(self.idx2rel_)


def load_entities(path):
    """
    Read one entity per line and number the entities by line position.

    Returns:
        (idx2ent, ent2idx): index-to-entity and entity-to-index dicts
    """
    idx2ent = {}
    ent2idx = {}
    with open(path, "r", encoding="utf-8") as handle:
        for position, raw_line in enumerate(handle):
            entity = raw_line.strip()
            ent2idx[entity] = position
            idx2ent[position] = entity
    return idx2ent, ent2idx


class Dataset:
    """
    Loads entities, relations and fact/train/valid/test triples from ``data_root``.

    File names are appended directly to ``data_root``, so it is expected to end
    with a path separator.

    Args:
        data_root (str): directory prefix of the dataset files
        sparsity (float): fraction of the facts kept (random sample)
        inv (bool): read ``facts.txt.inv`` and register ``inv_``-prefixed relations
    """

    def __init__(self, data_root, sparsity=1, inv=False):
        # Entities.
        self.idx2ent_, self.ent2idx_ = load_entities(data_root + "entities.txt")

        # Relations: rdict holds relation -> idx and idx -> relation.
        self.rdict = Dictionary()
        self.load_relation_dict(data_root + "relations.txt")

        # Head relations start as an independent copy of the relation dictionary.
        self.head_rdict = copy.deepcopy(self.rdict)

        # (h, r, t) tuples.
        fact_path = data_root + "facts.txt"
        if inv:
            fact_path += ".inv"
        self.rdf_data_ = self.load_data_(
            fact_path,
            data_root + "train.txt",
            data_root + "valid.txt",
            data_root + "test.txt",
            sparsity,
        )
        self.fact_rdf_, self.train_rdf_, self.valid_rdf_, self.test_rdf_ = self.rdf_data_

        if inv:
            # Register an inverse for every relation known so far.
            for rel in list(self.rdict.rel2idx_):
                inverse_name = "inv_" + rel
                self.rdict.add_relation(inverse_name)
                self.head_rdict.add_relation(inverse_name)

        # The head dictionary additionally knows the placeholder relation "None".
        self.head_rdict.add_relation("None")

    def load_rdfs(self, path):
        """Read a tab-separated file into a list of field lists, one per line."""
        with open(path, "r", encoding="utf-8") as handle:
            return [row.strip().split("\t") for row in handle.readlines()]

    def load_data_(self, fact_path, train_path, valid_path, test_path, sparsity):
        """Load the four splits; only the facts are subsampled by ``sparsity``."""
        fact = self.load_rdfs(fact_path)
        fact = sample(fact, int(len(fact) * sparsity))
        train = self.load_rdfs(train_path)
        valid = self.load_rdfs(valid_path)
        test = self.load_rdfs(test_path)
        return fact, train, valid, test

    def load_relation_dict(self, relation_path):
        """
        Read the relation file (one relation per line) into ``self.rdict``.
        """
        with open(relation_path, encoding="utf-8") as handle:
            for raw_line in handle.readlines():
                self.rdict.add_relation(raw_line.strip())

    def get_relation_dict(self):
        return self.rdict

    def get_head_relation_dict(self):
        return self.head_rdict

    @property
    def idx2ent(self):
        return self.idx2ent_

    @property
    def ent2idx(self):
        return self.ent2idx_

    @property
    def fact_rdf(self):
        return self.fact_rdf_

    @property
    def train_rdf(self):
        return self.train_rdf_

    @property
    def valid_rdf(self):
        return self.valid_rdf_

    @property
    def test_rdf(self):
        return self.test_rdf_


def sample_anchor_rdf(rdf_data, num=1):
    """
    Draw ``num`` random facts from ``rdf_data``.

    When ``num`` is not smaller than the number of facts, the input list itself
    is returned.
    """
    if num >= len(rdf_data):
        return rdf_data
    return sample(rdf_data, num)


def construct_descendant(rdf_data):
    """Map every head entity to the list of ``(relation, tail)`` pairs leaving it."""
    outgoing = {}
    for triple in rdf_data:
        head, relation, tail = parse_rdf(triple)
        outgoing.setdefault(head, []).append((relation, tail))
    return outgoing


def connected(entity2desced, head, tail):
    """
    Return the relation of the first recorded edge from ``head`` to ``tail``.

    Returns ``False`` when ``head`` has no recorded edges or none of them ends
    at ``tail``.
    """
    for relation, child in entity2desced.get(head, ()):
        if child == tail:
            return relation
    return False


def search_closed_rel_paths(anchor_rdf, entity2desced, max_path_len=2):
    """
    Collect the relation paths that lead from the anchor's head to its tail.

    A depth-first search follows outgoing edges from the anchor head, never
    revisiting an entity already on the current path and never using more than
    ``max_path_len`` relations. The single-edge path made of the anchor relation
    itself is excluded.

    Returns:
        list[str]: distinct paths in discovery order, relations joined by "|"
    """
    anchor_h, anchor_r, anchor_t = parse_rdf(anchor_rdf)
    on_current_path = set()
    found_paths = []

    def explore(node, relations):
        if len(relations) > max_path_len:  # path too long
            return
        if node == anchor_t:
            # Skip the direct anchor edge; record every other closed path once.
            is_anchor_edge = len(relations) == 1 and relations[-1] == anchor_r
            if not is_anchor_edge:
                encoded = "|".join(relations)
                if encoded not in found_paths:
                    found_paths.append(encoded)
            return
        on_current_path.add(node)
        for relation, neighbour in entity2desced.get(node, ()):
            if neighbour not in on_current_path:
                explore(neighbour, relations + [relation])
        on_current_path.remove(node)

    explore(anchor_h, [])
    return found_paths

def body2idx(body_list, head_rdict):
    """
    Turn each "|"-joined rule body into the list of its relation indices.
    """
    return [
        [head_rdict.rel2idx[rel] for rel in body.split("|")]
        for body in body_list
    ]


def inv_rel_idx(head_rdict):
    """Return the indices whose relation name contains ``"inv_"``."""
    total = len(head_rdict.idx2rel)
    return [idx for idx in range(total) if "inv_" in head_rdict.idx2rel[idx]]


def idx2body(index, head_rdict):
    """Join the relation names of the given indices with "|"."""
    names = [head_rdict.idx2rel[i] for i in index]
    return "|".join(names)


def rule2idx(rule, head_rdict):
    """
    Encode a rule string ``"body-head"`` as indices.

    The result is the body relation indices, then the separator ``-1``, then the
    head relation index.
    """
    body, head = rule.split("-")
    encoded = [head_rdict.rel2idx[rel] for rel in body.split("|")]
    encoded.append(-1)
    encoded.append(head_rdict.rel2idx[head])
    return encoded


def idx2rule(index, head_rdict):
    """
    Decode an index list (body indices, one separator slot, head index) back into
    the ``"body-head"`` rule string.
    """
    body_names = [head_rdict.idx2rel[b] for b in index[0:-2]]
    head_name = head_rdict.idx2rel[index[-1]]
    return "|".join(body_names) + "-" + head_name


def enumerate_body(relation_num, body_len, rdict):
    """
    Enumerate every rule body of length ``body_len`` over the first
    ``relation_num`` relation indices.

    Returns:
        (all_body_idx, all_body): the bodies as index lists and, in the same
        order, as relation-name lists
    """
    import itertools

    index_space = itertools.product(range(relation_num), repeat=body_len)
    all_body_idx = [list(combo) for combo in index_space]
    # Translate each index into its relation name.
    idx2rel = rdict.idx2rel
    all_body = [[idx2rel[i] for i in combo] for combo in all_body_idx]
    return all_body_idx, all_body


## Rule Learner

In [5]:
import os
import json
import itertools
import numpy as np
from collections import Counter

import copy
import re
import traceback



class Rule_Learner(object):
    def __init__(self, edges, id2relation, inv_relation_id, dataset):
        """
        Initialize rule learner object.

        Parameters:
            edges (dict): edges for each relation
            id2relation (dict): mapping of index to relation
            inv_relation_id (dict): mapping of relation to inverse relation
            dataset (str): dataset name

        Returns:
            None
        """

        self.edges = edges
        self.id2relation = id2relation
        self.inv_relation_id = inv_relation_id
        self.num_individual = 0
        self.num_shared = 0
        self.num_original = 0

        self.found_rules = []
        self.rule2confidence_dict = {}
        self.original_found_rules = []
        self.rules_dict = dict()
        self.output_dir = "./sampled_path/" + dataset + "/"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def create_rule(self, walk, confidence=0, use_relax_time=False):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}
            confidence (float): confidence value
            use_relax_time (bool): whether the rule is created with relaxed time

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        if rule not in self.found_rules:
            self.found_rules.append(rule.copy())
            (
                rule["conf"],
                rule["rule_supp"],
                rule["body_supp"],
            ) = self.estimate_confidence(rule, is_relax_time=use_relax_time)

            rule["llm_confidence"] = confidence

            if rule["conf"] or confidence:
                self.update_rules_dict(rule)

    def create_rule_for_merge(
        self, walk, confidence=0, rule_without_confidence="", rules_var_dict=None, is_merge=False, is_relax_time=False
    ):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        if is_merge is True:
            if rules_var_dict.get(rule_without_confidence) is None:
                if rule not in self.found_rules:
                    self.found_rules.append(rule.copy())
                    (
                        rule["conf"],
                        rule["rule_supp"],
                        rule["body_supp"],
                    ) = self.estimate_confidence(rule)

                    rule["llm_confidence"] = confidence

                    if rule["conf"] or confidence:
                        self.num_individual += 1
                        self.update_rules_dict(rule)

            else:
                rule_var = rules_var_dict[rule_without_confidence]
                rule_var["llm_confidence"] = confidence
                temp_var = {}
                temp_var["head_rel"] = rule_var["head_rel"]
                temp_var["body_rels"] = rule_var["body_rels"]
                temp_var["var_constraints"] = rule_var["var_constraints"]
                if temp_var not in self.original_found_rules:
                    self.original_found_rules.append(temp_var.copy())
                    self.update_rules_dict(rule_var)
                    self.num_shared += 1
        else:
            if rule not in self.found_rules:
                self.found_rules.append(rule.copy())
                (
                    rule["conf"],
                    rule["rule_supp"],
                    rule["body_supp"],
                ) = self.estimate_confidence(rule, is_relax_time=is_relax_time)

                # if rule["body_supp"] == 0:
                #     rule["body_supp"] = 2

                rule["llm_confidence"] = confidence

                if rule["conf"] or confidence:
                    self.update_rules_dict(rule)

    def create_rule_for_merge_for_iteration(
        self, walk, llm_confidence=0, rule_without_confidence="", rules_var_dict=None, is_merge=False
    ):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        rule_with_confidence = ""

        if is_merge is True:
            if rules_var_dict.get(rule_without_confidence) is None:
                if rule not in self.found_rules:
                    self.found_rules.append(rule.copy())
                    (
                        rule["conf"],
                        rule["rule_supp"],
                        rule["body_supp"],
                    ) = self.estimate_confidence(rule)

                    tuple_key = str(rule)
                    self.rule2confidence_dict[tuple_key] = rule["conf"]
                    rule_with_confidence = rule_without_confidence + "&" + str(rule["conf"])

                    rule["llm_confidence"] = llm_confidence

                    if rule["conf"] or llm_confidence:
                        self.num_individual += 1
                        self.update_rules_dict(rule)
                else:
                    tuple_key = tuple(rule)
                    confidence = self.rule2confidence_dict[tuple_key]
                    rule_with_confidence = rule_without_confidence + "&" + confidence

            else:
                rule_var = rules_var_dict[rule_without_confidence]
                rule_var["llm_confidence"] = llm_confidence
                temp_var = {}
                temp_var["head_rel"] = rule_var["head_rel"]
                temp_var["body_rels"] = rule_var["body_rels"]
                temp_var["var_constraints"] = rule_var["var_constraints"]
                if temp_var not in self.original_found_rules:
                    self.original_found_rules.append(temp_var.copy())
                    self.update_rules_dict(rule_var)
                    self.num_shared += 1
        else:
            if rule not in self.found_rules:
                tuple_key = str(rule)
                self.found_rules.append(rule.copy())
                (
                    rule["conf"],
                    rule["rule_supp"],
                    rule["body_supp"],
                ) = self.estimate_confidence(rule)

                self.rule2confidence_dict[tuple_key] = rule["conf"]
                rule_with_confidence = rule_without_confidence + "&" + str(rule["conf"])

                if rule["body_supp"] == 0:
                    rule["body_supp"] = 2

                rule["llm_confidence"] = llm_confidence

                if rule["conf"] or llm_confidence:
                    self.update_rules_dict(rule)
            else:
                tuple_key = str(rule)
                confidence = self.rule2confidence_dict[tuple_key]
                rule_with_confidence = rule_without_confidence + "&" + str(confidence)

        return rule_with_confidence

    def define_var_constraints(self, entities):
        """
        Find the positions at which entities repeat within a walk.

        Parameters:
            entities (list): entities in the temporal walk

        Returns:
            var_constraints (list): sorted list of index lists, one per entity
                                    that occurs more than once
        """

        positions_per_entity = (
            [pos for pos, other in enumerate(entities) if other == ent]
            for ent in set(entities)
        )
        return sorted(positions for positions in positions_per_entity if len(positions) > 1)

    def estimate_confidence(self, rule, num_samples=2000, is_relax_time=False):
        """
        Estimate the confidence of the rule by sampling bodies and checking the rule support.

        Parameters:
            rule (dict): rule
                         {"head_rel": int, "body_rels": list, "var_constraints": list}
            num_samples (int): number of samples

        Returns:
            confidence (float): confidence of the rule, rule_support/body_support
            rule_support (int): rule support
            body_support (int): body support
        """

        if any(body_rel not in self.edges for body_rel in rule["body_rels"]):
            return 0, 0, 0

        if rule["head_rel"] not in self.edges:
            return 0, 0, 0

        all_bodies = []
        for _ in range(num_samples):
            sample_successful, body_ents_tss = self.sample_body(
                rule["body_rels"], rule["var_constraints"], is_relax_time
            )
            if sample_successful:
                all_bodies.append(body_ents_tss)

        all_bodies.sort()
        unique_bodies = list(x for x, _ in itertools.groupby(all_bodies))
        body_support = len(unique_bodies)

        confidence, rule_support = 0, 0
        if body_support:
            rule_support = self.calculate_rule_support(unique_bodies, rule["head_rel"])
            confidence = round(rule_support / body_support, 6)

        return confidence, rule_support, body_support

    def sample_body(self, body_rels, var_constraints, use_relax_time=False):
        """
        Draw one random walk through the graph that follows the rule body.

        The first edge is picked uniformly among the edges of the first body
        relation. Each further edge must start at the node the walk currently
        sits on and, unless relaxed time is requested, must carry a timestamp
        that is not smaller than that of the previously chosen edge.

        Parameters:
            body_rels (list): relations in the rule body
            var_constraints (list): variable constraints for the entities
            use_relax_time (bool): if True, the timestamp condition is dropped

        Returns:
            sample_successful (bool): whether a complete body was sampled
            body_ents_tss (list): entity, timestamp, entity, timestamp, ...
                                  of the (possibly partial) sampled body
        """

        first_edges = self.edges[body_rels[0]]
        edge = first_edges[np.random.choice(len(first_edges))]
        cur_node, cur_ts = edge[2], edge[3]
        body_ents_tss = [edge[0], cur_ts, cur_node]

        for rel in body_rels[1:]:
            candidates = self.edges[rel]
            starts_here = candidates[:, 0] == cur_node
            if use_relax_time:
                mask = starts_here
            else:
                mask = starts_here * (candidates[:, 3] >= cur_ts)
            candidates = candidates[mask]

            if not len(candidates):
                # Dead end: return what has been collected so far.
                return False, body_ents_tss

            edge = candidates[np.random.choice(len(candidates))]
            cur_node, cur_ts = edge[2], edge[3]
            body_ents_tss.extend((cur_ts, cur_node))

        if var_constraints:
            # Every second item of the walk is an entity; the constraints
            # derived from them must equal the ones the rule requires.
            observed = self.define_var_constraints(body_ents_tss[::2])
            if observed != var_constraints:
                return False, body_ents_tss

        return True, body_ents_tss

    def calculate_rule_support(self, unique_bodies, head_rel):
        """
        Calculate the rule support.

        A body counts towards the support if there is at least one edge of the
        head relation that goes from the body's first entity to its last
        entity with a timestamp strictly larger than the body's last timestamp.

        Parameters:
            unique_bodies (list): bodies from self.sample_body
            head_rel (int): head relation

        Returns:
            rule_support (int): rule support; 0 if the head relation has no
                                edges in self.edges
        """

        try:
            head_rel_edges = self.edges[head_rel]
        except KeyError:
            # Without any edge of the head relation no body can be supported.
            return 0

        rule_support = 0
        for body in unique_bodies:
            # body layout: [entity, ts, entity, ts, ..., entity]
            mask = (
                (head_rel_edges[:, 0] == body[0])
                * (head_rel_edges[:, 2] == body[-1])
                * (head_rel_edges[:, 3] > body[-2])
            )

            if mask.any():
                rule_support += 1

        return rule_support

    def update_rules_dict(self, rule):
        """
        Register a rule under its head relation.

        Parameters:
            rule (dict): generated rule from self.create_rule

        Returns:
            None
        """

        # Start a new list for a head relation seen for the first time.
        self.rules_dict.setdefault(rule["head_rel"], []).append(rule)

    def sort_rules_dict(self):
        """
        Order the rules of every head relation from highest to lowest confidence.

        Parameters:
            None

        Returns:
            None
        """

        def by_confidence(rule):
            return rule["conf"]

        # Only values of existing keys are replaced, so iterating is safe.
        for rel, rules in self.rules_dict.items():
            self.rules_dict[rel] = sorted(rules, key=by_confidence, reverse=True)

    def save_rules(self, dt, rule_lengths, num_walks, transition_distr, seed):
        """
        Dump all rules to a JSON file inside self.output_dir.

        The file name is assembled from the run parameters, with every blank
        removed.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed

        Returns:
            None
        """

        # Convert the head relation keys to plain ints before serializing.
        serializable = {int(rel): rules for rel, rules in self.rules_dict.items()}
        filename = f"{dt}_r{rule_lengths}_n{num_walks}_{transition_distr}_s{seed}_rules.json".replace(" ", "")
        with open(self.output_dir + filename, "w", encoding="utf-8") as fout:
            json.dump(serializable, fout)

    def save_rules_verbalized(self, dt, rule_lengths, num_walks, transition_distr, seed, rel2idx, relation_regex):
        """
        Write the rules in human-readable form together with several derived files.

        Files produced (all below self.output_dir):
            original/rules_var.json            rule dicts keyed by verbalized rule
            <run>_rules.txt                    verbalized rules
            <run>_remove_rules.txt             same, without the first three columns
            closed_rel_paths.jsonl             body relation paths per head
            rules_name.json                    rules grouped by head relation name
            rules_id.json                      rules with relation ids
            relation_name_with_confidence.json rules with their first column appended

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed
            rel2idx (dict): mapping of relation name to relation id
            relation_regex (str): pattern used to parse a single rule atom

        Returns:
            None
        """

        original_dir = os.path.join(self.output_dir, "original/")
        os.makedirs(original_dir, exist_ok=True)

        rules_str, rules_var = self.verbalize_rules()
        save_json_data(rules_var, original_dir + "rules_var.json")

        run_params = (dt, rule_lengths, num_walks, transition_distr, seed)

        rules_txt_name = self.generate_filename(*run_params, "rules.txt")
        original_rule_txt = self.output_dir + rules_txt_name
        write_to_file(rules_str, original_rule_txt)

        stripped_name = self.generate_filename(*run_params, "remove_rules.txt")
        rule_id_content = self.remove_first_three_columns(original_rule_txt, self.output_dir + stripped_name)

        self.parse_and_save_rules(stripped_name, list(rel2idx.keys()), relation_regex, "closed_rel_paths.jsonl")
        self.parse_and_save_rules_with_names(
            stripped_name, rel2idx, relation_regex, "rules_name.json", rule_id_content
        )
        self.parse_and_save_rules_with_ids(rule_id_content, rel2idx, relation_regex, "rules_id.json")

        confidence_json = self.output_dir + "relation_name_with_confidence.json"
        self.save_rule_name_with_confidence(
            original_rule_txt,
            relation_regex,
            confidence_json,
            list(rel2idx.keys()),
        )

    def verbalize_rules(self):
        """
        Verbalize every stored rule.

        Returns:
            rules_str (str): all verbalized rules, one per line
            rules_var (dict): rule dicts keyed by the last whitespace-separated
                              token of their verbalized form
        """
        verbalized_lines = []
        rules_var = {}
        for rules in self.rules_dict.values():
            for rule in rules:
                text = verbalize_rule(rule, self.id2relation)
                key = re.split(r"\s+", text.strip())[-1]
                rules_var[f"{key}"] = rule
                verbalized_lines.append(text + "\n")
        return "".join(verbalized_lines), rules_var

    def generate_filename(self, dt, rule_lengths, num_walks, transition_distr, seed, suffix):
        """Build "<dt>_r<lengths>_n<walks>_<distr>_s<seed>_<suffix>" with all blanks removed."""
        template = "{}_r{}_n{}_{}_s{}_{}"
        raw_name = template.format(dt, rule_lengths, num_walks, transition_distr, seed, suffix)
        return raw_name.replace(" ", "")

    def remove_first_three_columns(self, input_path, output_path):
        """
        Strip the three leading columns from every line of a rules file.

        Each non-blank input line is split on whitespace. Everything from the
        fourth column on is written to output_path (columns joined by a single
        blank). The same text followed by "&" and the first column is collected
        for the caller.

        Parameters:
            input_path (str): file with one verbalized rule per line
            output_path (str): file that receives the stripped lines

        Returns:
            rule_id_content (list): "<rule>&<first column>\n" for every
                                    non-blank input line
        """
        rule_id_content = []
        with open(input_path, "r") as input_file, open(output_path, "w", encoding="utf-8") as output_file:
            for line in input_file:
                columns = line.split()
                if not columns:
                    # A blank line has no first column to append; skip it.
                    continue
                rule_text = " ".join(columns[3:])
                rule_id_content.append(f"{rule_text}&{columns[0]}\n")
                output_file.write(rule_text + "\n")
        return rule_id_content

    def parse_and_save_rules(self, remove_filename, keys, relation_regex, output_filename):
        """
        Convert a stripped rules file into relation paths and store them as JSON lines.

        Parameters:
            remove_filename (str): name of the stripped rules file, appended to self.output_dir
            keys (list): known relation names
            relation_regex (str): pattern used to parse a single rule atom
            output_filename (str): name of the JSONL file inside self.output_dir

        Returns:
            converted_rules (dict): result of parse_rules_for_path
        """
        destination = os.path.join(self.output_dir, output_filename)

        with open(self.output_dir + remove_filename, "r") as source:
            converted_rules = parse_rules_for_path(source.readlines(), keys, relation_regex)

        # One JSON object per head relation, one object per line.
        with open(destination, "w") as sink:
            for head, paths in converted_rules.items():
                record = {"head": head, "paths": paths}
                json.dump(record, sink)
                sink.write("\n")

        print(f"Rules have been converted and saved to {destination}")
        return converted_rules

    def parse_and_save_rules_with_names(
        self, remove_filename, rel2idx, relation_regex, output_filename, rule_id_content
    ):
        """
        Group the stripped rules by head relation name and store them as JSON.

        Parameters:
            remove_filename (str): name of the stripped rules file inside self.output_dir
            rel2idx (dict): mapping of relation name to relation id
            relation_regex (str): pattern used to parse a single rule atom
            output_filename (str): name of the JSON file inside self.output_dir
            rule_id_content (list): not used by this method

        Returns:
            None
        """
        source_path = os.path.join(self.output_dir, remove_filename)
        destination = os.path.join(self.output_dir, output_filename)

        with open(source_path, "r") as source:
            rule_lines = source.readlines()
            relation_names = list(rel2idx.keys())
            grouped_rules = parse_rules_for_name(rule_lines, relation_names, relation_regex)

        with open(destination, "w") as sink:
            json.dump(grouped_rules, sink, indent=4)

        print(f"Rules have been converted and saved to {destination}")

    def parse_and_save_rules_with_ids(self, rule_id_content, rel2idx, relation_regex, output_filename):
        """
        Convert "<rule>&<first column>" lines to their id form and store them as JSON.

        Parameters:
            rule_id_content (list): lines returned by self.remove_first_three_columns
            rel2idx (dict): mapping of relation name to relation id
            relation_regex (str): pattern used to parse a single rule atom
            output_filename (str): name of the JSON file inside self.output_dir

        Returns:
            None
        """
        destination = os.path.join(self.output_dir, output_filename)
        grouped_by_head = parse_rules_for_id(rule_id_content, rel2idx, relation_regex)

        with open(destination, "w") as sink:
            json.dump(grouped_by_head, sink, indent=4)

        print(f"Rules have been converted and saved to {destination}")

    def save_rule_name_with_confidence(self, file_path, relation_regex, out_file_path, relations):
        """
        Group "<rule>&<first column>" strings by head relation and save them as JSON.

        Every line of file_path is split on whitespace; the rule is everything
        from the fourth column on (joined without blanks). Lines whose head
        atom does not match relation_regex are ignored.

        Parameters:
            file_path (str): verbalized rules file
            relation_regex (str): pattern used to parse a single rule atom
            out_file_path (str): destination JSON file
            relations (list): known relation names

        Returns:
            None

        Raises:
            ValueError: if a head relation is not contained in relations
        """
        grouped = {}
        with open(file_path, "r") as fin:
            for line in fin.readlines():
                columns = line.split()
                leading_value = columns[0]
                rule_text = "".join(columns[3:])

                head_atom = rule_text.split("<-")[0]
                found = re.search(relation_regex, head_atom)
                if not found:
                    continue

                head = found[1].strip()
                if head not in relations:
                    raise ValueError(f"Not exist relation:{head}")

                grouped.setdefault(head, []).append(f"{rule_text}&{leading_value}")
        save_json_data(grouped, out_file_path)


def parse_rules_for_path(lines, relations, relation_regex):
    """
    Extract, for every head relation, the body relation paths of its rules.

    Each line has the form "head(...)<-body1(...)&body2(...)...". The body
    relation names of a rule are joined with "|" to form one path. Blank
    lines and lines whose head atom does not match relation_regex are
    skipped; body atoms that do not match are left out of the path.

    Parameters:
        lines (list): rule strings
        relations (list): known relation names
        relation_regex (str): pattern whose first group is the relation name

    Returns:
        converted_rules (dict): head relation name -> list of "|"-joined paths

    Raises:
        ValueError: if a parsed relation name is not contained in relations
    """
    converted_rules = {}
    for line in lines:
        rule = line.strip()
        if not rule:
            continue
        atoms = re.sub(r"\s*<-\s*", "&", rule).split("&")

        # Relation name per atom, or None where the atom could not be parsed.
        names = []
        for atom in atoms:
            match = re.search(relation_regex, atom)
            if match is None:
                names.append(None)
                continue
            rel_name = match.group(1).strip()
            if rel_name not in relations:
                raise ValueError(f"Not exist relation:{rel_name}")
            names.append(rel_name)

        head = names[0]
        if head is None:
            # Without a head the path cannot be attributed to any relation.
            continue

        body_list = [name for name in names[1:] if name is not None]
        converted_rules.setdefault(head, []).append("|".join(body_list))

    return converted_rules


def parse_rules_for_name(lines, relations, relation_regex):
    """
    Group rule strings by the relation name of their head atom.

    Lines whose head atom does not match relation_regex are ignored. The
    rule strings are stored unchanged.

    Parameters:
        lines (list): rule strings of the form "head(...)<-body(...)&..."
        relations (list): known relation names
        relation_regex (str): pattern whose first group is the relation name

    Returns:
        rules_dict (dict): head relation name -> list of rule strings

    Raises:
        ValueError: if a head relation is not contained in relations
    """
    rules_dict = {}
    for rule in lines:
        head_atom = re.sub(r"\s*<-\s*", "&", rule).split("&")[0]
        found = re.search(relation_regex, head_atom)
        if not found:
            continue

        head = found[1].strip()
        if head not in relations:
            raise ValueError(f"Not exist relation:{head}")

        rules_dict.setdefault(head, []).append(rule)

    return rules_dict


def parse_rules_for_id(rules, rel2idx, relation_regex):
    """
    Translate "<rule>&<suffix>" strings to relation ids, grouped by head relation name.

    The part after the last "&" is kept as it is (stripped) and re-attached
    to the id form of the rule. Entries whose head atom does not match
    relation_regex are ignored.

    Parameters:
        rules (list): strings of the form "head(...)<-body(...)&...&<suffix>"
        rel2idx (dict): mapping of relation name to relation id
        relation_regex (str): pattern used to parse a single rule atom

    Returns:
        rules_dict (dict): head relation name -> list of "<id rule>&<suffix>"

    Raises:
        ValueError: if the head relation is not contained in rel2idx
    """
    rules_dict = {}
    for rule in rules:
        head_atom = re.sub(r"\s*<-\s*", "&", rule).split("&")[0]
        found = re.search(relation_regex, head_atom)
        if not found:
            continue

        head = found[1].strip()
        if head not in rel2idx:
            raise ValueError(f"Relation '{head}' not found in rel2idx")

        # Split off whatever follows the last "&" before translating the rule.
        pieces = rule.rsplit("&", 1)
        translated = rule2id(pieces[0], rel2idx, relation_regex)
        translated = translated + "&" + pieces[1].strip()

        if head not in rules_dict:
            rules_dict[head] = []
        rules_dict[head].append(translated)
    return rules_dict


def rule2id(rule, relation2id, relation_regex):
    """
    Replace the relation names of a rule by their ids.

    Parameters:
        rule (str): rule of the form "head(s,o,t)<-body1(s,o,t)&body2(s,o,t)..."
        relation2id (dict): mapping of relation name to relation id
        relation_regex (str): pattern with four groups: relation name,
                              subject, object and timestamp of an atom

    Returns:
        rule2id_str (str): the rule with every relation name replaced by its id

    Raises:
        ValueError: if a relation name is missing from relation2id, or an
                    atom cannot be parsed with relation_regex. The original
                    exception is attached as the cause.
    """
    atoms = re.sub(r"\s*<-\s*", "&", rule).split("&")
    pieces = []

    try:
        for idx, atom in enumerate(atoms):
            match = re.search(relation_regex, atom)
            rel_name = match[1].strip()
            subject = match[2].strip()
            obj = match[3].strip()
            timestamp = match[4].strip()
            rel_id = relation2id[rel_name]
            # The head atom is followed by "<-", every body atom by "&".
            separator = "<-" if idx == 0 else "&"
            pieces.append(f"{rel_id}({subject},{obj},{timestamp}){separator}")
    except KeyError as keyerror:
        # Print the call stack before re-raising.
        traceback.print_exc()
        raise ValueError(f"KeyError: {keyerror}") from keyerror
    except Exception as err:
        raise ValueError(f"An error occurred: {rule}") from err

    # Drop the separator that follows the last atom.
    return "".join(pieces)[:-1]


def verbalize_rule(rule, id2relation):
    """
    Verbalize the rule to be in a human-readable format.

    The result has the form
    "<conf>  <rule_supp>  <body_supp>  head(X0,Xk,Tn)<-b0(Xi,Xj,T0)&b1(...)...",
    where entity positions that share a variable constraint get the same
    variable index.

    The rule itself is not modified: the variable constraints are completed
    on a copy.

    Parameters:
        rule (dict): rule from Rule_Learner.create_rule
        id2relation (dict): mapping of index to relation

    Returns:
        rule_str (str): human-readable rule
    """

    body_rels = rule["body_rels"]
    num_positions = len(body_rels) + 1

    if rule["var_constraints"]:
        # Work on a copy so the caller's constraint list keeps its content.
        var_constraints = [list(group) for group in rule["var_constraints"]]
        constrained = [x for group in var_constraints for x in group]
        # Every unconstrained entity position becomes a group of its own.
        var_constraints.extend([i] for i in range(num_positions) if i not in constrained)
        var_constraints.sort()
    else:
        var_constraints = [[x] for x in range(num_positions)]

    def var_index(position):
        # Index of the (sorted) constraint group that contains the position.
        return next(idx for idx, group in enumerate(var_constraints) if position in group)

    head_str = "{0:8.6f}  {1:4}  {2:4}  {3}(X0,X{4},T{5})<-".format(
        rule["conf"],
        rule["rule_supp"],
        rule["body_supp"],
        id2relation[rule["head_rel"]],
        var_index(len(body_rels)),
        len(body_rels),
    )

    body_str = "".join(
        "{0}(X{1},X{2},T{3})&".format(id2relation[rel], var_index(i), var_index(i + 1), i)
        for i, rel in enumerate(body_rels)
    )

    # Drop the trailing separator.
    return (head_str + body_str)[:-1]


def rules_statistics(rules_dict):
    """
    Print a short summary of the learned rules.

    Reported are the number of head relations that have rules, the total
    number of rules, and how many rules exist per body length.

    Parameters:
        rules_dict (dict): rules

    Returns:
        None
    """

    print("Number of relations with rules: ", len(rules_dict))  # Including inverse relations

    total_rules = sum(len(rules) for rules in rules_dict.values())
    print("Total number of rules: ", total_rules)

    length_counts = Counter(len(rule["body_rels"]) for rules in rules_dict.values() for rule in rules)
    print("Number of rules by length: ", sorted(length_counts.items()))


## Temporal Walk

In [6]:
import numpy as np
import pandas as pd


def initialize_temporal_walk(version_id, data, transition_distr):
    """
    Create a Temporal_Walk on the requested part of the dataset.

    Only the splits that belong to version_id are read and concatenated.

    Parameters:
        version_id (str): "all", "train_valid", "train", "test" or "valid"
        data: object providing train_idx / valid_idx / test_idx arrays and
              inv_relation_id
        transition_distr (str): transition distribution passed to Temporal_Walk

    Returns:
        Temporal_Walk: walker over the selected quadruples

    Raises:
        KeyError: if version_id is not one of the names listed above
    """
    splits_by_version = {
        "all": ("train_idx", "valid_idx", "test_idx"),
        "train_valid": ("train_idx", "valid_idx"),
        "train": ("train_idx",),
        "test": ("test_idx",),
        "valid": ("valid_idx",),
    }

    quads = []
    for split_name in splits_by_version[version_id]:
        quads += getattr(data, split_name).tolist()

    return Temporal_Walk(np.array(quads), data.inv_relation_id, transition_distr)


class Temporal_Walk(object):
    """Sampler of cyclic temporal random walks on a set of quadruples."""

    def __init__(self, learn_data, inv_relation_id, transition_distr):
        """
        Initialize temporal random walk object.

        Parameters:
            learn_data (np.ndarray): data on which the rules should be learned
            inv_relation_id (dict): mapping of relation to inverse relation
            transition_distr (str): transition distribution
                                    "unif" - uniform distribution
                                    "exp"  - exponential distribution

        Returns:
            None
        """

        self.learn_data = learn_data
        self.inv_relation_id = inv_relation_id
        self.transition_distr = transition_distr
        self.neighbors = store_neighbors(learn_data)
        self.edges = store_edges(learn_data)

    def sample_start_edge(self, rel_idx):
        """
        Define start edge distribution: uniform over the edges of the relation.

        Parameters:
            rel_idx (int): relation index

        Returns:
            start_edge (np.ndarray): start edge
        """

        rel_edges = self.edges[rel_idx]
        start_edge = rel_edges[np.random.choice(len(rel_edges))]

        return start_edge

    def sample_next_edge(self, filtered_edges, cur_ts):
        """
        Define next edge distribution.

        Parameters:
            filtered_edges (np.ndarray): filtered (according to time) edges
            cur_ts (int): current timestamp

        Returns:
            next_edge (np.ndarray): next edge

        Raises:
            ValueError: if self.transition_distr is neither "unif" nor "exp"
        """

        if self.transition_distr == "unif":
            next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
        elif self.transition_distr == "exp":
            # Weight every edge by exp(ts - cur_ts).
            tss = filtered_edges[:, 3]
            prob = np.exp(tss - cur_ts)
            try:
                prob = prob / np.sum(prob)
                next_edge = filtered_edges[np.random.choice(range(len(filtered_edges)), p=prob)]
            except ValueError:  # All timestamps are far away
                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
        else:
            # Fail with a clear message instead of an unbound local below.
            raise ValueError(f"Unknown transition distribution: {self.transition_distr!r}")

        return next_edge

    def transition_step(self, cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts=None):
        """
        Sample a neighboring edge given the current node and timestamp.
        In the second step (step == 1), the next timestamp should be smaller than the current timestamp.
        In the other steps, the next timestamp should be smaller than or equal to the current timestamp.
        In the last step (step == L-1), the edge should connect to the source of the walk (cyclic walk).
        It is not allowed to go back using the inverse edge.

        Parameters:
            cur_node (int): current node
            cur_ts (int): current timestamp
            prev_edge (np.ndarray): previous edge
            start_node (int): start node
            step (int): number of current step
            L (int): length of random walk
            target_cur_ts (int, optional): timestamp the candidate edges are compared
                                           against (relaxed time). Defaults to cur_ts.

        Returns:
            next_edge (np.ndarray): next edge, or an empty list if no edge qualifies
        """

        next_edges = self.neighbors[cur_node]
        if target_cur_ts is None:
            target_cur_ts = cur_ts

        if step == 1:  # The next timestamp should be smaller than the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] < target_cur_ts]
        else:  # The next timestamp should be smaller than or equal to the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] <= target_cur_ts]
            # Delete inverse edge
            inv_edge = [
                cur_node,
                self.inv_relation_id[prev_edge[1]],
                prev_edge[0],
                cur_ts,
            ]
            row_idx = np.where(np.all(filtered_edges == inv_edge, axis=1))
            filtered_edges = np.delete(filtered_edges, row_idx, axis=0)

        if step == L - 1:  # Find an edge that connects to the source of the walk
            filtered_edges = filtered_edges[filtered_edges[:, 2] == start_node]

        if len(filtered_edges):
            # The sampling weights always use cur_ts, also in relaxed mode.
            next_edge = self.sample_next_edge(filtered_edges, cur_ts)
        else:
            next_edge = []

        return next_edge

    def transition_step_with_relax_time(self, cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts):
        """
        Wrapper for transition_step with relaxed time handling.

        Parameters:
            cur_node (int): current node
            cur_ts (int): current timestamp
            prev_edge (np.ndarray): previous edge
            start_node (int): start node
            step (int): number of current step
            L (int): length of random walk
            target_cur_ts (int): target current timestamp for relaxed time

        Returns:
            next_edge (np.ndarray): next edge
        """
        return self.transition_step(cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts)

    def sample_walk(self, L, rel_idx, use_relax_time=False):
        """
        Try to sample a cyclic temporal random walk of length L (for a rule of length L-1).

        With use_relax_time, candidate edges of every step are compared against
        the timestamp of the start edge instead of the timestamp of the
        previous edge.

        Parameters:
            L (int): length of random walk
            rel_idx (int): relation index
            use_relax_time (bool): whether to use relaxed time sampling

        Returns:
            walk_successful (bool): if a cyclic temporal random walk has been successfully sampled
            walk (dict): information about the walk (entities, relations, timestamps)
        """

        walk_successful = True
        walk = dict()
        prev_edge = self.sample_start_edge(rel_idx)
        start_node = prev_edge[0]
        cur_node = prev_edge[2]
        cur_ts = prev_edge[3]
        # Fixed for the whole walk; only used in relaxed mode.
        target_cur_ts = cur_ts
        walk["entities"] = [start_node, cur_node]
        walk["relations"] = [prev_edge[1]]
        walk["timestamps"] = [cur_ts]

        for step in range(1, L):
            if use_relax_time:
                next_edge = self.transition_step_with_relax_time(
                    cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts
                )
            else:
                next_edge = self.transition_step(cur_node, cur_ts, prev_edge, start_node, step, L)

            if len(next_edge):
                cur_node = next_edge[2]
                cur_ts = next_edge[3]
                walk["relations"].append(next_edge[1])
                walk["entities"].append(cur_node)
                walk["timestamps"].append(cur_ts)
                prev_edge = next_edge
            else:  # No valid neighbors (due to temporal or cyclic constraints)
                walk_successful = False
                break

        return walk_successful, walk


def store_neighbors(quads):
    """
    Collect, for every node, all quadruples in which it is the head (outgoing edges).

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        neighbors (dict): node -> array of its outgoing quadruples
    """

    # Put the quadruples into a DataFrame so they can be grouped by column.
    frame = pd.DataFrame(quads, columns=["head", "relation", "target", "timestamp"])

    # One group per head node; keep each group as a plain array.
    neighbors = {}
    for node, rows in frame.groupby("head"):
        neighbors[node] = rows.values

    return neighbors


def store_edges(quads):
    """
    Collect, for every relation, all quadruples that use it.

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        edges (dict): relation -> array of its quadruples
    """

    relation_column = quads[:, 1]
    return {rel: quads[relation_column == rel] for rel in set(relation_column)}


## Grapher

In [7]:
import os.path

import numpy as np



class Grapher(object):
    """Holds the id mappings and the train/valid/test quadruples of a dataset."""

    def __init__(self, dataset_dir, args=None, test_mask=None):
        """
        Store information about the graph (train/valid/test set).
        Add corresponding inverse quadruples to the data.

        Parameters:
            dataset_dir (str): path to the graph dataset directory
            args (dict, optional): if given, args["bgkg"] selects which splits
                                   are combined into self.all_idx
            test_mask (sequence, optional): (lowest, highest) timestamp id to keep
                                            in the test set, both inclusive

        Returns:
            None
        """

        self.args = args
        self.dataset_dir = dataset_dir
        self.entity2id = load_json_data(os.path.join(self.dataset_dir, "entity2id.json"))
        self.relation2id_old = load_json_data(os.path.join(self.dataset_dir, "relation2id.json"))
        num_relations = len(self.relation2id_old)

        # The inverse of relation i gets id i + num_relations, matching
        # inv_relation_id below regardless of the order of the JSON file.
        self.relation2id = self.relation2id_old.copy()
        for relation, rel_id in self.relation2id_old.items():
            self.relation2id["inv_" + relation] = rel_id + num_relations  # Inverse relation

        self.ts2id = load_json_data(os.path.join(self.dataset_dir, "ts2id.json"))
        self.id2entity = {v: k for k, v in self.entity2id.items()}
        self.id2relation = {v: k for k, v in self.relation2id.items()}
        self.id2ts = {v: k for k, v in self.ts2id.items()}

        self.inv_relation_id = dict()
        for i in range(num_relations):
            self.inv_relation_id[i] = i + num_relations
        for i in range(num_relations, num_relations * 2):
            self.inv_relation_id[i] = i % num_relations

        self.train_idx = self.create_store("train.txt")
        self.valid_idx = self.create_store("valid.txt")
        self.test_idx = self.create_store("test.txt")

        if test_mask is not None:
            mask = (self.test_idx[:, 3] >= test_mask[0]) * (self.test_idx[:, 3] <= test_mask[1])
            self.test_idx = self.test_idx[mask]

        if self.args is not None:
            splits_by_name = {
                "all": (self.train_idx, self.valid_idx, self.test_idx),
                "train": (self.train_idx,),
                "valid": (self.valid_idx,),
                "test": (self.test_idx,),
                "train_valid": (self.train_idx, self.valid_idx),
                "train_test": (self.train_idx, self.test_idx),
                "valid_test": (self.valid_idx, self.test_idx),
            }
            splits = splits_by_name.get(self.args["bgkg"])
            # For any other value self.all_idx is left undefined.
            if splits is not None:
                self.all_idx = splits[0] if len(splits) == 1 else np.vstack(splits)

        print("Grapher initialized.")

    def create_store(self, file):
        """
        Store the quadruples from the file as indices.
        The quadruples in the file should be in the format "subject\trelation\tobject\ttimestamp\n".

        Parameters:
            file (str): file name

        Returns:
            store_idx (np.ndarray): indices of quadruples
        """

        with open(os.path.join(self.dataset_dir, file), "r", encoding="utf-8") as f:
            quads = f.readlines()
        store = self.split_quads(quads)
        store_idx = self.map_to_idx(store)
        store_idx = self.add_inverses(store_idx)

        return store_idx

    def split_quads(self, quads):
        """
        Split quadruples into a list of strings.

        Only the line terminator is removed, so a last line without a trailing
        newline keeps its full timestamp. Blank lines are skipped.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form "subject\trelation\tobject\ttimestamp\n".

        Returns:
            split_q (list): list of quadruples
                            Each quadruple has the form [subject, relation, object, timestamp].
        """

        split_q = []
        for quad in quads:
            line = quad.rstrip("\r\n")
            if line:
                split_q.append(line.split("\t"))

        return split_q

    def map_to_idx(self, quads):
        """
        Map quadruples to their indices.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form [subject, relation, object, timestamp].

        Returns:
            quads (np.ndarray): indices of quadruples
        """

        subs = [self.entity2id[x[0]] for x in quads]
        rels = [self.relation2id[x[1]] for x in quads]
        objs = [self.entity2id[x[2]] for x in quads]
        tss = [self.ts2id[x[3]] for x in quads]
        quads = np.column_stack((subs, rels, objs, tss))

        return quads

    def add_inverses(self, quads_idx):
        """
        Add the inverses of the quadruples as indices.

        The inverse of (s, r, o, t) is (o, inv_relation_id[r], s, t).

        Parameters:
            quads_idx (np.ndarray): indices of quadruples

        Returns:
            quads_idx (np.ndarray): indices of quadruples along with the indices of their inverses
        """

        subs = quads_idx[:, 2]
        rels = [self.inv_relation_id[x] for x in quads_idx[:, 1]]
        objs = quads_idx[:, 0]
        tss = quads_idx[:, 3]
        inv_quads_idx = np.column_stack((subs, rels, objs, tss))
        quads_idx = np.vstack((quads_idx, inv_quads_idx))

        return quads_idx


## Main

In [8]:
import argparse
import glob
import os.path
import time
import traceback
from difflib import get_close_matches

from tqdm import tqdm
from functools import partial


from multiprocessing.pool import ThreadPool




from concurrent.futures import ThreadPoolExecutor



def read_paths(path):
    """Load a JSON-lines file: one parsed JSON value per line, in file order."""
    with open(path, "r") as f:
        return [json.loads(line.strip()) for line in f]


def build_prompt(head, prompt_dict):
    """
    Assemble the prompt parts from the "common" section of prompt_dict.

    Parameters:
        head: value substituted for "{head}" in the templates
        prompt_dict (dict): prompt templates

    Returns:
        tuple: (definition + context + chain, predict, return instructions)
    """
    common = prompt_dict["common"]
    fixed_context = common["definition"].format(head=head)
    fixed_context += common["context"].format(head=head)
    fixed_context += common["chain"]
    predict = common["predict"].format(head=head)
    return fixed_context, predict, common["return"]


def build_prompt_for_zero(head, prompt_dict):
    """
    Assemble the prompt parts from the "zero" section of prompt_dict.

    Parameters:
        head: value substituted for "{head}" in the templates
        prompt_dict (dict): prompt templates

    Returns:
        tuple: (context, predict, return instructions)
    """
    zero = prompt_dict["zero"]
    context = zero["context"].format(head=head)
    predict = zero["predict"].format(head=head)
    return context, predict, zero["return"]


def build_prompt_based_high(head, candidate_rels, is_zero, args, prompt_dict):
    """
    Assemble the prompt parts from the "*_for_high" templates.

    candidate_rels, is_zero and args are accepted but not used.

    Parameters:
        head: value substituted for "{head}" in the templates
        prompt_dict (dict): prompt templates (flat keys ending in "_for_high")

    Returns:
        tuple: (chain definition + context + example, predict, return instructions)
    """
    fixed_context = prompt_dict["chain_defination_for_high"].format(head=head)
    fixed_context += prompt_dict["iteration_context_for_high"].format(head=head)
    fixed_context += prompt_dict["example_for_high"]

    predict = prompt_dict["interaction_finale_predict_for_high"].format(head=head)
    return fixed_context, predict, prompt_dict["return_for_high"]


def build_prompt_based_low(head, candidate_rels, is_zero, args, prompt_dict):
    """
    Assemble the prompt parts from the "iteration" section of prompt_dict.

    candidate_rels, is_zero and args are accepted but not used.

    Parameters:
        head: value substituted for "{head}" in the templates
        prompt_dict (dict): prompt templates

    Returns:
        tuple: (context + low quality example, sampled rules, predict, return instructions)
    """
    iteration = prompt_dict["iteration"]
    fixed_context = iteration["context"].format(head=head) + iteration["low_quality_example"]
    sampled_rules = iteration["sampled_rules"]
    predict = iteration["predict"].format(head=head)
    return fixed_context, sampled_rules, predict, iteration["return"]


# def build_prompt_for_unknown(head, candidate_rels, is_zero, args, prompt_dict): # UNUSED
#     # head = clean_symbol_in_rel(head)
#     chain_defination = prompt_dict["chain_defination"].format(head=head)

#     context = prompt_dict["unknown_relation_context"].format(head=head)

#     predict = prompt_dict["unknown_relation_final_predict"].format(head=head)
#     return_rules = prompt_dict["unknown_relation_return"]
#     return chain_defination + context, predict, return_rules


def get_rule_format(head, path, kg_rules_path):
    """
    Return the example rules for a head relation.

    If the JSON file at kg_rules_path can be loaded, its entry for head is
    returned. Otherwise every "|"-separated relation path is written as a
    chain rule "head(X,Y) <-- r1(X, Z_1) & r2(Z_1, Z_2) & ... & rn(Z_n-1, Y)".

    Parameters:
        head (str): head relation
        path (list): relation paths, relations separated by "|"
        kg_rules_path (str): JSON file with rules per head relation

    Returns:
        list of rule strings, or whatever the JSON file stores for head
    """
    kg_rules_dict = load_json_data(kg_rules_path)
    if kg_rules_dict is not None:
        return kg_rules_dict[head]

    formatted = []
    for p in path:
        relations = p.split("|")
        last_pos = len(relations) - 1
        atoms = []
        for i, r in enumerate(relations):
            # The chain starts at X, ends at Y and links through Z_1, Z_2, ...
            left = "X" if i == 0 else f"Z_{i}"
            right = "Y" if i == last_pos else f"Z_{i + 1}"
            atoms.append(f"{r}({left}, {right}) & ")
        rule_text = f"{head}(X,Y) <-- " + "".join(atoms)
        formatted.append(rule_text.strip(" & "))
    return formatted


def _select_few_shot_samples(path_content_list, args):
    """Pick the example rules for one prompt, by confidence or at random."""
    if args.select_with_confidence is True:

        def confidence(rule):
            # The confidence is the text after the last "&" of the rule.
            return float(rule.split("&")[-1])

        ranked = sorted(path_content_list, key=confidence, reverse=True)
        confident = [rule for rule in ranked if confidence(rule) > 0.01]
        # Fall back to the top-f rules when too few clear the 0.01 threshold.
        return confident if len(confident) >= args.f else ranked[: args.f]
    return random.sample(path_content_list, min(args.f, len(path_content_list)))


def _collect_candidate_relations(few_shot_samples, relation_regex, similiary_rel_dict):
    """Return the relations used in the sampled rule bodies plus their similar relations."""
    relation_set = set()
    for rule in few_shot_samples:
        rule_body = rule.split("<-")[-1]
        for match in re.findall(relation_regex, rule_body):
            relation_set.add(match[0])

    similiary_rel_set = set()
    for rel_name in relation_set:
        similiary_rel_set.update(similiary_rel_dict[rel_name])

    return similiary_rel_set.union(relation_set)


def generate_rule(
    row, rdict, rule_path, kg_rules_path, model, args, relation_regex, similiary_rel_dict, prompt_info_dict
):
    """Query the LLM for rules about one head relation and store prompts and answers.

    In the zero-shot setting (``args.is_zero``) a single prompt is written to
    ``<head>_zero_shot.query`` and the answer to ``<head>_zero_shot.txt``.
    Otherwise ``args.l`` few-shot prompts are built; every prompt goes to
    ``<file_name>.query`` and every answer to ``<file_name>.txt``. When the
    model returns ``None`` the prompt is saved to ``fail_<file_name>.txt`` and
    sampling for this head stops.

    Args:
        row: Mapping with the keys ``"head"`` and ``"paths"``.
        rdict: Relation dictionary exposing ``rel2idx``.
        rule_path: Directory that receives the query/answer files.
        kg_rules_path: Path passed to ``get_rule_format``.
        model: Language model exposing ``generate_sentence``.
        args: Run configuration (``is_rel_name``, ``k``, ``is_zero``, ``l``,
            ``f``, ``select_with_confidence``, ``dry_run``).
        relation_regex: Regex used to pull relation names out of rule bodies.
        similiary_rel_dict: Maps a relation name to its similar relations.
        prompt_info_dict: Prompt templates passed to ``build_prompt``.

    Raises:
        NotImplementedError: If ``args.k == 0`` in the zero-shot setting.
        ValueError: If no few-shot example fits into the prompt.
    """
    head = row["head"]
    paths = row["paths"]
    head_id = rdict.rel2idx[head]
    # The prompt addresses the head by name or by numeric id, depending on the run.
    head_formate = head if args.is_rel_name is True else head_id

    # Raise an error if k=0 for zero-shot setting
    if args.k == 0 and args.is_zero:
        raise NotImplementedError("Cannot implement for zero-shot(f=0) and generate zero(k=0) rules.")

    # Build prompt excluding rules
    fixed_context, predict, return_rules_template = build_prompt(head_formate, prompt_info_dict)

    if args.is_zero:  # For zero-shot setting
        current_prompt = fixed_context + predict + return_rules_template
        with open(os.path.join(rule_path, f"{head}_zero_shot.query"), "w") as f:
            f.write(current_prompt + "\n")
        if not args.dry_run:
            response = query(current_prompt, model)
            with open(os.path.join(rule_path, f"{head}_zero_shot.txt"), "w") as f:
                f.write(response + "\n")
        return

    # For few-shot setting
    path_content_list = get_rule_format(head, paths, kg_rules_path)
    file_name = head.replace("/", "-")
    with (
        open(os.path.join(rule_path, f"{file_name}.txt"), "w") as rule_file,
        open(os.path.join(rule_path, f"{file_name}.query"), "w") as query_file,
    ):
        rule_file.write(f"Rule_head: {head}\n")
        for i in range(args.l):
            few_shot_samples = _select_few_shot_samples(path_content_list, args)

            # Candidate relations are derived from the chosen examples in both
            # selection modes, so the return template can always be filled.
            condicate = _collect_candidate_relations(few_shot_samples, relation_regex, similiary_rel_dict)
            formatted_string = ";".join([f"{name}" for name in condicate])

            # Format a fresh copy each round; the template itself stays intact
            # so later samples get their own candidate list.
            return_rules = return_rules_template.format(candidate_rels=formatted_string)

            temp_current_prompt = fixed_context + predict + return_rules
            few_shot_paths = check_prompt_length(temp_current_prompt, few_shot_samples, model)
            if not few_shot_paths:
                raise ValueError("few_shot_paths is empty, head:{}".format(head))

            prompt = fixed_context + few_shot_paths + "\n\n" + predict + "\n\n" + return_rules
            query_file.write(f"Sample {i + 1} time: \n")
            query_file.write(prompt + "\n")
            if args.dry_run:
                continue

            response = model.generate_sentence(prompt)
            if response is None:
                # Keep the failing prompt so it can be retried later.
                with open(os.path.join(rule_path, f"fail_{file_name}.txt"), "w") as fail_rule_file:
                    fail_rule_file.write(prompt + "\n")
                break
            rule_file.write(f"Sample {i + 1} time: \n")
            rule_file.write(response + "\n")


def generate_rule_for_zero(head, rdict, rule_path, model, args, prompt_info_zero_dict):
    """Query the LLM ``args.l`` times for a head relation that has no sampled rules.

    Every prompt is appended to ``<file_name>.query`` and every answer to
    ``<file_name>.txt``. If a query yields no answer, the prompt is written to
    ``fail_<file_name>.txt`` and the remaining samples are skipped.

    Args:
        head: Relation name to generate rules for.
        rdict: Relation dictionary exposing ``rel2idx``.
        rule_path: Directory that receives the output files.
        model: Language model handed to ``query``.
        args: Run configuration; ``args.l`` is the number of samples.
        prompt_info_zero_dict: Prompt templates passed to ``build_prompt_for_zero``.
    """
    all_rels = list(rdict.rel2idx.keys())

    fixed_context, predict, return_rules_template = build_prompt_for_zero(head, prompt_info_zero_dict)
    return_rules = return_rules_template.format(candidate_rels=all_rels)
    current_prompt = fixed_context + predict + return_rules

    # Define the file paths. "/" is replaced exactly as generate_rule does, so
    # a head containing a slash cannot point into a non-existent sub-directory.
    file_name = head.replace("/", "-")
    query_file_path = os.path.join(rule_path, f"{file_name}.query")
    txt_file_path = os.path.join(rule_path, f"{file_name}.txt")
    fail_file_path = os.path.join(rule_path, f"fail_{file_name}.txt")

    try:
        with open(query_file_path, "w") as fout_zero_query, open(txt_file_path, "w") as fout_zero_txt:
            for i in range(args.l):
                entry = f"Sample {i + 1} time:\n"
                fout_zero_query.write(entry + current_prompt + "\n")

                response = query(current_prompt, model)
                if not response:
                    raise ValueError("Failed to generate response.")
                fout_zero_txt.write(entry + response + "\n")
    except ValueError as e:
        # Record the prompt so it can be retried later.
        with open(fail_file_path, "w") as fail_rule_file:
            fail_rule_file.write(current_prompt + "\n")
        print(e)


def extract_and_expand_relations(args, path_content_list, similiary_rel_dict, relation_regex):
    """Randomly sample rules and return their relations plus similar relations.

    :param args: Namespace of run parameters; ``args.f`` caps the sample size.
    :param path_content_list: List of rule strings to sample from.
    :param similiary_rel_dict: Maps a relation name to the relations similar to it.
    :param relation_regex: Regular expression used to extract relations from a rule body.
    :return: Set holding the extracted relations and their similar relations.
    """
    # Draw the random sample.
    sample_size = min(args.f, len(path_content_list))
    drawn_rules = random.sample(path_content_list, sample_size)

    # Extract the relations that occur in the body (text after the last "<-").
    found_relations = {
        match[0]
        for rule in drawn_rules
        for match in re.findall(relation_regex, rule.split("<-")[-1])
    }

    # Expand with the similar relations of every extracted relation.
    expanded = set()
    for relation in found_relations:
        expanded.update(similiary_rel_dict[relation])

    # Merge the original and the similar relations.
    return expanded.union(found_relations)


def generate_rule_for_iteration_by_multi_thread(
    row,
    rdict,
    rule_path,
    kg_rules_path,
    model,
    args,
    relation_regex,
    similiary_rel_dict,
    kg_rules_path_with_valid,
    prompt_dict_low,
    prompt_dict_high,
):
    """Ask the LLM to refine the rules of one head relation during an iteration.

    For each of ``args.second`` samples a prompt is assembled from the fixed
    context, up to 20 of the head's current rules, up to 20 rules from the
    validation rule file and a list of candidate relations. Prompts are
    written to ``<file_name>.query`` and answers to ``<file_name>.txt``; a
    prompt that yields ``None`` is saved to ``fail_<file_name>.txt``.

    Args:
        row: Mapping with the keys ``"head"`` and ``"rules"``.
        rdict: Relation dictionary exposing ``rel2idx``.
        rule_path: Directory that receives the output files.
        kg_rules_path: JSON file mapping heads to sampled rules.
        model: Language model exposing ``generate_sentence``.
        args: Run configuration (``is_rel_name``, ``k``, ``is_zero``,
            ``based_rule_type``, ``second``, ``f``, ``dry_run``).
        relation_regex: Regex used to extract relations from rule bodies.
        similiary_rel_dict: Maps a relation name to its similar relations.
        kg_rules_path_with_valid: JSON file mapping heads to validation rules.
        prompt_dict_low: Templates used when ``args.based_rule_type == "low"``.
        prompt_dict_high: Templates used otherwise.

    Raises:
        NotImplementedError: If ``args.k == 0`` in the zero-shot setting.
    """
    relation2id = rdict.rel2idx
    head = row["head"]
    rules = row["rules"]
    head_id = relation2id[head]

    # load_json_data can return None (get_rule_format checks for it), so fall
    # back to an empty mapping instead of failing on ``.get``.
    valid_rules_name = load_json_data(kg_rules_path_with_valid) or {}

    if args.is_rel_name is True:
        all_rels = list(relation2id.keys())
        head_formate = head
    else:
        all_rels = list(relation2id.values())
        head_formate = head_id
    candidate_rels = ", ".join(str(item) for item in all_rels)

    # Raise an error if k=0 for zero-shot setting
    if args.k == 0 and args.is_zero:
        raise NotImplementedError("Cannot implement for zero-shot(f=0) and generate zero(k=0) rules.")

    # Build prompt excluding rules
    if args.based_rule_type == "low":
        fixed_context, _sampled_rules, predict, return_rules_template = build_prompt_based_low(
            head_formate, candidate_rels, args.is_zero, args, prompt_dict_low
        )
    else:
        fixed_context, predict, return_rules_template = build_prompt_based_high(
            head_formate, candidate_rels, args.is_zero, args, prompt_dict_high
        )

    kg_rules_dict = load_json_data(kg_rules_path) or {}
    path_content_list = kg_rules_dict.get(head, None)
    temp_valid_rules = valid_rules_name.get(head, [])
    file_name = head.replace("/", "-")
    with (
        open(os.path.join(rule_path, f"{file_name}.txt"), "w") as rule_file,
        open(os.path.join(rule_path, f"{file_name}.query"), "w") as query_file,
    ):
        rule_file.write(f"Rule_head: {head}\n")
        for i in range(args.second):
            if path_content_list is not None:
                condicate = extract_and_expand_relations(args, path_content_list, similiary_rel_dict, relation_regex)
            else:
                # No sampled rules for this head: offer every known relation.
                condicate = set(all_rels)

            valid_rules_with_head = random.sample(temp_valid_rules, min(20, len(temp_valid_rules)))
            temp_rules = random.sample(rules, min(20, len(rules)))

            quitity_string = "".join(temp_rules) + "\n"
            valid_rules_string = "".join(valid_rules_with_head) + "\n"

            temp_current_prompt = fixed_context + quitity_string + valid_rules_string + predict

            formatted_string = iteration_check_prompt_length(
                temp_current_prompt, list(condicate), return_rules_template, model
            )

            # Fill a fresh copy of the template each round; overwriting the
            # template would freeze the first candidate list for all samples.
            return_rules = return_rules_template.format(candidate_rels=formatted_string)

            prompt = temp_current_prompt + return_rules
            query_file.write(f"Sample {i + 1} time: \n")
            query_file.write(prompt + "\n")
            if args.dry_run:
                continue

            response = model.generate_sentence(prompt)
            if response is not None:
                rule_file.write(f"Sample {i + 1} time: \n")
                rule_file.write(response + "\n")
            else:
                # Keep the failing prompt for inspection or a later retry.
                with open(os.path.join(rule_path, f"fail_{file_name}.txt"), "w") as fail_rule_file:
                    fail_rule_file.write(prompt + "\n")


# def generate_rule_for_unknown_relation_by_multi_thread( # UNUSED
#     row, rdict, rule_path, kg_rules_path, model, args, relation_subgraph, relation_regex, similiary_rel_dict
# ):
#     relation2id = rdict.rel2idx
#     head = row

#     head_id = relation2id[head]
#     # print("Head: ", head)

#     head_formate = head
#     if args.is_rel_name is True:
#         all_rels = list(relation2id.keys())
#         candidate_rels = ", ".join(all_rels)
#         head_formate = head
#     else:
#         all_rels = list(relation2id.values())
#         str_list = [str(item) for item in all_rels]
#         candidate_rels = ", ".join(str_list)
#         head_formate = head_id

#     # Build prompt excluding rules
#     fixed_context, predict, return_rules = build_prompt_for_unknown(
#         head_formate, candidate_rels, args.is_zero, args, prompt_dict
#     )

#     file_name = head.strip()
#     with (
#         open(os.path.join(rule_path, f"{file_name}.txt"), "w") as rule_file,
#         open(os.path.join(rule_path, f"{file_name}.query"), "w") as query_file,
#     ):
#         rule_file.write(f"Rule_head: {head}\n")
#         for i in range(args.second):
#             # # Convert list elements to the desired string format
#             # formatted_string = ';'.join([f'{name}' for name in condicate])

#             formatted_string = unknown_check_prompt_length(fixed_context + predict, all_rels, return_rules, model)

#             return_rules = return_rules.format(candidate_rels=formatted_string)

#             prompt = fixed_context + predict + return_rules
#             # tqdm.write("Prompt: \n{}".format(prompt))
#             query_file.write(f"Sample {i + 1} time: \n")
#             query_file.write(prompt + "\n")
#             if not args.dry_run:
#                 response = model.generate_sentence(prompt)
#                 if response is not None:
#                     # tqdm.write("Response: \n{}".format(response))
#                     rule_file.write(f"Sample {i + 1} time: \n")
#                     rule_file.write(response + "\n")
#                 else:
#                     with open(os.path.join(rule_path, f"fail_{file_name}.txt"), "w") as fail_rule_file:
#                         fail_rule_file.write(prompt + "\n")


def copy_files(source_dir, destination_dir, file_extension):
    """Copy every entry of ``source_dir`` whose name ends with ``file_extension``.

    The destination directory is created when it does not exist yet. Only
    the top level of ``source_dir`` is inspected.

    Args:
        source_dir: Directory to copy from.
        destination_dir: Directory to copy into.
        file_extension: Required suffix of the file name (e.g. ``"txt"``).
    """
    # Create the destination folder.
    if not os.path.exists(destination_dir):
        os.makedirs(destination_dir)

    # Keep only the names carrying the requested suffix, then copy them.
    wanted_names = [name for name in os.listdir(source_dir) if name.endswith(file_extension)]
    for name in wanted_names:
        shutil.copyfile(os.path.join(source_dir, name), os.path.join(destination_dir, name))


def _convert_rule_ids_to_names(rule, rdict, relation_regex):
    """Translate one id-based rule line into its name-based form.

    Returns ``(converted_rule, None)`` on success and ``(None, error_message)``
    when an atom cannot be parsed or refers to an unknown relation id.
    """
    # Treat "<-" like "&" so head and body atoms can be split uniformly; the
    # last piece is the confidence value.
    pieces = re.split(r"\s*&\s*|\t", re.sub(r"\s*<-\s*", "&", rule))
    confidence = pieces[-1].strip()

    converted_atoms = []
    for position, atom in enumerate(pieces[:-1]):
        match = re.search(relation_regex, atom)
        if not match:
            return None, f"Error rule:{rule}, rule:{rule}"

        raw_id = match[1].strip()
        if not raw_id.isdigit():
            return None, f"Error relation id:{raw_id}, rule:{rule}"

        rel_id = int(raw_id)
        if rel_id not in rdict.idx2rel:
            return None, f"Error relation id:{rel_id}, rule:{rule}"

        relation_name = rdict.idx2rel[rel_id]
        subject = match[2].strip()
        obj = match[3].strip()
        timestamp = match[4].strip()
        # The first atom is the rule head; all following atoms form the body.
        separator = "<-" if position == 0 else "&"
        converted_atoms.append(f"{relation_name}({subject},{obj},{timestamp}){separator}")

    return "".join(converted_atoms) + confidence, None


def process_rules_files(input_dir, output_dir, rdict, relation_regex, error_file_path):
    """Rewrite id-based rule files so that relations are referred to by name.

    Every ``*.txt`` file in ``input_dir`` whose name does not start with
    ``fail`` is read line by line. Lines starting with ``Rule_head:`` or
    ``Sample`` are dropped; each remaining line is converted and written to a
    file of the same name in ``output_dir``. Lines that cannot be converted
    are printed and logged to ``error_file_path``, followed by the total
    number of such errors.

    Args:
        input_dir: Directory holding the generated rule files.
        output_dir: Existing directory that receives the converted files.
        rdict: Relation dictionary exposing ``idx2rel``.
        relation_regex: Regex with four groups (relation id, subject, object,
            timestamp) matching a single atom.
        error_file_path: File that collects the conversion errors.
    """
    error_count = 0
    with open(error_file_path, "w") as f_error_out:
        for input_filepath in glob.glob(os.path.join(input_dir, "*.txt")):
            file_name = os.path.basename(input_filepath)
            if file_name.startswith("fail"):
                continue

            with open(input_filepath, "r") as fin, open(os.path.join(output_dir, file_name), "w") as fout:
                for rule in fin.readlines():
                    if rule.startswith(("Rule_head:", "Sample")):
                        continue

                    converted, error = _convert_rule_ids_to_names(rule, rdict, relation_regex)
                    if error is not None:
                        print(error)
                        f_error_out.write(error)
                        error_count += 1
                        continue
                    fout.write(converted + "\n")
        f_error_out.write(f"The number of error during id maps name is:{error_count}")


def get_topk_similiary_rel(topk, similary_matrix, transformers_id2rel, transformers_rel2id):
    """Map every relation to the names of its ``topk`` most similar relations.

    Args:
        topk: Number of similar relations to keep per row. Values of zero or
            below yield empty lists; values above the number of columns keep
            every column.
        similary_matrix: 2-D array of similarity scores; row ``i`` belongs to
            the relation ``transformers_id2rel[str(i)]``.
        transformers_id2rel: Maps a stringified row/column index to a relation name.
        transformers_rel2id: Not used by this function; kept for the signature.

    Returns:
        Dict from relation name to a list of relation names, ordered by
        ascending similarity (the most similar relation comes last).
    """
    # Column indices of every row, sorted by ascending similarity.
    ranked = np.argsort(similary_matrix, axis=1)
    n_candidates = ranked.shape[1]

    # Slice from an explicit start index: a negative-index slice ``[-topk:]``
    # would return the whole row for topk == 0.
    keep = min(max(topk, 0), n_candidates)
    top_k_indices = ranked[:, n_candidates - keep:]

    similiary_rel_dict = {}
    for idx, similary_rels in enumerate(top_k_indices):
        rel_name = transformers_id2rel[str(idx)]
        similiary_rel_dict[rel_name] = [transformers_id2rel[str(i)] for i in similary_rels]

    return similiary_rel_dict


def get_low_conf(low_conf_file_path, relation_regex, rdict):
    """Group the rules of a low-confidence file by their head relation.

    Lines containing ``"index"`` and lines whose head (the text before the
    first ``"<-"``) does not match ``relation_regex`` are skipped.

    Args:
        low_conf_file_path: Text file with one rule per line.
        relation_regex: Regex whose first group captures the relation name.
        rdict: Relation dictionary exposing ``rel2idx``.

    Returns:
        A list of ``{"head": <relation>, "rules": [<line>, ...]}`` dicts in
        order of first appearance; the lines keep their trailing newline.

    Raises:
        ValueError: If a head relation is not a key of ``rdict.rel2idx``.
    """
    rules_by_head = {}
    with open(low_conf_file_path, "r") as handle:
        for line in handle:
            if "index" in line:
                continue

            head_match = re.search(relation_regex, line.split("<-")[0])
            if not head_match:
                continue

            head = head_match[1].strip()
            if head not in rdict.rel2idx.keys():
                raise ValueError(f"Not exist relation:{head}")
            rules_by_head.setdefault(head, []).append(line)

    return [{"head": head, "rules": lines} for head, lines in rules_by_head.items()]


def get_high_conf(high_conf_file_path, relation_regex, rdict):
    """Group the rules of a high-confidence file by their head relation.

    Lines containing ``"index"`` and lines whose head (the text before the
    first ``"<-"``) does not match ``relation_regex`` are skipped.

    Args:
        high_conf_file_path: Text file with one rule per line.
        relation_regex: Regex whose first group captures the relation name.
        rdict: Relation dictionary exposing ``rel2idx``.

    Returns:
        A list of ``{"head": <relation>, "rules": [<line>, ...]}`` dicts in
        order of first appearance; the lines keep their trailing newline.

    Raises:
        ValueError: If a head relation is not a key of ``rdict.rel2idx``.
    """
    grouped_rules = {}
    with open(high_conf_file_path, "r") as fin_high:
        for raw_rule in fin_high:
            if "index" in raw_rule:
                continue

            head_part = raw_rule.split("<-")[0]
            found = re.search(relation_regex, head_part)
            if not found:
                continue

            head = found[1].strip()
            if head not in rdict.rel2idx.keys():
                raise ValueError(f"Not exist relation:{head}")

            if head in grouped_rules:
                grouped_rules[head].append(raw_rule)
            else:
                grouped_rules[head] = [raw_rule]

    return [{"head": head, "rules": members} for head, members in grouped_rules.items()]


def analysis_data(confidence_folder, kg_rules_path):
    """Count the high/low confidence rules that are absent from the sampled rules.

    Reads ``hight_conf.txt`` and ``low_conf.txt`` from ``confidence_folder``
    (lines containing ``"index"`` are ignored), compares the stripped rules
    with all rules stored in the JSON at ``kg_rules_path`` and writes the two
    counts to ``statistic.txt`` in the same folder.

    Args:
        confidence_folder: Folder holding the confidence files.
        kg_rules_path: JSON file mapping heads to lists of rule strings.
    """

    def read_rule_set(handle):
        # Lines mentioning "index" are skipped, mirroring the confidence readers.
        return {line.strip() for line in handle.readlines() if "index" not in line}

    with (
        open(os.path.join(confidence_folder, "hight_conf.txt"), "r") as fin_hight,
        open(os.path.join(confidence_folder, "low_conf.txt"), "r") as fin_low,
    ):
        hight_rule_set = read_rule_set(fin_hight)
        low_rule_set = read_rule_set(fin_low)

    rules_dict = load_json_data(kg_rules_path)

    known_rules = set()
    for rule_group in rules_dict.values():
        for item in rule_group:
            known_rules.add(item.strip())

    new_high = hight_rule_set - known_rules
    new_low = low_rule_set - known_rules
    with open(os.path.join(confidence_folder, "statistic.txt"), "w") as fout_state:
        fout_state.write(f"valid_high:{len(new_high)}\n")
        fout_state.write(f"valid_low:{len(new_low)}\n")


def load_data_and_paths(args):
    """Load the dataset and the sampled paths, and derive the prompt file paths.

    Args:
        args: Run configuration providing ``data_path``, ``dataset``,
            ``sampled_paths`` and ``prompt_paths``.

    Returns:
        An 8-tuple ``(dataset, sampled_path, sampled_path_with_valid_dir,
        sampled_path_dir, prompt_path, prompt_path_for_zero, prompt_path_low,
        prompt_path_high)``.
    """
    dataset_root = os.path.join(args.data_path, args.dataset) + "/"
    dataset = Dataset(data_root=dataset_root, inv=True)

    sampled_path_dir = os.path.join(args.sampled_paths, args.dataset)
    sampled_path_with_valid_dir = os.path.join(args.sampled_paths, args.dataset + "_valid")
    sampled_path = read_paths(os.path.join(sampled_path_dir, "closed_rel_paths.jsonl"))

    # Order matters: common, zero-shot, iteration (low), iteration (high).
    prompt_file_names = ("common.json", "zero.json", "iteration_low.json", "iteration_high.json")
    prompt_file_paths = [os.path.join(args.prompt_paths, name) for name in prompt_file_names]

    return (
        dataset,
        sampled_path,
        sampled_path_with_valid_dir,
        sampled_path_dir,
        *prompt_file_paths,
    )


def prepare_rule_heads(dataset, sampled_path):
    """Split the relations into those with and without sampled paths.

    Args:
        dataset: Object whose ``rdict.rel2idx`` maps relation names to ids.
        sampled_path: Iterable of mappings, each with a ``"head"`` key.

    Returns:
        ``(heads_with_samples, heads_without_samples)`` as two sets; the
        second holds every relation name that never occurs as a sampled head.
    """
    heads_with_samples = set()
    for entry in sampled_path:
        heads_with_samples.add(entry["head"])

    every_relation = set(dataset.rdict.rel2idx.keys())
    heads_without_samples = every_relation.difference(heads_with_samples)
    return heads_with_samples, heads_without_samples


def determine_kg_rules_path(args, sampled_path_dir):
    """Return the rules JSON path matching the relation representation in use.

    ``rules_name.json`` is chosen when ``args.is_rel_name`` is truthy,
    ``rules_id.json`` otherwise; both live in ``sampled_path_dir``.
    """
    rules_file = "rules_name.json" if args.is_rel_name else "rules_id.json"
    return os.path.join(sampled_path_dir, rules_file)


def load_configuration(dataset, sampled_path_dir, args):
    """Load the relation regex, the relation dictionary and the similar-relation map.

    Args:
        dataset: Dataset object exposing ``get_relation_dict()``.
        sampled_path_dir: Directory holding ``matrix.npy`` and the
            ``transfomers_id2rel.json`` / ``transfomers_rel2id.json`` files.
        args: Run configuration providing ``dataset`` and ``topk``.

    Returns:
        ``(rdict, relation_regex, similar_rel_dict)``.
    """
    relation_regex = load_json_data("./Config/constant.json")["relation_regex"][args.dataset]
    rdict = dataset.get_relation_dict()

    def in_sampled_dir(file_name):
        return os.path.join(sampled_path_dir, file_name)

    similarity_matrix = np.load(in_sampled_dir("matrix.npy"))
    transformers_id2rel = load_json_data(in_sampled_dir("transfomers_id2rel.json"))
    transformers_rel2id = load_json_data(in_sampled_dir("transfomers_rel2id.json"))

    similar_rel_dict = get_topk_similiary_rel(
        args.topk,
        similarity_matrix,
        transformers_id2rel,
        transformers_rel2id,
    )
    return rdict, relation_regex, similar_rel_dict


def create_directories(args):
    """Prepare the output folder for generated rules and its filtered copy.

    The rule folder is created if needed and otherwise left untouched. The
    ``copy_``-prefixed folder is created if missing and emptied with
    ``clear_folder`` when it already exists.

    Args:
        args: Run configuration providing ``rule_path``, ``dataset``,
            ``prefix``, ``model_name``, ``k``, ``f`` and ``l``.

    Returns:
        ``(rule_path, filter_rule_path)``.
    """
    run_name = f"{args.prefix}{args.model_name}-top-{args.k}-f-{args.f}-l-{args.l}"

    rule_path = os.path.join(args.rule_path, args.dataset, run_name)
    os.makedirs(rule_path, exist_ok=True)

    filter_rule_path = os.path.join(args.rule_path, args.dataset, f"copy_{run_name}")
    if os.path.exists(filter_rule_path):
        clear_folder(filter_rule_path)
    else:
        os.makedirs(filter_rule_path)

    return rule_path, filter_rule_path


def main(args, LLM):
    """Load all inputs, set up the language model and run the rule generation.

    Args:
        args: Run configuration shared by every pipeline step.
        LLM: Callable (typically a class) that builds the language model from
            ``args``; the result must expose ``prepare_for_inference()``.
    """
    loaded = load_data_and_paths(args)
    dataset, sampled_path, sampled_path_with_valid_dir, sampled_path_dir = loaded[:4]
    prompt_path, prompt_path_for_zero, prompt_path_low, prompt_path_high = loaded[4:]

    _, rule_head_with_zero = prepare_rule_heads(dataset, sampled_path)
    kg_rules_path = determine_kg_rules_path(args, sampled_path_dir)
    rdict, relation_regex, similar_rel_dict = load_configuration(dataset, sampled_path_dir, args)
    rule_path, filter_rule_path = create_directories(args)

    # Prompt templates for the few-shot, zero-shot and iteration stages.
    prompt_info_dict = load_json_data(prompt_path)
    prompt_info_zero_dict = load_json_data(prompt_path_for_zero)
    prompt_dict_low = load_json_data(prompt_path_low)
    prompt_dict_high = load_json_data(prompt_path_high)

    model = LLM(args)
    print("Preparing pipeline for inference...")
    model.prepare_for_inference()

    llm_rule_generate(
        args=args,
        filter_rule_path=filter_rule_path,
        kg_rules_path=kg_rules_path,
        model=model,
        rdict=rdict,
        relation_regex=relation_regex,
        rule_path=rule_path,
        sampled_path=sampled_path,
        similiary_rel_dict=similar_rel_dict,
        sampled_path_with_valid_dir=sampled_path_with_valid_dir,
        rule_head_with_zero=rule_head_with_zero,
        prompt_info_dict=prompt_info_dict,
        prompt_info_zero_dict=prompt_info_zero_dict,
        prompt_dict_low=prompt_dict_low,
        prompt_dict_high=prompt_dict_high,
    )


# Maps the ``args.bgkg`` setting to the mode name used for the evaluation and
# filter sub-folders and passed to ``evaluation``.
_BGKG_TO_EVAL_MODE = {"train": "train", "valid": "eva", "train_valid": "train_eva"}
_EVAL_MODES = ("train", "eva", "train_eva")


def _reset_directory(path):
    """Create ``path`` when missing, otherwise empty it with ``clear_folder``; return ``path``."""
    if not os.path.exists(path):
        os.makedirs(path)
    else:
        clear_folder(path)
    return path


def _run_on_thread_pool(worker, items, n_threads):
    """Apply ``worker`` to every item on a thread pool while showing a progress bar."""
    with ThreadPool(n_threads) as p:
        for _ in tqdm(p.imap_unordered(worker, items), total=len(items)):
            pass


def _retry_failed_prompts(rule_path, model):
    """Re-send every prompt saved as ``fail_*.txt`` and store the answer under the original name."""
    failed_files = glob.glob(os.path.join(rule_path, "fail_*.txt"))
    for input_filepath in tqdm(failed_files, desc="Generating rules for failed..."):
        filename = os.path.basename(input_filepath).removeprefix("fail_")
        with open(input_filepath, "r") as fin:
            content = fin.read()
        response = model.generate_sentence(content)
        if response is None:
            # Leave the existing output untouched rather than truncating it.
            print(f"Error:{filename}")
            continue
        with open(os.path.join(rule_path, filename), "w") as fout:
            fout.write(response + "\n")


def _evaluate_and_filter(args, rules_dir, eva_folder, filter_folder, eval_mode, index):
    """Score the cleaned rules and split them by confidence into ``filter_folder``."""
    rule_set = evaluation(args, rules_dir, eva_folder, eval_mode, index=index)
    filter_rules_based_confidence(rule_set, args.min_conf, filter_folder, index)


def llm_rule_generate(
    args,
    filter_rule_path,
    kg_rules_path,
    model,
    rdict,
    relation_regex,
    rule_path,
    sampled_path,
    similiary_rel_dict,
    sampled_path_with_valid_dir,
    rule_head_with_zero,
    prompt_info_dict,
    prompt_info_zero_dict,
    prompt_dict_low,
    prompt_dict_high,
):
    """Generate rules with the LLM, refine them iteratively and write the final summary.

    Steps, in order:
      1. Generate rules for every sampled head and for every head without
         samples, then retry the prompts that failed.
      2. Copy (or id-to-name convert) the answers into ``filter_rule_path``
         and let the model write its statistics.
      3. Reset the working folders below ``args.rule_path/args.dataset``.
      4. For ``args.num_iter`` rounds: clean, evaluate and filter the current
         rules, then regenerate from the low- or high-confidence rules.
      5. Clean, evaluate, filter and analyse one last time and write the
         selected rules to ``gen_rules_iteration/<dataset>/final_summary/rules.txt``.

    Args:
        args: Run configuration (``n``, ``rule_path``, ``dataset``,
            ``is_rel_name``, ``num_iter``, ``bgkg``, ``min_conf``,
            ``is_high``, ``rule_domain`` among others).
        filter_rule_path: Folder receiving the name-based copies of the answers.
        kg_rules_path: JSON file with the sampled rules per head.
        model: Language model exposing ``generate_sentence`` and ``gen_rule_statistic``.
        rdict: Relation dictionary.
        relation_regex: Regex for a single atom; it is reloaded from
            ``./Config/constant.json`` before the conversion step.
        rule_path: Folder receiving the raw prompts and answers.
        sampled_path: Sequence of rows handed to ``generate_rule``.
        similiary_rel_dict: Maps a relation name to its similar relations.
        sampled_path_with_valid_dir: Folder holding the validation ``rules_name.json``.
        rule_head_with_zero: Heads handed to ``generate_rule_for_zero``.
        prompt_info_dict: Few-shot prompt templates.
        prompt_info_zero_dict: Zero-shot prompt templates.
        prompt_dict_low: Iteration templates for low-confidence rules.
        prompt_dict_high: Iteration templates for high-confidence rules.

    Raises:
        ValueError: If ``args.bgkg`` is not ``"train"``, ``"valid"`` or ``"train_valid"``.
    """
    # Reject an unknown background-KG setting before any LLM call is made;
    # later steps need the folders that depend on it.
    eval_mode = _BGKG_TO_EVAL_MODE.get(args.bgkg)
    if eval_mode is None:
        raise ValueError(
            f"Unsupported bgkg value: {args.bgkg!r}; expected one of {sorted(_BGKG_TO_EVAL_MODE)}"
        )

    # Rule generation
    _run_on_thread_pool(
        partial(
            generate_rule,
            rdict=rdict,
            rule_path=rule_path,
            kg_rules_path=kg_rules_path,
            model=model,
            args=args,
            relation_regex=relation_regex,
            similiary_rel_dict=similiary_rel_dict,
            prompt_info_dict=prompt_info_dict,
        ),
        sampled_path,
        args.n,
    )
    print("Generating rules for zero-shot...")
    _run_on_thread_pool(
        partial(
            generate_rule_for_zero,
            rdict=rdict,
            rule_path=rule_path,
            model=model,
            args=args,
            prompt_info_zero_dict=prompt_info_zero_dict,
        ),
        rule_head_with_zero,
        args.n,
    )
    _retry_failed_prompts(rule_path, model)

    # Rules of the valid dataset
    kg_rules_path_with_valid = os.path.join(sampled_path_with_valid_dir, "rules_name.json")

    # Analyse the result of the first LLM rule generation
    dataset_root = os.path.join(args.rule_path, args.dataset)
    statistics_dir = _reset_directory(os.path.join(dataset_root, "statistics"))
    statistics_file_path = os.path.join(statistics_dir, "statistics.txt")
    error_file_path = os.path.join(statistics_dir, "error.txt")

    # From here on the regex from the configuration file takes precedence
    # over the one passed in.
    constant_config = load_json_data("./Config/constant.json")
    relation_regex = constant_config["relation_regex"][args.dataset]
    if args.is_rel_name is True:
        copy_files(rule_path, filter_rule_path, "txt")
    else:
        process_rules_files(rule_path, filter_rule_path, rdict, relation_regex, error_file_path)

    model.gen_rule_statistic(rule_path, statistics_file_path)

    output_clean_folder = _reset_directory(os.path.join(dataset_root, "clean"))

    # filter folders: rules produced per iteration (high conf and low conf files)
    filter_folders = {}
    for mode in _EVAL_MODES:
        filter_folders[mode] = _reset_directory(os.path.join(dataset_root, "filter", mode))

    # evaluation folders: the intermediate confidence.json of every iteration
    eva_folders = {}
    for mode in _EVAL_MODES:
        eva_folders[mode] = _reset_directory(os.path.join(dataset_root, "evaluation", mode))

    # Temporary folder: LLM requests and responses
    iteration_rule_file_path = _reset_directory(os.path.join(dataset_root, "iteration"))

    # Per-iteration snapshots of the LLM responses
    for i in range(args.num_iter + 1):
        clear_folder(os.path.join(dataset_root, f"only_txt_{i}"))

    # Temporary folder: LLM responses
    iteration_only_txt_file_path = _reset_directory(os.path.join(dataset_root, "only_txt"))
    copy_folder_contents(filter_rule_path, iteration_only_txt_file_path)

    conf_folder = filter_folders[eval_mode]
    eva_folder = eva_folders[eval_mode]

    for i in range(args.num_iter):
        copy_folder_contents(iteration_only_txt_file_path, os.path.join(dataset_root, f"only_txt_{i}"))

        start_time = time.time()
        output_rules_folder_dir = clean(args, model, iteration_only_txt_file_path, output_clean_folder)
        _evaluate_and_filter(args, output_rules_folder_dir, eva_folder, conf_folder, eval_mode, i)
        elapsed_minutes = (time.time() - start_time) / 60

        print(f"프로그램 실행 시간: {elapsed_minutes}분")

        if args.is_high is False:
            conf = get_low_conf(os.path.join(conf_folder, "temp_low_conf.txt"), relation_regex, rdict)
        else:
            conf = get_high_conf(os.path.join(conf_folder, "temp_hight_conf.txt"), relation_regex, rdict)

        clear_folder(iteration_only_txt_file_path)
        clear_folder(iteration_rule_file_path)
        gen_rules_iteration(
            args,
            kg_rules_path,
            model,
            rdict,
            relation_regex,
            iteration_rule_file_path,
            conf,
            similiary_rel_dict,
            kg_rules_path_with_valid,
            prompt_dict_low,
            prompt_dict_high,
        )
        copy_files(iteration_rule_file_path, iteration_only_txt_file_path, "txt")

    # Final pass over the rules produced by the last iteration.
    copy_folder_contents(iteration_only_txt_file_path, os.path.join(dataset_root, f"only_txt_{args.num_iter}"))
    output_rules_folder_dir = clean(args, model, iteration_only_txt_file_path, output_clean_folder)
    _evaluate_and_filter(args, output_rules_folder_dir, eva_folder, conf_folder, eval_mode, args.num_iter)
    analysis_data(conf_folder, kg_rules_path)
    source_rule_path = conf_folder

    final_sumary_file_path = os.path.join("gen_rules_iteration", args.dataset, "final_summary")
    os.makedirs(final_sumary_file_path, exist_ok=True)

    # An unrecognised rule_domain selects no rules, so rules.txt ends up empty.
    unique_strings = set()
    if args.rule_domain in ("high", "iteration"):
        with open(os.path.join(source_rule_path, "hight_conf.txt"), "r") as fin_high:
            high_unique_strings = set(fin_high.read().split())
        unique_strings = high_unique_strings

        if args.rule_domain == "iteration":
            with open(os.path.join(source_rule_path, "low_conf.txt"), "r") as fin_low:
                low_unique_strings = set(fin_low.read().split())
            unique_strings = low_unique_strings.union(high_unique_strings)
    else:
        print(f"Unknown rule_domain {args.rule_domain!r}: no rules selected for the final summary.")

    with open(os.path.join(final_sumary_file_path, "rules.txt"), "w") as fout_final:
        for rule in unique_strings:
            fout_final.write(f"{rule}\n")


def gen_rules_iteration(
    args,
    kg_rules_path,
    model,
    rdict,
    relation_regex,
    rule_path,
    conf,
    similiary_rel_dict,
    kg_rules_path_with_valid,
    prompt_dict_low,
    prompt_dict_high,
):
    """Run the per-item rule generation worker over ``conf`` on a thread pool.

    Every element of ``conf`` is passed as the single positional argument of
    ``generate_rule_for_iteration_by_multi_thread``; all remaining parameters
    of this function are forwarded to it unchanged as keyword arguments.

    Args:
        args: Run configuration; ``args.n`` is used as the pool size.
        conf: Sized collection of work items (``len(conf)`` drives the
            progress bar total).

    Returns:
        None. The worker's return values are discarded.
    """
    worker = partial(
        generate_rule_for_iteration_by_multi_thread,
        prompt_dict_low=prompt_dict_low,
        prompt_dict_high=prompt_dict_high,
        rdict=rdict,
        rule_path=rule_path,
        kg_rules_path=kg_rules_path,
        model=model,
        args=args,
        relation_regex=relation_regex,
        similiary_rel_dict=similiary_rel_dict,
        kg_rules_path_with_valid=kg_rules_path_with_valid,
    )
    with ThreadPool(args.n) as pool:
        completed = tqdm(pool.imap_unordered(worker, conf), total=len(conf))
        # Draining the lazy iterator is what makes the pool finish every task.
        for _ in completed:
            pass


# def gen_rules_for_unknown_relation( # UNUSED
#     args, kg_rules_path, model, rdict, relation_regex, relation_subgraph, rule_path, low_conf, similiary_rel_dict
# ):
#     with ThreadPool(args.n) as p:
#         for _ in tqdm(
#             p.imap_unordered(
#                 partial(
#                     generate_rule_for_unknown_relation_by_multi_thread,
#                     rdict=rdict,
#                     rule_path=rule_path,
#                     kg_rules_path=kg_rules_path,
#                     model=model,
#                     args=args,
#                     relation_subgraph=relation_subgraph,
#                     relation_regex=relation_regex,
#                     similiary_rel_dict=similiary_rel_dict,
#                 ),
#                 low_conf,
#             ),
#             total=len(low_conf),
#         ):
#             pass


def filter_rules_based_confidence(rule_set, min_conf, output_folder, index):
    """Split scored rules into high- and low-confidence files in ``output_folder``.

    Each rule is read as ``<rule text>&<confidence>``: the text after the last
    ``&`` is parsed as a float and the text before it is written out without
    the score. A rule is "high" only when its confidence is strictly greater
    than ``min_conf``; everything else is "low".

    Files written:
        * ``hight_conf.txt`` / ``low_conf.txt`` - opened in append mode; each
          call first writes an ``index:<index>`` header line.
        * ``temp_hight_conf.txt`` / ``temp_low_conf.txt`` - truncated on every
          call, so they hold only the rules of this call.

    Args:
        rule_set: Iterable of rule strings ending in ``&<confidence>``.
        min_conf: Confidence threshold (exclusive for the high bucket).
        output_folder: Existing directory that receives the four files.
        index: Value written into the ``index:`` header lines.

    Raises:
        ValueError: If the trailing field of a rule is not a valid float.
    """

    def in_output(file_name):
        return os.path.join(output_folder, file_name)

    with (
        open(in_output("hight_conf.txt"), "a") as high_history,
        open(in_output("low_conf.txt"), "a") as low_history,
        open(in_output("temp_hight_conf.txt"), "w") as high_latest,
        open(in_output("temp_low_conf.txt"), "w") as low_latest,
    ):
        header = f"index:{index}\n"
        high_history.write(header)
        low_history.write(header)

        for rule in rule_set:
            # rpartition yields ("", "", rule) when there is no "&", which
            # matches split("&")[:-1] / split("&")[-1] on such input.
            rule_text, _, score_text = rule.rpartition("&")
            score = float(score_text.strip())
            line = rule_text + "\n"
            if score > min_conf:
                sinks = (high_history, high_latest)
            else:
                sinks = (low_history, low_latest)
            for sink in sinks:
                sink.write(line)


def evaluation(args, output_rules_folder_dir, output_evaluation_folder, dataset_type, index=0):
    """Score rules against a background graph and save the confidence table.

    The background edges are chosen by ``dataset_type``: ``"train"`` uses
    ``data.train_idx``, ``"eva"`` uses ``data.valid_idx`` and any other value
    uses the concatenation of validation and training edges.

    Two modes, selected by ``args.is_only_with_original_rules``:

    * True - the sampled rules from ``rules_var.json`` are registered on the
      rule learner and the result is saved as ``original_confidence.json``.
    * False - the ``*_cleaned_rules.txt`` files under
      ``output_rules_folder_dir`` are scored through ``calculate_confidence``.
      With ``args.is_merge`` the sampled rules that were not produced by the
      LLM are merged in and the result is saved as ``merge_confidence.json``;
      otherwise it is saved as ``<index>_confidence.json`` and the rules that
      could not be scored are appended to ``fail_confidence.txt``.

    Args:
        args: Run configuration (``dataset``, ``is_merge``,
            ``transition_distr``, ``is_only_with_original_rules``).
        output_rules_folder_dir: Directory holding the cleaned rule files.
        output_evaluation_folder: Existing directory receiving the outputs.
        dataset_type: Background graph selector, see above.
        index: Iteration number used in the non-merge output file name.

    Returns:
        set: The scored rule strings returned by ``calculate_confidence``;
        an empty set in original-rules-only mode.
    """
    is_merge = args.is_merge
    only_original = args.is_only_with_original_rules
    dataset_dir = "./datasets/" + args.dataset + "/"
    data = Grapher(dataset_dir)

    if dataset_type == "train":
        background_edges = data.train_idx.tolist()
    elif dataset_type == "eva":
        background_edges = data.valid_idx.tolist()
    else:
        background_edges = data.valid_idx.tolist() + data.train_idx.tolist()
    temporal_walk = Temporal_Walk(np.array(background_edges), data.inv_relation_id, args.transition_distr)

    rl = Rule_Learner(temporal_walk.edges, data.id2relation, data.inv_relation_id, args.dataset)
    constant_config = load_json_data("./Config/constant.json")
    relation_regex = constant_config["relation_regex"][args.dataset]

    rules_var_path = os.path.join("sampled_path", args.dataset, "original", "rules_var.json")
    rules_var_dict = load_json_data(rules_var_path)

    # Defined up front so the final return works in every mode. Previously the
    # original-rules-only mode never assigned llm_gen_rules_list and the
    # function ended with an UnboundLocalError.
    llm_gen_rules_list = []
    fail_calc_confidence = []

    if only_original:
        for value in rules_var_dict.values():
            temp_var = {
                "head_rel": value["head_rel"],
                "body_rels": value["body_rels"],
                "var_constraints": value["var_constraints"],
            }
            if temp_var not in rl.original_found_rules:
                rl.original_found_rules.append(temp_var.copy())
                rl.update_rules_dict(value)
                rl.num_original += 1
    else:
        llm_gen_rules_list, fail_calc_confidence = calculate_confidence(
            output_rules_folder_dir,
            data.relation2id,
            data.inv_relation_id,
            rl,
            relation_regex,
            rules_var_dict,
            is_merge,
            is_has_confidence=False,
        )

    rules_statistics(rl.rules_dict)

    write_failures = False
    if only_original:
        confidence_file_name = "original_confidence.json"
    elif is_merge is True:
        # NOTE(review): calculate_confidence appends "\n" to every returned
        # rule; confirm the keys of rules_var.json use the same form,
        # otherwise this difference never removes anything.
        for rule_chain in set(rules_var_dict) - set(llm_gen_rules_list):
            rl.update_rules_dict(rules_var_dict[rule_chain])
        rules_statistics(rl.rules_dict)
        confidence_file_name = "merge_confidence.json"
    else:
        confidence_file_name = f"{index}_confidence.json"
        write_failures = True

    save_json_data(rl.rules_dict, os.path.join(output_evaluation_folder, confidence_file_name))

    if write_failures:
        fail_confidence_file_path = os.path.join(output_evaluation_folder, "fail_confidence.txt")
        with open(fail_confidence_file_path, "a") as fout:
            for fail_rule in fail_calc_confidence:
                fout.write(f"{fail_rule}\n")

    return set(llm_gen_rules_list)


def calculate_confidence(
    rule_path, relation2id, inv_relation_id, rl, relation_regex, rules_var_dict, is_merge, is_has_confidence=False
):
    """Score every rule found in the ``*_cleaned_rules.txt`` files of ``rule_path``.

    Each line is converted to a walk with ``get_walk`` and handed to
    ``rl.create_rule_for_merge_for_iteration``. When ``is_has_confidence`` is
    true the text after the last ``&`` of the line is parsed as the incoming
    confidence and removed from the rule; otherwise the whole line is the rule
    and the incoming confidence is 0.

    Returns:
        tuple[list, list]: ``(scored, failed)``. ``scored`` holds the learner's
        result for each successful rule with ``"\\n"`` appended. ``failed``
        holds each raw line that raised, also with ``"\\n"`` appended (the raw
        line keeps its own line ending, so entries usually end in two).
    """
    llm_gen_rules_list = []
    fail_calc_confidence = []
    file_pattern = os.path.join(rule_path, "*_cleaned_rules.txt")

    for input_filepath in glob.glob(file_pattern):
        with open(input_filepath, "r") as f:
            for rule in f.readlines():
                try:
                    try:
                        if is_has_confidence:
                            rule_text, _, score_text = rule.rpartition("&")
                            confidence = float(score_text.strip())
                        else:
                            rule_text = rule
                            confidence = 0
                        rule_text = rule_text.strip()
                        walk = get_walk(rule_text, relation2id, inv_relation_id, relation_regex)
                        scored_rule = rl.create_rule_for_merge_for_iteration(
                            walk, confidence, rule_text, rules_var_dict, is_merge
                        )
                        llm_gen_rules_list.append(scored_rule + "\n")
                    except Exception as e:
                        # Any parsing or scoring failure sends the raw line to the failure list.
                        print(e)
                        fail_calc_confidence.append(rule + "\n")
                except Exception as e:
                    # Only reached if the failure bookkeeping above raises itself.
                    print(f"Error processing rule: {rule}")
                    traceback.print_exc()  # print the exception details and call stack

    return llm_gen_rules_list, fail_calc_confidence


def process_rule(rule, relation2id, inv_relation_id, rl, relation_regex, rules_var_dict, is_merge, is_has_confidence):
    """Score a single rule line; the per-item worker of the threaded scorers.

    When ``is_has_confidence`` is true the text after the last ``&`` is parsed
    as the incoming confidence and stripped from the rule; otherwise the whole
    line is the rule and the incoming confidence is 0.

    Returns:
        tuple: ``(scored_rule + "\\n", None)`` on success, or
        ``(None, rule + "\\n")`` when any step raises.
    """
    try:
        if is_has_confidence:
            rule_text, _, score_text = rule.rpartition("&")
            confidence = float(score_text.strip())
        else:
            rule_text = rule
            confidence = 0
        rule_text = rule_text.strip()
        walk = get_walk(rule_text, relation2id, inv_relation_id, relation_regex)
        scored_rule = rl.create_rule_for_merge_for_iteration(walk, confidence, rule_text, rules_var_dict, is_merge)
        # Kept inside the try so a non-string result is reported as a failure too.
        return scored_rule + "\n", None
    except Exception:
        return None, rule + "\n"


def calculate_confidence_O(
    rule_path, relation2id, inv_relation_id, rl, relation_regex, rules_var_dict, is_merge, is_has_confidence=False
):
    """Score all ``*_cleaned_rules.txt`` lines with a 12-worker ``ThreadPoolExecutor``.

    One ``process_rule`` task is submitted per line; results are gathered in
    submission order.

    Returns:
        tuple[list, list]: ``(scored, failed)`` - the truthy first and second
        elements, respectively, of the pairs returned by ``process_rule``.
    """
    file_pattern = os.path.join(rule_path, "*_cleaned_rules.txt")
    pending = []

    with ThreadPoolExecutor(max_workers=12) as executor:
        for input_filepath in glob.glob(file_pattern):
            with open(input_filepath, "r") as f:
                pending.extend(
                    executor.submit(
                        process_rule,
                        rule,
                        relation2id,
                        inv_relation_id,
                        rl,
                        relation_regex,
                        rules_var_dict,
                        is_merge,
                        is_has_confidence,
                    )
                    for rule in f.readlines()
                )

        outcomes = [task.result() for task in pending]

    llm_gen_rules_list = [scored for scored, _ in outcomes if scored]
    fail_calc_confidence = [failed for _, failed in outcomes if failed]
    return llm_gen_rules_list, fail_calc_confidence


def calculate_confidence_1(
    rule_path,
    relation2id,
    inv_relation_id,
    rl,
    relation_regex,
    rules_var_dict,
    is_merge,
    is_has_confidence=False,
    num_threads=10,
):
    """Score all ``*_cleaned_rules.txt`` lines on a ``ThreadPool`` with a progress bar.

    All lines are read first, then ``process_rule`` is mapped over them with
    ``imap_unordered``, so the output order is not tied to the input order.

    Args:
        num_threads: Size of the thread pool.

    Returns:
        tuple[list, list]: ``(scored, failed)`` - the truthy first and second
        elements, respectively, of the pairs returned by ``process_rule``.
    """
    all_rules = []
    for input_filepath in glob.glob(os.path.join(rule_path, "*_cleaned_rules.txt")):
        with open(input_filepath, "r") as f:
            all_rules += f.readlines()

    score_one = partial(
        process_rule,
        relation2id=relation2id,
        inv_relation_id=inv_relation_id,
        rl=rl,
        relation_regex=relation_regex,
        rules_var_dict=rules_var_dict,
        is_merge=is_merge,
        is_has_confidence=is_has_confidence,
    )

    llm_gen_rules_list = []
    fail_calc_confidence = []
    with ThreadPool(num_threads) as pool:
        outcomes = tqdm(pool.imap_unordered(score_one, all_rules), total=len(all_rules))
        for scored, failed in outcomes:
            if scored:
                llm_gen_rules_list.append(scored)
            if failed:
                fail_calc_confidence.append(failed)

    return llm_gen_rules_list, fail_calc_confidence


def get_walk(rule, relation2id, inv_relation_id, regex):
    """Convert a textual rule ``head<-atom&atom...`` into an entity/relation walk.

    ``regex`` is applied to the head and to the body; its first three groups
    are read as relation name, subject and object (the head must expose at
    least four groups).

    Args:
        rule: Rule text containing ``<-`` between head and body.
        relation2id: Mapping from relation name to id.
        inv_relation_id: Mapping from a relation id to the id of its inverse.
        regex: Pattern that captures the parts of one atom.

    Returns:
        dict: ``{"entities": [...], "relations": [...]}``. ``entities`` is the
        head object followed by the body variables in reversed order;
        ``relations`` is the head relation id followed by the inverse ids of
        the body relations in reversed order.

    Raises:
        IndexError: If ``<-`` is missing or the body has no matching atom.
        AttributeError: If the head does not match ``regex``.
        KeyError: If a relation is missing from one of the mappings.
    """
    parts = rule.split("<-")
    head_text = parts[0].strip()
    body_text = parts[1].strip()

    # Head atom: only the relation name and the object are needed.
    head_match = re.search(regex, head_text)
    head_relation, _, head_object, _ = head_match.groups()[:4]

    # Body atoms, in the order they are written.
    body_atoms = re.findall(regex, body_text)
    forward_chain = [atom[1].strip() for atom in body_atoms[:-1]]
    closing_atom = body_atoms[-1]
    forward_chain += [closing_atom[1].strip(), closing_atom[2].strip()]

    head_id = relation2id[head_relation]
    body_ids = [relation2id[atom[0].strip()] for atom in body_atoms]

    # Keep the head element first and reverse everything that follows it.
    return {
        "entities": [head_object] + forward_chain[::-1],
        "relations": [head_id] + [inv_relation_id[rel_id] for rel_id in reversed(body_ids)],
    }


def clean(args, llm_model, filter_rule_path, output_folder):
    """Summarize and clean the rule files in ``filter_rule_path``.

    Creates ``<output_folder>/clean_statistics`` (kept across calls, its logs
    are appended to) and ``<output_folder>/rules`` (emptied with
    ``clear_folder`` when it already exists), then delegates the per-file work
    to ``clean_processing``.

    Args:
        args: Run configuration; ``data_path`` and ``dataset`` locate the dataset.
        llm_model: Model handle forwarded to ``clean_processing``.
        filter_rule_path: Folder with the rule files to clean.
        output_folder: Parent folder for the two output sub-folders.

    Returns:
        str: Path of the ``rules`` sub-folder.
    """
    dataset = Dataset(data_root=os.path.join(args.data_path, args.dataset) + "/", inv=True)
    all_rels = list(dataset.get_relation_dict().rel2idx.keys())

    statistics_dir = os.path.join(output_folder, "clean_statistics")
    if not os.path.exists(statistics_dir):
        os.makedirs(statistics_dir)

    rules_dir = os.path.join(output_folder, "rules")
    if os.path.exists(rules_dir):
        clear_folder(rules_dir)
    else:
        os.makedirs(rules_dir)

    # Logs recording which rules survived the cleaning pass and which were rejected.
    error_log_path = os.path.join(statistics_dir, "error.txt")
    success_log_path = os.path.join(statistics_dir, "suc.txt")
    with open(error_log_path, "a") as fout_error, open(success_log_path, "a") as fout_suc:
        num_error, num_suc = clean_processing(
            all_rels, args, fout_error, filter_rule_path, llm_model, rules_dir, fout_suc
        )
        fout_error.write(f"The number of cleaned rules is {num_error}\n")
        fout_suc.write(f"The number of retain rules is {num_suc}\n")

    return rules_dir


def clean_processing(all_rels, args, fout_error, input_folder, llm_model, output_folder, fout_suc):
    """Summarize and clean every eligible rule file of ``input_folder``.

    A file is eligible when its name ends in ``.txt``, does not contain
    ``query`` and does not start with ``fail``. For each one:

    1. unless ``args.clean_only`` is set, ``summarize_rule`` is run and its
       output written to ``<name>_summarized_rules.txt`` in ``output_folder``;
    2. ``clean_rules`` is run on that summarized file and, if anything
       survives, the result is written to ``<name>_cleaned_rules.txt``.

    Returns:
        tuple[int, int]: Total rejected and total retained rule counts as
        reported by ``clean_rules``.
    """
    constant_config = load_json_data("./Config/constant.json")
    rule_start_with_regex = constant_config["rule_start_with_regex"]
    replace_regex = constant_config["replace_regex"]
    relation_regex = constant_config["relation_regex"][args.dataset]

    total_rejected = 0
    total_retained = 0
    for filename in os.listdir(input_folder):
        if not filename.endswith(".txt") or "query" in filename or filename.startswith("fail"):
            continue

        stem = os.path.splitext(filename)[0]
        input_filepath = os.path.join(input_folder, filename)
        summarized_filepath = os.path.join(output_folder, f"{stem}_summarized_rules.txt")
        clean_filepath = os.path.join(output_folder, stem + "_cleaned_rules.txt")

        if not args.clean_only:
            # Step 1: summarize the raw rules of this file.
            print("Start summarize: ", filename)
            summarized_rules = summarize_rule(input_filepath, llm_model, args, rule_start_with_regex, replace_regex)
            print("write file", summarized_filepath)
            with open(summarized_filepath, "w") as f:
                f.write("\n".join(summarized_rules))

        # Step 2: validate and normalize the summarized rules.
        print(f"Clean file {summarized_filepath} with keeping the format")
        cleaned_rules, rejected, retained = clean_rules(
            summarized_filepath, all_rels, relation_regex, fout_error, fout_suc
        )
        total_rejected = total_rejected + rejected
        total_retained = total_retained + retained

        if len(cleaned_rules) != 0:
            with open(clean_filepath, "w") as f:
                f.write("\n".join(cleaned_rules))

    return total_rejected, total_retained


def extract_rules(content_list, rule_start_with_regex, replace_regex):
    """Return the distinct rules of ``content_list`` with their numbering removed.

    A line is kept when ``rule_start_with_regex`` matches at its start (tested
    on the unstripped line). Kept lines are stripped and every match of
    ``replace_regex`` is deleted from them. Duplicates are dropped through a
    set, so the order of the returned list is arbitrary.
    """
    starts_like_rule = re.compile(rule_start_with_regex).match
    strip_numbering = re.compile(replace_regex).sub

    unique_rules = set()
    for line in content_list:
        if starts_like_rule(line):
            unique_rules.add(strip_numbering("", line.strip()))
    return list(unique_rules)


def summarize_rules_prompt(relname, k):
    """Build the summarization instruction for the rule head ``relname(X,Y,T)``.

    Args:
        relname: Relation name placed in the rule head.
        k: Number of rules to ask for; ``0`` asks for as many as possible.

    Returns:
        str: The instruction text, starting with two newlines so it can be
        appended directly after a list of rules.
    """
    if k == 0:
        request = f'\n\nPlease identify as many of the most important rules for the rule head: "{relname}(X,Y,T)" as possible. '
    else:
        request = f'\n\nPlease identify the most important {k} rules from the following rules for the rule head: "{relname}(X,Y,T)". '

    guidance = (
        "You can summarize the rules that have similar meanings as one rule, if you think they are important. "
        "Return the rules only without any explanations. "
    )
    return request + guidance


def summarize_rule_for_unkown(file, llm_model, args, rule_start_with_regex, replace_regex):
    """Extract the rules of ``file`` and optionally let the LLM summarize them.

    The relation name is taken from the file name without its extension.
    When ``args.force_summarize`` is falsy or ``llm_model`` is None the
    extracted rules are returned as they are. Otherwise the rules are split
    into batches by ``shuffle_split_path_list``, each batch is sent to the
    model followed by the summarization prompt, and the rules extracted from
    all responses are returned.

    Returns:
        list[str]: The extracted (or summarized) rules.
    """
    with open(file, "r") as f:
        raw_text = f.read()
    rel_name = os.path.splitext(file)[0].rsplit("/", 1)[-1]

    # Keep only the rule lines, without explanations or numbering.
    rule_list = extract_rules(raw_text.split("\n"), rule_start_with_regex, replace_regex)
    if not args.force_summarize or llm_model is None:
        return rule_list

    summarize_prompt = summarize_rules_prompt(rel_name, args.k)
    prompt_tokens = num_tokens_from_message(summarize_prompt, args.model_name)
    batches = shuffle_split_path_list(rule_list, prompt_tokens, args.model_name)

    response_lines = []
    for batch in batches:
        message = "\n".join(batch) + summarize_prompt
        print("prompt: ", message)
        response_lines.extend(query(message, llm_model).split("\n"))

    # The model may add prose around the rules; filter it out again.
    return extract_rules(response_lines, rule_start_with_regex, replace_regex)


def summarize_rule(file, llm_model, args, rule_start_with_regex, replace_regex):
    """Return the rules contained in ``file``, summarized by the LLM on request.

    Args:
        file: Path of a rule file; its base name without extension is used as
            the relation name in the prompt.
        llm_model: Model handle passed to ``query``; None disables summarization.
        args: Run configuration (``force_summarize``, ``k``, ``model_name``).
        rule_start_with_regex: Pattern identifying lines that are rules.
        replace_regex: Pattern whose matches are removed from each rule.

    Returns:
        list[str]: The rules extracted from the file when summarization is
        off, otherwise the rules extracted from the model responses.
    """
    with open(file, "r") as f:
        content = f.read()
    rel_name = os.path.splitext(file)[0].rsplit("/", 1)[-1]

    extracted = extract_rules(content.split("\n"), rule_start_with_regex, replace_regex)

    summarization_off = not args.force_summarize or llm_model is None
    if summarization_off:
        return extracted

    summarize_prompt = summarize_rules_prompt(rel_name, args.k)
    summarize_prompt_len = num_tokens_from_message(summarize_prompt, args.model_name)

    # One request per batch; all response lines are pooled before extraction.
    response_list = []
    for chunk in shuffle_split_path_list(extracted, summarize_prompt_len, args.model_name):
        message = "\n".join(chunk) + summarize_prompt
        print("prompt: ", message)
        response = query(message, llm_model)
        response_list.extend(response.split("\n"))

    return extract_rules(response_list, rule_start_with_regex, replace_regex)


def modify_process(temp_rule, relation_regex):
    """Parse each ``&``-separated atom of ``temp_rule``; currently a stub.

    For every atom that matches ``relation_regex`` the first four groups
    (relation name, subject, object, timestamp) are read and stripped, but
    nothing is done with them: the function has no return statement and so
    always returns None.
    """
    for atom in temp_rule.split("&"):
        found = re.search(relation_regex, atom)
        if not found:
            continue
        # Groups 1-4: relation name, subject, object, timestamp. The values
        # are discarded; reading them only surfaces malformed matches.
        for group_index in (1, 2, 3, 4):
            found[group_index].strip()


def clean_rules_for_unknown(summarized_file_path, all_rels, relation_regex, fout_error, fout_suc):
    """Validate and normalize the rules of ``summarized_file_path``.

    Every non-empty line is split into atoms on ``<-`` and ``&``. A rule is
    kept when

    * every atom matches ``relation_regex`` (groups 1-4: relation, subject,
      object, timestamp),
    * every relation is in ``all_rels`` or has a close match there (the close
      match replaces it),
    * the body forms a chain: the first body subject equals the head subject,
      each later subject equals the previous object, and the last object
      equals the head object,
    * the body timestamps (digits after the first character) are
      non-decreasing and none is later than the head timestamp.

    An atom whose timestamp has no numeric suffix hands the rule to
    ``modify_process``; the rule is kept only if that returns a non-empty
    rule, otherwise it is rejected. Previously the atoms parsed so far were
    saved as a truncated (possibly empty) rule in that case.

    Each rejected rule is counted and logged exactly once. Previously the
    timestamp check also ran for rules that were already rejected and raised
    ``IndexError`` on an empty body, which logged and counted the rule twice.

    Args:
        summarized_file_path: Text file with one rule per line.
        all_rels: Known relation names.
        relation_regex: Pattern capturing the four parts of an atom.
        fout_error: Writable text stream receiving one message per rejected rule.
        fout_suc: Writable text stream receiving every kept rule.

    Returns:
        tuple[list[str], int, int]: Kept rules, number rejected, number kept.
    """
    num_error = 0
    num_suc = 0
    with open(summarized_file_path, "r") as f:
        input_rules = [line.strip() for line in f]
    cleaned_rules = []

    for input_rule in input_rules:
        if input_rule == "":
            continue

        temp_rule = re.sub(r"\s*<-\s*", "&", input_rule)
        regrex_list = temp_rule.split("&")
        chain_error = f"Error: Rule {input_rule} does not conform to the definition of chain rule."

        rule_list = []
        last_subject = None
        final_object = None
        final_time = None
        time_squeque = []
        correct_rule = None
        error_message = None

        try:
            for idx, regrex in enumerate(regrex_list):
                match = re.search(relation_regex, regrex)
                if not match:
                    error_message = f"Error: rule {input_rule}"
                    break

                relation_name = match[1].strip()
                subject = match[2].strip()
                obj = match[3].strip()
                timestamp = match[4].strip()

                if not timestamp[1:].isdigit():
                    # Symbolic timestamp: let modify_process try to repair the rule.
                    correct_rule = modify_process(temp_rule, relation_regex)
                    if not correct_rule:
                        error_message = f"Error: Rule {input_rule}:{timestamp} is not digit"
                    break

                if relation_name not in all_rels:
                    best_match = get_close_matches(relation_name, all_rels, n=1)
                    if not best_match:
                        error_message = f"Cannot correctify this rule, head not in relation:{input_rule}"
                        break
                    relation_name = best_match[0].strip()

                rule_list.append(f"{relation_name}({subject},{obj},{timestamp})")

                if idx == 0:
                    last_subject = subject
                    final_object = obj
                    final_time = int(timestamp[1:])
                elif subject == last_subject:
                    last_subject = obj
                    time_squeque.append(int(timestamp[1:]))
                else:
                    error_message = chain_error
                    break
            else:
                # Every atom parsed: run the whole-rule checks.
                if last_subject != final_object:
                    error_message = chain_error
                elif not time_squeque:
                    error_message = f"Error: Rule {input_rule} has no body atom."
                elif (
                    any(earlier > later for earlier, later in zip(time_squeque, time_squeque[1:]))
                    or final_time < time_squeque[-1]
                ):
                    error_message = f"Error: Rule {input_rule} time_squeque is error."
                else:
                    correct_rule = rule_list[0] + "<-" + "&".join(rule_list[1:])

            if error_message is None:
                cleaned_rules.append(correct_rule)
                fout_suc.write(correct_rule + "\n")
                num_suc = num_suc + 1

        except Exception as e:
            error_message = f"Processing {input_rule} failed.\n Error: {str(e)}"

        if error_message is not None:
            print(error_message)
            fout_error.write(error_message + "\n")
            num_error = num_error + 1

    return cleaned_rules, num_error, num_suc


def _validate_chain_rule(input_rule, all_rels, relation_regex):
    """Normalize one rule and report the first violation of the chain-rule format.

    Returns ``(normalized_rule, None)`` when the rule is valid, otherwise
    ``(None, error_message)``.
    """
    atoms = re.sub(r"\s*<-\s*", "&", input_rule).split("&")
    chain_error = f"Error: Rule {input_rule} does not conform to the definition of chain rule."

    rule_list = []
    last_subject = None
    final_object = None
    final_time = None
    time_squeque = []

    for idx, atom in enumerate(atoms):
        match = re.search(relation_regex, atom)
        if not match:
            return None, f"Error: rule {input_rule}"

        relation_name = match[1].strip()
        subject = match[2].strip()
        obj = match[3].strip()
        timestamp = match[4].strip()

        # The timestamp must be one leading character followed by digits.
        if not timestamp[1:].isdigit():
            return None, f"Error: Rule {input_rule}:{timestamp} is not digit"

        # Unknown relation names are replaced by their closest known name.
        if relation_name not in all_rels:
            best_match = get_close_matches(relation_name, all_rels, n=1)
            if not best_match:
                return None, f"Cannot correctify this rule, head not in relation:{input_rule}"
            relation_name = best_match[0].strip()

        rule_list.append(f"{relation_name}({subject},{obj},{timestamp})")

        if idx == 0:
            last_subject = subject
            final_object = obj
            final_time = int(timestamp[1:])
        elif subject == last_subject:
            last_subject = obj
            time_squeque.append(int(timestamp[1:]))
        else:
            return None, chain_error

    # The chain has to end on the object of the head.
    if last_subject != final_object:
        return None, chain_error
    if not time_squeque:
        return None, f"Error: Rule {input_rule} has no body atom."

    in_order = all(earlier <= later for earlier, later in zip(time_squeque, time_squeque[1:]))
    if not in_order or final_time < time_squeque[-1]:
        return None, f"Error: Rule {input_rule} time_squeque is error."

    return rule_list[0] + "<-" + "&".join(rule_list[1:]), None


def clean_rules(summarized_file_path, all_rels, relation_regex, fout_error, fout_suc):
    """Validate and normalize the rules of ``summarized_file_path``.

    Every non-empty line is split into atoms on ``<-`` and ``&``. A rule is
    kept when

    * every atom matches ``relation_regex`` (groups 1-4: relation, subject,
      object, timestamp) and its timestamp is one character plus digits,
    * every relation is in ``all_rels`` or has a close match there (the close
      match replaces it),
    * the body forms a chain: the first body subject equals the head subject,
      each later subject equals the previous object, and the last object
      equals the head object,
    * the body timestamps are non-decreasing and none is later than the head
      timestamp.

    Each rejected rule is counted and logged exactly once. Previously the
    timestamp check also ran for rules that were already rejected and raised
    ``IndexError`` on an empty body, which logged and counted the rule twice.

    Args:
        summarized_file_path: Text file with one rule per line.
        all_rels: Known relation names.
        relation_regex: Pattern capturing the four parts of an atom.
        fout_error: Writable text stream receiving one message per rejected rule.
        fout_suc: Writable text stream receiving every kept rule.

    Returns:
        tuple[list[str], int, int]: Kept rules, number rejected, number kept.
    """
    num_error = 0
    num_suc = 0
    with open(summarized_file_path, "r") as f:
        input_rules = [line.strip() for line in f]
    cleaned_rules = []

    for input_rule in input_rules:
        if input_rule == "":
            continue
        try:
            correct_rule, error_message = _validate_chain_rule(input_rule, all_rels, relation_regex)
            if error_message is None:
                cleaned_rules.append(correct_rule)
                fout_suc.write(correct_rule + "\n")
                num_suc = num_suc + 1
        except Exception as e:
            # Best effort: one unexpected failure must not abort the whole file.
            error_message = f"Processing {input_rule} failed.\n Error: {str(e)}"

        if error_message is not None:
            print(error_message)
            fout_error.write(error_message + "\n")
            num_error = num_error + 1

    return cleaned_rules, num_error, num_suc




## Execution

In [None]:
from dotenv import load_dotenv

load_dotenv()

def parse_arguments():
    """Build the command-line parser for rule generation and parse the arguments.

    Unknown arguments are ignored (``parse_known_args`` is used), so extra
    entries in ``sys.argv`` do not raise by themselves.

    Returns:
        tuple: ``(args, parser)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``parser`` is the ``ArgumentParser`` itself,
        returned so that callers can register further options and parse again.
    """
    parser = argparse.ArgumentParser(description="KGC rule generation parameters")

    # --- Paths and dataset selection ---
    parser.add_argument("--data_path", type=str, default="datasets", help="Data directory")
    parser.add_argument("--dataset", "-d", type=str, default="icews14", help="Dataset name")
    parser.add_argument("--sampled_paths", type=str, default="sampled_path", help="Sampled path directory")
    # Help text previously duplicated the --sampled_paths description.
    parser.add_argument("--prompt_paths", type=str, default="prompt", help="Prompt directory")
    parser.add_argument("--rule_path", type=str, default="gen_rules_iteration", help="Path to rule file")

    # --- Model and generation settings ---
    parser.add_argument("--model_name", type=str, default="gpt-4.1-nano", help="Model name")
    parser.add_argument("--is_zero", action="store_true", help="Enable zero-shot rule generation")
    parser.add_argument("-k", type=int, default=0, help="Number of generated rules, 0 denotes as much as possible")
    # NOTE(review): when this runs inside a Jupyter kernel, sys.argv usually
    # contains "-f <connection file>", which would be parsed by this int-typed
    # option -- confirm how the notebook is launched.
    parser.add_argument("-f", type=int, default=50, help="Few-shot number")
    parser.add_argument("-topk", type=int, default=20, help="Top-k paths")
    parser.add_argument("-n", type=int, default=10, help="Number of threads")  # changed
    parser.add_argument("-l", type=int, default=5, help="Sample times for generating k rules")
    parser.add_argument("--prefix", type=str, default="", help="Prefix for files")
    parser.add_argument("--dry_run", action="store_true", help="Dry run mode")

    # --- Boolean-like switches; string defaults are converted by str_to_bool ---
    parser.add_argument("--is_rel_name", type=str_to_bool, default="yes", help="Enable relation names")
    parser.add_argument("--select_with_confidence", type=str_to_bool, default="no", help="Select with confidence")
    parser.add_argument("--clean_only", action="store_true", help="Load summarized rules and clean rules only")
    # parser.add_argument("--force_summarize", action="store_true", help="Force summarize rules")
    parser.add_argument("--is_merge", type=str_to_bool, default="no", help="Enable merge")
    parser.add_argument("--transition_distr", type=str, default="exp", help="Transition distribution")
    parser.add_argument("--is_only_with_original_rules", type=str_to_bool, default="no", help="Use only original rules")
    parser.add_argument("--is_high", type=str_to_bool, default="No", help="Enable high mode")

    # --- Iteration and confidence settings ---
    parser.add_argument("--min_conf", type=float, default=0.01, help="Minimum confidence")
    parser.add_argument("--num_iter", type=int, default=2, help="Number of iterations")
    parser.add_argument("-second", type=int, default=3, help="Second sampling times for generating k rules")
    parser.add_argument(
        "--bgkg",
        type=str,
        default="valid",
        choices=["train", "train_valid", "valid", "test"],
        help="Background knowledge graph",
    )
    parser.add_argument("--based_rule_type", type=str, default="low", choices=["low", "high"], help="Base rule type")
    parser.add_argument(
        "--rule_domain", type=str, default="iteration", choices=["iteration", "high", "all"], help="Rule domain"
    )

    # parse_known_args tolerates arguments this parser does not define.
    args, _ = parser.parse_known_args()
    return args, parser

# Registry mapping a model-name key to the language-model implementation that
# serves it. get_registed_model() tests each key as a substring of the
# lowercased requested name, so keys should be written in lowercase.
registed_language_models = {
    'gpt-4.1-nano': ChatGPT,
}



def get_registed_model(model_name) -> BaseLanguageModel:
    """Return the registry entry whose key occurs in the lowercased model name.

    Entries of ``registed_language_models`` are checked in insertion order and
    the value of the first key found as a substring of ``model_name.lower()``
    is returned.

    Args:
        model_name: Requested model name; matched case-insensitively.

    Returns:
        The value stored in ``registed_language_models`` for the first
        matching key.

    Raises:
        ValueError: If no registered key is contained in the model name.
    """
    # A private sentinel distinguishes "no match" from any stored value.
    not_found = object()
    selected = next(
        (
            implementation
            for registered_name, implementation in registed_language_models.items()
            if registered_name in model_name.lower()
        ),
        not_found,
    )
    if selected is not_found:
        raise ValueError(f"No registered model found for name {model_name}")
    return selected


# Parse the base options first so that the requested model name is known.
args, parser = parse_arguments()
# Look up the language-model implementation registered for that name.
LLM = get_registed_model(args.model_name)
# Let the model register its own options on the same parser, then parse again
# so that `args` also carries those model-specific values.
LLM.add_args(parser)
args, _ = parser.parse_known_args()
# Hand the final arguments and the model implementation to main().
main(args, LLM)


Request Time out. Retrying in 30 seconds...
