# Dynamic rule adaptation (1): rule sampling

## Input
- files in /datasets

## Output
- files in `./sampled_path/<dataset>/`



# Setup Requirements
- python 3.12


In [1]:
# Version specifiers containing ">" must be quoted; otherwise the shell treats
# ">=1.5.0" as an output redirection and the constraint is silently dropped.
!pip install "joblib>=1.5.0" \
    matplotlib \
    networkx==3.2.1 \
    numpy==1.26.3 \
    openai==1.79.0 \
    pandas==2.2.0 \
    python-dotenv==1.0.0 \
    scipy==1.12.0 \
    seaborn==0.13.1 \
    "sentence-transformers>=3.0.1" \
    tiktoken==0.9.0 \
    torch==2.4.0 \
    tqdm==4.66.1 \
    transformers==4.35.2 \
    datasets

In [None]:
from dotenv import load_dotenv
load_dotenv()



### Helper functions

In [3]:

import torch
import time
import numpy as np


def load_json_data(file_path, default=None):
    """
    Load JSON data from a file, falling back to a default value.

    Parameters:
        file_path (str): path to the JSON file
        default: value returned if the file is missing or cannot be read/parsed

    Returns:
        data: parsed JSON content, or `default` on any failure
    """

    # Imported locally so this helper does not depend on later cells having
    # already imported these modules.
    import json
    import os

    try:
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return default

        print(f"Use cache from: {file_path}")
        # Read explicitly as UTF-8 (save_json_data writes UTF-8 with
        # ensure_ascii=False) instead of relying on the locale default.
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except Exception as e:  # best-effort loader: report and fall back
        print(f"Error loading JSON data from {file_path}: {e}")
    return default


### Grapher


In [4]:
import os.path

import numpy as np



class Grapher(object):
    def __init__(self, dataset_dir, args=None, test_mask=None):
        """
        Store information about the graph (train/valid/test set).
        Add corresponding inverse quadruples to the data.

        Parameters:
            dataset_dir (str): path to the graph dataset directory
            args (dict, optional): settings; only args["bgkg"] is read here. It selects
                                   which splits are stacked into self.all_idx. If args is
                                   None or the value is not recognized, self.all_idx is
                                   not set.
            test_mask (sequence, optional): (lower, upper) inclusive bounds on the
                                            timestamp index used to filter the test set

        Returns:
            None
        """

        self.args = args
        self.dataset_dir = dataset_dir
        self.entity2id = load_json_data(os.path.join(self.dataset_dir, "entity2id.json"))
        self.relation2id_old = load_json_data(os.path.join(self.dataset_dir, "relation2id.json"))
        self.relation2id = self.relation2id_old.copy()
        # Inverse relations receive the ids directly after the original relations.
        num_relations = len(self.relation2id_old)
        for offset, relation in enumerate(self.relation2id_old):
            self.relation2id["inv_" + relation] = num_relations + offset
        self.ts2id = load_json_data(os.path.join(self.dataset_dir, "ts2id.json"))
        self.id2entity = {v: k for k, v in self.entity2id.items()}
        self.id2relation = {v: k for k, v in self.relation2id.items()}
        self.id2ts = {v: k for k, v in self.ts2id.items()}

        # Map every relation id to its inverse (and every inverse back to the original).
        self.inv_relation_id = dict()
        for i in range(num_relations):
            self.inv_relation_id[i] = i + num_relations
        for i in range(num_relations, num_relations * 2):
            self.inv_relation_id[i] = i % num_relations

        self.train_idx = self.create_store("train.txt")
        self.valid_idx = self.create_store("valid.txt")
        self.test_idx = self.create_store("test.txt")

        if test_mask is not None:
            mask = (self.test_idx[:, 3] >= test_mask[0]) & (self.test_idx[:, 3] <= test_mask[1])
            self.test_idx = self.test_idx[mask]

        if self.args is not None:
            background_splits = {
                "all": (self.train_idx, self.valid_idx, self.test_idx),
                "train": (self.train_idx,),
                "valid": (self.valid_idx,),
                "test": (self.test_idx,),
                "train_valid": (self.train_idx, self.valid_idx),
                "train_test": (self.train_idx, self.test_idx),
                "valid_test": (self.valid_idx, self.test_idx),
            }
            parts = background_splits.get(self.args["bgkg"])
            if parts is not None:
                # A single split is used as-is; several splits are stacked row-wise.
                self.all_idx = parts[0] if len(parts) == 1 else np.vstack(parts)

        print("Grapher initialized.")

    def create_store(self, file):
        """
        Store the quadruples from the file as indices.
        The quadruples in the file should be in the format "subject\trelation\tobject\ttimestamp\n".

        Parameters:
            file (str): file name

        Returns:
            store_idx (np.ndarray): indices of quadruples
        """

        with open(os.path.join(self.dataset_dir, file), "r", encoding="utf-8") as f:
            quads = f.readlines()
        store = self.split_quads(quads)
        store_idx = self.map_to_idx(store)
        store_idx = self.add_inverses(store_idx)

        return store_idx

    def split_quads(self, quads):
        """
        Split quadruples into a list of strings.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form "subject\trelation\tobject\ttimestamp\n".
                          The trailing newline is optional; blank lines are skipped.

        Returns:
            split_q (list): list of quadruples
                            Each quadruple has the form [subject, relation, object, timestamp].
        """

        split_q = []
        for quad in quads:
            # Strip only the line terminator: slicing off the last character would
            # truncate the timestamp of a final line that has no trailing newline.
            line = quad.rstrip("\n")
            if not line:
                continue
            split_q.append(line.split("\t"))

        return split_q

    def map_to_idx(self, quads):
        """
        Map quadruples to their indices.

        Parameters:
            quads (list): list of quadruples
                          Each quadruple has the form [subject, relation, object, timestamp].

        Returns:
            quads (np.ndarray): indices of quadruples
        """

        subs = [self.entity2id[x[0]] for x in quads]
        rels = [self.relation2id[x[1]] for x in quads]
        objs = [self.entity2id[x[2]] for x in quads]
        tss = [self.ts2id[x[3]] for x in quads]
        quads = np.column_stack((subs, rels, objs, tss))

        return quads

    def add_inverses(self, quads_idx):
        """
        Add the inverses of the quadruples as indices.

        Parameters:
            quads_idx (np.ndarray): indices of quadruples

        Returns:
            quads_idx (np.ndarray): indices of quadruples along with the indices of their inverses
        """

        # The inverse quadruple swaps subject and object and uses the inverse relation.
        subs = quads_idx[:, 2]
        rels = [self.inv_relation_id[x] for x in quads_idx[:, 1]]
        objs = quads_idx[:, 0]
        tss = quads_idx[:, 3]
        inv_quads_idx = np.column_stack((subs, rels, objs, tss))
        quads_idx = np.vstack((quads_idx, inv_quads_idx))

        return quads_idx


### Temporal walk

In [5]:
import pandas as pd

def initialize_temporal_walk(version_id, data, transition_distr):
    """
    Build a Temporal_Walk over the data splits selected by `version_id`.

    Parameters:
        version_id (str): one of "all", "train_valid", "train", "test", "valid"
        data: object exposing train_idx, valid_idx, test_idx (arrays of quadruple
              indices) and inv_relation_id
        transition_distr (str): transition distribution passed on to Temporal_Walk

    Returns:
        temporal_walk (Temporal_Walk): walker over the selected quadruples

    Raises:
        KeyError: if `version_id` is not one of the supported names
    """

    splits_by_version = {
        "all": ("train_idx", "valid_idx", "test_idx"),
        "train_valid": ("train_idx", "valid_idx"),
        "train": ("train_idx",),
        "test": ("test_idx",),
        "valid": ("valid_idx",),
    }

    # Only the requested splits are materialized, instead of building an array
    # for every possible version on each call.
    rows = []
    for attr in splits_by_version[version_id]:
        rows.extend(getattr(data, attr).tolist())

    return Temporal_Walk(np.array(rows), data.inv_relation_id, transition_distr)

def store_neighbors(quads):
    """
    Collect, for every node, the quadruples that start at that node.

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        neighbors (dict): maps each head node to an array of its outgoing quadruples
    """

    # Load the quadruples into a DataFrame so they can be grouped by head node.
    frame = pd.DataFrame(quads, columns=["head", "relation", "target", "timestamp"])

    # Group by the "head" column and keep each group's rows as an array.
    neighbors = {}
    for head_node, rows in frame.groupby("head"):
        neighbors[head_node] = rows.values

    return neighbors


def store_edges(quads):
    """
    Group the quadruples by their relation.

    Parameters:
        quads (np.ndarray): indices of quadruples

    Returns:
        edges (dict): maps each relation to the array of quadruples using it
    """

    relation_column = quads[:, 1]
    return {rel: quads[relation_column == rel] for rel in set(relation_column)}

class Temporal_Walk(object):
    def __init__(self, learn_data, inv_relation_id, transition_distr):
        """
        Initialize temporal random walk object.

        Parameters:
            learn_data (np.ndarray): data on which the rules should be learned
            inv_relation_id (dict): mapping of relation to inverse relation
            transition_distr (str): transition distribution
                                    "unif" - uniform distribution
                                    "exp"  - exponential distribution

        Returns:
            None
        """

        self.learn_data = learn_data
        self.inv_relation_id = inv_relation_id
        self.transition_distr = transition_distr
        self.neighbors = store_neighbors(learn_data)
        self.edges = store_edges(learn_data)

    def sample_start_edge(self, rel_idx):
        """
        Define start edge distribution.

        Parameters:
            rel_idx (int): relation index

        Returns:
            start_edge (np.ndarray): start edge, drawn uniformly from the edges of rel_idx
        """

        rel_edges = self.edges[rel_idx]
        start_edge = rel_edges[np.random.choice(len(rel_edges))]

        return start_edge

    def sample_next_edge(self, filtered_edges, cur_ts):
        """
        Define next edge distribution.

        Parameters:
            filtered_edges (np.ndarray): filtered (according to time) edges
            cur_ts (int): current timestamp

        Returns:
            next_edge (np.ndarray): next edge

        Raises:
            ValueError: if self.transition_distr is neither "unif" nor "exp"
        """

        if self.transition_distr == "unif":
            return filtered_edges[np.random.choice(len(filtered_edges))]

        if self.transition_distr == "exp":
            tss = filtered_edges[:, 3]
            prob = np.exp(tss - cur_ts)
            try:
                prob = prob / np.sum(prob)
                next_edge = filtered_edges[np.random.choice(range(len(filtered_edges)), p=prob)]
            except ValueError:  # All timestamps are far away
                # The weights are not a valid distribution; fall back to uniform sampling.
                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
            return next_edge

        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unknown transition distribution: {self.transition_distr!r}")

    def transition_step(self, cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts=None):
        """
        Sample a neighboring edge given the current node and timestamp.
        In the second step (step == 1), the next timestamp should be smaller than the current timestamp.
        In the other steps, the next timestamp should be smaller than or equal to the current timestamp.
        In the last step (step == L-1), the edge should connect to the source of the walk (cyclic walk).
        It is not allowed to go back using the inverse edge.

        Parameters:
            cur_node (int): current node
            cur_ts (int): current timestamp
            prev_edge (np.ndarray): previous edge
            start_node (int): start node
            step (int): number of current step
            L (int): length of random walk
            target_cur_ts (int, optional): target current timestamp for relaxed time. Defaults to cur_ts.

        Returns:
            next_edge (np.ndarray): next edge, or an empty list if no edge qualifies
        """

        next_edges = self.neighbors.get(cur_node)
        if next_edges is None:
            # The node has no outgoing edges in the learning data: the walk ends here.
            return []

        if target_cur_ts is None:
            target_cur_ts = cur_ts

        if step == 1:  # The next timestamp should be smaller than the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] < target_cur_ts]
        else:  # The next timestamp should be smaller than or equal to the current timestamp
            filtered_edges = next_edges[next_edges[:, 3] <= target_cur_ts]
            # Delete inverse edge
            inv_edge = [
                cur_node,
                self.inv_relation_id[prev_edge[1]],
                prev_edge[0],
                cur_ts,
            ]
            row_idx = np.where(np.all(filtered_edges == inv_edge, axis=1))
            filtered_edges = np.delete(filtered_edges, row_idx, axis=0)

        if step == L - 1:  # Find an edge that connects to the source of the walk
            filtered_edges = filtered_edges[filtered_edges[:, 2] == start_node]

        if len(filtered_edges):
            next_edge = self.sample_next_edge(filtered_edges, cur_ts)
        else:
            next_edge = []

        return next_edge

    def transition_step_with_relax_time(self, cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts):
        """
        Wrapper for transition_step with relaxed time handling.

        Parameters:
            cur_node (int): current node
            cur_ts (int): current timestamp
            prev_edge (np.ndarray): previous edge
            start_node (int): start node
            step (int): number of current step
            L (int): length of random walk
            target_cur_ts (int): target current timestamp for relaxed time

        Returns:
            next_edge (np.ndarray): next edge
        """
        return self.transition_step(cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts)

    def sample_walk(self, L, rel_idx, use_relax_time=False):
        """
        Try to sample a cyclic temporal random walk of length L (for a rule of length L-1).

        Parameters:
            L (int): length of random walk
            rel_idx (int): relation index
            use_relax_time (bool): whether to use relaxed time sampling. When set, every
                                   step is filtered against the timestamp of the start edge
                                   instead of the timestamp of the previous edge.

        Returns:
            walk_successful (bool): if a cyclic temporal random walk has been successfully sampled
            walk (dict): information about the walk (entities, relations, timestamps)
        """

        walk_successful = True
        walk = dict()
        prev_edge = self.sample_start_edge(rel_idx)
        start_node = prev_edge[0]
        cur_node = prev_edge[2]
        cur_ts = prev_edge[3]
        target_cur_ts = cur_ts
        walk["entities"] = [start_node, cur_node]
        walk["relations"] = [prev_edge[1]]
        walk["timestamps"] = [cur_ts]

        for step in range(1, L):
            if use_relax_time:
                next_edge = self.transition_step_with_relax_time(
                    cur_node, cur_ts, prev_edge, start_node, step, L, target_cur_ts
                )
            else:
                next_edge = self.transition_step(cur_node, cur_ts, prev_edge, start_node, step, L)

            if len(next_edge):
                cur_node = next_edge[2]
                cur_ts = next_edge[3]
                walk["relations"].append(next_edge[1])
                walk["entities"].append(cur_node)
                walk["timestamps"].append(cur_ts)
                prev_edge = next_edge
            else:  # No valid neighbors (due to temporal or cyclic constraints)
                walk_successful = False
                break

        return walk_successful, walk


### Rule learner 


In [6]:
import os
import json
import itertools
import numpy as np
from collections import Counter

import copy
import re
import traceback

def save_json_data(data, file_path):
    """
    Serialize `data` as indented UTF-8 JSON into `file_path`.

    Failures are reported on stdout and not raised.

    Parameters:
        data: JSON-serializable object
        file_path (str): destination path

    Returns:
        None
    """
    try:
        with open(file_path, "w", encoding="utf-8") as handle:
            json.dump(data, handle, indent=4, ensure_ascii=False)
        print(f"Data has been converted to JSON and saved to {file_path}")
    except Exception as err:
        print(f"Error saving JSON data to {file_path}: {err}")


def write_to_file(content, path):
    """
    Write the text `content` to `path` as UTF-8, replacing any existing file.

    Parameters:
        content (str): text to write
        path (str): destination path

    Returns:
        None
    """
    with open(path, "w", encoding="utf-8") as target:
        target.write(content)

class Rule_Learner(object):
    def __init__(self, edges, id2relation, inv_relation_id, dataset):
        """
        Initialize rule learner object.

        Parameters:
            edges (dict): edges for each relation
            id2relation (dict): mapping of index to relation
            inv_relation_id (dict): mapping of relation to inverse relation
            dataset (str): dataset name

        Returns:
            None
        """

        self.edges = edges
        self.id2relation = id2relation
        self.inv_relation_id = inv_relation_id
        self.num_individual = 0
        self.num_shared = 0
        self.num_original = 0

        self.found_rules = []
        self.rule2confidence_dict = {}
        self.original_found_rules = []
        self.rules_dict = dict()
        self.output_dir = "./sampled_path/" + dataset + "/"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def create_rule(self, walk, confidence=0, use_relax_time=False):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}
            confidence (float): confidence value
            use_relax_time (bool): whether the rule is created with relaxed time

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        if rule not in self.found_rules:
            self.found_rules.append(rule.copy())
            (
                rule["conf"],
                rule["rule_supp"],
                rule["body_supp"],
            ) = self.estimate_confidence(rule, is_relax_time=use_relax_time)

            rule["llm_confidence"] = confidence

            if rule["conf"] or confidence:
                self.update_rules_dict(rule)

    def create_rule_for_merge(
        self, walk, confidence=0, rule_without_confidence="", rules_var_dict=None, is_merge=False, is_relax_time=False
    ):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        if is_merge is True:
            if rules_var_dict.get(rule_without_confidence) is None:
                if rule not in self.found_rules:
                    self.found_rules.append(rule.copy())
                    (
                        rule["conf"],
                        rule["rule_supp"],
                        rule["body_supp"],
                    ) = self.estimate_confidence(rule)

                    rule["llm_confidence"] = confidence

                    if rule["conf"] or confidence:
                        self.num_individual += 1
                        self.update_rules_dict(rule)

            else:
                rule_var = rules_var_dict[rule_without_confidence]
                rule_var["llm_confidence"] = confidence
                temp_var = {}
                temp_var["head_rel"] = rule_var["head_rel"]
                temp_var["body_rels"] = rule_var["body_rels"]
                temp_var["var_constraints"] = rule_var["var_constraints"]
                if temp_var not in self.original_found_rules:
                    self.original_found_rules.append(temp_var.copy())
                    self.update_rules_dict(rule_var)
                    self.num_shared += 1
        else:
            if rule not in self.found_rules:
                self.found_rules.append(rule.copy())
                (
                    rule["conf"],
                    rule["rule_supp"],
                    rule["body_supp"],
                ) = self.estimate_confidence(rule, is_relax_time=is_relax_time)

                # if rule["body_supp"] == 0:
                #     rule["body_supp"] = 2

                rule["llm_confidence"] = confidence

                if rule["conf"] or confidence:
                    self.update_rules_dict(rule)

    def create_rule_for_merge_for_iteration(
        self, walk, llm_confidence=0, rule_without_confidence="", rules_var_dict=None, is_merge=False
    ):
        """
        Create a rule given a cyclic temporal random walk.
        The rule contains information about head relation, body relations,
        variable constraints, confidence, rule support, and body support.
        A rule is a dictionary with the content
        {"head_rel": int, "body_rels": list, "var_constraints": list,
         "conf": float, "rule_supp": int, "body_supp": int}

        Parameters:
            walk (dict): cyclic temporal random walk
                         {"entities": list, "relations": list, "timestamps": list}

        Returns:
            rule (dict): created rule
        """

        rule = dict()
        rule["head_rel"] = int(walk["relations"][0])
        rule["body_rels"] = [self.inv_relation_id[x] for x in walk["relations"][1:][::-1]]
        rule["var_constraints"] = self.define_var_constraints(walk["entities"][1:][::-1])

        rule_with_confidence = ""

        if is_merge is True:
            if rules_var_dict.get(rule_without_confidence) is None:
                if rule not in self.found_rules:
                    self.found_rules.append(rule.copy())
                    (
                        rule["conf"],
                        rule["rule_supp"],
                        rule["body_supp"],
                    ) = self.estimate_confidence(rule)

                    tuple_key = str(rule)
                    self.rule2confidence_dict[tuple_key] = rule["conf"]
                    rule_with_confidence = rule_without_confidence + "&" + str(rule["conf"])

                    rule["llm_confidence"] = llm_confidence

                    if rule["conf"] or llm_confidence:
                        self.num_individual += 1
                        self.update_rules_dict(rule)
                else:
                    tuple_key = tuple(rule)
                    confidence = self.rule2confidence_dict[tuple_key]
                    rule_with_confidence = rule_without_confidence + "&" + confidence

            else:
                rule_var = rules_var_dict[rule_without_confidence]
                rule_var["llm_confidence"] = llm_confidence
                temp_var = {}
                temp_var["head_rel"] = rule_var["head_rel"]
                temp_var["body_rels"] = rule_var["body_rels"]
                temp_var["var_constraints"] = rule_var["var_constraints"]
                if temp_var not in self.original_found_rules:
                    self.original_found_rules.append(temp_var.copy())
                    self.update_rules_dict(rule_var)
                    self.num_shared += 1
        else:
            if rule not in self.found_rules:
                tuple_key = str(rule)
                self.found_rules.append(rule.copy())
                (
                    rule["conf"],
                    rule["rule_supp"],
                    rule["body_supp"],
                ) = self.estimate_confidence(rule)

                self.rule2confidence_dict[tuple_key] = rule["conf"]
                rule_with_confidence = rule_without_confidence + "&" + str(rule["conf"])

                if rule["body_supp"] == 0:
                    rule["body_supp"] = 2

                rule["llm_confidence"] = llm_confidence

                if rule["conf"] or llm_confidence:
                    self.update_rules_dict(rule)
            else:
                tuple_key = str(rule)
                confidence = self.rule2confidence_dict[tuple_key]
                rule_with_confidence = rule_without_confidence + "&" + str(confidence)

        return rule_with_confidence

    def define_var_constraints(self, entities):
        """
        Define variable constraints, i.e., state the indices of reoccurring entities in a walk.

        Parameters:
            entities (list): entities in the temporal walk

        Returns:
            var_constraints (list): list of indices for reoccurring entities
        """

        var_constraints = []
        for ent in set(entities):
            all_idx = [idx for idx, x in enumerate(entities) if x == ent]
            var_constraints.append(all_idx)
        var_constraints = [x for x in var_constraints if len(x) > 1]

        return sorted(var_constraints)

    def estimate_confidence(self, rule, num_samples=2000, is_relax_time=False):
        """
        Estimate the confidence of the rule by sampling bodies and checking the rule support.

        Parameters:
            rule (dict): rule
                         {"head_rel": int, "body_rels": list, "var_constraints": list}
            num_samples (int): number of samples

        Returns:
            confidence (float): confidence of the rule, rule_support/body_support
            rule_support (int): rule support
            body_support (int): body support
        """

        if any(body_rel not in self.edges for body_rel in rule["body_rels"]):
            return 0, 0, 0

        if rule["head_rel"] not in self.edges:
            return 0, 0, 0

        all_bodies = []
        for _ in range(num_samples):
            sample_successful, body_ents_tss = self.sample_body(
                rule["body_rels"], rule["var_constraints"], is_relax_time
            )
            if sample_successful:
                all_bodies.append(body_ents_tss)

        all_bodies.sort()
        unique_bodies = list(x for x, _ in itertools.groupby(all_bodies))
        body_support = len(unique_bodies)

        confidence, rule_support = 0, 0
        if body_support:
            rule_support = self.calculate_rule_support(unique_bodies, rule["head_rel"])
            confidence = round(rule_support / body_support, 6)

        return confidence, rule_support, body_support

    def sample_body(self, body_rels, var_constraints, use_relax_time=False):
        """
        Sample a walk according to the rule body.
        The sequence of timesteps should be non-decreasing.

        Parameters:
            body_rels (list): relations in the rule body
            var_constraints (list): variable constraints for the entities
            use_relax_time (bool): whether to use relaxed time sampling

        Returns:
            sample_successful (bool): if a body has been successfully sampled
            body_ents_tss (list): entities and timestamps (alternately entity and timestamp)
                                  of the sampled body
        """

        sample_successful = True
        body_ents_tss = []
        cur_rel = body_rels[0]
        rel_edges = self.edges[cur_rel]
        next_edge = rel_edges[np.random.choice(len(rel_edges))]
        cur_ts = next_edge[3]
        cur_node = next_edge[2]
        body_ents_tss.append(next_edge[0])
        body_ents_tss.append(cur_ts)
        body_ents_tss.append(cur_node)

        for cur_rel in body_rels[1:]:
            next_edges = self.edges[cur_rel]
            if use_relax_time:
                mask = next_edges[:, 0] == cur_node
            else:
                mask = (next_edges[:, 0] == cur_node) * (next_edges[:, 3] >= cur_ts)

            filtered_edges = next_edges[mask]

            if len(filtered_edges):
                next_edge = filtered_edges[np.random.choice(len(filtered_edges))]
                cur_ts = next_edge[3]
                cur_node = next_edge[2]
                body_ents_tss.append(cur_ts)
                body_ents_tss.append(cur_node)
            else:
                sample_successful = False
                break

        if sample_successful and var_constraints:
            # Check variable constraints
            body_var_constraints = self.define_var_constraints(body_ents_tss[::2])
            if body_var_constraints != var_constraints:
                sample_successful = False

        return sample_successful, body_ents_tss

    def calculate_rule_support(self, unique_bodies, head_rel):
        """
        Calculate the rule support. Check for each body if there is a timestamp
        (larger than the timestamps in the rule body) for which the rule head holds.

        Parameters:
            unique_bodies (list): bodies from self.sample_body
            head_rel (int): head relation

        Returns:
            rule_support (int): rule support
        """

        rule_support = 0
        try:
            head_rel_edges = self.edges[head_rel]
        except Exception as e:
            print(head_rel)
        for body in unique_bodies:
            mask = (
                (head_rel_edges[:, 0] == body[0])
                * (head_rel_edges[:, 2] == body[-1])
                * (head_rel_edges[:, 3] > body[-2])
            )

            if True in mask:
                rule_support += 1

        return rule_support

    def update_rules_dict(self, rule):
        """
        Update the rules if a new rule has been found.

        Parameters:
            rule (dict): generated rule from self.create_rule

        Returns:
            None
        """

        try:
            self.rules_dict[rule["head_rel"]].append(rule)
        except KeyError:
            self.rules_dict[rule["head_rel"]] = [rule]

    def sort_rules_dict(self):
        """
        Sort the found rules for each head relation by decreasing confidence.

        Parameters:
            None

        Returns:
            None
        """

        for rel in self.rules_dict:
            self.rules_dict[rel] = sorted(self.rules_dict[rel], key=lambda x: x["conf"], reverse=True)

    def save_rules(self, dt, rule_lengths, num_walks, transition_distr, seed):
        """
        Save all rules.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed

        Returns:
            None
        """

        rules_dict = {int(k): v for k, v in self.rules_dict.items()}
        filename = "{0}_r{1}_n{2}_{3}_s{4}_rules.json".format(dt, rule_lengths, num_walks, transition_distr, seed)
        filename = filename.replace(" ", "")
        with open(self.output_dir + filename, "w", encoding="utf-8") as fout:
            json.dump(rules_dict, fout)

    def save_rules_verbalized(self, dt, rule_lengths, num_walks, transition_distr, seed, rel2idx, relation_regex):
        """
        Save all rules in a human-readable format, together with several derived files.

        All files are written below self.output_dir:
          - "original/rules_var.json": the rule dictionaries keyed by the last token
            of their verbalized line
          - "<prefix>_rules.txt": the verbalized rules
          - "<prefix>_remove_rules.txt": the verbalized rules without their first
            three whitespace-separated columns
          - further files produced by the parse_and_save_* helpers and
            save_rule_name_with_confidence (output names: "closed_rel_paths.jsonl",
            "rules_name.json", "rules_id.json", "relation_name_with_confidence.json")
        where <prefix> is built by self.generate_filename from the first five arguments.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed
            rel2idx (dict): mapping whose keys are passed on as the list of relation
                            names (presumably relation name -> index; confirm with caller)
            relation_regex: passed unchanged to the parsing helpers
                            (presumably a pattern matching relation names — TODO confirm)

        Returns:
            None
        """

        output_original_dir = os.path.join(self.output_dir, "original/")
        os.makedirs(output_original_dir, exist_ok=True)

        # Verbalize once; keep both the text and the per-rule lookup table.
        rules_str, rules_var = self.verbalize_rules()
        save_json_data(rules_var, output_original_dir + "rules_var.json")

        filename = self.generate_filename(dt, rule_lengths, num_walks, transition_distr, seed, "rules.txt")
        write_to_file(rules_str, self.output_dir + filename)

        original_rule_txt = self.output_dir + filename
        remove_filename = self.generate_filename(
            dt, rule_lengths, num_walks, transition_distr, seed, "remove_rules.txt"
        )

        # Must run after the rules text file exists: it reads that file and writes the
        # stripped copy that the parsing helpers below read in turn.
        rule_id_content = self.remove_first_three_columns(self.output_dir + filename, self.output_dir + remove_filename)

        self.parse_and_save_rules(remove_filename, list(rel2idx.keys()), relation_regex, "closed_rel_paths.jsonl")
        self.parse_and_save_rules_with_names(
            remove_filename, rel2idx, relation_regex, "rules_name.json", rule_id_content
        )
        self.parse_and_save_rules_with_ids(rule_id_content, rel2idx, relation_regex, "rules_id.json")

        self.save_rule_name_with_confidence(
            original_rule_txt,
            relation_regex,
            self.output_dir + "relation_name_with_confidence.json",
            list(rel2idx.keys()),
        )

    def verbalize_rules(self):
        """
        Verbalize every rule stored in self.rules_dict.

        Returns:
            rules_str (str): all verbalized rules, each terminated by a newline
            rules_var (dict): last whitespace-separated token of each verbalized
                rule -> the rule dict it was produced from
        """

        verbalized_lines = []
        rules_var = {}
        for rules_of_relation in self.rules_dict.values():
            for rule in rules_of_relation:
                line = verbalize_rule(rule, self.id2relation) + "\n"
                # The leading columns are numeric; the final token is the rule itself.
                tokens = re.split(r"\s+", line.strip())
                rules_var[f"{tokens[-1]}"] = rule
                verbalized_lines.append(line)
        return "".join(verbalized_lines), rules_var

    def generate_filename(self, dt, rule_lengths, num_walks, transition_distr, seed, suffix):
        """
        Build the file name used for one sampling run.

        Parameters:
            dt (str): time now
            rule_lengths (list): rule lengths
            num_walks (int): number of walks
            transition_distr (str): transition distribution
            seed (int): random seed
            suffix (str): trailing part of the name, e.g. "rules.txt"

        Returns:
            filename (str): underscore-separated name with every space removed
        """
        pieces = [
            f"{dt}",
            f"r{rule_lengths}",
            f"n{num_walks}",
            f"{transition_distr}",
            f"s{seed}",
            f"{suffix}",
        ]
        # Formatting a list inserts ", " between items; dropping spaces keeps the name compact.
        return "_".join(pieces).replace(" ", "")

    def remove_first_three_columns(self, input_path, output_path):
        """
        Copy a whitespace-separated rule file without its three leading columns.

        Each non-blank line of `input_path` is split on whitespace. Columns four and
        onwards, re-joined with single spaces, are written to `output_path`. The same
        text followed by "&" and the first column is collected and returned.
        Blank lines are skipped; they carry no rule and have no first column.

        Parameters:
            input_path (str): path of the file to read
            output_path (str): path of the file to write (UTF-8)

        Returns:
            rule_id_content (list): one "<remaining columns>&<first column>\\n" string per input rule
        """
        rule_id_content = []
        with open(input_path, "r") as input_file, open(output_path, "w", encoding="utf-8") as output_file:
            for line in input_file:
                columns = line.split()
                if not columns:
                    # A blank line would otherwise fail on columns[0] below.
                    continue
                remainder = " ".join(columns[3:])
                rule_id_content.append(f"{remainder}&{columns[0]}\n")
                output_file.write(remainder + "\n")
        return rule_id_content

    def parse_and_save_rules(self, remove_filename, keys, relation_regex, output_filename):
        """
        Convert a stripped rule file into head/paths records and save them as JSON lines.

        Parameters:
            remove_filename (str): name of the stripped rule file, appended to self.output_dir
            keys (list): known relation names
            relation_regex (str): regular expression used to find relation atoms in a rule
            output_filename (str): name of the JSON-lines file created in self.output_dir

        Returns:
            converted_rules (dict): result of parse_rules_for_path for the file
        """
        with open(self.output_dir + remove_filename, "r") as file:
            converted_rules = parse_rules_for_path(file.readlines(), keys, relation_regex)

        output_file_path = os.path.join(self.output_dir, output_filename)
        with open(output_file_path, "w") as file:
            # One JSON object per line: {"head": ..., "paths": [...]}
            file.writelines(
                json.dumps({"head": head, "paths": paths}) + "\n" for head, paths in converted_rules.items()
            )

        print(f"Rules have been converted and saved to {output_file_path}")
        return converted_rules

    def parse_and_save_rules_with_names(
        self, remove_filename, rel2idx, relation_regex, output_filename, rule_id_content
    ):
        """
        Group the rules of a stripped rule file by head relation name and save them as JSON.

        Parameters:
            remove_filename (str): name of the stripped rule file inside self.output_dir
            rel2idx (dict): mapping of relation name to index; only its keys are used
            relation_regex (str): regular expression used to find the head relation of a rule
            output_filename (str): name of the JSON file created in self.output_dir
            rule_id_content (list): not used by this method

        Returns:
            None
        """
        source_path = os.path.join(self.output_dir, remove_filename)
        with open(source_path, "r") as file:
            rules_name_dict = parse_rules_for_name(file.readlines(), list(rel2idx.keys()), relation_regex)

        output_file_path = os.path.join(self.output_dir, output_filename)
        with open(output_file_path, "w") as file:
            file.write(json.dumps(rules_name_dict, indent=4))

        print(f"Rules have been converted and saved to {output_file_path}")

    def parse_and_save_rules_with_ids(self, rule_id_content, rel2idx, relation_regex, output_filename):
        """
        Rewrite rules with relation ids instead of names and save them grouped by head as JSON.

        Parameters:
            rule_id_content (list): rule strings as returned by remove_first_three_columns
            rel2idx (dict): mapping of relation name to index
            relation_regex (str): regular expression used to find relation atoms in a rule
            output_filename (str): name of the JSON file created in self.output_dir

        Returns:
            None
        """
        rules_by_head = parse_rules_for_id(rule_id_content, rel2idx, relation_regex)
        output_file_path = os.path.join(self.output_dir, output_filename)
        with open(output_file_path, "w") as file:
            file.write(json.dumps(rules_by_head, indent=4))
        print(f"Rules have been converted and saved to {output_file_path}")

    def save_rule_name_with_confidence(self, file_path, relation_regex, out_file_path, relations):
        """
        Group rules, each tagged with its first column, by head relation and save them as JSON.

        Every non-blank line of `file_path` is split on whitespace. Columns four and
        onwards are concatenated without separators to form the rule text, and the
        first column (the confidence in files written from verbalize_rule output) is
        appended after an "&". The entry is stored under the relation name captured
        by group 1 of `relation_regex` in the part of the rule before "<-".
        Blank lines and lines whose head does not match `relation_regex` are skipped.

        Parameters:
            file_path (str): path of the verbalized rule file to read
            relation_regex (str): regular expression whose first group captures a relation name
            out_file_path (str): path of the JSON file to write
            relations (list): known relation names

        Returns:
            None

        Raises:
            ValueError: if a rule head names a relation that is not in `relations`
        """
        rules_dict = {}
        with open(file_path, "r") as fin:
            for rule in fin:
                # Split the line on whitespace to get the columns
                columns = rule.split()
                if not columns:
                    # A blank line has no first column; indexing it would raise IndexError.
                    continue

                first_column = columns[0]
                fourth_column = "".join(columns[3:])

                # Only the text before "<-" is searched, so the match is the rule head.
                match = re.search(relation_regex, fourth_column.split("<-")[0])
                if not match:
                    continue

                head = match[1].strip()
                if head not in relations:
                    raise ValueError(f"Not exist relation:{head}")

                rules_dict.setdefault(head, []).append(f"{fourth_column}&{first_column}")
        save_json_data(rules_dict, out_file_path)


def parse_rules_for_path(lines, relations, relation_regex):
    """
    Group the body relation paths of verbalized rules by head relation.

    Each non-empty line is treated as ``head<-body_1&body_2&...``. The relation
    name of every atom is taken from group 1 of `relation_regex`. Body atoms that
    do not match the pattern are ignored. A line whose head atom does not match
    is skipped entirely, because there is no head to file its path under.

    Parameters:
        lines (list): verbalized rules, one per entry
        relations (list): known relation names
        relation_regex (str): regular expression whose first group captures a relation name

    Returns:
        converted_rules (dict): head relation name -> list of body paths, each path
            being the body relation names joined with "|"

    Raises:
        ValueError: if a matched atom names a relation that is not in `relations`
    """

    def relation_name(atom):
        # Return the validated relation name of an atom, or None if the atom does not parse.
        match = re.search(relation_regex, atom)
        if not match:
            return None
        rel_name = match.group(1).strip()
        if rel_name not in relations:
            raise ValueError(f"Not exist relation:{rel_name}")
        return rel_name

    converted_rules = {}
    for line in lines:
        rule = line.strip()
        if not rule:
            continue
        atoms = re.sub(r"\s*<-\s*", "&", rule).split("&")

        head = relation_name(atoms[0])
        if head is None:
            # Previously such a line reused the path list of the preceding rule
            # (or failed on the very first line); it is now ignored.
            continue

        body_names = [name for name in map(relation_name, atoms[1:]) if name is not None]
        converted_rules.setdefault(head, []).append("|".join(body_names))

    return converted_rules


def parse_rules_for_name(lines, relations, relation_regex):
    """
    Group rule strings by the name of their head relation.

    Parameters:
        lines (list): rule strings of the form ``head<-body``
        relations (list): known relation names
        relation_regex (str): regular expression whose first group captures a relation name

    Returns:
        rules_dict (dict): head relation name -> list of the unmodified rule strings

    Raises:
        ValueError: if a rule head names a relation that is not in `relations`
    """
    rules_dict = {}
    for rule in lines:
        # Everything before the first "<-" (or "&") is the head atom.
        head_atom = re.sub(r"\s*<-\s*", "&", rule).split("&")[0]
        match = re.search(relation_regex, head_atom)
        if not match:
            # Lines without a recognizable head are left out.
            continue

        head = match[1].strip()
        if head not in relations:
            raise ValueError(f"Not exist relation:{head}")

        rules_dict.setdefault(head, []).append(rule)

    return rules_dict


def parse_rules_for_id(rules, rel2idx, relation_regex):
    """
    Rewrite rules with relation ids and group them by head relation name.

    Each entry of `rules` is a rule followed by "&" and a trailing field. The rule
    part is converted with rule2id and the trailing field, stripped of surrounding
    whitespace, is appended again after an "&".

    Parameters:
        rules (list): strings of the form ``head<-body&<trailing field>``
        rel2idx (dict): mapping of relation name to index
        relation_regex (str): regular expression whose first group captures a relation name

    Returns:
        rules_dict (dict): head relation name -> list of id-encoded rule strings

    Raises:
        ValueError: if a rule head names a relation that is not in `rel2idx`
    """
    rules_dict = {}
    for rule in rules:
        head_atom = re.sub(r"\s*<-\s*", "&", rule).split("&")[0]
        match = re.search(relation_regex, head_atom)
        if not match:
            # Entries without a recognizable head are left out.
            continue

        head = match[1].strip()
        if head not in rel2idx:
            raise ValueError(f"Relation '{head}' not found in rel2idx")

        # Split once at the last "&": rule text on the left, trailing field on the right.
        pieces = rule.rsplit("&", 1)
        encoded_rule = rule2id(pieces[0], rel2idx, relation_regex)
        rules_dict.setdefault(head, []).append(f"{encoded_rule}&{pieces[1].strip()}")
    return rules_dict


def rule2id(rule, relation2id, relation_regex):
    """
    Replace every relation name in a verbalized rule by its id.

    The rule is split into atoms at "<-" and "&". For each atom, `relation_regex`
    must capture the relation name, subject, object and timestamp as groups 1-4.
    The atom is re-emitted as ``<id>(<subject>,<object>,<timestamp>)``, with "<-"
    after the first atom and "&" between the remaining ones.

    Parameters:
        rule (str): verbalized rule, e.g. ``head(X0,X1,T1)<-body(X0,X1,T0)``
        relation2id (dict): mapping of relation name to index
        relation_regex (str): regular expression with the four groups described above

    Returns:
        rule2id_str (str): the rule with relation names replaced by ids

    Raises:
        ValueError: if a relation name is missing from `relation2id`, or if an atom
            does not match `relation_regex`; the original exception is attached as
            the cause
    """
    atoms = re.sub(r"\s*<-\s*", "&", rule).split("&")
    encoded_atoms = []

    try:
        for idx, atom in enumerate(atoms):
            match = re.search(relation_regex, atom)
            rel_name = match[1].strip()
            subject = match[2].strip()
            obj = match[3].strip()
            timestamp = match[4].strip()
            rel_id = relation2id[rel_name]
            # The head atom is followed by "<-", every body atom by "&".
            separator = "<-" if idx == 0 else "&"
            encoded_atoms.append(f"{rel_id}({subject},{obj},{timestamp}){separator}")
    except KeyError as keyerror:
        # Print the stack trace before re-raising as ValueError.
        traceback.print_exc()
        raise ValueError(f"KeyError: {keyerror}") from keyerror
    except Exception as error:
        raise ValueError(f"An error occurred: {rule}") from error

    # Drop the final character, i.e. the "&" left behind by the last body atom.
    return "".join(encoded_atoms)[:-1]


def verbalize_rule(rule, id2relation):
    """
    Turn a rule dict into a single human-readable line.

    The line has the form
    ``<conf>  <rule_supp>  <body_supp>  head(X0,Xk,Tn)<-b1(X..,X..,T0)&b2(...)``.
    Entity positions that share a group in rule["var_constraints"] are written
    with the same variable index.

    Note: when rule["var_constraints"] is non-empty, positions missing from it are
    appended to that list as singleton groups, so the passed-in rule is modified.

    Parameters:
        rule (dict): rule from Rule_Learner.create_rule
        id2relation (dict): mapping of index to relation

    Returns:
        rule_str (str): human-readable rule
    """

    body_rels = rule["body_rels"]
    num_positions = len(body_rels) + 1  # entity positions along the walk

    if rule["var_constraints"]:
        var_constraints = rule["var_constraints"]
        already_grouped = {pos for group in var_constraints for pos in group}
        for pos in range(num_positions):
            if pos not in already_grouped:
                var_constraints.append([pos])  # in place: extends rule["var_constraints"]
        var_constraints = sorted(var_constraints)
    else:
        var_constraints = [[pos] for pos in range(num_positions)]

    def variable_of(pos):
        # Index of the first group that contains this entity position.
        return next(idx for idx, group in enumerate(var_constraints) if pos in group)

    head_name = id2relation[rule["head_rel"]]
    pieces = [
        f"{rule['conf']:8.6f}  {rule['rule_supp']:4}  {rule['body_supp']:4}  "
        f"{head_name}(X0,X{variable_of(len(body_rels))},T{len(body_rels)})<-"
    ]
    for step, body_rel in enumerate(body_rels):
        pieces.append(f"{id2relation[body_rel]}(X{variable_of(step)},X{variable_of(step + 1)},T{step})&")

    # Drop the final character, i.e. the "&" after the last body atom.
    return "".join(pieces)[:-1]


def rules_statistics(rules_dict):
    """
    Print summary statistics of the learned rules.

    Reports how many relations have rules, how many rules there are in total,
    and how many rules exist for each body length.

    Parameters:
        rules_dict (dict): rules

    Returns:
        None
    """

    print("Number of relations with rules: ", len(rules_dict))  # Including inverse relations

    total_rules = sum(len(rules) for rules in rules_dict.values())
    print("Total number of rules: ", total_rules)

    # Histogram of body lengths over every rule of every relation.
    length_counts = Counter(len(rule["body_rels"]) for rules in rules_dict.values() for rule in rules)
    print("Number of rules by length: ", sorted(length_counts.items()))


In [None]:

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

def select_similary_relations(relation2id, output_dir):
    """
    Embed all relation names and save their pairwise cosine similarities.

    Files written to `output_dir`:
        transfomers_id2rel.json  index -> relation name
        transfomers_rel2id.json  relation name -> index
        matrix.npy               similarity matrix with the diagonal set to 0

    Rows and columns of the matrix follow the iteration order of `relation2id`.

    Parameters:
        relation2id (dict): mapping of relation name to index
        output_dir (str): directory the files are written to

    Returns:
        None
    """
    id2relation = {idx: relation for relation, idx in relation2id.items()}

    save_json_data(id2relation, os.path.join(output_dir, "transfomers_id2rel.json"))
    save_json_data(relation2id, os.path.join(output_dir, "transfomers_rel2id.json"))

    all_rels = list(relation2id.keys())
    # Load the pre-trained model
    model = SentenceTransformer("bert-base-nli-mean-tokens")

    # Both sides of the comparison are the same sentences, so encode them once
    # instead of running the model twice over identical input.
    embeddings = model.encode(all_rels)

    # Cosine similarity between every pair of relation names
    similarity_matrix = cosine_similarity(embeddings, embeddings)

    # Zero out the diagonal (similarity of each relation with itself)
    np.fill_diagonal(similarity_matrix, 0)

    np.save(os.path.join(output_dir, "matrix.npy"), similarity_matrix)



### Main function

In [8]:
from datetime import datetime

from joblib import Parallel, delayed

def main(parsed):
    """
    Sample temporal walks, learn rules from them in parallel and save the results.

    Parameters:
        parsed (dict): run configuration. Keys read here: "dataset",
            "max_path_len", "num_walks", "transition_distr", "cores", "seed",
            "version" and "is_relax_time". "is_relax_time" may be a bool or a
            string; the strings "", "0", "f", "false", "n" and "no" (any case)
            count as False.

    Returns:
        None
    """
    dataset = parsed["dataset"]
    rule_lengths = parsed["max_path_len"]
    rule_lengths = (torch.arange(rule_lengths) + 1).tolist()
    num_walks = parsed["num_walks"]
    transition_distr = parsed["transition_distr"]
    num_processes = parsed["cores"]
    seed = parsed["seed"]
    version_id = parsed["version"]

    # A non-empty string such as "no" is truthy, so textual flags are mapped to a
    # real boolean before being handed on as `use_relax_time`.
    is_relax_time = parsed["is_relax_time"]
    if isinstance(is_relax_time, str):
        is_relax_time = is_relax_time.strip().lower() not in ("", "0", "f", "false", "n", "no")

    dataset_dir = "./datasets/" + dataset + "/"
    data = Grapher(dataset_dir)

    temporal_walk = initialize_temporal_walk(version_id, data, transition_distr)

    rl = Rule_Learner(temporal_walk.edges, data.id2relation, data.inv_relation_id, dataset)
    all_relations = sorted(temporal_walk.edges)  # Learn for all relations
    all_relations = [int(item) for item in all_relations]
    rel2idx = data.relation2id

    select_similary_relations(data.relation2id, rl.output_dir)

    constant_config = load_json_data("./Config/constant.json")
    relation_regex = constant_config["relation_regex"][dataset]

    def learn_rules(i, num_relations, use_relax_time=False):
        """
        Learn rules, optionally with relaxed time (multiprocessing possible).

        Parameters:
            i (int): process number
            num_relations (int): minimum number of relations for each process
            use_relax_time (bool): whether to use relaxed time when sampling

        Returns:
            rl.rules_dict (dict): rules dictionary
        """

        set_seed_if_provided()
        relations_idx = calculate_relations_idx(i, num_relations)
        num_rules = [0]

        for k in relations_idx:
            rel = all_relations[k]
            for length in rule_lengths:
                it_start = time.time()
                process_rules_for_relation(rel, length, use_relax_time)
                it_end = time.time()
                it_time = round(it_end - it_start, 6)
                num_rules.append(sum(len(rules) for rules in rl.rules_dict.values()) // 2)
                num_new_rules = num_rules[-1] - num_rules[-2]

                print(
                    f"Process {i}: relation {k - relations_idx[0] + 1}/{len(relations_idx)}, length {length}: {it_time} sec, {num_new_rules} rules"
                )

        return rl.rules_dict

    def set_seed_if_provided():
        # `is not None` rather than truthiness so that a seed of 0 is honoured.
        if seed is not None:
            np.random.seed(seed)

    def calculate_relations_idx(i, num_relations):
        # Every process gets `num_relations` relations; the last one also takes the remainder.
        if i < num_processes - 1:
            return range(i * num_relations, (i + 1) * num_relations)
        else:
            return range(i * num_relations, len(all_relations))

    def process_rules_for_relation(rel, length, use_relax_time):
        # Sample `num_walks` walks for this relation and turn each successful one into a rule.
        for _ in range(num_walks):
            walk_successful, walk = temporal_walk.sample_walk(length + 1, rel, use_relax_time)
            if walk_successful:
                rl.create_rule(walk, use_relax_time)

    start = time.time()
    num_relations = len(all_relations) // num_processes
    output = Parallel(n_jobs=num_processes)(
        delayed(learn_rules)(i, num_relations, is_relax_time) for i in range(num_processes)
    )
    end = time.time()

    # Merge the per-process rule dictionaries into the first one.
    all_graph = output[0]
    for i in range(1, num_processes):
        all_graph.update(output[i])

    total_time = round(end - start, 6)
    print("학습 완료: {} 초".format(total_time))

    rl.rules_dict = all_graph
    rl.sort_rules_dict()
    dt = datetime.now()
    dt = dt.strftime("%d%m%y%H%M%S")
    rl.save_rules(dt, rule_lengths, num_walks, transition_distr, seed)
    save_json_data(rl.rules_dict, rl.output_dir + "confidence.json")
    rules_statistics(rl.rules_dict)
    rl.save_rules_verbalized(dt, rule_lengths, num_walks, transition_distr, seed, rel2idx, relation_regex)




In [None]:
# Command-line equivalent of this cell:
# uv run rule_sampler.py -d icews14 -m 3 -n 200 -p 16 -s 12 --is_relax_time No
# NOTE(review): num_walks, cores and seed below differ from the example command
# above; confirm which values are intended.


parsed = {
    "data_path": "datasets",
    "dataset": "icews14",
    "max_path_len": 3,
    "anchor": 5,
    "output_path": "sampled_path",
    "sparsity": 1,
    "cores": 20,
    "num_walks": 100,
    "transition_distr": "exp",
    "seed": None,
    "window": 0,
    "version": "train",
    # A real boolean: main() forwards this value as `use_relax_time`, and the
    # string "no" is non-empty and therefore truthy.
    "is_relax_time": False,
}

main(parsed)