In [1]:
def get_num_atoms(query_type):
    """Return how many atoms a query of the given type is made of.

    Args:
        query_type: Query structure code; one of '2p', '3p', '2i', '2u',
            '3i', 'pi', 'up', 'ip'.

    Returns:
        The atom count (2 or 3) registered for ``query_type``.

    Raises:
        ValueError: If ``query_type`` is not one of the supported codes.
    """
    atoms_by_type = {
        '2p': 2, '2i': 2, '2u': 2,
        '3p': 3, '3i': 3, 'pi': 3, 'up': 3, 'ip': 3,
    }
    num_atoms = atoms_by_type.get(query_type)
    if num_atoms is None:
        raise ValueError(f"Unsupported query type: {query_type}.")
    return num_atoms

def get_query_file_paths(data_dir, query_type, hard=False, split='test'):
    """Build the path of the pickle file holding the answers for a query type.

    The file name has the form ``<split>_ans_<stem>[_hard].pkl`` where the
    stem is the on-disk code of the query structure (e.g. '2p' -> '2c').

    Args:
        data_dir: Directory containing the query pickle files.
        query_type: Query structure code ('2p', '3p', '2i', '2u', '3i',
            'pi', 'ip', 'up').
        hard: When true, point at the ``_hard`` variant of the file.
        split: Dataset split used as file-name prefix (default 'test').

    Returns:
        The path as a string, joined with a forward slash.

    Raises:
        ValueError: If ``query_type`` is not one of the supported codes.
    """
    prefix = split + '_ans_'
    stems = {
        '2p': '2c', '3p': '3c', '2i': '2i', '2u': '2u',
        '3i': '3i', 'pi': 'ci', 'ip': 'ic', 'up': 'uc',
    }

    stem = stems.get(query_type)
    if stem is None:
        raise ValueError(f"Unsupported query type: {query_type}.")

    hard_suffix = '_hard' if hard else ''
    return f"{data_dir}/{prefix}{stem}{hard_suffix}.pkl"

In [2]:
from graph import Dataset, Graph
from query import  QueryDataset

def setup_dataset_and_graphs(data_dir):
    """Build the shared dataset plus the training and validation graphs.

    The validation graph is the training graph's edges plus the triples from
    ``valid.txt``. No test graph is built here.

    Args:
        data_dir: Directory holding ``ind2ent.pkl``, ``ind2rel.pkl``,
            ``extra/entity2text.txt``, ``train.txt`` and ``valid.txt``.

    Returns:
        A ``(dataset, graph_train, graph_valid)`` tuple.
    """
    train_path = f'{data_dir}/train.txt'
    valid_path = f'{data_dir}/valid.txt'

    # Entity / relation vocabularies and human-readable entity titles.
    dataset = Dataset()
    dataset.set_id2node(f'{data_dir}/ind2ent.pkl')
    dataset.set_id2rel(f'{data_dir}/ind2rel.pkl')
    dataset.set_node2title(f'{data_dir}/extra/entity2text.txt')

    # Training graph: triples from train.txt, loaded with add_reverse=True.
    graph_train = Graph(dataset)
    graph_train.load_triples(train_path, skip_missing=False, add_reverse=True)

    # Validation graph: first replay every training edge as-is ...
    # NOTE(review): add_reverse=False here presumably because get_edges()
    # already yields the reverse edges created while loading train.txt - confirm.
    graph_valid = Graph(dataset)
    for train_edge in graph_train.get_edges():
        head_name = train_edge.get_head().get_name()
        relation_name = train_edge.get_name()
        tail_name = train_edge.get_tail().get_name()
        graph_valid.add_edge(head_name, relation_name, tail_name, skip_missing=False, add_reverse=False)
    # ... then add the validation triples on top.
    graph_valid.load_triples(valid_path, skip_missing=False, add_reverse=True)

    return dataset, graph_train, graph_valid

def load_query_datasets(dataset, data_dir, query_type, split='test'):
    """Load the complete and the hard query sets for one query type.

    Args:
        dataset: Dataset object handed to each ``QueryDataset``.
        data_dir: Directory containing the query pickle files.
        query_type: Query structure code (e.g. '2p').
        split: Dataset split whose files are read (default 'test').

    Returns:
        A ``(query_dataset, query_dataset_hard)`` tuple: the first is loaded
        from the regular answers file, the second from its ``_hard`` variant.
    """
    def _load(hard):
        # One QueryDataset per file; the path depends only on the `hard` flag.
        queries = QueryDataset(dataset)
        pkl_path = get_query_file_paths(data_dir, query_type, hard=hard, split=split)
        queries.load_queries_from_pkl(pkl_path, query_type=query_type)
        return queries

    query_dataset = _load(hard=False)
    query_dataset_hard = _load(hard=True)
    return query_dataset, query_dataset_hard

In [3]:
# Load the FB15k-237 vocabularies and build the training and validation graphs.
# The path is relative to the notebook's working directory.
dataset, graph_train, graph_valid = setup_dataset_and_graphs('data/FB15k-237')

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.


In [4]:
from symbolic_torch import SymbolicReasoning

# Symbolic reasoner built on the validation graph (training edges + valid.txt triples).
# NOTE(review): the recorded output below still prints "Edge: ..." lines even with
# logging=False - presumably a leftover debug print in symbolic_torch; verify.
reasoner = SymbolicReasoning(graph_valid, logging=False)

Edge: <graph.Edge object at 0x72337e525e40>, Head ID: 13432, Relation ID: 31
Edge: <graph.Edge object at 0x72337e36f700>, Head ID: 62, Relation ID: 51
Edge: <graph.Edge object at 0x72337e193880>, Head ID: 62, Relation ID: 51
Edge: <graph.Edge object at 0x72337df659c0>, Head ID: 62, Relation ID: 51
Edge: <graph.Edge object at 0x72337de7c640>, Head ID: 8009, Relation ID: 363
Edge: <graph.Edge object at 0x72337db1e440>, Head ID: 774, Relation ID: 237
Edge: <graph.Edge object at 0x72337db874c0>, Head ID: 14397, Relation ID: 31
Edge: <graph.Edge object at 0x72337da68220>, Head ID: 3387, Relation ID: 430
Edge: <graph.Edge object at 0x72337d6156c0>, Head ID: 774, Relation ID: 237


In [5]:
# Load the complete and the hard answer sets of the "2p" test queries.
query_dataset, query_dataset_hard = load_query_datasets(dataset, 'data/FB15k-237', "2p", "test")

In [6]:
from xcqa_torch import XCQA

# Query executor combining the symbolic reasoner with the model checkpoint at model_path.
# The checkpoint path is relative to the notebook's working directory.
xcqa = XCQA(symbolic=reasoner, dataset=dataset, logging=False, model_path="models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt")

ComplEx(
  (embeddings): ModuleList(
    (0): Embedding(14505, 2000, sparse=True)
    (1): Embedding(474, 2000, sparse=True)
  )
)
Successfully loaded model and set to eval mode (device: cuda:0)


In [7]:
# Query structure evaluated in the rest of the notebook.
# NOTE(review): this repeats the "2p" literal passed to load_query_datasets earlier;
# keep the two in sync.
query_type = "2p"
num_atoms = get_num_atoms(query_type)

# Query lists from the hard and the complete datasets. The evaluation loops index
# both with the same position i, so they are assumed to be aligned - TODO confirm.
hard = query_dataset_hard.get_queries(query_type)
complete = query_dataset.get_queries(query_type)

I have a question regarding the necessary and sufficient evaluation. In the CQDA paper, they mention that they only consider predictions where “their original H@1 and MRR are both 1.0”. You told me to follow the same approach and, for each query, to consider the hard answer predicted in the top position.

However, in the CQD implementation, when computing metrics such as Hit@1, they remove all easy answers as well as the other hard answers (except for the one currently under evaluation). As a result, it is possible for multiple hard answers in a single query to receive a rank of 1.

For example, suppose the model prediction is [4, 78, 3, 2], where 4 is an easy answer and 78 and 3 are hard answers. In this case, both 78 and 3 would have Hit@1 = 1. Would this be acceptable, and do you see any issue with averaging over all such hard answers?

In [8]:
def compute_metrics(result, answer_complete, target_answer):
    """Compute filtered ranking metrics for a single target answer.

    All correct answers other than ``target_answer`` are removed from the
    ranking before the target's position is determined, so several answers of
    the same query can each obtain rank 1.

    Args:
        result: Object whose ``.index`` lists the candidate answers in ranked
            order, best first (e.g. a pd.Series or pd.DataFrame sorted by
            predicted score, descending).
        answer_complete: Iterable with every correct answer of the query.
        target_answer: The specific answer whose rank is evaluated.

    Returns:
        ``(mrr, hit_1, hit_3, hit_10)``: the reciprocal rank as a float and
        three 0/1 integer flags. All are zero when ``target_answer`` does not
        appear among the candidates.
    """
    # Correct answers that must not compete with the target for a rank.
    competing_answers = set(answer_complete) - {target_answer}

    # Candidate ranking with those competing answers filtered out.
    ranking = [
        candidate
        for candidate in list(result.index)
        if candidate not in competing_answers
    ]

    if target_answer not in ranking:
        return 0.0, 0, 0, 0

    rank = ranking.index(target_answer) + 1
    return 1.0 / rank, int(rank == 1), int(rank <= 3), int(rank <= 10)

In [9]:
# Value passed as the `k` argument of xcqa.query_execution and shapley_value.
k = 10
# Coalition with every atom set to 1; starting point of the necessity experiment.
# NOTE(review): presumably 1 means the atom is answered by the neural (CQD) model - confirm in xcqa_torch.
cqd_coalition = [1] * num_atoms
# Coalition with every atom set to 0; starting point of the sufficiency experiment.
# NOTE(review): presumably 0 means the atom is answered symbolically on the graph - confirm.
symbolic_coalition = [0] * num_atoms
# t-norm / t-conorm names forwarded to query_execution and shapley_value.
t_norm = "prod"
t_conorm = "prod"

In [20]:
import logging

def setup_logger(filename):
    """Return a DEBUG-level logger that writes to the console and to a file.

    The logger is named ``filename + "evaluation"`` and logs to
    ``filename + ".log"``. The log file is opened in write mode, so it is
    truncated the first time the logger is configured.

    Calling this function again for the same ``filename`` (e.g. when a
    notebook cell is re-run) returns the already configured logger unchanged.
    No new handlers are created in that case, so the log file that the
    existing handler is writing to is neither reopened nor truncated.

    Args:
        filename: Base name used for both the logger name and the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(filename + "evaluation")
    logger.setLevel(logging.DEBUG)  # Capture every level; handlers filter further

    # Already configured by an earlier call: bail out BEFORE building handlers.
    # Constructing a FileHandler with mode="w" opens (and truncates) the file
    # immediately, which would wipe the log of the handler that is still attached.
    if logger.handlers:
        return logger

    # Shared output format for both handlers
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Console handler (for terminal / notebook output)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)

    # File handler (for saving logs to a file)
    file_handler = logging.FileHandler(filename + ".log", mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger

In [10]:
def get_metric_dict():
    """Return a fresh accumulator mapping each metric name to its own empty list.

    The keys are, in order, "mrr", "hit_1", "hit_3" and "hit_10".
    """
    metric_names = ("mrr", "hit_1", "hit_3", "hit_10")
    return {name: [] for name in metric_names}

## Necessary

In [21]:
# Logger for the necessity experiment: console output plus the file "necessary.log".
logger = setup_logger("necessary")

In [22]:
from tqdm import tqdm
from shapley import shapley_value

# Necessity experiment.
# For every hard answer ranked first by the all-ones coalition, set the atom with
# the smallest Shapley value to 0 and measure how the ranking metrics change.
# Per-answer metrics are averaged per query, then collected across queries.
metrics_cqd = get_metric_dict()
metrics_new = get_metric_dict()
metric_names = ("mrr", "hit_1", "hit_3", "hit_10")

for i in tqdm(range(len(hard))):
    query_hard = hard[i]
    query_complete = complete[i]
    # Easy answers: answers of the complete query that are not hard answers.
    easy_answers = [a for a in query_complete.get_answer() if a not in query_hard.get_answer()]
    cqd_result = xcqa.query_execution(query_hard, k=k, coalition=cqd_coalition, t_norm=t_norm, t_conorm=t_conorm)

    current_metrics_cqd = get_metric_dict()
    current_metrics_new = get_metric_dict()

    for hard_answer in query_hard.answer:
        mrr, hit_1, hit_3, hit_10 = compute_metrics(cqd_result, query_complete.answer, hard_answer)
        # Only answers already ranked first (filtered) take part in the experiment.
        if hit_1 != 1.0 or mrr != 1.0:
            continue
        #logger.info(f"Match! Query #{i} and Hard answer {hard_answer}")
        #logger.info(f"Original metrics: MRR: {mrr}, Hit@1: {hit_1}, Hit@3: {hit_3}, Hit@10: {hit_10}")
        for name, value in zip(metric_names, (mrr, hit_1, hit_3, hit_10)):
            current_metrics_cqd[name].append(value)

        # Shapley value of every atom for this (query, hard answer) pair.
        shapley = {
            atom: shapley_value(xcqa, query_hard, atom, easy_answers, hard_answer, "rank", k, t_norm, t_conorm)
            for atom in range(num_atoms)
        }

        # NOTE(review): the atom with the SMALLEST value is treated as the most
        # important one - presumably because the value function is the rank
        # (lower is better); confirm against shapley.shapley_value.
        lowest_atom = min(shapley, key=shapley.get)
        #logger.info(f"Lowest contributing atom: {lowest_atom} with Shapley value: {shapley[lowest_atom]} (Full shapley value: {shapley})")
        new_coalition = cqd_coalition.copy()
        new_coalition[lowest_atom] = 0
        new_result = xcqa.query_execution(query_hard, k=k, coalition=new_coalition, t_norm=t_norm, t_conorm=t_conorm)
        mrr_new, hit_1_new, hit_3_new, hit_10_new = compute_metrics(new_result, query_complete.answer, hard_answer)
        #logger.info(f"New metrics after removing atom {lowest_atom}: MRR: {mrr_new}, Hit@1: {hit_1_new}, Hit@3: {hit_3_new}, Hit@10: {hit_10_new}")
        for name, value in zip(metric_names, (mrr_new, hit_1_new, hit_3_new, hit_10_new)):
            current_metrics_new[name].append(value)

    # Queries without any qualifying hard answer contribute nothing.
    if not current_metrics_cqd["mrr"]:
        continue

    # Average over the qualifying hard answers of this query, then record.
    for name in metric_names:
        metrics_cqd[name].append(sum(current_metrics_cqd[name]) / len(current_metrics_cqd[name]))
        metrics_new[name].append(sum(current_metrics_new[name]) / len(current_metrics_new[name]))


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5000/5000 [06:40<00:00, 12.48it/s]  


In [23]:
# Number of queries with at least one hard answer ranked first by the all-ones
# coalition; only these queries enter the necessity averages.
len(metrics_cqd["mrr"])

1549

In [24]:
# Necessity summary: mean of the per-query metrics before and after the selected
# atom was set to 0, plus the difference (after - before).
for key in metrics_new:
    before_values = metrics_cqd[key]
    after_values = metrics_new[key]
    avg_before = sum(before_values) / len(before_values)
    avg_after = sum(after_values) / len(after_values)
    avg_diff = avg_after - avg_before
    print(f"Average {key} before removing the most important atom: {avg_before}")
    print(f"Average {key} after removing the most important atom: {avg_after}")
    print(f"Average {key} difference: {avg_diff}")


Average mrr before removing the most important atom: 1.0
Average mrr after removing the most important atom: 0.3287072230187944
Average mrr difference: -0.6712927769812056
Average hit_1 before removing the most important atom: 1.0
Average hit_1 after removing the most important atom: 0.2666838963160033
Average hit_1 difference: -0.7333161036839967
Average hit_3 before removing the most important atom: 1.0
Average hit_3 after removing the most important atom: 0.3660907200374906
Average hit_3 difference: -0.6339092799625095
Average hit_10 before removing the most important atom: 1.0
Average hit_10 after removing the most important atom: 0.44883617109698676
Average hit_10 difference: -0.5511638289030132


## Sufficient

In [None]:
# Logger for the sufficiency experiment: console output plus the file "sufficient.log".
logger = setup_logger("sufficient")

In [12]:
from tqdm import tqdm
from shapley import shapley_value

# Sufficiency experiment.
# For every hard answer NOT ranked first by the all-zeros coalition, set the atom
# with the smallest Shapley value to 1 and measure how the ranking metrics change.
# Per-answer metrics are averaged per query, then collected across queries.
metrics_symbolic = get_metric_dict()
metrics_new = get_metric_dict()
metric_names = ("mrr", "hit_1", "hit_3", "hit_10")

for i in tqdm(range(len(hard))):
    query_hard = hard[i]
    query_complete = complete[i]
    # Easy answers: answers of the complete query that are not hard answers.
    easy_answers = [a for a in query_complete.get_answer() if a not in query_hard.get_answer()]
    symbolic_result = xcqa.query_execution(query_hard, k=k, coalition=symbolic_coalition, t_norm=t_norm, t_conorm=t_conorm)

    current_metrics_symbolic = get_metric_dict()
    current_metrics_new = get_metric_dict()

    for hard_answer in query_hard.answer:
        #logger.info(f"Evaluating query # {i} with hard answer: {hard_answer}")
        mrr, hit_1, hit_3, hit_10 = compute_metrics(symbolic_result, query_complete.answer, hard_answer)

        # Answers already ranked first leave no room for improvement: skip them.
        if mrr == 1.0 or hit_1 == 1.0:
            continue
        #logger.info(f"Original metrics: MRR: {mrr}, Hit@1: {hit_1}, Hit@3: {hit_3}, Hit@10: {hit_10}")
        for name, value in zip(metric_names, (mrr, hit_1, hit_3, hit_10)):
            current_metrics_symbolic[name].append(value)

        # Shapley value of every atom for this (query, hard answer) pair.
        shapley = {
            atom: shapley_value(xcqa, query_hard, atom, easy_answers, hard_answer, "rank", k, t_norm, t_conorm)
            for atom in range(num_atoms)
        }

        # NOTE(review): the atom with the SMALLEST value is treated as the most
        # important one - presumably because the value function is the rank
        # (lower is better); confirm against shapley.shapley_value.
        lowest_atom = min(shapley, key=shapley.get)
        #logger.info(f"Lowest contributing atom: {lowest_atom} with Shapley value: {shapley[lowest_atom]} (Full shapley value: {shapley})")

        # Add that single atom to the otherwise all-zeros coalition.
        new_coalition = symbolic_coalition.copy()
        new_coalition[lowest_atom] = 1

        new_result = xcqa.query_execution(query_hard, k=k, coalition=new_coalition, t_norm=t_norm, t_conorm=t_conorm)
        mrr_new, hit_1_new, hit_3_new, hit_10_new = compute_metrics(new_result, query_complete.answer, hard_answer)
        #logger.info(f"New metrics after adding atom {lowest_atom}: MRR: {mrr_new}, Hit@1: {hit_1_new}, Hit@3: {hit_3_new}, Hit@10: {hit_10_new}")
        for name, value in zip(metric_names, (mrr_new, hit_1_new, hit_3_new, hit_10_new)):
            current_metrics_new[name].append(value)

    # Queries without any qualifying hard answer contribute nothing.
    if not current_metrics_symbolic["mrr"]:
        continue

    # Average over the qualifying hard answers of this query, then record.
    for name in metric_names:
        metrics_symbolic[name].append(sum(current_metrics_symbolic[name]) / len(current_metrics_symbolic[name]))
        metrics_new[name].append(sum(current_metrics_new[name]) / len(current_metrics_new[name]))


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5000/5000 [40:39<00:00,  2.05it/s]   


In [15]:
# Number of queries that contributed an "after" value to the sufficiency averages.
len(metrics_new["mrr"])

4937

In [16]:
# Number of queries with at least one hard answer that the all-zeros coalition
# did not rank first; should equal the count of "after" values.
len(metrics_symbolic["mrr"])

4937

In [18]:
# Sufficiency summary: mean of the per-query metrics for the all-zeros coalition
# ("before") and after the selected atom was set to 1 ("after"), plus the
# difference (after - before).
# The messages say "adding" because this experiment switches one atom ON
# (new_coalition[lowest_atom] = 1); the earlier "removing" wording had been
# copied from the necessity report and described the opposite operation.
for key in metrics_new:
    avg_before = sum(metrics_symbolic[key]) / len(metrics_symbolic[key])
    avg_after = sum(metrics_new[key]) / len(metrics_new[key])
    avg_diff = avg_after - avg_before
    print(f"Average {key} before adding the most important atom: {avg_before}")
    print(f"Average {key} after adding the most important atom: {avg_after}")
    print(f"Average {key} difference: {avg_diff}")

Average mrr before removing the most important atom: 0.0011034287627172425
Average mrr after removing the most important atom: 0.3200484838822459
Average mrr difference: 0.31894505511952864
Average hit_1 before removing the most important atom: 0.0
Average hit_1 after removing the most important atom: 0.24257409198923327
Average hit_1 difference: 0.24257409198923327
Average hit_3 before removing the most important atom: 0.0008579196266527068
Average hit_3 after removing the most important atom: 0.34149325865913005
Average hit_3 difference: 0.34063533903247734
Average hit_10 before removing the most important atom: 0.0017161851087366179
Average hit_10 after removing the most important atom: 0.4742260084649854
Average hit_10 difference: 0.4725098233562488
