In [1]:
import pickle

class Dataset:
    def __init__(self):
        self.id2node = {}
        self.node2id = {}
        self.id2rel = {}
        self.rel2id = {}
        self.node2title = {}
        self.title2node = {}

    def load_key_value_files(self, filename):
        '''
        Load key-value pairs from a file
        Args:
            filename (str): The name of the file to load.
            file_format (str): The format of the file ('pkl' or 'txt').
        Returns:
            tuple: (id2value, value2id) where:
                id2value (dict): Dictionary mapping IDs to values.
                value2id (dict): Dictionary mapping values to IDs.
        '''
        id2value = {}
        value2id = {}
        file_format = filename.split('.')[-1]
        if file_format == 'pkl':
            with open(filename, 'rb') as f:
                id2value = pickle.load(f)
        elif file_format == 'txt':
            with open(filename, 'r') as f:
                id2value = {}
                for line in f:
                    id, value = line.strip().split('\t')
                    id2value[id] = value
        else:
            raise ValueError("Unsupported file format. Use 'pkl' or 'txt'.")
        value2id = {v: k for k, v in id2value.items()}
        return id2value, value2id
    
    def set_id2node(self, filename):
        id2node, node2id = self.load_key_value_files(filename)
        self.id2node = id2node
        self.node2id = node2id
        print(f"Loaded {len(self.id2node)} nodes from {filename}.")

    def get_node_by_id(self, node_id):
        return self.id2node.get(node_id, None)
    
    def get_id_by_node(self, node):
        return self.node2id.get(node, None)

    def set_id2rel(self, filename):
        id2rel, rel2id = self.load_key_value_files(filename)
        self.id2rel = id2rel
        self.rel2id = rel2id
        print(f"Loaded {len(self.id2rel)} relations from {filename}.")

    def get_relation_by_id(self, rel_id):
        return self.id2rel.get(rel_id, None)
    
    def get_id_by_relation(self, relation):
        return self.rel2id.get(relation, None)

    def set_node2title(self, filename):
        node2title, title2node = self.load_key_value_files(filename)
        self.node_to_title = node2title
        self.title2node = title2node
        print(f"Loaded {len(self.node_to_title)} node titles from {filename}.")


    def get_title_by_node(self, node):
        return self.node_to_title.get(node, None)
    
    def get_node_by_title(self, title):
        return self.title2node.get(title, None)
    
    def get_num_nodes(self):
        return len(self.id2node)

In [2]:
import numpy as np

class Node:
    def __init__(self, name: str, id: int, title: str):
        self.id = id
        self.name = name
        self.title = title

    def get_id(self):
        return self.id

    def get_name(self):
        return self.name
    
    def get_title(self):
        return self.title

class Edge:
    def __init__(self, name:str, id: int, head: Node, tail: Node):
        self.id = id
        self.name = name
        self.head = head
        self.tail = tail

    def get_id(self):
        return self.id
    
    def get_name(self):
        return self.name

    def get_head(self):
        return self.head

    def get_tail(self):
        return self.tail

class Graph:
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.edges = []
    
    def add_edge(self, head: str, relation: str, tail: str, skip_missing: bool = True, add_reverse: bool = True):
        head_id = self.dataset.get_id_by_node(head)
        tail_id = self.dataset.get_id_by_node(tail)
        relation_id = self.dataset.get_id_by_relation(relation)
        if head_id is None and skip_missing:
            print(f'Node {head} not found in dataset, skipping edge')
        elif tail_id is None and skip_missing:
            print(f'Node {tail} not found in dataset, skipping edge')
        elif relation_id is None and skip_missing:
            print(f'Relation {relation} not found in dataset, skipping edge')
        else:
            head_node = Node(head, head_id, self.dataset.get_title_by_node(head))
            tail_node = Node(tail, tail_id, self.dataset.get_title_by_node(tail))
            edge = Edge(relation, relation_id, head_node, tail_node)
            self.edges.append(edge)
            if add_reverse:
                reverse_relation = f'{relation}_reverse'
                reverse_relation_id = self.dataset.get_id_by_relation(reverse_relation)
                reverse_edge = Edge(reverse_relation, reverse_relation_id, tail_node, head_node)
                self.edges.append(reverse_edge)

    def load_triples(self, filename: str, skip_missing: bool = True, add_reverse: bool = True):
        try:
            with open(filename, 'r') as f:
                for line in f:
                    head, relation, tail = line.strip().split('\t')
                    self.add_edge(head, relation, tail, skip_missing, add_reverse)
        except FileNotFoundError:
            raise ValueError(f'File {filename} not found')
        except Exception as e:
            raise ValueError(f'Error loading triples from {filename}: {e}')
        
    def get_num_edges(self):
        return len(self.edges)
    
    def get_edges(self):
        return self.edges

    def get_num_nodes(self):
        return self.dataset.get_num_nodes()

In [3]:
dataset = Dataset()

data_dir = 'data/FB15k-237'
dataset.set_id2node(f'{data_dir}/ind2ent.pkl')
dataset.set_id2rel(f'{data_dir}/ind2rel.pkl')
dataset.set_node2title(f'{data_dir}/extra/entity2text.txt')

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.


In [4]:
graph_train = Graph(dataset)
graph_train.load_triples(f'{data_dir}/train.txt', skip_missing=False, add_reverse=True)
graph_train.get_num_nodes(), graph_train.get_num_edges()

(14505, 544230)

In [5]:
graph_valid = Graph(dataset)
# add training edges to validation graph
for edge in graph_train.get_edges():
    # we set add_reverse=False because it already exists in the training graph
    graph_valid.add_edge(edge.get_head().get_name(), edge.get_name(), edge.get_tail().get_name(), skip_missing=False, add_reverse=False)
graph_valid.load_triples(f'{data_dir}/valid.txt', skip_missing=False, add_reverse=True)
graph_valid.get_num_nodes(), graph_valid.get_num_edges()

(14505, 579300)

In [6]:
graph_test = Graph(dataset)
# add training and validation edges to test graph (validation graph contains all training edges)
for edge in graph_valid.get_edges():
    # we set add_reverse=False because it already exists in the validation graph
    graph_test.add_edge(edge.get_head().get_name(), edge.get_name(), edge.get_tail().get_name(), skip_missing=False, add_reverse=False)
graph_test.load_triples(f'{data_dir}/test.txt', skip_missing=False, add_reverse=True)
graph_test.get_num_nodes(), graph_test.get_num_edges()

(14505, 620232)

In [7]:
class Query:
    def __init__(self, query_type: str, query_answer: tuple):
        self.query_type = query_type
        if len(query_answer) != 2:
            raise ValueError("Query answer must be a tuple of (query, answer)")
        elif type(query_answer[1]) is not list:
            raise ValueError("Query answer must be a tuple of (query, answer) where answer is a list")
        self.query = query_answer[0]
        self.answer = query_answer[1]

    def get_query(self):
        return self.query
    
    def get_answer(self):
        return self.answer
    
    def __repr__(self):
        return f"Query(type={self.query_type}, query={self.query}, answer={self.answer})"
    
class QueryDataset:
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.queries = {}

    def add_query(self, query_type: str, query_answer: tuple):
        if query_type not in self.queries:
            self.queries[query_type] = []
        query = Query(query_type, query_answer)
        self.queries[query_type].append(query)

    def get_queries(self, query_type: str):
        if query_type not in self.queries:
            raise ValueError(f"No queries of type {query_type} found")
        return self.queries[query_type]
    
    def get_all_queries(self):
        all_queries = []
        for query_type, queries in self.queries.items():
            all_queries.extend(queries)
        return all_queries
    
    def get_num_queries(self):
        return sum(len(queries) for queries in self.queries.values())
    
    def get_num_queries_by_type(self, query_type: str):
        if query_type not in self.queries:
            return 0
        return len(self.queries[query_type])
    
    def load_queries_from_pkl(self, filename: str, query_type: str = ''):
        try:
            with open(filename, 'rb') as f:
                queries = pickle.load(f)
                for query, answer in queries.items():
                    answer = list(answer)
                    self.add_query(query_type, (query, answer))
        except FileNotFoundError:
            raise ValueError(f'File {filename} not found')
        except Exception as e:
            raise ValueError(f'Error loading queries from {filename}: {e}')
        
def human_readable(query: Query, dataset: Dataset):
    if query.query_type == '2p' or query.query_type == '3p':
        anchor = query.query[0][0]
        relations = query.query[0][1]
        rel1 = relations[0]
        rel2 = relations[1]
        if query.query_type =='3p':
            rel3 = relations[2]
        anchor_name = dataset.get_node_by_id(anchor)
        rel1_name = dataset.get_relation_by_id(rel1)
        rel2_name = dataset.get_relation_by_id(rel2)
        if query.query_type == '3p':
            rel3_name = dataset.get_relation_by_id(rel3)
        anchor_title = dataset.get_title_by_node(anchor_name)
        answers_titles = [dataset.get_title_by_node(dataset.get_node_by_id(a)) for a in query.answer]
        if query.query_type == '3p':
            print(f"Query:\n{anchor_title}\t--{rel1_name}-->\tV1")
            print(f"V1\t--{rel2_name}-->\tV2")
            print(f"V2\t--{rel3_name}-->\t?")
        else:
            print(f"Query:\n{anchor_title}\t--{rel1_name}-->\tV")
            print(f"V\t--{rel2_name}-->\t?")
        print(f"\nAnswer Set (?): \n{answers_titles}")
    elif query.query_type == '2u':
        query1 = query.query[0]
        query2 = query.query[1]
        anchor1 = query1[0]
        relation1 = query1[1][0]
        anchor2 = query2[0]
        relation2 = query2[1][0]
        anchor1_name = dataset.get_node_by_id(anchor1)
        anchor2_name = dataset.get_node_by_id(anchor2)
        rel1_name = dataset.get_relation_by_id(relation1)
        rel2_name = dataset.get_relation_by_id(relation2)
        anchor1_title = dataset.get_title_by_node(anchor1_name)
        anchor2_title = dataset.get_title_by_node(anchor2_name)
        answers_titles = [dataset.get_title_by_node(dataset.get_node_by_id(a)) for a in query.answer]
        print(f"Query:\n{anchor1_title}\t--{rel1_name}-->\tV1")
        print(f"{anchor2_title}\t--{rel2_name}-->\tV2")
        print(f"V1 or V2 --> ?")
        print(f"\nAnswer Set (?): \n{answers_titles}")

In [8]:
import pandas as pd

class SymbolicReasoning:
    def __init__(self, graph: Graph, logging: bool = True):
        self.graph = graph
        self.logging = logging

    def query_1p(self, head: int, relation: int):
        if self.logging:
            print(f"Querying for head: {self.graph.dataset.get_title_by_node(self.graph.dataset.get_node_by_id(head))} ({head} | {self.graph.dataset.get_node_by_id(head)}) and relation: {self.graph.dataset.get_relation_by_id(relation)} ({relation})")
        answers = []
        edges = self.graph.get_edges()
        for edge in edges:
            if edge.get_head().get_id() == head and edge.get_id() == relation:
                if self.logging:
                    print(f"Found edge: {edge.get_head().get_title()} --{edge.get_name()}--> {edge.get_tail().get_title()} ({edge.get_tail().get_id()})")
                answers.append(edge.get_tail().get_id())
        if self.logging:
            print("-" * 50)
        return list(set(answers))
    
    def query_2p(self, head: int, relations: tuple):
        first_level_answers = self.query_1p(head, relations[0])
        second_level_answers = {}
        for answer in first_level_answers:
            second_level_answers[answer] = self.query_1p(answer, relations[1])
        answers_set = set()
        for answer, second_level in second_level_answers.items():
            for item in second_level:
                answers_set.add(item)
        return second_level_answers, list(answers_set)
    
    def query_3p(self, head: int, relations: tuple):
        first_level_answers = self.query_2p(head, (relations[0], relations[1]))
        second_level_answers = {}
        for answer, second_level in first_level_answers[0].items():
            second_level_answers[answer] = self.query_1p(answer, relations[2])
        answers_set = set()
        for answer, second_level in second_level_answers.items():
            for item in second_level:
                answers_set.add(item)
        return second_level_answers, list(answers_set)
    
    def fixed_size_answer(self, answers: list, size: int):
        # make a dataframe which the index are answers and there is a column called score which the value is 1 for all answers
        array = np.full((len(answers), 1), 1)
        answers = np.array(answers)
        df = pd.DataFrame(array, index=answers, columns=['score'])

        if len(df) < size:
            # add random nodes to fill the size
            all_nodes = list(self.graph.dataset.id2node.keys())
            all_nodes_remaining = [node for node in all_nodes if node not in df.index]
            additional_nodes = np.random.choice(all_nodes_remaining, size - len(df), replace=False)
            additional_nodes = [int(node) for node in additional_nodes]
            # add them with score 0
            additional_df = pd.DataFrame(np.zeros((len(additional_nodes), 1)), index=additional_nodes, columns=['score'])
            if len(df) == 0:
                df = additional_df
            else:
                df = pd.concat([df, additional_df])
        elif len(df) > size:
            # truncate the dataframe to the size
            df = df.sample(size, replace=False)
        return df

In [9]:
dir_query_2p = 'data/FB15k-237/test_ans_2c.pkl'
dir_query_3p = 'data/FB15k-237/test_ans_3c.pkl'
dir_query_2i = 'data/FB15k-237/test_ans_2i.pkl'
dir_query_2u = 'data/FB15k-237/test_ans_2u.pkl'

query_dataset = QueryDataset(dataset)
query_dataset.load_queries_from_pkl(dir_query_2p, query_type='2p')
query_dataset.load_queries_from_pkl(dir_query_3p, query_type='3p')
query_dataset.load_queries_from_pkl(dir_query_2i, query_type='2i')
query_dataset.load_queries_from_pkl(dir_query_2u, query_type='2u')

In [10]:
sample_idx = 4000
sample_query_type = '2p'
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]


In [11]:
query_dataset.get_num_queries()

20000

In [12]:
dir_query_2p = 'data/FB15k-237/test_ans_2c_hard.pkl'
dir_query_3p = 'data/FB15k-237/test_ans_3c_hard.pkl'
dir_query_2i = 'data/FB15k-237/test_ans_2i_hard.pkl'
dir_query_2u = 'data/FB15k-237/test_ans_2u_hard.pkl'
query_dataset_hard = QueryDataset(dataset)
query_dataset_hard.load_queries_from_pkl(dir_query_2p, query_type='2p')
query_dataset_hard.load_queries_from_pkl(dir_query_3p, query_type='3p')
query_dataset_hard.load_queries_from_pkl(dir_query_2i, query_type='2i')
query_dataset_hard.load_queries_from_pkl(dir_query_2u, query_type='2u')
sample_query_type = '2p'

In [13]:
sample_idx = 4000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
human_readable(query_hard, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']


In [14]:
query_dataset_hard.get_num_queries()

20000

In [15]:
def accuracy(query: Query, answers: list):
    correct_answers = set(query.get_answer())
    predicted_answers = set(answers)
    if len(correct_answers) == 0:
        return 0.0
    return len(correct_answers.intersection(predicted_answers)) / len(correct_answers)

In [16]:
reasoner_train = SymbolicReasoning(graph_train)

sample_idx = 4000
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

middle_steps, answers_test = reasoner_train.query_2p(query.get_query()[0][0], query.get_query()[0][1])
print(f"Answers from test graph: {middle_steps}")
print(f"Final Answers: {answers_test}")
print(f"Expected Answers: {query.get_answer()}")
print(f"Accuracy: {accuracy(query, answers_test)}")

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]
Querying for head: Lamar Odom (12324 | /m/02_nkp) and relation: /education/educational_institution/students_graduates./education/education/student_reverse (45)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Rhode Island (4074)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Nevada, Las Vegas (9463)
--------------------------------------------------
Querying for head: University of Rhode Island (4074 | /m/02fjzt) and relation: /education/educational_degree/people_with_this_degree./education/education/i

In [17]:
reasoner_test = SymbolicReasoning(graph_test)

middle_steps, answers_test = reasoner_test.query_2p(query.get_query()[0][0], query.get_query()[0][1])
print(f"Answers from test graph: {middle_steps}")
print(f"Final Answers: {answers_test}")
print(f"Expected Answers: {query.get_answer()}")
print(f"Accuracy: {accuracy(query, answers_test)}")

Querying for head: Lamar Odom (12324 | /m/02_nkp) and relation: /education/educational_institution/students_graduates./education/education/student_reverse (45)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Rhode Island (4074)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Nevada, Las Vegas (9463)
--------------------------------------------------
Querying for head: University of Rhode Island (4074 | /m/02fjzt) and relation: /education/educational_degree/people_with_this_degree./education/education/institution_reverse (179)
Found edge: University of Rhode Island --/education/educational_degree/people_with_this_degree./education/education/institution_reverse--> PhD (587)
Found edge: University of Rhode Island --/education/educational_degree/people_with_this_degree./education/education/institution_reverse--> Bachelor o

In [18]:
reasoner_test.fixed_size_answer(answers_test, 10)

Unnamed: 0,score
706,1.0
587,1.0
3181,1.0
1177,1.0
1566,1.0
5495,0.0
5540,0.0
2795,0.0
11495,0.0
3617,0.0


In [19]:
reasoner_train = SymbolicReasoning(graph_train, logging=False)
reasoner_valid = SymbolicReasoning(graph_valid, logging=False)
reasoner_test = SymbolicReasoning(graph_test, logging=False)

answers_train = reasoner_train.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]
answers_valid = reasoner_valid.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]
answers_test = reasoner_test.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]

print(f"Train Accuracy: {accuracy(query, answers_train)}")
print(f"Valid Accuracy: {accuracy(query, answers_valid)}")
print(f"Test Accuracy: {accuracy(query, answers_test)}")

Train Accuracy: 0.8
Valid Accuracy: 0.8
Test Accuracy: 1.0


In [20]:
import pickle
from tqdm import tqdm
from kbc.chain_dataset import Chain, ChaineDataset

def create_cqd_file(queries: list, output_file: str = 'data/FB15k-237/FB15k-237_test_hard_sample.pkl'):
    """
    Create a file in CQD desired format. The CQD will make its predictions based on the queries provided in this file.
    Args:
        queries (list): List of Query objects.
        output_file (str): Path to the output CQD file.
    """
    
    query_chains = []
    for query in tqdm(queries, desc="Creating CQD file"):
        query_chain = Chain()
        # The last element in the query is the target, which is not important for our work, as it will be used only for metrics calculation.
        query_chain.data['raw_chain'] = [query.get_query()[0][0], query.get_query()[0][1][0], [0]]
        query_chain.data['anchors'] = [query.get_query()[0][0]]
        query_chain.data['optimisable'] = [-1, 0]
        query_chain.data['targets'] = [0]
        query_chains.append(query_chain)
        
    chains = ChaineDataset(None)
    chains.type1_1chain = query_chains
    
    with open(output_file, 'wb') as f:
        pickle.dump(chains, f)

In [21]:
create_cqd_file([query], output_file='data/FB15k-237/FB15k-237_test_hard_sample.pkl')

Creating CQD file: 100%|██████████| 1/1 [00:00<00:00, 21509.25it/s]


In [22]:
def get_cach_prediction(predictions_df: pd.DataFrame, entity: int, relation: int):
    """
    Get the prediction for a specific entity and relation from the predictions DataFrame.
    
    Args:
        predictions_df (pd.DataFrame): DataFrame containing the predictions.
        entity (int): The ID of the entity to get the prediction for.
        relation (int): The ID of the relation to get the prediction for.
        
    Returns:
        pd.DataFrame: A DataFrame containing the predicted entities and their scores, sorted by score in descending order.
    """
    filtered_df = predictions_df[(predictions_df['entity_id'] == entity) & (predictions_df['relation_id'] == relation)]
    if filtered_df.empty:
        return [], []
    
    predicted_entities = filtered_df['top_k_entities'].tolist()[0]
    scores = filtered_df['top_k_scores'].tolist()[0]
    
    # make a df which index is predicted_entities and column is scores
    predictions = pd.DataFrame(scores, index=predicted_entities, columns=['score'])
    predictions = predictions.sort_values(by='score', ascending=False)
    
    return predictions

In [23]:
import argparse
import torch
from kbc.cqd_co_xcqa import main
import pandas as pd
import json

def cqd_query(query: Query, sample_path: str = 'data/FB15k-237/FB15k-237_test_hard_sample.pkl', result_path: str = 'scores.json', k: int = 5, cqd_cache: pd.DataFrame = None):
    
    if cqd_cache is not None:
        # If a cache DataFrame is provided, use it to get the predictions
        entity = query.get_query()[0][0]
        relation = query.get_query()[0][1][0]
        
        if entity is None or relation is None:
            raise ValueError("Entity or relation not found in the dataset.")
        
        # Get cached predictions
        predictions = get_cach_prediction(cqd_cache, entity, relation)
        
        # Get the top k answers
        top_k_answers = predictions.head(k)
        
        return top_k_answers
    
    else:

        # Create a CQD file with the query
        create_cqd_file([query], output_file=sample_path)

        # Set up the arguments for the CQD model (cqd_co_xcqa)
        args = argparse.Namespace(
            path = 'FB15k-237',
            sample_path = sample_path,
            model_path = 'models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt',
            dataset = 'FB15k-237',
            mode = 'test',
            chain_type = '1_1', # '1_1', '1_2', '2_2', '2_2_disj', '1_3', '2_3', '3_3', '4_3', '4_3_disj', '1_3_joint'
            t_norm = 'prod', # 'min', 'prod'
            reg = None,
            lr = 0.1,
            optimizer='adam', # 'adam', 'adagrad', 'sgd'
            max_steps = 1000,
            sample = True,
            result_path = result_path,
            save_result = True,
            save_k = 5,
        )

        # Run the CQD model
        main(args)

        # Load the scores
        scores = None
        with open(result_path, 'rb') as f:
            scores = json.load(f)

        tmp_df = pd.read_json(result_path)
        tmp_df = pd.DataFrame(tmp_df.loc[0]['top_k_scores'], index=tmp_df.loc[0]['top_k_entities'], columns=['score'])
        tmp_df = tmp_df.sort_values(by='score', ascending=False)
        
        # Get the top k answers
        top_k_answers = tmp_df.head(k)
        
        return top_k_answers

In [24]:
cqd_query(query, sample_path='data/FB15k-237/FB15k-237_test_hard_sample.pkl', result_path='scores.json', k=5)

Creating CQD file: 100%|██████████| 1/1 [00:00<00:00, 21509.25it/s]


ComplEx(
  (embeddings): ModuleList(
    (0): Embedding(14505, 2000, sparse=True)
    (1): Embedding(474, 2000, sparse=True)
  )
)


100%|██████████| 1/1 [00:00<00:00,  9.34it/s]


Saving results to scores.json


Unnamed: 0,score
9463,10.657951
4074,10.455688
7265,6.088031
4683,5.748784
1236,5.67069


In [25]:
import pandas as pd

cqd_cache = pd.read_json('data/FB15k-237/all_1p_queries.json', orient='records')
cqd_cache

Unnamed: 0,entity_id,relation_id,top_k_entities,top_k_scores
0,8227,402,"[1964, 7094, 5233, 3593, 6775]","[1.641523718833923, 1.5853451490402222, 1.5796..."
1,8227,403,"[5145, 12787, 400, 7208, 4947]","[6.992974281311035, 4.854063987731934, 4.71984..."
2,8227,404,"[1527, 7216, 4168, 226, 4994]","[9.617685317993164, 4.58512020111084, 4.580133..."
3,8227,405,"[2037, 7859, 1779, 7421, 8288]","[1.398288369178772, 1.334656715393066, 1.30624..."
4,8227,406,"[10948, 6607, 9650, 4005, 4]","[1.6829309463500972, 1.536678314208984, 1.4762..."
...,...,...,...,...
6875365,1265,385,"[8410, 1265, 2328, 1266, 1568]","[8.260475158691406, 7.059150695800781, 6.96028..."
6875366,1265,386,"[62, 4, 11, 163, 23]","[2.077280521392822, 1.39265489578247, 1.258444..."
6875367,1265,387,"[470, 862, 657, 6676, 818]","[1.121290802955627, 1.064321041107177, 1.04515..."
6875368,1265,388,"[1265, 2328, 8410, 1266, 1568]","[4.588125228881836, 3.088305473327636, 2.88703..."


In [26]:
cqd_query(query, k=5, cqd_cache=cqd_cache)

Unnamed: 0,score
9463,10.657951
4074,10.455688
7265,6.088031
4683,5.748784
1236,5.67069


## Example Usage

In [361]:
sample_idx = 4000
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]


In [34]:
symbolic = SymbolicReasoning(graph_train, logging=False)

In [69]:
import time

def query_execution(query: Query, k: int = 10, coalition: list = None, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True, cqd_cache: pd.DataFrame = None, inner_cache: dict = None):
    # =============================================== 2p query =========================================================
    if query.query_type == '2p':
        anchor = query.get_query()[0][0]
        relations = query.get_query()[0][1]

        first_level_answers = None
        final_answers = None

        time_start = time.time()
        if coalition[0] == 1:
            if inner_cache is not None and (anchor, relations[0]) in inner_cache['cqd']:
                # print(f"Using cached CQD results for anchor: {anchor}, relation: {relations[0]}")
                first_level_answers = inner_cache['cqd'][(anchor, relations[0])]
            else:
                current_query = Query('1p', (((anchor, (relations[0],)),), []))
                first_level_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                if inner_cache is not None:
                    inner_cache['cqd'][(anchor, relations[0])] = first_level_answers
        elif coalition[0] == 0:
            if inner_cache is not None and (anchor, relations[0]) in inner_cache['symbolic']:
                # print(f"Using cached Symbolic results for anchor: {anchor}, relation: {relations[0]}")
                first_level_answers = inner_cache['symbolic'][(anchor, relations[0])]
            else:
                first_level_answers = symbolic.query_1p(anchor, relations[0])
                first_level_answers = symbolic.fixed_size_answer(first_level_answers, k)
                if inner_cache is not None:
                    inner_cache['symbolic'][(anchor, relations[0])] = first_level_answers
        time_end = time.time()

        if logging:
            print(f"Time taken for first level query: {time_end - time_start:.2f} seconds")

        start_time = time.time()
        for answer_idx, row in first_level_answers.iterrows():
            if coalition[1] == 1:
                if inner_cache is not None and (answer_idx, relations[1]) in inner_cache['cqd']:
                    # print(f"Using cached CQD results for answer: {answer_idx}, relation: {relations[1]}")
                    second_level_answers = inner_cache['cqd'][(answer_idx, relations[1])].copy()
                else:
                    current_query = Query('1p', (((answer_idx, (relations[1],)),), []))
                    second_level_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                    if inner_cache is not None:
                        inner_cache['cqd'][(answer_idx, relations[1])] = second_level_answers.copy()
            elif coalition[1] == 0:
                if inner_cache is not None and (answer_idx, relations[1]) in inner_cache['symbolic']:
                    # print(f"Using cached symbolic results for answer: {answer_idx}, relation: {relations[1]}")
                    second_level_answers = inner_cache['symbolic'][(answer_idx, relations[1])].copy()
                else:
                    second_level_answers = symbolic.query_1p(answer_idx, relations[1])
                    second_level_answers = symbolic.fixed_size_answer(second_level_answers, k)
                    if inner_cache is not None:
                        inner_cache['symbolic'][(answer_idx, relations[1])] = second_level_answers.copy()
            if t_norm == 'prod':  
                    second_level_answers['score'] = second_level_answers['score'] * row['score']
            elif t_norm == 'min':
                    second_level_answers['score'] = second_level_answers['score'].min(row['score'])
            # second_level_answers['path'] = str((anchor, relations[0], answer_idx, relations[1]))
            # second_row['path'] + f'--{relations[2]}-->{second_answer_idx}'
            second_level_answers['path'] = str(anchor) + f'--{relations[0]}-->{answer_idx}' + f'--{relations[1]}-->'
            if final_answers is None:
                final_answers = second_level_answers
            else:
                final_answers = pd.concat([final_answers, second_level_answers], axis=0)
        time_end = time.time()
        if logging:
            print(f"Time taken for second level query: {time_end - start_time:.2f} seconds")
    # =============================================== 2u query =========================================================
    elif query.query_type == '2u':
        query1 = query.get_query()[0]
        query2 = query.get_query()[1]
        anchor1 = query1[0]
        relation1 = query1[1][0]
        anchor2 = query2[0]
        relation2 = query2[1][0]

        first_branch_answers = None

        time_start = time.time()
        if coalition[0] == 1:
            if inner_cache is not None and (anchor1, relation1) in inner_cache['cqd']:
                # print(f"Using cached CQD results for anchor: {anchor1}, relation: {relation1}")
                first_branch_answers = inner_cache['cqd'][(anchor1, relation1)]
            else:
                current_query = Query('1p', (((anchor1, (relation1,)),), []))
                first_branch_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                if inner_cache is not None:
                    inner_cache['cqd'][(anchor1, relation1)] = first_branch_answers
        elif coalition[0] == 0:
            if inner_cache is not None and (anchor1, relation1) in inner_cache['symbolic']:
                # print(f"Using cached Symbolic results for anchor: {anchor1}, relation: {relation1}")
                first_branch_answers = inner_cache['symbolic'][(anchor1, relation1)]
            else:
                first_branch_answers = symbolic.query_1p(anchor1, relation1)
                first_branch_answers = symbolic.fixed_size_answer(first_branch_answers, k)
                if inner_cache is not None:
                    inner_cache['symbolic'][(anchor1, relation1)] = first_branch_answers
        first_branch_answers['path1'] = str(anchor1) + f'--{relation1}-->'
        time_end = time.time()

        if logging:
            print(f"Time taken for first level query: {time_end - time_start:.2f} seconds")
            
        second_branch_answers = None

        time_start = time.time()
        if coalition[1] == 1:
            if inner_cache is not None and (anchor2, relation2) in inner_cache['cqd']:
                # print(f"Using cached CQD results for anchor: {anchor2}, relation: {relation2}")
                second_branch_answers = inner_cache['cqd'][(anchor2, relation2)]
            else:
                current_query = Query('1p', (((anchor2, (relation2,)),), []))
                second_branch_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                if inner_cache is not None:
                    inner_cache['cqd'][(anchor2, relation2)] = second_branch_answers
        elif coalition[1] == 0:
            if inner_cache is not None and (anchor2, relation2) in inner_cache['symbolic']:
                # print(f"Using cached Symbolic results for anchor: {anchor2}, relation: {relation2}")
                second_branch_answers = inner_cache['symbolic'][(anchor2, relation2)]
            else:
                second_branch_answers = symbolic.query_1p(anchor2, relation2)
                second_branch_answers = symbolic.fixed_size_answer(second_branch_answers, k)
                if inner_cache is not None:
                    inner_cache['symbolic'][(anchor2, relation2)] = second_branch_answers
        second_branch_answers['path2'] = str(anchor2) + f'--{relation2}-->'

        if t_conorm == 'min':
            first_branch_answers = first_branch_answers.rename(columns={'score': 'score1'})
            second_branch_answers = second_branch_answers.rename(columns={'score': 'score2'})
            final_answers = pd.merge(first_branch_answers, second_branch_answers, left_index=True, right_index=True)
            final_answers['score'] = final_answers[['score1', 'score2']].min(axis=1)
            final_answers = final_answers.drop(columns=['score1', 'score2'])
        elif t_conorm == 'prod':
            first_branch_answers = first_branch_answers.rename(columns={'score': 'score1'})
            second_branch_answers = second_branch_answers.rename(columns={'score': 'score2'})
            final_answers = pd.merge(first_branch_answers, second_branch_answers, left_index=True, right_index=True)
            final_answers['score'] = final_answers['score1'] * final_answers['score2']
            final_answers = final_answers.drop(columns=['score1', 'score2'])
            
        time_end = time.time() 
        if logging:
            print(f"Time taken for second level query: {time_end - time_start:.2f} seconds")
    
    # =============================================== 3p query =========================================================
    elif query.query_type == '3p':
        anchor = query.get_query()[0][0]
        relations = query.get_query()[0][1]

        first_level_answers = None
        final_answers = None

        time_start = time.time()
        if coalition[0] == 1:
            if inner_cache is not None and (anchor, relations[0]) in inner_cache['cqd']:
                # print(f"Using cached CQD results for anchor: {anchor}, relation: {relations[0]}")
                first_level_answers = inner_cache['cqd'][(anchor, relations[0])]
            else:
                current_query = Query('1p', (((anchor, (relations[0],)),), []))
                first_level_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                if inner_cache is not None:
                    inner_cache['cqd'][(anchor, relations[0])] = first_level_answers
        elif coalition[0] == 0:
            if inner_cache is not None and (anchor, relations[0]) in inner_cache['symbolic']:
                # print(f"Using cached Symbolic results for anchor: {anchor}, relation: {relations[0]}")
                first_level_answers = inner_cache['symbolic'][(anchor, relations[0])]
            else:
                first_level_answers = symbolic.query_1p(anchor, relations[0])
                first_level_answers = symbolic.fixed_size_answer(first_level_answers, k)
                if inner_cache is not None:
                    inner_cache['symbolic'][(anchor, relations[0])] = first_level_answers
        time_end = time.time()

        if logging:
            print(f"Time taken for first level query: {time_end - time_start:.2f} seconds")

        start_time = time.time()
        second_level_answers = None
        for answer_idx, row in first_level_answers.iterrows():
            if coalition[1] == 1:
                if inner_cache is not None and (answer_idx, relations[1]) in inner_cache['cqd']:
                    # print(f"Using cached CQD results for answer: {answer_idx}, relation: {relations[1]}")
                    second_level_answer = inner_cache['cqd'][(answer_idx, relations[1])].copy()
                else:
                    current_query = Query('1p', (((answer_idx, (relations[1],)),), []))
                    second_level_answer = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                    if inner_cache is not None:
                        inner_cache['cqd'][(answer_idx, relations[1])] = second_level_answer.copy()
            elif coalition[1] == 0:
                if inner_cache is not None and (answer_idx, relations[1]) in inner_cache['symbolic']:
                    # print(f"Using cached symbolic results for answer: {answer_idx}, relation: {relations[1]}")
                    second_level_answer = inner_cache['symbolic'][(answer_idx, relations[1])].copy()
                else:
                    second_level_answer = symbolic.query_1p(answer_idx, relations[1])
                    second_level_answer = symbolic.fixed_size_answer(second_level_answer, k)
                    if inner_cache is not None:
                        inner_cache['symbolic'][(answer_idx, relations[1])] = second_level_answer.copy()
            if t_norm == 'prod':  
                    second_level_answer['score'] = second_level_answer['score'] * row['score']
            elif t_norm == 'min':
                    second_level_answer['score'] = min(second_level_answer['score'], row['score'])
            second_level_answer['path'] = str(anchor) + f'--{relations[0]}-->{answer_idx}' + f'--{relations[1]}-->'
            if second_level_answers is None:
                second_level_answers = second_level_answer
            else:
                second_level_answers = pd.concat([second_level_answers, second_level_answer], axis=0)
        
        for second_answer_idx, second_row in second_level_answers.iterrows():
            if coalition[2] == 1:
                if inner_cache is not None and (second_answer_idx, relations[2]) in inner_cache['cqd']:
                    # print(f"Using cached CQD results for second answer: {second_answer_idx}, relation: {relations[2]}")
                    third_level_answers = inner_cache['cqd'][(second_answer_idx, relations[2])].copy()
                else:
                    current_query = Query('1p', (((second_answer_idx, (relations[2],)),), []))
                    third_level_answers = cqd_query(current_query, k=k, cqd_cache=cqd_cache)
                    if inner_cache is not None:
                        inner_cache['cqd'][(second_answer_idx, relations[2])] = third_level_answers.copy()
            elif coalition[2] == 0:
                if inner_cache is not None and (second_answer_idx, relations[2]) in inner_cache['symbolic']:
                    # print(f"Using cached symbolic results for second answer: {second_answer_idx}, relation: {relations[2]}")
                    third_level_answers = inner_cache['symbolic'][(second_answer_idx, relations[2])].copy()
                else:
                    third_level_answers = symbolic.query_1p(second_answer_idx, relations[2])
                    third_level_answers = symbolic.fixed_size_answer(third_level_answers, k)
                    if inner_cache is not None:
                        inner_cache['symbolic'][(second_answer_idx, relations[2])] = third_level_answers.copy()
            if t_norm == 'prod':  
                third_level_answers['score'] = third_level_answers['score'] * second_row['score']
            elif t_norm == 'min':
                third_level_answers['score'] = min(third_level_answers['score'], second_row['score'])
            third_level_answers['path'] = second_row['path'] + str(second_answer_idx) + f'--{relations[2]}-->'
            if final_answers is None:
                final_answers = third_level_answers
            else:
                final_answers = pd.concat([final_answers, third_level_answers], axis=0)
        time_end = time.time()
        if logging:
            print(f"Time taken for second level query: {time_end - start_time:.2f} seconds")
            
    else:
        raise ValueError(f"Unsupported query type: {query.query_type}. Only '2p' queries are supported.")
    final_answers = final_answers.sort_values(by='score', ascending=False)

    # if we have duplicate answers, we need to keep only the one with the highest score
    # as the final_answers dataframe is already sorted by score, we can just keep the first occurrence of each index which means that we keep the highest score
    final_answers = final_answers[~final_answers.index.duplicated(keep='first')]

    # the output should be a dataframe of scores for each possible node in the graph
    df = pd.DataFrame(index=dataset.id2node.keys(), columns=['score', 'path'])
    df['score'] = 0.0
    for answer in final_answers.index:
        df.loc[answer, 'score'] = final_answers.loc[answer, 'score']
        df.loc[answer, 'path'] = final_answers.loc[answer, 'path']
    # shuffle the data to have a random order of answers and a fair measurement of performance
    # df = df.sample(frac=1)
    # sort by index and score at the same time to make sure that the order is consistent
    df.index.name = 'entity_id'
    df = df.sort_values(by=['score', 'entity_id'], ascending=[False, True])
    return df

In [364]:
sample_idx = 4000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
query = query_dataset.get_queries(sample_query_type)[sample_idx]
easy_answers = query.get_answer()
easy_answers = [a for a in easy_answers if a not in query_hard.get_answer()]
human_readable(query_hard, dataset)
print("-"*70)
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']
----------------------------------------------------------------------
Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]


In [365]:
query_hard

Query(type=2p, query=((12324, (45, 179)),), answer=[3181])

In [366]:
query

Query(type=2p, query=((12324, (45, 179)),), answer=[706, 587, 3181, 1177, 1566])

In [367]:
inner_cache = {
    'cqd': {},
    'symbolic': {}
}

In [368]:
query_execution(query, k=5, coalition=[0, 0], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Time taken for first level query: 0.14 seconds
Time taken for second level query: 0.69 seconds


Unnamed: 0_level_0,score,path
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1
587,1.0,12324--45-->4074--179-->1177
706,1.0,12324--45-->4074--179-->1177
1177,1.0,12324--45-->4074--179-->1177
1566,1.0,12324--45-->9463--179-->1177
0,0.0,
...,...,...
14500,0.0,
14501,0.0,
14502,0.0,
14503,0.0,


In [None]:
query_execution(query, k=5, coalition=[0, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Using cached Symbolic results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.07 seconds


Unnamed: 0_level_0,score,path
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1177,8.299611,"(12324, 45, 4074, 179)"
706,8.166074,"(12324, 45, 4074, 179)"
1566,7.979342,"(12324, 45, 9463, 179)"
587,7.124403,"(12324, 45, 4074, 179)"
1019,6.903434,"(12324, 45, 4074, 179)"
...,...,...
14500,0.000000,
14501,0.000000,
14502,0.000000,
14503,0.000000,


In [None]:
query_execution(query, k=5, coalition=[1, 0], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Time taken for first level query: 0.01 seconds
Using cached symbolic results for answer: 9463, relation: 179
Using cached symbolic results for answer: 4074, relation: 179
Time taken for second level query: 0.49 seconds


Unnamed: 0_level_0,score,path
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1177,10.657951,"(12324, 45, 9463, 179)"
1566,10.657951,"(12324, 45, 9463, 179)"
587,10.455688,"(12324, 45, 4074, 179)"
706,10.455688,"(12324, 45, 4074, 179)"
484,6.088031,"(12324, 45, 7265, 179)"
...,...,...
14500,0.000000,
14501,0.000000,
14502,0.000000,
14503,0.000000,


In [None]:
query_execution(query, k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Using cached CQD results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 9463, relation: 179
Using cached CQD results for answer: 4074, relation: 179
Time taken for second level query: 0.04 seconds


Unnamed: 0_level_0,score,path
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1177,86.778140,"(12324, 45, 4074, 179)"
706,85.381916,"(12324, 45, 4074, 179)"
1566,85.043434,"(12324, 45, 9463, 179)"
587,74.490537,"(12324, 45, 4074, 179)"
1019,72.180147,"(12324, 45, 4074, 179)"
...,...,...
14500,0.000000,
14501,0.000000,
14502,0.000000,
14503,0.000000,


In [28]:
def value_function(query: Query, easy_answers: list, target_entity: int, qoi: str = 'rank', k: int = 10, coalition: list = None, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True, cqd_cache: pd.DataFrame = None, inner_cache: dict = None):
    
    if sum(coalition) == 0:
        # this is the requirement of shapley values definition
        return 0
    else:
        result = query_execution(query, k=k, coalition=coalition, t_norm=t_norm, t_conorm=t_conorm, logging=logging, cqd_cache=cqd_cache, inner_cache=inner_cache)
        # remove easy answers from the result
        result = result[~result.index.isin(easy_answers)]
        if qoi == 'rank':
            if target_entity in result.index:
                value = result.index.get_loc(target_entity)
            else:
                raise ValueError(f"Target entity {target_entity} not found in the result")
        elif qoi == 'hit1':
            value = 1 if target_entity in result.index[:1] else 0
        elif qoi == 'hit3':
            value = 1 if target_entity in result.index[:3] else 0
        elif qoi == 'hit10':
            value = 1 if target_entity in result.index[:10] else 0
        else:
            raise ValueError(f"Unsupported QoI: {qoi}. Supported values are 'rank', 'hit1', 'hit3', 'hit10'.")
        return value

In [333]:
inner_cache = {
    'cqd': {},
    'symbolic': {}
}

In [334]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[0, 0], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

0

In [369]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[0, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Using cached Symbolic results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.07 seconds


3178

In [370]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 0], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Time taken for first level query: 0.01 seconds
Using cached symbolic results for answer: 9463, relation: 179
Using cached symbolic results for answer: 4074, relation: 179
Time taken for second level query: 0.40 seconds


3181

In [371]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Using cached CQD results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 9463, relation: 179
Using cached CQD results for answer: 4074, relation: 179
Time taken for second level query: 0.04 seconds


3179

In [372]:
inner_cache = {
        'cqd': {},
        'symbolic': {}
}

In [29]:
import math

def shapley_value(query: Query, atom_idx: int, easy_answers: list, target_entity: int, qoi: str = 'rank', k: int = 10, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True, cqd_cache: pd.DataFrame = None, inner_cache: dict = None):
    num_atoms = 0
    if query.query_type == '2p':
        num_atoms = 2
    elif query.query_type == '3p':
        num_atoms = 3
    elif query.query_type == '2i':
        num_atoms = 2
    elif query.query_type == '2u':
        num_atoms = 2
    else:
        raise ValueError(f"Unsupported query type: {query.query_type}. Only '2p' queries are supported.")

    shapley_value = 0.0
    
    num_remaining_atoms = num_atoms - 1
    for i in range(2**num_remaining_atoms):
        coalition = [int(x) for x in bin(i)[2:].zfill(num_remaining_atoms)]
        # create the coalition in the format of a list of 0s and 1s
        coalition_mask = [0] * num_atoms
        counter = 0
        for idx, _ in enumerate(coalition_mask):
            if idx== atom_idx:
                coalition_mask[idx] = 0
            else:
                coalition_mask[idx] = coalition[counter]
                counter += 1
        if logging:
            print(f"Coalition: {coalition_mask}, Atom Index: {atom_idx}")

        # calculate the weight term (|S|! (p-|S|-1)! \ p!
        weight = (math.factorial(sum(coalition)) * math.factorial(num_atoms - sum(coalition) - 1)) / math.factorial(num_atoms)
        
        # calculate the value function for the current coalition
        value = value_function(query, easy_answers, target_entity=target_entity, qoi=qoi, k=k, coalition=coalition_mask, t_norm=t_norm, t_conorm=t_conorm, logging=logging, cqd_cache=cqd_cache, inner_cache=inner_cache)
        
        # calculate the contribution of the current coalition when the atom is added
        added_coalition_mask = coalition_mask.copy()
        added_coalition_mask[atom_idx] = 1
        added_value = value_function(query, easy_answers, target_entity=target_entity, qoi=qoi, k=k, coalition=added_coalition_mask, t_norm=t_norm, t_conorm=t_conorm, logging=logging, cqd_cache=cqd_cache, inner_cache=inner_cache)
        
        # compute the difference
        contribution = added_value - value
        if logging:
            print(f"Coalition: {coalition_mask}, Contribution: {contribution} (before adding atom: {value}, after adding atom: {added_value}), weight: {weight})")
            
        # add the contribution to the shapley value
        shapley_value += contribution * weight
        
    if logging:
        print(f"Shapley value for atom {atom_idx}: {shapley_value}")
    return shapley_value

Shapley value of the first player (atom 0):

In [374]:
shapley_value(query, atom_idx=0, easy_answers=easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 0
Time taken for first level query: 0.01 seconds
Time taken for second level query: 0.68 seconds
Coalition: [0, 0], Contribution: 3179 (before adding atom: 0, after adding atom: 3179), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Time taken for first level query: 0.14 seconds
Time taken for second level query: 0.07 seconds
Using cached CQD results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 9463, relation: 179
Using cached CQD results for answer: 4074, relation: 179
Time taken for second level query: 0.04 seconds
Coalition: [0, 1], Contribution: 1 (before adding atom: 3178, after adding atom: 3179), weight: 0.5)
Shapley value for atom 0: 1590.0


1590.0

Shapley value of the second player (atom 1):

In [375]:
shapley_value(query, atom_idx=1, easy_answers=easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 1
Using cached Symbolic results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 4074, relation: 179
Using cached CQD results for answer: 9463, relation: 179
Using cached CQD results for answer: 5153, relation: 179
Using cached CQD results for answer: 13778, relation: 179
Using cached CQD results for answer: 6058, relation: 179
Time taken for second level query: 0.00 seconds
Coalition: [0, 0], Contribution: 3178 (before adding atom: 0, after adding atom: 3178), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Using cached CQD results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached symbolic results for answer: 9463, relation: 179
Using cached symbolic results for answer: 4074, relation: 179
Using cached symbolic results for answer: 7265, relation: 179
Using cached symbolic results for answer: 4683, relation: 179
Using cached symbolic results for answ

1589.0

Value function when every atom is present in the coalition:

In [376]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Using cached CQD results for anchor: 12324, relation: 45
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 9463, relation: 179
Using cached CQD results for answer: 4074, relation: 179
Using cached CQD results for answer: 7265, relation: 179
Using cached CQD results for answer: 4683, relation: 179
Using cached CQD results for answer: 1236, relation: 179
Time taken for second level query: 0.00 seconds


3179

Final check:

In [255]:
# efficiency check
3183 == 3177.5 + 5.5

True

## Shapley Value Examples

In [377]:
inner_cache = {
        'cqd': {},
        'symbolic': {}
}

In [378]:
sample_idx = 1000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
query = query_dataset.get_queries(sample_query_type)[sample_idx]
easy_answers = query.get_answer()
easy_answers = [a for a in easy_answers if a not in query_hard.get_answer()]
human_readable(query_hard, dataset)
print("-"*70)
human_readable(query, dataset)

Query:
Perkin Medal	--/award/award_category/winners./award/award_honor/award_winner-->	V
V	--/people/ethnicity/people_reverse-->	?

Answer Set (?): 
['African American']
----------------------------------------------------------------------
Query:
Perkin Medal	--/award/award_category/winners./award/award_honor/award_winner-->	V
V	--/people/ethnicity/people_reverse-->	?

Answer Set (?): 
['African American']


In [379]:
cqd_result = query_execution(query, k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)
human_df = cqd_result.copy()
human_df['title'] = human_df.index.map(dataset.id2node)
human_df['title'] = human_df['title'].map(dataset.get_title_by_node)
human_df['is_easy_answer'] = human_df.index.isin(easy_answers)
human_df['is_hard_answer'] = human_df.index.isin(query_hard.get_answer())
human_df

Time taken for first level query: 0.01 seconds
Time taken for second level query: 0.07 seconds


Unnamed: 0_level_0,score,path,title,is_easy_answer,is_hard_answer
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1253,59.608866,14483--94-->4165--143-->1253,Irish American,False,False
879,59.601220,14483--94-->4165--143-->1253,Jewish people,False,False
11038,55.189793,14483--94-->4165--143-->1253,Mexican American,False,False
3324,55.045102,14483--94-->4165--143-->1253,White American,False,False
2279,53.064225,14483--94-->4165--143-->1253,African American,False,True
...,...,...,...,...,...
14500,0.000000,,"Strategic Simulations, Inc.",False,False
14501,0.000000,,House of Plantagenet,False,False
14502,0.000000,,Humour,False,False
14503,0.000000,,Modernism,False,False


Let's explain the true hard answer which the model could predict it correctly.

In [259]:
target = 'African American'
target_entity = dataset.get_node_by_title(target)
target_id = dataset.get_id_by_node(target_entity)
print(f"Target entity: {target} ({target_entity}) with ID {target_id}")

Target entity: African American (/m/0x67) with ID 2279


In [260]:
atom_idx = 0
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 0
Using cached CQD results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds


Time taken for second level query: 0.86 seconds
Coalition: [0, 0], Contribution: 2280 (before adding atom: 0, after adding atom: 2280), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Time taken for first level query: 0.17 seconds
Using cached CQD results for answer: 4165, relation: 143
Time taken for second level query: 0.06 seconds
Using cached CQD results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 4165, relation: 143
Using cached CQD results for answer: 6996, relation: 143
Using cached CQD results for answer: 227, relation: 143
Using cached CQD results for answer: 3505, relation: 143
Using cached CQD results for answer: 6599, relation: 143
Time taken for second level query: 0.00 seconds
Coalition: [0, 1], Contribution: 0 (before adding atom: 4, after adding atom: 4), weight: 0.5)
Shapley value for atom 0: 1140.0


1140.0

In [261]:
atom_idx = 1
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 1
Using cached Symbolic results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 4165, relation: 143
Using cached CQD results for answer: 908, relation: 143
Using cached CQD results for answer: 6722, relation: 143
Using cached CQD results for answer: 11605, relation: 143
Using cached CQD results for answer: 2424, relation: 143
Time taken for second level query: 0.00 seconds
Coalition: [0, 0], Contribution: 4 (before adding atom: 0, after adding atom: 4), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Using cached CQD results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached symbolic results for answer: 4165, relation: 143
Using cached symbolic results for answer: 6996, relation: 143
Using cached symbolic results for answer: 227, relation: 143
Using cached symbolic results for answer: 3505, relation: 143
Using cached symbolic results for answer: 6599

-1136.0

Now, let's look at the explanation for another target entity, which is not in the answer set of the query.

In [262]:
target = "Italian American"
target_entity = dataset.get_node_by_title(target)
target_id = dataset.get_id_by_node(target_entity)
print(f"Target entity: {target} ({target_entity}) with ID {target_id}")

Target entity: Italian American (/m/0xnvg) with ID 3330


In [263]:
atom_idx = 0
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 0
Using cached CQD results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached symbolic results for answer: 4165, relation: 143
Using cached symbolic results for answer: 6996, relation: 143
Using cached symbolic results for answer: 227, relation: 143
Using cached symbolic results for answer: 3505, relation: 143
Using cached symbolic results for answer: 6599, relation: 143
Time taken for second level query: 0.00 seconds
Coalition: [0, 0], Contribution: 3331 (before adding atom: 0, after adding atom: 3331), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Using cached Symbolic results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 4165, relation: 143
Using cached CQD results for answer: 908, relation: 143
Using cached CQD results for answer: 6722, relation: 143
Using cached CQD results for answer: 11605, relation: 143
Using cached CQD results for answer

3.5

In [264]:
atom_idx = 1
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0], Atom Index: 1
Using cached Symbolic results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached CQD results for answer: 4165, relation: 143
Using cached CQD results for answer: 908, relation: 143
Using cached CQD results for answer: 6722, relation: 143
Using cached CQD results for answer: 11605, relation: 143
Using cached CQD results for answer: 2424, relation: 143
Time taken for second level query: 0.00 seconds
Coalition: [0, 0], Contribution: 3331 (before adding atom: 0, after adding atom: 3331), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Using cached CQD results for anchor: 14483, relation: 94
Time taken for first level query: 0.00 seconds
Using cached symbolic results for answer: 4165, relation: 143
Using cached symbolic results for answer: 6996, relation: 143
Using cached symbolic results for answer: 227, relation: 143
Using cached symbolic results for answer: 3505, relation: 143
Using cached symbolic results for answer

3.5

## Shapley Value for 3p

In [81]:
inner_cache = {
        'cqd': {},
        'symbolic': {}
}

In [82]:
sample_query_type = '3p'
sample_idx = 40
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
query = query_dataset.get_queries(sample_query_type)[sample_idx]
easy_answers = query.get_answer()
easy_answers = [a for a in easy_answers if a not in query_hard.get_answer()]
human_readable(query_hard, dataset)
print("-"*70)
human_readable(query, dataset)

Query:
Rocco and His Brothers	--/award/award_category/nominees./award/award_nomination/nominated_for_reverse-->	V1
V1	--/award/award_winning_work/awards_won./award/award_honor/award_reverse-->	V2
V2	--/film/director/film_reverse-->	?

Answer Set (?): 
['Steven Spielberg']
----------------------------------------------------------------------
Query:
Rocco and His Brothers	--/award/award_category/nominees./award/award_nomination/nominated_for_reverse-->	V1
V1	--/award/award_winning_work/awards_won./award/award_honor/award_reverse-->	V2
V2	--/film/director/film_reverse-->	?

Answer Set (?): 
['Joseph L. Mankiewicz', 'Peter Jackson', 'Peter Weir', 'Sam Mendes', 'Laurence Olivier', 'David Lean', 'Woody Allen', 'Danny Boyle', 'Ben Affleck', 'Mike Nichols', 'George Cukor', 'Tony Richardson', 'Ang Lee', 'Anthony Minghella', 'Roman Polanski', 'Kathryn Bigelow', 'Bob Fosse', 'Bernardo Bertolucci', 'Stanley Kubrick', 'Martin Scorsese', 'John Schlesinger', 'Steven Spielberg', 'Fred Zinnemann']


In [72]:
cqd_result = query_execution(query, k=5, coalition=[0, 0, 0], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)
human_df = cqd_result.copy()
human_df['title'] = human_df.index.map(dataset.id2node)
human_df['title'] = human_df['title'].map(dataset.get_title_by_node)
human_df['is_easy_answer'] = human_df.index.isin(easy_answers)
human_df['is_hard_answer'] = human_df.index.isin(query_hard.get_answer())
human_df

Time taken for first level query: 0.07 seconds
Time taken for second level query: 2.22 seconds


Unnamed: 0_level_0,score,path,title,is_easy_answer,is_hard_answer
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2494,1.0,11526--15-->112--61-->5992--205-->,Mike Nichols,True,False
7390,1.0,11526--15-->112--61-->5845--205-->,Roman Polanski,True,False
0,0.0,,Dominican Republic,False,False
1,0.0,,Republic,False,False
2,0.0,,Mighty Morphin Power Rangers,False,False
...,...,...,...,...,...
14500,0.0,,"Strategic Simulations, Inc.",False,False
14501,0.0,,House of Plantagenet,False,False
14502,0.0,,Humour,False,False
14503,0.0,,Modernism,False,False


In [73]:
cqd_result = query_execution(query, k=5, coalition=[1, 1, 1], t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)
human_df = cqd_result.copy()
human_df['title'] = human_df.index.map(dataset.id2node)
human_df['title'] = human_df['title'].map(dataset.get_title_by_node)
human_df['is_easy_answer'] = human_df.index.isin(easy_answers)
human_df['is_hard_answer'] = human_df.index.isin(query_hard.get_answer())
human_df

Time taken for first level query: 0.02 seconds
Time taken for second level query: 0.42 seconds


Unnamed: 0_level_0,score,path,title,is_easy_answer,is_hard_answer
entity_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2928,421.913810,11526--15-->20--61-->10827--205-->,Costa-Gavras,False,False
4383,410.450208,11526--15-->1127--61-->370--205-->,William Wyler,False,False
2558,407.203200,11526--15-->1127--61-->8492--205-->,Fred Zinnemann,True,False
2723,405.566671,11526--15-->1127--61-->4857--205-->,Michael Cimino,False,False
2222,396.516644,11526--15-->1127--61-->5031--205-->,Kevin Costner,False,False
...,...,...,...,...,...
14500,0.000000,,"Strategic Simulations, Inc.",False,False
14501,0.000000,,House of Plantagenet,False,False
14502,0.000000,,Humour,False,False
14503,0.000000,,Modernism,False,False


In [83]:
target = 'Steven Spielberg'
target_entity = dataset.get_node_by_title(target)
target_id = dataset.get_id_by_node(target_entity)
print(f"Target entity: {target} ({target_entity}) with ID {target_id}")

Target entity: Steven Spielberg (/m/06pj8) with ID 1917


In [84]:
atom_idx = 0
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0, 0], Atom Index: 0
Time taken for first level query: 0.02 seconds
Time taken for second level query: 2.14 seconds
Coalition: [0, 0, 0], Contribution: 1917 (before adding atom: 0, after adding atom: 1917), weight: 0.3333333333333333)
Coalition: [0, 0, 1], Atom Index: 0
Time taken for first level query: 0.07 seconds
Time taken for second level query: 0.77 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.39 seconds
Coalition: [0, 0, 1], Contribution: -1870 (before adding atom: 1927, after adding atom: 57), weight: 0.16666666666666666)
Coalition: [0, 1, 0], Atom Index: 0
Time taken for first level query: 0.00 seconds
Time taken for second level query: 1.83 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.67 seconds
Coalition: [0, 1, 0], Contribution: 4 (before adding atom: 1911, after adding atom: 1915), weight: 0.16666666666666666)
Coalition: [0, 1, 1], Atom Index: 0
Time taken for first

341.33333333333337

In [85]:
atom_idx = 1
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0, 0], Atom Index: 1
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [0, 0, 0], Contribution: 1911 (before adding atom: 0, after adding atom: 1911), weight: 0.3333333333333333)
Coalition: [0, 0, 1], Atom Index: 1
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [0, 0, 1], Contribution: -1 (before adding atom: 1927, after adding atom: 1926), weight: 0.16666666666666666)
Coalition: [1, 0, 0], Atom Index: 1
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [1, 0, 0], Contribution: -2 (before adding atom: 1917, after adding atom: 1915), weight: 0.16666666666666666)
Coalition: [1, 0, 1], Atom Index: 1
Time taken for first

1272.8333333333333

In [86]:
atom_idx = 2
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True, cqd_cache=cqd_cache, inner_cache=inner_cache)

Coalition: [0, 0, 0], Atom Index: 2
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [0, 0, 0], Contribution: 1927 (before adding atom: 0, after adding atom: 1927), weight: 0.3333333333333333)
Coalition: [0, 1, 0], Atom Index: 2
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [0, 1, 0], Contribution: 15 (before adding atom: 1911, after adding atom: 1926), weight: 0.16666666666666666)
Coalition: [1, 0, 0], Atom Index: 2
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Time taken for first level query: 0.00 seconds
Time taken for second level query: 0.01 seconds
Coalition: [1, 0, 0], Contribution: -1860 (before adding atom: 1917, after adding atom: 57), weight: 0.16666666666666666)
Coalition: [1, 1, 0], Atom Index: 2
Time taken for firs

351.83333333333326

## Caching

Compute all possible 1p queries and cache them in a file.

In [161]:
# for each possible combination of node and relation, make a query and append it
all_1p_queries = []
for node in tqdm(dataset.id2node.values(), desc="Creating 1p queries"):
    for relation in dataset.id2rel.values():
        node_id = dataset.get_id_by_node(node)
        relation_id = dataset.get_id_by_relation(relation)
        if node_id is not None and relation_id is not None:
            query = Query('1p', (((node_id, (relation_id,)),), []))
            all_1p_queries.append(query)

Creating 1p queries:   0%|          | 0/14505 [00:00<?, ?it/s]

Creating 1p queries: 100%|██████████| 14505/14505 [00:29<00:00, 493.24it/s] 


In [162]:
print(f'Number of all 1p queries: {len(all_1p_queries)}')

Number of all 1p queries: 6875370


In [29]:
import os
import argparse
import torch
from kbc.cqd_co_xcqa import main
import pandas as pd
from tqdm import tqdm

if not os.path.exists('data/FB15k-237/all_1p_queries'):
    os.makedirs('data/FB15k-237/all_1p_queries')
    os.system('cp data/FB15k-237/FB15k-237_test_complete.pkl data/FB15k-237/all_1p_queries/FB15k-237_test_complete.pkl')
    
chunk_size = 100000

for i in tqdm(range(0, len(all_1p_queries), chunk_size), desc="Creating CQD files for 1p queries"):
    batch_queries = all_1p_queries[i:i + chunk_size]
    create_cqd_file(batch_queries, output_file=f'data/FB15k-237/all_1p_queries/all_1p_queries_{i // chunk_size}.pkl')
    sample_path = f'data/FB15k-237/all_1p_queries/all_1p_queries_{i // chunk_size}.pkl'
    result_path = f'data/FB15k-237/all_1p_queries/results_{i // chunk_size}.json'
    # run the cqd_co_xcqa model
    args = argparse.Namespace(
        path = 'FB15k-237',
        sample_path = sample_path,
        model_path = 'models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt',
        dataset = 'FB15k-237',
        mode = 'test',
        chain_type = '1_1', # '1_1', '1_2', '2_2', '2_2_disj', '1_3', '2_3', '3_3', '4_3', '4_3_disj', '1_3_joint'
        t_norm = 'prod', # 'min', 'prod'
        reg = None,
        lr = 0.1,
        optimizer='adam', # 'adam', 'adagrad', 'sgd'
        max_steps = 1000,
        sample = False,
        result_path = result_path,
        save_result = True,
        save_k = 5
    )
    main(args)

Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 686782.44it/s]
100%|██████████| 100000/100000 [00:09<00:00, 10373.23it/s]


Saving results to data/FB15k-237/all_1p_queries/results_0.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 737489.87it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10388.94it/s]


Saving results to data/FB15k-237/all_1p_queries/results_1.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 713120.53it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10376.86it/s]


Saving results to data/FB15k-237/all_1p_queries/results_2.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 687463.47it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.54it/s]


Saving results to data/FB15k-237/all_1p_queries/results_3.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 703424.80it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10401.26it/s]


Saving results to data/FB15k-237/all_1p_queries/results_4.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 757150.62it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.93it/s]


Saving results to data/FB15k-237/all_1p_queries/results_5.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 727307.79it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10401.90it/s]


Saving results to data/FB15k-237/all_1p_queries/results_6.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 30473.58it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.64it/s]


Saving results to data/FB15k-237/all_1p_queries/results_7.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 686708.23it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10401.82it/s]


Saving results to data/FB15k-237/all_1p_queries/results_8.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 754898.02it/s]it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.03it/s]


Saving results to data/FB15k-237/all_1p_queries/results_9.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 739759.64it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10400.77it/s]


Saving results to data/FB15k-237/all_1p_queries/results_10.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 751636.40it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.97it/s]


Saving results to data/FB15k-237/all_1p_queries/results_11.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 709612.06it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.82it/s]


Saving results to data/FB15k-237/all_1p_queries/results_12.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 715097.45it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.77it/s]


Saving results to data/FB15k-237/all_1p_queries/results_13.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 712549.92it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10371.36it/s]


Saving results to data/FB15k-237/all_1p_queries/results_14.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 751381.91it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10395.23it/s]


Saving results to data/FB15k-237/all_1p_queries/results_15.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 738436.40it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.43it/s]


Saving results to data/FB15k-237/all_1p_queries/results_16.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 739723.11it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10395.48it/s]


Saving results to data/FB15k-237/all_1p_queries/results_17.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 702124.80it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.59it/s]


Saving results to data/FB15k-237/all_1p_queries/results_18.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 30695.91it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10385.99it/s]


Saving results to data/FB15k-237/all_1p_queries/results_19.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 702928.49it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.42it/s]


Saving results to data/FB15k-237/all_1p_queries/results_20.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 748326.73it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10364.04it/s]


Saving results to data/FB15k-237/all_1p_queries/results_21.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 740888.65it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.74it/s]


Saving results to data/FB15k-237/all_1p_queries/results_22.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 750749.80it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.01it/s]


Saving results to data/FB15k-237/all_1p_queries/results_23.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 31179.87it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.14it/s]


Saving results to data/FB15k-237/all_1p_queries/results_24.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 717686.60it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10395.23it/s]


Saving results to data/FB15k-237/all_1p_queries/results_25.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 705147.85it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10391.47it/s]


Saving results to data/FB15k-237/all_1p_queries/results_26.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 748135.86it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.63it/s]


Saving results to data/FB15k-237/all_1p_queries/results_27.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 732875.30it/s]/it]
100%|██████████| 100000/100000 [00:13<00:00, 7616.17it/s]


Saving results to data/FB15k-237/all_1p_queries/results_28.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 740654.46it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10394.97it/s]


Saving results to data/FB15k-237/all_1p_queries/results_29.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 699802.95it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.71it/s]


Saving results to data/FB15k-237/all_1p_queries/results_30.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 710841.15it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.70it/s]


Saving results to data/FB15k-237/all_1p_queries/results_31.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 709842.64it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.54it/s]


Saving results to data/FB15k-237/all_1p_queries/results_32.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 733045.66it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.38it/s]


Saving results to data/FB15k-237/all_1p_queries/results_33.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 731340.57it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.54it/s]


Saving results to data/FB15k-237/all_1p_queries/results_34.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 746155.46it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10396.84it/s]


Saving results to data/FB15k-237/all_1p_queries/results_35.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 31139.91it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.36it/s]


Saving results to data/FB15k-237/all_1p_queries/results_36.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 706008.38it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.97it/s]


Saving results to data/FB15k-237/all_1p_queries/results_37.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 705730.40it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10399.53it/s]


Saving results to data/FB15k-237/all_1p_queries/results_38.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 753606.77it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.76it/s]


Saving results to data/FB15k-237/all_1p_queries/results_39.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 739291.54it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10350.19it/s]


Saving results to data/FB15k-237/all_1p_queries/results_40.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 742632.03it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.19it/s]


Saving results to data/FB15k-237/all_1p_queries/results_41.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 697088.36it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10396.47it/s]


Saving results to data/FB15k-237/all_1p_queries/results_42.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 706269.92it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10315.44it/s]


Saving results to data/FB15k-237/all_1p_queries/results_43.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 697938.62it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.04it/s]


Saving results to data/FB15k-237/all_1p_queries/results_44.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 736239.34it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10393.21it/s]


Saving results to data/FB15k-237/all_1p_queries/results_45.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 723026.22it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.85it/s]


Saving results to data/FB15k-237/all_1p_queries/results_46.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 743917.58it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10396.38it/s]


Saving results to data/FB15k-237/all_1p_queries/results_47.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 30829.28it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.96it/s]


Saving results to data/FB15k-237/all_1p_queries/results_48.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 707538.77it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.83it/s]


Saving results to data/FB15k-237/all_1p_queries/results_49.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 632470.50it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.87it/s]


Saving results to data/FB15k-237/all_1p_queries/results_50.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 741542.27it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.52it/s]


Saving results to data/FB15k-237/all_1p_queries/results_51.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 747622.45it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.11it/s]


Saving results to data/FB15k-237/all_1p_queries/results_52.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 30657.41it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.63it/s]


Saving results to data/FB15k-237/all_1p_queries/results_53.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 718096.99it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.43it/s]


Saving results to data/FB15k-237/all_1p_queries/results_54.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 695534.74it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.95it/s]


Saving results to data/FB15k-237/all_1p_queries/results_55.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 738413.00it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.29it/s]


Saving results to data/FB15k-237/all_1p_queries/results_56.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 743467.92it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.81it/s]


Saving results to data/FB15k-237/all_1p_queries/results_57.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 748982.86it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10370.65it/s]


Saving results to data/FB15k-237/all_1p_queries/results_58.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 693651.74it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.21it/s]


Saving results to data/FB15k-237/all_1p_queries/results_59.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 31234.58it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10367.84it/s]


Saving results to data/FB15k-237/all_1p_queries/results_60.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 713864.55it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10356.48it/s]


Saving results to data/FB15k-237/all_1p_queries/results_61.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 755701.86it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10396.83it/s]


Saving results to data/FB15k-237/all_1p_queries/results_62.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 743317.71it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.71it/s]


Saving results to data/FB15k-237/all_1p_queries/results_63.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 753429.44it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10398.55it/s]


Saving results to data/FB15k-237/all_1p_queries/results_64.json


Creating CQD file: 100%|██████████| 100000/100000 [00:03<00:00, 31569.05it/s]s/it]
100%|██████████| 100000/100000 [00:09<00:00, 10395.25it/s]


Saving results to data/FB15k-237/all_1p_queries/results_65.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 721709.95it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10350.71it/s]


Saving results to data/FB15k-237/all_1p_queries/results_66.json


Creating CQD file: 100%|██████████| 100000/100000 [00:00<00:00, 709392.43it/s]/it]
100%|██████████| 100000/100000 [00:09<00:00, 10397.81it/s]


Saving results to data/FB15k-237/all_1p_queries/results_67.json


Creating CQD file: 100%|██████████| 75370/75370 [00:00<00:00, 737625.02it/s]7s/it]
100%|██████████| 75370/75370 [00:07<00:00, 10401.49it/s]


Saving results to data/FB15k-237/all_1p_queries/results_68.json


Creating CQD files for 1p queries: 100%|██████████| 69/69 [21:40<00:00, 18.85s/it]


In [30]:
# read all json files and merge them into a single one
import json
import glob

def merge_json_files(directory: str, output_file: str):
    all_data = []
    for filename in tqdm(glob.glob(os.path.join(directory, '*.json')), desc="Merging JSON files"):
        with open(filename, 'r') as f:
            data = json.load(f)
            all_data.extend(data)
    
    with open(output_file, 'w') as f:
        json.dump(all_data, f)
        
merge_json_files('data/FB15k-237/all_1p_queries', 'data/FB15k-237/all_1p_queries.json')

Merging JSON files: 100%|██████████| 69/69 [00:31<00:00,  2.17it/s]


In [53]:
import pandas as pd

link_predictions_df = pd.read_json('data/FB15k-237/all_1p_queries.json', orient='records')
link_predictions_df

Unnamed: 0,entity_id,relation_id,top_k_entities,top_k_scores
0,8227,402,"[1964, 7094, 5233, 3593, 6775]","[1.641523718833923, 1.5853451490402222, 1.5796..."
1,8227,403,"[5145, 12787, 400, 7208, 4947]","[6.992974281311035, 4.854063987731934, 4.71984..."
2,8227,404,"[1527, 7216, 4168, 226, 4994]","[9.617685317993164, 4.58512020111084, 4.580133..."
3,8227,405,"[2037, 7859, 1779, 7421, 8288]","[1.398288369178772, 1.334656715393066, 1.30624..."
4,8227,406,"[10948, 6607, 9650, 4005, 4]","[1.6829309463500972, 1.536678314208984, 1.4762..."
...,...,...,...,...
6875365,1265,385,"[8410, 1265, 2328, 1266, 1568]","[8.260475158691406, 7.059150695800781, 6.96028..."
6875366,1265,386,"[62, 4, 11, 163, 23]","[2.077280521392822, 1.39265489578247, 1.258444..."
6875367,1265,387,"[470, 862, 657, 6676, 818]","[1.121290802955627, 1.064321041107177, 1.04515..."
6875368,1265,388,"[1265, 2328, 8410, 1266, 1568]","[4.588125228881836, 3.088305473327636, 2.88703..."


In [56]:
# read results_44.json
tmp_df = pd.read_json('data/FB15k-237/all_1p_queries/results_44.json')
# convert tmp_df.loc[0] to a dataframe which index is top_k_entities and column scores is top_k_scores
tmp_df = pd.DataFrame(tmp_df.loc[0]['top_k_scores'], index=tmp_df.loc[0]['top_k_entities'], columns=['score'])
tmp_df = tmp_df.sort_values(by='score', ascending=False)
tmp_df

Unnamed: 0,score
62,2.415622
90,2.042978
141,1.919169
7265,1.865481
32,1.85036


In [None]:
def get_prediction(predictions_df: pd.DataFrame, entity: int, relation: int):
    """
    Get the prediction for a specific entity and relation from the predictions DataFrame.
    
    Args:
        predictions_df (pd.DataFrame): DataFrame containing the predictions.
        entity (int): The ID of the entity to get the prediction for.
        relation (int): The ID of the relation to get the prediction for.
        
    Returns:
        pd.DataFrame: A DataFrame containing the predicted entities and their scores, sorted by score in descending order.
    """
    filtered_df = predictions_df[(predictions_df['entity_id'] == entity) & (predictions_df['relation_id'] == relation)]
    if filtered_df.empty:
        return [], []
    
    predicted_entities = filtered_df['top_k_entities'].tolist()[0]
    scores = filtered_df['top_k_scores'].tolist()[0]
    
    # make a df which index is predicted_entities and column is scores
    predictions = pd.DataFrame(scores, index=predicted_entities, columns=['score'])
    predictions = predictions.sort_values(by='score', ascending=False)
    
    return predictions

In [67]:
entity = dataset.get_id_by_node(dataset.get_node_by_title('Goldfinger'))
relation = dataset.get_id_by_relation('/award/award_nominee/award_nominations./award/award_nomination/nominated_for_reverse')
results = get_prediction(link_predictions_df, entity=entity, relation=relation)
results

Unnamed: 0,score
2276,8.854484
3429,6.829935
5593,6.796679
4978,6.461661
5259,6.018537


In [48]:
for i in range(len(results[0])):
    entity_id = results[0][i]
    score = results[1][i]
    entity_name = dataset.get_node_by_id(entity_id)
    entity_title = dataset.get_title_by_node(entity_name)
    print(f"Entity: {entity_title} ({entity_id}), Score: {score:.4f}")

Entity: Ken Adam (2276), Score: 8.8545
Entity: Michael G. Wilson (3429), Score: 6.8299
Entity: Sean Connery (5593), Score: 6.7967
Entity: John Barry (4978), Score: 6.4617
Entity: Peter Lamont (5259), Score: 6.0185


In [49]:
# compute the average time taken for each query via testing on the first 1000 queries
import time
avg_time = 0.0
for i in tqdm(range(1000), desc="Calculating average time for 1p queries"):
    query = all_1p_queries[i]
    start_time = time.time()
    get_prediction(link_predictions_df, entity=query.get_query()[0][0], relation=query.get_query()[0][1][0])
    end_time = time.time()
    avg_time += (end_time - start_time)
avg_time /= 1000
print(f"Average time taken for each 1p query: {avg_time:.4f} seconds")

Calculating average time for 1p queries: 100%|██████████| 1000/1000 [00:12<00:00, 80.86it/s]

Average time taken for each 1p query: 0.0123 seconds





In [None]:
# cache all the possible predictions when using the symbolic reasoning
import pickle

def cache_symbolic_predictions(queries: list, result_path: str = 'data/FB15k-237/all_1p_queries_symbolic.json', k: int = 5):
    """
    Cache the symbolic predictions for all 1p queries.
    
    Args:
        queries (list): List of Query objects.
        output_file (str): Path to the output file where predictions will be cached.
    """
    results = []
    for query in tqdm(queries, desc="Caching symbolic predictions"):
        entity = query.get_query()[0][0]
        relation = query.get_query()[0][1][0]
        answers = symbolic.query_1p(entity, relation)
        fixed_answers = symbolic.fixed_size_answer(answers, k)
        top_k_entities = fixed_answers.index.tolist()
        top_k_scores = fixed_answers['score'].tolist()
        result = {
                    'entity_id': entity,
                    'relation_id': relation,
                    'top_k_entities': top_k_entities,
                    'top_k_scores': top_k_scores
                }
        results.append(result)

    with open(result_path, 'w') as f:
        json.dump(results, f, indent=4)

cache_symbolic_predictions(all_1p_queries, result_path='data/FB15k-237/all_1p_queries_symbolic.json', k=5)
# This takes so much time, so we will not run it for now.