## Preparation

In [1]:
from utils import setup_dataset_and_graphs

# Directory holding the FB15k-237 files read by every loader in this notebook.
data_dir = "data/FB15k-237"

# Build the dataset and one graph per split.
# NOTE(review): the logged edge counts grow train -> valid -> test, which suggests the
# graphs are cumulative (valid includes train, test includes valid). add_reverse=True
# presumably also inserts the inverse of every triple. Confirm in utils.
dataset, graph_train, graph_valid, graph_test = setup_dataset_and_graphs(
    data_dir,
    logging=True,
    add_reverse=True,
)

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.
Loaded 544230 edges from data/FB15k-237/train.txt, skipped 0 edges due to missing nodes or relations.
Loaded 579300 edges from data/FB15k-237/valid.txt, skipped 0 edges due to missing nodes or relations.
Loaded 620232 edges from data/FB15k-237/test.txt, skipped 0 edges due to missing nodes or relations.


In [2]:
from utils import load_all_queries

# Load the "test"-split queries (version 1). Judging by how they are used below, the first
# result carries the complete answer sets and the second only the hard answers.
query_dataset, query_dataset_hard = load_all_queries(
    dataset, data_dir, "test", version=1
)

In [3]:
# Experiment settings shared by the cells below.
k = 10  # forwarded as `k` to xcqa.query_execution
t_norm = t_conorm = "prod"  # same "prod" operator for the t-norm and the t-conorm
# Pretrained embedding checkpoint (loading it prints a ComplEx module in the next cell).
model_path = "models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt"

In [4]:
from symbolic_torch import SymbolicReasoning
from xcqa_torch import XCQA

# Symbolic reasoner over the validation graph.
reasoner = SymbolicReasoning(graph_valid, logging=False)

# Engine that combines the symbolic reasoner with the pretrained embedding model
# loaded from model_path. This cell's output shows the loaded ComplEx module.
xcqa = XCQA(
    symbolic=reasoner,
    dataset=dataset,
    logging=False,
    model_path=model_path,
    normalize=False,
)

ComplEx(
  (embeddings): ModuleList(
    (0): Embedding(14505, 2000, sparse=True)
    (1): Embedding(474, 2000, sparse=True)
  )
)


## Query Sampling

In [5]:
from utils import get_num_atoms

# Work with "2i" queries; num_atoms sizes the coalition vectors used later.
query_type = "2i"
num_atoms = get_num_atoms(query_type)

# Hard and complete views of the same query list, indexed in parallel below.
queries_hard, queries_complete = (
    query_dataset_hard.get_queries(query_type),
    query_dataset.get_queries(query_type),
)

In [6]:
# Pick one query and take it from both views (same index in each).
query_idx = 337
query_hard, query_complete = queries_hard[query_idx], queries_complete[query_idx]
query_hard

Query(type=2i, query=((212, (84,)), (1374, (40,))), answer=[7298, 8836, 6667, 8209, 10641, 917, 1951, 2995, 198, 1742, 4688, 3070, 4443, 6494, 2784, 3940, 2664, 1775, 6128, 10743, 123, 1150])

In [7]:
from utils import human_readable

# Render the query twice: as an edge listing (fol=False) and as a logic formula (fol=True).
hr, fol = (
    human_readable(query_hard, dataset, fol=as_fol, full_relation=True)
    for as_fol in (False, True)
)
for rendering in (hr, fol):
    print(rendering)

Piano (212) --[/music/instrument/instrumentalists (84)]--> V0
Rock music (1374) --[/music/genre/artists (40)]--> V0
?V1: /music/instrument/instrumentalists(Piano, V0) ∧ /music/genre/artists(Rock music, V0)


In [8]:
# Answer sets for the selected query. Easy answers are the complete answers
# that are not also hard answers.
hard_answers = query_hard.get_answer()
complete_answers = query_complete.get_answer()
easy_answers = [*(set(complete_answers) - set(hard_answers))]

In [9]:
# Print each hard answer's id next to its entity title.
for answer_id in query_hard.get_answer():
    answer_title = dataset.get_title_by_id(answer_id)
    print(f"{answer_id}: {answer_title}")

7298: Hikaru Utada
8836: Ryan Adams
6667: Ringo Starr
8209: Sheryl Crow
10641: Yoko Ono
917: Bob Dylan
1951: B.B. King
2995: Demi Lovato
198: Jon Bon Jovi
1742: Dave Grohl
4688: Michael Jackson
3070: George Harrison
4443: Lou Reed
6494: Paul Weller
2784: George Martin
3940: Roy Bittan
2664: Lenny Kravitz
1775: Pharrell Williams
6128: Serj Tankian
10743: Enrique Iglesias
123: Jay Bennett
1150: Robbie Williams


In [10]:
from shapley import Shapley

# Shapley-value explainer built on top of the XCQA engine.
shapley = Shapley(xcqa)

## CQD-SHAP Explanations

### Target Answer: Paul Weller

In [11]:
from utils import compute_rank

target = 6494  # Paul Weller

# Run the query with an all-1 coalition (reported as "CQD") and an all-0 coalition
# (reported as "Symbolic"). Then rank the target in each result; complete_answers is
# presumably used to filter the other correct answers out of the ranking.
result_cqd, result_sym = (
    xcqa.query_execution(query_hard, k=k, coalition=[bit] * num_atoms, t_norm=t_norm, t_conorm=t_conorm)
    for bit in (1, 0)
)
rank_cqd, rank_sym = (
    compute_rank(result, complete_answers, target) for result in (result_cqd, result_sym)
)
for engine_name, engine_rank in (("CQD", rank_cqd), ("Symbolic", rank_sym)):
    print(f"Rank of {target} ({dataset.get_title_by_id(target)}) in {engine_name}: {engine_rank}")

Rank of 6494 (Paul Weller) in CQD: 61
Rank of 6494 (Paul Weller) in Symbolic: 56


In [12]:
# Every correct answer except the target, handed to the explainer as the nodes to filter.
filtered_nodes = list(set(complete_answers) - {target})
# NOTE(review): the result appears to map atom index -> Shapley value; confirm in shapley.py.
shapleys = shapley.shapley_values(query_hard, filtered_nodes, target)
shapleys

{0: 123.5, 1: -128.5}

In [13]:
from utils import compute_rank

target = 6494  # Paul Weller

# Mixed coalition: position 0 uses the CQD setting (1), position 1 the symbolic one (0).
# NOTE(review): assumes coalition positions follow the atom order of the query.
result_hyb = xcqa.query_execution(
    query_hard,
    k=k,
    coalition=[1, 0],
    t_norm=t_norm,
    t_conorm=t_conorm,
)
rank_hyb = compute_rank(result_hyb, complete_answers, target)
print(f"Rank of {target} ({dataset.get_title_by_id(target)}) in Hybrid: {rank_hyb}")

Rank of 6494 (Paul Weller) in Hybrid: 33


In [14]:
import math

# Total attribution across all atoms for the current target.
total_shapley = sum(shapleys.values())
print(f"Sum of Shapley values: {total_shapley}")

# Efficiency check: Shapley values must add up to the payoff gap between the full
# coalition (CQD) and the empty one (Symbolic). The recorded outputs (CQD rank 61,
# Symbolic rank 56, sum -5.0) match a payoff of v(S) = rank_sym - rank(S), so the sum
# should equal rank_sym - rank_cqd. abs_tol guards against float noise near zero.
assert math.isclose(total_shapley, rank_sym - rank_cqd, abs_tol=1e-9), (
    total_shapley,
    rank_sym,
    rank_cqd,
)

Sum of Shapley values: -5.0


### Target: Billy Joel (not one of the hard answers listed above)

In [15]:
from utils import compute_rank

target = 6347 # Billy Joel

# Run the query with an all-1 coalition (reported as "CQD") and an all-0 coalition
# (reported as "Symbolic"). Then rank the target in each result; complete_answers is
# presumably used to filter the other correct answers out of the ranking.
result_cqd, result_sym = (
    xcqa.query_execution(query_hard, k=k, coalition=[bit] * num_atoms, t_norm=t_norm, t_conorm=t_conorm)
    for bit in (1, 0)
)
rank_cqd, rank_sym = (
    compute_rank(result, complete_answers, target) for result in (result_cqd, result_sym)
)
for engine_name, engine_rank in (("CQD", rank_cqd), ("Symbolic", rank_sym)):
    print(f"Rank of {target} ({dataset.get_title_by_id(target)}) in {engine_name}: {engine_rank}")

Rank of 6347 (Billy Joel) in CQD: 1
Rank of 6347 (Billy Joel) in Symbolic: 259


In [16]:
# Every correct answer except the target, handed to the explainer as the nodes to filter.
filtered_nodes = list(set(complete_answers) - {target})
# NOTE(review): the result appears to map atom index -> Shapley value; confirm in shapley.py.
shapleys = shapley.shapley_values(query_hard, filtered_nodes, target)
shapleys

{0: 219.0, 1: 39.0}

In [17]:
import math

# Total attribution across all atoms for the current target.
total_shapley = sum(shapleys.values())
print(f"Sum of Shapley values: {total_shapley}")

# Efficiency check: Shapley values must add up to the payoff gap between the full
# coalition (CQD) and the empty one (Symbolic). The recorded outputs (CQD rank 1,
# Symbolic rank 259, sum 258.0) match a payoff of v(S) = rank_sym - rank(S), so the sum
# should equal rank_sym - rank_cqd. abs_tol guards against float noise near zero.
assert math.isclose(total_shapley, rank_sym - rank_cqd, abs_tol=1e-9), (
    total_shapley,
    rank_sym,
    rank_cqd,
)

Sum of Shapley values: 258.0
