In [2]:
import pickle

class Dataset:
    def __init__(self):
        self.id2node = {}
        self.node2id = {}
        self.id2rel = {}
        self.rel2id = {}
        self.node2title = {}
        self.title2node = {}

    def load_key_value_files(self, filename):
        '''
        Load key-value pairs from a file
        Args:
            filename (str): The name of the file to load.
            file_format (str): The format of the file ('pkl' or 'txt').
        Returns:
            tuple: (id2value, value2id) where:
                id2value (dict): Dictionary mapping IDs to values.
                value2id (dict): Dictionary mapping values to IDs.
        '''
        id2value = {}
        value2id = {}
        file_format = filename.split('.')[-1]
        if file_format == 'pkl':
            with open(filename, 'rb') as f:
                id2value = pickle.load(f)
        elif file_format == 'txt':
            with open(filename, 'r') as f:
                id2value = {}
                for line in f:
                    id, value = line.strip().split('\t')
                    id2value[id] = value
        else:
            raise ValueError("Unsupported file format. Use 'pkl' or 'txt'.")
        value2id = {v: k for k, v in id2value.items()}
        return id2value, value2id
    
    def set_id2node(self, filename):
        id2node, node2id = self.load_key_value_files(filename)
        self.id2node = id2node
        self.node2id = node2id
        print(f"Loaded {len(self.id2node)} nodes from {filename}.")

    def get_node_by_id(self, node_id):
        return self.id2node.get(node_id, None)
    
    def get_id_by_node(self, node):
        return self.node2id.get(node, None)

    def set_id2rel(self, filename):
        id2rel, rel2id = self.load_key_value_files(filename)
        self.id2rel = id2rel
        self.rel2id = rel2id
        print(f"Loaded {len(self.id2rel)} relations from {filename}.")

    def get_relation_by_id(self, rel_id):
        return self.id2rel.get(rel_id, None)
    
    def get_id_by_relation(self, relation):
        return self.rel2id.get(relation, None)

    def set_node2title(self, filename):
        node2title, title2node = self.load_key_value_files(filename)
        self.node_to_title = node2title
        self.title2node = title2node
        print(f"Loaded {len(self.node_to_title)} node titles from {filename}.")


    def get_title_by_node(self, node):
        return self.node_to_title.get(node, None)
    
    def get_node_by_title(self, title):
        return self.title2node.get(title, None)
    
    def get_num_nodes(self):
        return len(self.id2node)

In [3]:
import numpy as np

class Node:
    def __init__(self, name: str, id: int, title: str):
        self.id = id
        self.name = name
        self.title = title

    def get_id(self):
        return self.id

    def get_name(self):
        return self.name
    
    def get_title(self):
        return self.title

class Edge:
    def __init__(self, name:str, id: int, head: Node, tail: Node):
        self.id = id
        self.name = name
        self.head = head
        self.tail = tail

    def get_id(self):
        return self.id
    
    def get_name(self):
        return self.name

    def get_head(self):
        return self.head

    def get_tail(self):
        return self.tail

class Graph:
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.edges = []
    
    def add_edge(self, head: str, relation: str, tail: str, skip_missing: bool = True, add_reverse: bool = True):
        head_id = self.dataset.get_id_by_node(head)
        tail_id = self.dataset.get_id_by_node(tail)
        relation_id = self.dataset.get_id_by_relation(relation)
        if head_id is None and skip_missing:
            print(f'Node {head} not found in dataset, skipping edge')
        elif tail_id is None and skip_missing:
            print(f'Node {tail} not found in dataset, skipping edge')
        elif relation_id is None and skip_missing:
            print(f'Relation {relation} not found in dataset, skipping edge')
        else:
            head_node = Node(head, head_id, self.dataset.get_title_by_node(head))
            tail_node = Node(tail, tail_id, self.dataset.get_title_by_node(tail))
            edge = Edge(relation, relation_id, head_node, tail_node)
            self.edges.append(edge)
            if add_reverse:
                reverse_relation = f'{relation}_reverse'
                reverse_relation_id = self.dataset.get_id_by_relation(reverse_relation)
                reverse_edge = Edge(reverse_relation, reverse_relation_id, tail_node, head_node)
                self.edges.append(reverse_edge)

    def load_triples(self, filename: str, skip_missing: bool = True, add_reverse: bool = True):
        try:
            with open(filename, 'r') as f:
                for line in f:
                    head, relation, tail = line.strip().split('\t')
                    self.add_edge(head, relation, tail, skip_missing, add_reverse)
        except FileNotFoundError:
            raise ValueError(f'File {filename} not found')
        except Exception as e:
            raise ValueError(f'Error loading triples from {filename}: {e}')
        
    def get_num_edges(self):
        return len(self.edges)
    
    def get_edges(self):
        return self.edges

    def get_num_nodes(self):
        return self.dataset.get_num_nodes()

In [4]:
dataset = Dataset()

data_dir = 'data/FB15k-237'
dataset.set_id2node(f'{data_dir}/ind2ent.pkl')
dataset.set_id2rel(f'{data_dir}/ind2rel.pkl')
dataset.set_node2title(f'{data_dir}/extra/entity2text.txt')

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.


In [5]:
graph_train = Graph(dataset)
graph_train.load_triples(f'{data_dir}/train.txt', skip_missing=False, add_reverse=True)
graph_train.get_num_nodes(), graph_train.get_num_edges()

(14505, 544230)

In [6]:
graph_valid = Graph(dataset)
# add training edges to validation graph
for edge in graph_train.get_edges():
    # we set add_reverse=False because it already exists in the training graph
    graph_valid.add_edge(edge.get_head().get_name(), edge.get_name(), edge.get_tail().get_name(), skip_missing=False, add_reverse=False)
graph_valid.load_triples(f'{data_dir}/valid.txt', skip_missing=False, add_reverse=True)
graph_valid.get_num_nodes(), graph_valid.get_num_edges()

(14505, 579300)

In [7]:
graph_test = Graph(dataset)
# add training and validation edges to test graph (validation graph contains all training edges)
for edge in graph_valid.get_edges():
    # we set add_reverse=False because it already exists in the validation graph
    graph_test.add_edge(edge.get_head().get_name(), edge.get_name(), edge.get_tail().get_name(), skip_missing=False, add_reverse=False)
graph_test.load_triples(f'{data_dir}/test.txt', skip_missing=False, add_reverse=True)
graph_test.get_num_nodes(), graph_test.get_num_edges()

(14505, 620232)

In [71]:
count = 0
with open(f'{data_dir}/all_combinations.txt', 'w') as f:
    for node in dataset.id2node.values():
        for relation in dataset.id2rel.values():
            node_id = dataset.get_id_by_node(node)
            relation_id = dataset.get_id_by_relation(relation)
            if node_id is not None and relation_id is not None:
                f.write(f'{node_id}\t{relation_id}\n')
            count += 1
print(f'Number of all combinations: {count}')

Number of all combinations: 6875370


In [72]:
# read all combinations in a dataframe
import pandas as pd

link_df = pd.read_csv(f'{data_dir}/all_combinations.txt', sep='\t', header=None, names=['node', 'relation'])
link_df

Unnamed: 0,node,relation
0,0,0
1,0,1
2,0,2
3,0,3
4,0,4
...,...,...
6875365,14504,469
6875366,14504,470
6875367,14504,471
6875368,14504,472


In [8]:
class Query:
    def __init__(self, query_type: str, query_answer: tuple):
        self.query_type = query_type
        if len(query_answer) != 2:
            raise ValueError("Query answer must be a tuple of (query, answer)")
        elif type(query_answer[1]) is not list:
            raise ValueError("Query answer must be a tuple of (query, answer) where answer is a list")
        self.query = query_answer[0]
        self.answer = query_answer[1]

    def get_query(self):
        return self.query
    
    def get_answer(self):
        return self.answer
    
    def __repr__(self):
        return f"Query(type={self.query_type}, query={self.query}, answer={self.answer})"
    
class QueryDataset:
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.queries = {}

    def add_query(self, query_type: str, query_answer: tuple):
        if query_type not in self.queries:
            self.queries[query_type] = []
        query = Query(query_type, query_answer)
        self.queries[query_type].append(query)

    def get_queries(self, query_type: str):
        if query_type not in self.queries:
            raise ValueError(f"No queries of type {query_type} found")
        return self.queries[query_type]
    
    def get_all_queries(self):
        all_queries = []
        for query_type, queries in self.queries.items():
            all_queries.extend(queries)
        return all_queries
    
    def get_num_queries(self):
        return sum(len(queries) for queries in self.queries.values())
    
    def get_num_queries_by_type(self, query_type: str):
        if query_type not in self.queries:
            return 0
        return len(self.queries[query_type])
    
    def load_queries_from_pkl(self, filename: str, query_type: str = ''):
        try:
            with open(filename, 'rb') as f:
                queries = pickle.load(f)
                for query, answer in queries.items():
                    answer = list(answer)
                    self.add_query(query_type, (query, answer))
        except FileNotFoundError:
            raise ValueError(f'File {filename} not found')
        except Exception as e:
            raise ValueError(f'Error loading queries from {filename}: {e}')
        
def human_readable(query: Query, dataset: Dataset):
    if query.query_type == '2p':
        anchor = query.query[0][0]
        relations = query.query[0][1]
        rel1 = relations[0]
        rel2 = relations[1]
        anchor_name = dataset.get_node_by_id(anchor)
        rel1_name = dataset.get_relation_by_id(rel1)
        rel2_name = dataset.get_relation_by_id(rel2)
        anchor_title = dataset.get_title_by_node(anchor_name)
        answers_titles = [dataset.get_title_by_node(dataset.get_node_by_id(a)) for a in query.answer]
        print(f"Query:\n{anchor_title}\t--{rel1_name}-->\tV")
        print(f"V\t--{rel2_name}-->\t?")
        print(f"\nAnswer Set (?): \n{answers_titles}")

In [9]:
import pandas as pd

class SymbolicReasoning:
    def __init__(self, graph: Graph, logging: bool = True):
        self.graph = graph
        self.logging = logging

    def query_1p(self, head: int, relation: int):
        if self.logging:
            print(f"Querying for head: {self.graph.dataset.get_title_by_node(self.graph.dataset.get_node_by_id(head))} ({head} | {self.graph.dataset.get_node_by_id(head)}) and relation: {self.graph.dataset.get_relation_by_id(relation)} ({relation})")
        answers = []
        edges = self.graph.get_edges()
        for edge in edges:
            if edge.get_head().get_id() == head and edge.get_id() == relation:
                if self.logging:
                    print(f"Found edge: {edge.get_head().get_title()} --{edge.get_name()}--> {edge.get_tail().get_title()} ({edge.get_tail().get_id()})")
                answers.append(edge.get_tail().get_id())
        if self.logging:
            print("-" * 50)
        return list(set(answers))
    
    def query_2p(self, head: int, relations: tuple):
        first_level_answers = self.query_1p(head, relations[0])
        second_level_answers = {}
        for answer in first_level_answers:
            second_level_answers[answer] = self.query_1p(answer, relations[1])
        answers_set = set()
        for answer, second_level in second_level_answers.items():
            for item in second_level:
                answers_set.add(item)
        return second_level_answers, list(answers_set)
    
    def fixed_size_answer(self, answers: list, size: int):
        # make a dataframe which the index are answers and there is a column called score which the value is 1 for all answers
        array = np.full((len(answers), 1), 1)
        answers = np.array(answers)
        df = pd.DataFrame(array, index=answers, columns=['score'])

        if len(df) < size:
            # add random nodes to fill the size
            all_nodes = list(self.graph.dataset.id2node.keys())
            all_nodes_remaining = [node for node in all_nodes if node not in df.index]
            additional_nodes = np.random.choice(all_nodes_remaining, size - len(df), replace=False)
            additional_nodes = [int(node) for node in additional_nodes]
            # add them with score 0
            additional_df = pd.DataFrame(np.zeros((len(additional_nodes), 1)), index=additional_nodes, columns=['score'])
            if len(df) == 0:
                df = additional_df
            else:
                df = pd.concat([df, additional_df])
        elif len(df) > size:
            # truncate the dataframe to the size
            df = df.sample(size, replace=False)
        return df

In [10]:
dir_query = 'data/FB15k-237/test_ans_2c.pkl'
sample_query_type = '2p'
query_dataset = QueryDataset(dataset)
query_dataset.load_queries_from_pkl(dir_query, query_type=sample_query_type)

In [11]:
sample_idx = 4000
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]


In [12]:
query_dataset.get_num_queries()

5000

In [13]:
dir_query = 'data/FB15k-237/test_ans_2c_hard.pkl'
sample_query_type = '2p'
query_dataset_hard = QueryDataset(dataset)
query_dataset_hard.load_queries_from_pkl(dir_query, query_type=sample_query_type)

In [14]:
sample_idx = 4000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
human_readable(query_hard, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']


In [15]:
query_dataset_hard.get_num_queries()

5000

In [16]:
def accuracy(query: Query, answers: list):
    correct_answers = set(query.get_answer())
    predicted_answers = set(answers)
    if len(correct_answers) == 0:
        return 0.0
    return len(correct_answers.intersection(predicted_answers)) / len(correct_answers)

In [17]:
reasoner_train = SymbolicReasoning(graph_train)

sample_idx = 4000
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

middle_steps, answers_test = reasoner_train.query_2p(query.get_query()[0][0], query.get_query()[0][1])
print(f"Answers from test graph: {middle_steps}")
print(f"Final Answers: {answers_test}")
print(f"Expected Answers: {query.get_answer()}")
print(f"Accuracy: {accuracy(query, answers_test)}")

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Bachelor of Science', 'PhD', 'Doctorate', 'Bachelor of Arts', "Bachelor's degree"]
Querying for head: Lamar Odom (12324 | /m/02_nkp) and relation: /education/educational_institution/students_graduates./education/education/student_reverse (45)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Rhode Island (4074)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Nevada, Las Vegas (9463)
--------------------------------------------------
Querying for head: University of Rhode Island (4074 | /m/02fjzt) and relation: /education/educational_degree/people_with_this_degree./education/education/i

In [18]:
reasoner_test = SymbolicReasoning(graph_test)

middle_steps, answers_test = reasoner_test.query_2p(query.get_query()[0][0], query.get_query()[0][1])
print(f"Answers from test graph: {middle_steps}")
print(f"Final Answers: {answers_test}")
print(f"Expected Answers: {query.get_answer()}")
print(f"Accuracy: {accuracy(query, answers_test)}")

Querying for head: Lamar Odom (12324 | /m/02_nkp) and relation: /education/educational_institution/students_graduates./education/education/student_reverse (45)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Rhode Island (4074)
Found edge: Lamar Odom --/education/educational_institution/students_graduates./education/education/student_reverse--> University of Nevada, Las Vegas (9463)
--------------------------------------------------
Querying for head: University of Rhode Island (4074 | /m/02fjzt) and relation: /education/educational_degree/people_with_this_degree./education/education/institution_reverse (179)
Found edge: University of Rhode Island --/education/educational_degree/people_with_this_degree./education/education/institution_reverse--> PhD (587)
Found edge: University of Rhode Island --/education/educational_degree/people_with_this_degree./education/education/institution_reverse--> Bachelor o

In [19]:
reasoner_test.fixed_size_answer(answers_test, 10)

Unnamed: 0,score
706,1.0
587,1.0
3181,1.0
1177,1.0
1566,1.0
14197,0.0
4968,0.0
10371,0.0
10184,0.0
6673,0.0


In [20]:
reasoner_train = SymbolicReasoning(graph_train, logging=False)
reasoner_valid = SymbolicReasoning(graph_valid, logging=False)
reasoner_test = SymbolicReasoning(graph_test, logging=False)

answers_train = reasoner_train.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]
answers_valid = reasoner_valid.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]
answers_test = reasoner_test.query_2p(query.get_query()[0][0], query.get_query()[0][1])[1]

print(f"Train Accuracy: {accuracy(query, answers_train)}")
print(f"Valid Accuracy: {accuracy(query, answers_valid)}")
print(f"Test Accuracy: {accuracy(query, answers_test)}")

Train Accuracy: 0.8
Valid Accuracy: 0.8
Test Accuracy: 1.0


In [21]:
import pickle

def create_cqd_file(query, original_file='data/FB15k-237/FB15k-237_test_hard.pkl', output_file='data/FB15k-237/FB15k-237_test_hard_sample1.pkl'):
    with open(original_file, 'rb') as f:
        data_hard = pickle.load(f)
    
    # remove all queries except the first one
    data_hard.type1_1chain = [data_hard.type1_1chain[0]]
    
    # replace the query with the provided one (note that the target is not important here, so we set it to 0)
    data_hard.type1_1chain[0].data['raw_chain'] = [query[0][0], query[0][1][0], [0]]
    data_hard.type1_1chain[0].data['anchors'] = [query[0][0]]
    data_hard.type1_1chain[0].data['optimisable'] = [-1, 0]
    data_hard.type1_1chain[0].data['targets'] = [0]
    
    with open(output_file, 'wb') as f:
        pickle.dump(data_hard, f)

In [22]:
query.get_query()

((12324, (45, 179)),)

In [23]:
create_cqd_file(query.get_query(), output_file='data/FB15k-237/FB15k-237_test_hard_sample.pkl')

In [24]:
import argparse
import torch
from kbc.cqd_co_xcqa import main
import pandas as pd

def cqd_query(query: Query, sample_path: str, result_path: str, k: int = 10):
    
    if sample_path is None:
        sample_path = 'data/FB15k-237/FB15k-237_test_hard_sample.pkl'
    if result_path is None:
        result_path = 'scores.pt'

    # Create a CQD file with the query
    create_cqd_file(query.get_query(), output_file=sample_path)

    # Set up the arguments for the CQD model (cqd_co_xcqa)
    args = argparse.Namespace(
        path = 'FB15k-237',
        sample_path = 'data/FB15k-237/FB15k-237_test_hard_sample.pkl',
        model_path = 'models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt',
        dataset = 'FB15k-237',
        mode = 'test',
        chain_type = '1_1', # '1_1', '1_2', '2_2', '2_2_disj', '1_3', '2_3', '3_3', '4_3', '4_3_disj', '1_3_joint'
        t_norm = 'prod', # 'min', 'prod'
        reg = None,
        lr = 0.1,
        optimizer='adam', # 'adam', 'adagrad', 'sgd'
        max_steps = 1000,
        sample = True,
        result_path = result_path
    )

    # Run the CQD model
    main(args)

    # Load the scores
    scores = torch.load(result_path)
    scores_np = scores.cpu().numpy()

    # Create a DataFrame with the scores
    df = pd.DataFrame({'score': scores_np[0]})
    df = df.sort_values(by='score', ascending=False)
    
    # Get the top k answers
    top_k_answers = df.head(k)
    
    return top_k_answers

In [25]:
cqd_query(query, sample_path='data/FB15k-237/FB15k-237_test_hard_sample.pkl', result_path='scores.pt', k=10)

ComplEx(
  (embeddings): ModuleList(
    (0): Embedding(14505, 2000, sparse=True)
    (1): Embedding(474, 2000, sparse=True)
  )
)


Unnamed: 0,score
9463,10.657951
4074,10.455688
7265,6.088031
4683,5.748784
1236,5.67069
3169,5.650109
2173,5.561179
6483,5.536769
5153,5.439585
3895,5.423999


## Example Usage

In [26]:
dataset = Dataset()

data_dir = 'data/FB15k-237'
dataset.set_id2node(f'{data_dir}/ind2ent.pkl')
dataset.set_id2rel(f'{data_dir}/ind2rel.pkl')
dataset.set_node2title(f'{data_dir}/extra/entity2text.txt')

graph_train = Graph(dataset)
graph_train.load_triples(f'{data_dir}/train.txt', skip_missing=False, add_reverse=True)
graph_train.get_num_nodes(), graph_train.get_num_edges()

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.


(14505, 544230)

In [27]:
dir_query = 'data/FB15k-237/test_ans_2c_hard.pkl'
sample_query_type = '2p'
query_dataset = QueryDataset(dataset)
query_dataset.load_queries_from_pkl(dir_query, query_type=sample_query_type)

In [28]:
sample_idx = 4000
query = query_dataset.get_queries(sample_query_type)[sample_idx]
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']


In [29]:
symolic = SymbolicReasoning(graph_train, logging=False)

In [30]:
import time

def query_execution(query: Query, k: int = 10, coalition: list = None, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True):

    if query.query_type == '2p':
        anchor = query.get_query()[0][0]
        relations = query.get_query()[0][1]

        first_level_answers = None
        final_answers = None

        time_start = time.time()
        if coalition[0] == 1:
            current_query = Query('1p', (((anchor, (relations[0],)),), []))
            create_cqd_file(current_query.get_query(), output_file='data/FB15k-237/FB15k-237_test_hard_sample.pkl')
            first_level_answers = cqd_query(current_query, sample_path='data/FB15k-237/FB15k-237_test_hard_sample.pkl', result_path='scores.pt', k=k)
        elif coalition[0] == 0:
            first_level_answers = symolic.query_1p(anchor, relations[0])
            first_level_answers = symolic.fixed_size_answer(first_level_answers, 10)
        time_end = time.time()

        if logging:
            print(f"Time taken for first level query: {time_end - time_start:.2f} seconds")

        start_time = time.time()
        for answer_idx, row in first_level_answers.iterrows():
            if coalition[1] == 1:
                current_query = Query('1p', (((answer_idx, (relations[1],)),), []))
                create_cqd_file(current_query.get_query(), output_file='data/FB15k-237/FB15k-237_test_hard_sample.pkl')
                second_level_answers = cqd_query(current_query, sample_path='data/FB15k-237/FB15k-237_test_hard_sample.pkl', result_path='scores.pt', k=k)
            elif coalition[1] == 0:
                second_level_answers = symolic.query_1p(answer_idx, relations[1])
                second_level_answers = symolic.fixed_size_answer(second_level_answers, 10)
            if t_norm == 'prod':
                    second_level_answers['score'] = second_level_answers['score'] * row['score']
            elif t_norm == 'min':
                    second_level_answers['score'] = second_level_answers['score'].min(row['score'])
            second_level_answers['path'] = str((anchor, relations[0], answer_idx, relations[1]))
            if final_answers is None:
                final_answers = second_level_answers
            else:
                final_answers = pd.concat([final_answers, second_level_answers], axis=0)
        time_end = time.time()
        if logging:
            print(f"Time taken for second level query: {time_end - start_time:.2f} seconds")
                 
    else:
        raise ValueError(f"Unsupported query type: {query.query_type}. Only '2p' queries are supported.")
    final_answers = final_answers.sort_values(by='score', ascending=False)

    # if we have duplicate answers, we need to keep only the one with the highest score
    # as the final_answers dataframe is already sorted by score, we can just keep the first occurrence of each index which means that we keep the highest score
    final_answers = final_answers[~final_answers.index.duplicated(keep='first')]

    # the output should be a dataframe of scores for each possible node in the graph
    df = pd.DataFrame(index=dataset.id2node.keys(), columns=['score', 'path'])
    df['score'] = 0.0
    for answer in final_answers.index:
        df.loc[answer, 'score'] = final_answers.loc[answer, 'score']
        df.loc[answer, 'path'] = final_answers.loc[answer, 'path']
    # shuffle the data to have a random order of answers and a fair measurement of performance
    # df = df.sample(frac=1)
    # sort the dataframe by score in descending order
    df = df.sort_values(by='score', ascending=False)
        
    return df

In [31]:
sample_idx = 4000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
query = query_dataset.get_queries(sample_query_type)[sample_idx]
easy_answers = query.get_answer()
easy_answers = [a for a in easy_answers if a not in query_hard.get_answer()]
human_readable(query_hard, dataset)
print("-"*70)
human_readable(query, dataset)

Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']
----------------------------------------------------------------------
Query:
Lamar Odom	--/education/educational_institution/students_graduates./education/education/student_reverse-->	V
V	--/education/educational_degree/people_with_this_degree./education/education/institution_reverse-->	?

Answer Set (?): 
['Doctorate']


In [32]:
query_execution(query, k=5, coalition=[0, 0], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 0.08 seconds
Time taken for second level query: 0.81 seconds


Unnamed: 0,score,path
706,1.0,"(12324, 45, 4074, 179)"
1566,1.0,"(12324, 45, 9463, 179)"
587,1.0,"(12324, 45, 4074, 179)"
1177,1.0,"(12324, 45, 4074, 179)"
9664,0.0,
...,...,...
4843,0.0,
4844,0.0,
4845,0.0,
4846,0.0,


In [33]:
query_execution(query, k=5, coalition=[0, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 0.08 seconds
Time taken for second level query: 31.34 seconds


Unnamed: 0,score,path
1177,8.299611,"(12324, 45, 4074, 179)"
706,8.166074,"(12324, 45, 4074, 179)"
1566,7.979342,"(12324, 45, 9463, 179)"
587,7.124403,"(12324, 45, 4074, 179)"
1019,6.903434,"(12324, 45, 4074, 179)"
...,...,...
4845,0.000000,
4846,0.000000,
4847,0.000000,
4848,0.000000,


In [34]:
query_execution(query, k=5, coalition=[1, 0], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 3.15 seconds
Time taken for second level query: 0.40 seconds


Unnamed: 0,score,path
1177,10.657951,"(12324, 45, 9463, 179)"
1566,10.657951,"(12324, 45, 9463, 179)"
587,10.455688,"(12324, 45, 4074, 179)"
706,10.455688,"(12324, 45, 4074, 179)"
4709,6.088031,"(12324, 45, 7265, 179)"
...,...,...
4848,0.000000,
4849,0.000000,
4850,0.000000,
4851,0.000000,


In [35]:
query_execution(query, k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 3.19 seconds
Time taken for second level query: 15.86 seconds


Unnamed: 0,score,path
1177,86.778137,"(12324, 45, 4074, 179)"
706,85.381912,"(12324, 45, 4074, 179)"
1566,85.043434,"(12324, 45, 9463, 179)"
587,74.490540,"(12324, 45, 4074, 179)"
1019,72.180145,"(12324, 45, 4074, 179)"
...,...,...
4846,0.000000,
4847,0.000000,
4848,0.000000,
4849,0.000000,


In [36]:
query_execution(query, k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 3.15 seconds
Time taken for second level query: 15.96 seconds


Unnamed: 0,score,path
1177,86.778137,"(12324, 45, 4074, 179)"
706,85.381912,"(12324, 45, 4074, 179)"
1566,85.043434,"(12324, 45, 9463, 179)"
587,74.490540,"(12324, 45, 4074, 179)"
1019,72.180145,"(12324, 45, 4074, 179)"
...,...,...
4846,0.000000,
4847,0.000000,
4848,0.000000,
4849,0.000000,


In [37]:
query.answer

[3181]

In [38]:
easy_answers

[]

In [39]:
query_hard.answer

[3181]

In [40]:
def value_function(query: Query, easy_answers: list, target_entity: int, qoi: str = 'rank', k: int = 10, coalition: list = None, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True):
    if sum(coalition) == 0:
        # this is the requirement of shapley values definition
        return 0
    else:
        result = query_execution(query, k=k, coalition=coalition, t_norm=t_norm, t_conorm=t_conorm, logging=logging)
        # remove easy answers from the result
        result = result[~result.index.isin(easy_answers)]
        if qoi == 'rank':
            if target_entity in result.index:
                value = result.index.get_loc(target_entity)
            else:
                raise ValueError(f"Target entity {target_entity} not found in the result")
        elif qoi == 'hit1':
            value = 1 if target_entity in result.index[:1] else 0
        elif qoi == 'hit3':
            value = 1 if target_entity in result.index[:3] else 0
        elif qoi == 'hit10':
            value = 1 if target_entity in result.index[:10] else 0
        else:
            raise ValueError(f"Unsupported QoI: {qoi}. Supported values are 'rank', 'hit1', 'hit3', 'hit10'.")
        return value

In [41]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[0, 0], t_norm='prod', t_conorm='min', logging=True)

0

In [42]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[0, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 0.08 seconds
Time taken for second level query: 34.78 seconds


8306

In [43]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 0], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 1.83 seconds
Time taken for second level query: 0.40 seconds


5

In [44]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 4.73 seconds
Time taken for second level query: 16.75 seconds


8306

In [45]:
import math

def shapley_value(query: Query, atom_idx: int, easy_answers: list, target_entity: int, qoi: str = 'rank', k: int = 10, t_norm: str = 'prod', t_conorm: str = 'min', logging: bool = True):
    num_atoms = 0
    if query.query_type == '2p':
        num_atoms = 2
    else:
        raise ValueError(f"Unsupported query type: {query.query_type}. Only '2p' queries are supported.")

    shapley_value = 0.0
    
    num_remaining_atoms = num_atoms - 1
    for i in range(2**num_remaining_atoms):
        coalition = [int(x) for x in bin(i)[2:].zfill(num_remaining_atoms)]
        # create the coalition in the format of a list of 0s and 1s
        coalition_mask = [0] * num_atoms
        counter = 0
        for idx, _ in enumerate(coalition_mask):
            if idx== atom_idx:
                coalition_mask[idx] = 0
            else:
                coalition_mask[idx] = coalition[counter]
                counter += 1
        if logging:
            print(f"Coalition: {coalition_mask}, Atom Index: {atom_idx}")

        # calculate the weight term (|S|! (p-|S|-1)! \ p!
        weight = (math.factorial(sum(coalition)) * math.factorial(num_atoms - sum(coalition) - 1)) / math.factorial(num_atoms)
        
        # calculate the value function for the current coalition
        value = value_function(query, easy_answers, target_entity=target_entity, qoi=qoi, k=k, coalition=coalition_mask, t_norm=t_norm, t_conorm=t_conorm, logging=logging)
        
        # calculate the contribution of the current coalition when the atom is added
        added_coalition_mask = coalition_mask.copy()
        added_coalition_mask[atom_idx] = 1
        added_value = value_function(query, easy_answers, target_entity=target_entity, qoi=qoi, k=k, coalition=added_coalition_mask, t_norm=t_norm, t_conorm=t_conorm, logging=logging)
        
        # compute the difference
        contribution = added_value - value
        if logging:
            print(f"Coalition: {coalition_mask}, Contribution: {contribution} (before adding atom: {value}, after adding atom: {added_value}), weight: {weight})")
            
        # add the contribution to the shapley value
        shapley_value += contribution * weight
        
    if logging:
        print(f"Shapley value for atom {atom_idx}: {shapley_value}")
    return shapley_value

## Efficiency Axiom

Shapley value of the first player (atom 0):

In [46]:
shapley_value(query, atom_idx=0, easy_answers=easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 0
Time taken for first level query: 3.41 seconds
Time taken for second level query: 0.41 seconds
Coalition: [0, 0], Contribution: 5 (before adding atom: 0, after adding atom: 5), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Time taken for first level query: 0.08 seconds
Time taken for second level query: 33.03 seconds
Time taken for first level query: 3.45 seconds
Time taken for second level query: 15.04 seconds
Coalition: [0, 1], Contribution: 0 (before adding atom: 8306, after adding atom: 8306), weight: 0.5)
Shapley value for atom 0: 2.5


2.5

Shapley value of the second player (atom 1):

In [47]:
shapley_value(query, atom_idx=1, easy_answers=easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 1
Time taken for first level query: 0.08 seconds
Time taken for second level query: 34.80 seconds
Coalition: [0, 0], Contribution: 8306 (before adding atom: 0, after adding atom: 8306), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Time taken for first level query: 3.45 seconds
Time taken for second level query: 0.40 seconds
Time taken for first level query: 1.76 seconds
Time taken for second level query: 17.92 seconds
Coalition: [1, 0], Contribution: 8301 (before adding atom: 5, after adding atom: 8306), weight: 0.5)
Shapley value for atom 1: 8303.5


8303.5

Value function when every atom is present in the coalition:

In [48]:
value_function(query, easy_answers, target_entity=query_hard.answer[0], qoi='rank', k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True)

Time taken for first level query: 3.39 seconds
Time taken for second level query: 16.83 seconds


8306

Final check:

In [49]:
# efficiency check
8306 == 8303.5 + 2.5

True

## Shapley Value Examples

In [63]:
sample_idx = 3000
query_hard = query_dataset_hard.get_queries(sample_query_type)[sample_idx]
query = query_dataset.get_queries(sample_query_type)[sample_idx]
easy_answers = query.get_answer()
easy_answers = [a for a in easy_answers if a not in query_hard.get_answer()]
human_readable(query_hard, dataset)
print("-"*70)
human_readable(query, dataset)

Query:
Ja Rule	--/people/person/spouse_s./people/marriage/type_of_union-->	V
V	--/people/marriage_union_type/unions_of_this_type./people/marriage/location_of_ceremony-->	?

Answer Set (?): 
['Parma', 'Pittsburgh', 'Andhra Pradesh', 'Darmstadt', 'Marburg', 'Lincoln', 'County Kildare', 'Oahu', 'Connecticut', 'Barnet', 'Massachusetts', 'Anguilla', 'Caracas', 'Fiji', 'San Jose', 'Windsor', 'Newport', 'City of Winchester', 'Reading', 'Singapore', 'Fort Lauderdale', 'Fall River', 'Sonoma County', 'Nantucket', 'Yuma', 'Hawaii', 'Nevada', 'Nuremberg', 'Halle', 'Osaka', 'Helsinki', 'Greenwich', 'Hartford', 'Brussels', 'Islamabad', 'Istanbul', 'Erie', 'Long Island', 'Salzburg', 'Richmond', 'Lebanon', 'Las Vegas', 'Nashville', 'Cumberland County', 'Duval County', 'Buenos Aires', 'Newburyport', 'Prague', 'Weimar', 'Luxembourg', 'Guildford', 'Fort Smith', 'Clark County']
----------------------------------------------------------------------
Query:
Ja Rule	--/people/person/spouse_s./people/marriage/

In [64]:
symolic.query_1p(query.get_query()[0][0], query.get_query()[0][1][0])

[434]

In [None]:
cqd_result = query_execution(query, k=5, coalition=[1, 1], t_norm='prod', t_conorm='min', logging=True)
human_df = cqd_result.copy()
human_df['title'] = human_df.index.map(dataset.id2node)
human_df['title'] = human_df['title'].map(dataset.get_title_by_node)
human_df['is_easy_answer'] = human_df.index.isin(easy_answers)
human_df['is_hard_answer'] = human_df.index.isin(query_hard.get_answer())
human_df

Time taken for first level query: 4.38 seconds
Time taken for second level query: 17.34 seconds


Unnamed: 0,score,path,title,is_easy_answer,is_hard_answer
6851,55.762383,"(5594, 77, 2276, 164)",Judaism-GB,False,True
2456,51.476990,"(5594, 77, 2276, 164)",Atheism,False,False
1771,49.634693,"(5594, 77, 5593, 164)",Catholicism,False,False
7859,45.179855,"(5594, 77, 2276, 164)",Agnosticism,False,False
7142,41.937160,"(5594, 77, 5593, 164)",Buddhism,False,False
...,...,...,...,...,...
4842,0.000000,,Soviet Union,False,False
4843,0.000000,,Freiburg im Breisgau,False,False
4844,0.000000,,The Stand,False,False
4845,0.000000,,Yale Divinity School,False,False


Let's explain the true hard answer which the model could predict it correctly.

In [185]:
target = "Judaism-GB"
target_entity = dataset.get_node_by_title(target)
target_id = dataset.get_id_by_node(target_entity)
print(f"Target entity: {target} ({target_entity}) with ID {target_id}")

Target entity: Judaism-GB (/m/03_gx) with ID 6851


In [186]:
atom_idx = 0 # Goldfinger	--/award/award_nominee/award_nominations./award/award_nomination/nominated_for_reverse-->	V
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 0
Time taken for first level query: 1.83 seconds
Time taken for second level query: 0.41 seconds
Coalition: [0, 0], Contribution: 6827 (before adding atom: 0, after adding atom: 6827), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Time taken for first level query: 0.08 seconds
Time taken for second level query: 33.88 seconds
Time taken for first level query: 4.38 seconds
Time taken for second level query: 17.05 seconds
Coalition: [0, 1], Contribution: 0 (before adding atom: 0, after adding atom: 0), weight: 0.5)
Shapley value for atom 0: 3413.5


3413.5

In [187]:
atom_idx = 1 # V	--/people/person/religion-->	?
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 1
Time taken for first level query: 0.08 seconds
Time taken for second level query: 33.71 seconds
Coalition: [0, 0], Contribution: 0 (before adding atom: 0, after adding atom: 0), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Time taken for first level query: 4.72 seconds
Time taken for second level query: 0.42 seconds
Time taken for first level query: 1.84 seconds
Time taken for second level query: 17.09 seconds
Coalition: [1, 0], Contribution: -6827 (before adding atom: 6827, after adding atom: 0), weight: 0.5)
Shapley value for atom 1: -3413.5


-3413.5

Now, let's look at the explanation for another target entity, which is not in the answer set of the query.

In [188]:
target = "Catholicism"
target_entity = dataset.get_node_by_title(target)
target_id = dataset.get_id_by_node(target_entity)
print(f"Target entity: {target} ({target_entity}) with ID {target_id}")

Target entity: Catholicism (/m/0c8wxp) with ID 1771


In [190]:
atom_idx = 0 # Goldfinger	--/award/award_nominee/award_nominations./award/award_nomination/nominated_for_reverse-->	V
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 0
Time taken for first level query: 4.42 seconds
Time taken for second level query: 0.41 seconds
Coalition: [0, 0], Contribution: 1795 (before adding atom: 0, after adding atom: 1795), weight: 0.5)
Coalition: [0, 1], Atom Index: 0
Time taken for first level query: 0.08 seconds
Time taken for second level query: 33.44 seconds
Time taken for first level query: 1.83 seconds
Time taken for second level query: 16.64 seconds
Coalition: [0, 1], Contribution: -1 (before adding atom: 3, after adding atom: 2), weight: 0.5)
Shapley value for atom 0: 897.0


897.0

In [191]:
atom_idx = 1 # V	--/people/person/religion-->	?
shapley_value(query, atom_idx=atom_idx, easy_answers=easy_answers, target_entity=target_id, qoi='rank', k=5, t_norm='prod', t_conorm='min', logging=True)

Coalition: [0, 0], Atom Index: 1
Time taken for first level query: 0.08 seconds
Time taken for second level query: 34.00 seconds
Coalition: [0, 0], Contribution: 3 (before adding atom: 0, after adding atom: 3), weight: 0.5)
Coalition: [1, 0], Atom Index: 1
Time taken for first level query: 4.33 seconds
Time taken for second level query: 0.41 seconds
Time taken for first level query: 1.81 seconds
Time taken for second level query: 16.53 seconds
Coalition: [1, 0], Contribution: -1793 (before adding atom: 1795, after adding atom: 2), weight: 0.5)
Shapley value for atom 1: -895.0


-895.0