In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from torch import nn
import data
from os import path
from collections import defaultdict
from enum import Enum
from codes.model import KGEModel
import argparse
import json

# Expose only GPU 0 to this kernel. This must run before CUDA is first used
# (the kge_model.cuda() call in the next cell).
os.environ.update(CUDA_VISIBLE_DEVICES="0")

### To run this notebook, please train a kge model first.
### See the script 'linebone/kge/best_config.sh' to run a kge model.

In [2]:
# Dataset / model selection: keep exactly one line of each group uncommented.
name_dataset = "wn18rr"
# name_dataset = "FB15k-237"
# name_dataset = "YAGO3-10"

# name_model = "TransE"
name_model = "pRotatE"
# name_model = "RotatE"

# Every artefact of a trained run lives under model_dir. The file paths are
# derived from it so that changing the run directory needs a single edit.
data_path = f"data/{name_dataset}"
model_dir = f"models/{name_model}_{name_dataset}_0"
checkpoint_dir = os.path.join(model_dir, "checkpoint")  # path of the checkpoint file
config_dir = os.path.join(model_dir, "config.json")     # path of the saved training args

# Load the hyper-parameters the model was trained with (JSON is UTF-8 by spec).
with open(config_dir, 'r', encoding='utf-8') as fjson:
    argparse_dict = json.load(fjson)
model = argparse_dict['model']
if model != name_model:
    # The model is rebuilt from name_model below, so a mismatch would pair a
    # checkpoint with the wrong scoring function.
    print(f"warning: {config_dir} says model={model!r} but name_model={name_model!r}")
double_entity_embedding = argparse_dict['double_entity_embedding']
double_relation_embedding = argparse_dict['double_relation_embedding']
hidden_dim = argparse_dict['hidden_dim']
gamma = argparse_dict['gamma']
# test_batch_size = argparse_dict['test_batch_size']

def load_id_map(dict_path):
    """Read a tab-separated ``<id>\\t<name>`` file into a ``{name: id}`` dict.

    Blank lines (e.g. a trailing empty line) are ignored.
    """
    name2id = {}
    with open(dict_path, encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            idx, name = line.split('\t')
            name2id[name] = int(idx)
    return name2id


ent2id = load_id_map(os.path.join(data_path, 'entities.dict'))
rel2id = load_id_map(os.path.join(data_path, 'relations.dict'))

def read_triples(file_path, entity2id=None, relation2id=None):
    """Read a tab-separated ``head\\trelation\\ttail`` file into id triples.

    Args:
        file_path: path of the triple file.
        entity2id: ``{entity_name: id}``; defaults to the module-level ``ent2id``.
        relation2id: ``{relation_name: id}``; defaults to the module-level ``rel2id``.

    Returns:
        List of ``(head_id, relation_id, tail_id)`` tuples in file order.
        Blank lines are skipped.
    """
    if entity2id is None:
        entity2id = ent2id
    if relation2id is None:
        relation2id = rel2id
    triples = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            h, r, t = line.split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples

id2ent = {v: k for k, v in ent2id.items()}
id2rel = {v: k for k, v in rel2id.items()}

num_ent = len(ent2id)
num_rel = len(rel2id)

train_triples = read_triples(os.path.join(data_path, 'train.txt'))
valid_triples = read_triples(os.path.join(data_path, 'valid.txt'))
test_triples = read_triples(os.path.join(data_path, 'test.txt'))

# All true triples; used to filter other correct answers when ranking.
all_true_triples = train_triples + valid_triples + test_triples

kge_model = KGEModel(
    model_name=name_model,
    nentity=num_ent,
    nrelation=num_rel,
    hidden_dim=hidden_dim,
    gamma=gamma,
    double_entity_embedding=double_entity_embedding,
    double_relation_embedding=double_relation_embedding
)

# Load the trained weights onto the CPU first. The checkpoint then does not
# have to be restored onto the GPU index it was saved from (only
# CUDA_VISIBLE_DEVICES="0" is visible). The model is moved to the GPU afterwards.
checkpoint = torch.load(checkpoint_dir, map_location='cpu')
kge_model.load_state_dict(checkpoint['model_state_dict'])

kge_model = kge_model.cuda()
kge_model.eval()  # inference-only analysis

# Embedding tables read by the per-dimension scoring functions.
entity_embedding = kge_model.entity_embedding
relation_embedding = kge_model.relation_embedding

# Model constants. `gamma` is rebound here from the float in config.json to
# the model's own attribute; the scoring functions call gamma.item() on it.
epsilon = kge_model.epsilon
embedding_range = kge_model.embedding_range
entity_dim = kge_model.entity_dim
relation_dim = kge_model.relation_dim
gamma = kge_model.gamma

# `modulus` is only read for pRotatE (used by the pRotatE scorer).
if name_model == 'pRotatE':
    modulus = kge_model.modulus

def filtered_dataset(triples, rel_type):
    """Return the set of ``(head, rel, tail)`` triples whose relation id is in ``rel_type``.

    ``rel_type`` is any container supporting ``in``. Duplicate triples collapse
    because the result is a set.
    """
    return {(head, rel, tail) for head, rel, tail in triples if rel in rel_type}

# proportion[r] = number of test triples with relation id r (float tensor of
# length num_rel). It weights relations within each category later on.
proportion = torch.zeros(num_rel)
for _, test_rel, _ in test_triples:
    proportion[test_rel] += 1

In [3]:
# Relation-type classification; the default thresholds are tuned for wn18rr
# and YAGO3-10. For FB15k-237, either comment this cell out and use the next
# one, or pass the FB15k-237 thresholds through the keyword arguments below.
# NOTE: BatchType is defined in a later cell. These functions only resolve it
# when called, which happens after that cell has run.
def relation_fanout_fraction(index, threshold=2):
    """Return, per relation, the fraction of keys that map to at least ``threshold`` values.

    Args:
        index: nested mapping ``{rel: {key: collection_of_values}}``, e.g.
            ``{rel: {head: set_of_tails}}``.
        threshold: minimum number of values for a key to count as "many".

    Returns:
        ``{rel: fraction}``. A relation with no keys gets 0.0 instead of
        raising ZeroDivisionError.
    """
    fractions = {}
    for rel, by_key in index.items():
        n_many = sum(1 for values in by_key.values() if len(values) >= threshold)
        fractions[rel] = n_many / len(by_key) if by_key else 0.0
    return fractions


def get_subset_rel(batch_type, A_set, B_set, threshold=2,
                   n1_margin=0.25, n1_min=0.5, one_n_margin=0.2, one_n_min=0.2):
    """Select N-1 relations (HEAD_BATCH) or 1-N relations (TAIL_BATCH).

    Args:
        batch_type: ``BatchType.HEAD_BATCH`` selects N-1 relations and
            ``BatchType.TAIL_BATCH`` selects 1-N relations; any other value returns [].
        A_set: ``{rel: {head: set_of_tails}}``; gives each relation's "1N" fraction.
        B_set: ``{rel: {tail: set_of_heads}}``; gives each relation's "N1" fraction.
        threshold: a key counts as "many" when it maps to >= threshold entities.
        n1_margin, n1_min: a relation is N-1 when ``N1 - 1N > n1_margin`` and
            ``N1 > n1_min``.
        one_n_margin, one_n_min: a relation is 1-N when ``1N - N1 > one_n_margin``
            and ``1N > one_n_min``.
            (The FB15k-237 variant uses 0.4 / 0.7 and 0.525 / 0.6.)

    Returns:
        List of the selected relation ids; their names are also printed.
    """
    _1N = relation_fanout_fraction(A_set, threshold)
    _N1 = relation_fanout_fraction(B_set, threshold)
    id2relation = {v: k for k, v in rel2id.items()}
    rst = []
    if batch_type == BatchType.HEAD_BATCH:
        for rel in A_set.keys():
            n1, one_n = _N1.get(rel, 0.0), _1N[rel]
            if n1 - one_n > n1_margin and n1 > n1_min:
                rst.append(rel)
        print("N-1 rel:", [id2relation[i] for i in rst])
    elif batch_type == BatchType.TAIL_BATCH:
        for rel in B_set.keys():
            one_n, n1 = _1N.get(rel, 0.0), _N1[rel]
            if one_n - n1 > one_n_margin and one_n > one_n_min:
                rst.append(rel)
        print("1-N rel:", [id2relation[i] for i in rst])
    return rst


def get_11_rel(A_set, B_set, threshold=2, proportion_threshold=0.25):
    """Select 1-1 relations: both fan-out fractions are <= ``proportion_threshold``.

    Args:
        A_set: ``{rel: {head: set_of_tails}}``.
        B_set: ``{rel: {tail: set_of_heads}}``.
        threshold: a key counts as "many" when it maps to >= threshold entities.
        proportion_threshold: upper bound on both fractions
            (the FB15k-237 variant uses 0.2).

    Returns:
        List of the selected relation ids; their names are also printed.
    """
    _1N = relation_fanout_fraction(A_set, threshold)
    _N1 = relation_fanout_fraction(B_set, threshold)
    id2relation = {v: k for k, v in rel2id.items()}
    rst = [rel for rel in A_set.keys()
           if _N1.get(rel, 0.0) <= proportion_threshold and _1N[rel] <= proportion_threshold]
    print("1-1 rel:", [id2relation[i] for i in rst])
    return rst

In [4]:
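### FB15k-237 variant of the previous cell: same logic with different thresholds and the prints disabled.
### To use it, comment out the previous cell and uncomment this one.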
# def get_subset_rel(batch_type, A_set, B_set):
#     rst = []
#     threshold = 2
#     _1N = {}
#     for rel in A_set.keys():
#         tmp = 0
#         for head in A_set[rel].keys():
#             if len(A_set[rel][head]) >= threshold:
#                 tmp += 1
#         _1N[rel] = tmp / len(A_set[rel])
#     id2relation = {v:k for k,v in rel2id.items()}
#     _N1 = {}
#     for rel in B_set.keys():
#         tmp = 0
#         for tail in B_set[rel].keys():
#             if len(B_set[rel][tail]) >= threshold:
#                 tmp += 1
#         _N1[rel] = tmp / len(B_set[rel])
#     if batch_type == BatchType.HEAD_BATCH:
#         for rel in A_set.keys():
#             if _N1[rel] - _1N[rel] > 0.4 and _N1[rel] > 0.7:
#                 rst.append(rel)
#         # print("N-1 rel:", [id2relation[i] for i in rst])
#     elif batch_type == BatchType.TAIL_BATCH:
#         for rel in B_set.keys():
#             if _1N[rel] - _N1[rel] > 0.525 and _1N[rel] > 0.6:
#                 rst.append(rel)
#         # print("1-N rel:", [id2relation[i] for i in rst])
#     return rst

# def get_11_rel(A_set, B_set):
#     rst = []
#     threshold = 2
#     proportion_threshold = 0.2
#     _1N = {}
#     for rel in A_set.keys():
#         tmp = 0
#         for head in A_set[rel].keys():
#             if len(A_set[rel][head]) >= threshold:
#                 tmp += 1
#         _1N[rel] = tmp / len(A_set[rel])
#     id2relation = {v:k for k,v in rel2id.items()}
#     _N1 = {}
#     for rel in B_set.keys():
#         tmp = 0
#         for tail in B_set[rel].keys():
#             if len(B_set[rel][tail]) >= threshold:
#                 tmp += 1
#         _N1[rel] = tmp / len(B_set[rel])
#     for rel in A_set.keys():
#         if _N1[rel] <= proportion_threshold and _1N[rel] <= proportion_threshold:
#             rst.append(rel)
#     # print("1-1 rel:", [id2relation[i] for i in rst])
#     return rst

In [5]:
# This cell classifies relations into 1-1 / N-1 / 1-N / N-N ("complex relation types").
class BatchType(Enum):
    """Batch-mode flags (names presumably mirror the training code -- confirm).

    get_subset_rel (defined in an earlier cell) uses HEAD_BATCH to select N-1
    relations and TAIL_BATCH to select 1-N relations. SINGLE is not referenced
    in the visible cells.
    """
    HEAD_BATCH = 0
    TAIL_BATCH = 1
    SINGLE = 2

def build_relation_indices(triples):
    """Index ``(head, rel, tail)`` triples by relation in both directions.

    Returns:
        ``(got_head_find_tail, got_tail_find_head)`` where
        ``got_head_find_tail[rel][head]`` is the set of tails seen with (head, rel) and
        ``got_tail_find_head[rel][tail]`` is the set of heads seen with (rel, tail).
    """
    got_tail_find_head = {}
    got_head_find_tail = {}
    for head, rel, tail in triples:
        # Every triple is recorded, including the first one seen for a relation.
        got_tail_find_head.setdefault(rel, {}).setdefault(tail, set()).add(head)
        got_head_find_tail.setdefault(rel, {}).setdefault(head, set()).add(tail)
    return got_head_find_tail, got_tail_find_head


def generate_get_tail_find_head(triples=None):
    """Classify relations as N-1, 1-N and 1-1 from their triples.

    Args:
        triples: ``(head, rel, tail)`` id triples to analyse; defaults to
            ``train_triples``.

    Returns:
        ``(_N1_rel, _1N_rel, _sym_rel)``: lists of relation ids that are N-1,
        1-N and 1-1 respectively. The 1-1 list is what the rest of the
        notebook calls ``_sym_rel``.
    """
    if triples is None:
        triples = train_triples
    got_head_find_tail, got_tail_find_head = build_relation_indices(triples)

    _N1_rel = get_subset_rel(BatchType.HEAD_BATCH, got_head_find_tail, got_tail_find_head)
    _1N_rel = get_subset_rel(BatchType.TAIL_BATCH, got_head_find_tail, got_tail_find_head)
    _sym_rel = get_11_rel(got_head_find_tail, got_tail_find_head)

    return _N1_rel, _1N_rel, _sym_rel

# Four relation categories, always in the order one2one, many2one, one2many,
# many2many (the order also used by complex_rel_dist and get_contribution).
num_complex_rel_type = 4
_N1_rel, _1N_rel, _sym_rel = generate_get_tail_find_head()

print()

# Put each relation id into exactly one category. Membership is tested in
# priority order 1-1, then N-1, then 1-N; anything left over is N-N.
complex_rel_idx_dict = defaultdict(list)
for rel in range(num_rel):
    complex_rel_idx_dict[
        'one2one' if rel in _sym_rel
        else 'many2one' if rel in _N1_rel
        else 'one2many' if rel in _1N_rel
        else 'many2many'
    ].append(rel)

def complex_rel_triples_gen(triples):
    """Partition triples by the category of their relation.

    Reads the module-level ``_sym_rel``, ``_N1_rel`` and ``_1N_rel`` lists. A
    triple goes to the first category whose list contains its relation, in
    that order; otherwise it is many-to-many.

    Returns:
        ``[one2one, many2one, one2many, many2many]``: four sets of
        ``(head, rel, tail)`` tuples.
    """
    buckets = [set(), set(), set(), set()]
    for head, rel, tail in triples:
        if rel in _sym_rel:
            slot = 0
        elif rel in _N1_rel:
            slot = 1
        elif rel in _1N_rel:
            slot = 2
        else:
            slot = 3
        buckets[slot].add((head, rel, tail))
    return buckets


# complex_rel_dist: share of test triples whose relation falls in each
# category, in the order one2one, many2one, one2many, many2many.
complex_rel_dist = torch.stack([
    proportion[complex_rel_idx_dict[category]].sum()
    for category in ('one2one', 'many2one', 'one2many', 'many2many')
])
complex_rel_dist /= complex_rel_dist.sum()
print("complex_rel_dist: ", complex_rel_dist)

N-1 rel: ['_hypernym', '_instance_hypernym', '_synset_domain_topic_of']
1-N rel: ['_member_meronym', '_has_part', '_member_of_domain_usage', '_member_of_domain_region']
1-1 rel: ['_verb_group', '_similar_to']

complex_rel_dist:  tensor([0.0134, 0.4745, 0.1516, 0.3606])


In [6]:
def get_table(dataset):
    """Build lookup tables that map two known triple elements to the third.

    Returns:
        ``(table_head, table_tail)``, both ``defaultdict(list)``:
        ``table_head[(rel, tail)]`` lists the heads and
        ``table_tail[(rel, head)]`` lists the tails, in dataset order
        (duplicates kept).
    """
    heads_by_rel_tail, tails_by_rel_head = defaultdict(list), defaultdict(list)
    for triple in dataset:
        h, r, t = triple
        heads_by_rel_tail[r, t].append(h)
        tails_by_rel_head[r, h].append(t)
    return heads_by_rel_tail, tails_by_rel_head

In [7]:
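# Scoring-mode convention for the functions below: mode 'N1' = head prediction
# (rank candidate heads); any other mode (e.g. '1N') = tail prediction.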
def TransE(head, relation, tail, mode):
    """Per-dimension TransE score ``gamma - |h + r - t|`` (no norm over dimensions).

    'N1' groups the arithmetic as ``h + (r - t)`` and other modes as
    ``(h + r) - t``. Reads the module-level ``gamma`` tensor. Returns the
    squeezed per-dimension score.
    """
    residual = head + (relation - tail) if mode == 'N1' else (head + relation) - tail
    return (gamma.item() - residual.abs()).squeeze()

def RotatE(head, relation, tail, mode):
    """Per-dimension RotatE score: ``gamma - |h * r - t|`` per complex coordinate.

    Entity embeddings are split into real/imaginary halves along the last
    axis. The relation is turned into a phase and then a unit complex number.
    'N1' compares ``conj(r) * t`` with ``h``; any other mode compares ``h * r``
    with ``t``. Reads the module-level ``embedding_range`` and ``gamma``.
    Returns the squeezed per-dimension score (no sum over dimensions).
    """
    pi = 3.14159265358979323846

    re_head, im_head = head.chunk(2, dim=-1)
    re_tail, im_tail = tail.chunk(2, dim=-1)

    # Scale relation values to phases (intended range [-pi, pi]).
    phase_relation = relation / (embedding_range.item() / pi)
    re_relation, im_relation = torch.cos(phase_relation), torch.sin(phase_relation)

    if mode != 'N1':
        re_diff = (re_head * re_relation - im_head * im_relation) - re_tail
        im_diff = (re_head * im_relation + im_head * re_relation) - im_tail
    else:
        re_diff = (re_relation * re_tail + im_relation * im_tail) - re_head
        im_diff = (re_relation * im_tail - im_relation * re_tail) - im_head

    distance = torch.stack([re_diff, im_diff], dim=0).norm(dim=0)
    return (gamma.item() - distance).squeeze()

def pRotatE(head, relation, tail, mode):
    """Per-dimension pRotatE score: ``gamma - modulus * |sin(ph_h + ph_r - ph_t)|``.

    Entities and relations are scaled to phases by ``embedding_range / pi``.
    'N1' groups the phase sum as ``h + (r - t)`` and other modes as
    ``(h + r) - t``. Reads the module-level ``embedding_range``, ``gamma`` and
    ``modulus``. Returns the squeezed per-dimension score.
    """
    # NOTE(review): this literal differs from math.pi (3.14159265...) after
    # 3.1415926. It is kept as-is; it presumably mirrors the scorer the model
    # was trained with. Confirm against codes.model before changing it.
    pi = 3.14159262358979323846
    to_phase = embedding_range.item() / pi

    phase_head = head / to_phase
    phase_relation = relation / to_phase
    phase_tail = tail / to_phase

    if mode == 'N1':
        angle = phase_head + (phase_relation - phase_tail)
    else:
        angle = (phase_head + phase_relation) - phase_tail

    distance = torch.abs(torch.sin(angle))
    return (gamma.item() - distance * modulus).squeeze()


# Model name -> per-dimension scoring function; indexed with name_model.
# DistMult / ComplEx have no per-dimension scorer in this notebook.
model_func = dict(
    TransE=TransE,
    RotatE=RotatE,
    pRotatE=pRotatE,
)

def _mode_for_relation(test_type, rel):
    """Resolve the ranking mode ('N1' or '1N') to use for relation ``rel``.

    'rel_based' ranks tails ('1N') for relations in ``_1N_rel`` that are not
    also in ``_N1_rel``, and heads ('N1') for everything else. Any other value
    is returned unchanged.
    """
    if test_type != 'rel_based':
        return test_type
    if rel in _N1_rel:
        return "N1"
    if rel in _1N_rel:
        return "1N"
    return "N1"


def _relation_contribution(dataset, mode, score_fn, table_head, table_tail):
    """Per-dimension ``mean(std) / (mean(rank) * mean(rate))`` over one relation's triples.

    For each (head, rel, tail):
      * rate: the triple's per-dimension score, normalised to sum to 1;
      * std: the std, over all candidate entities, of each dimension's score;
      * rank: the 1-based filtered rank of the true entity when candidates are
        sorted by that single dimension. Other known true answers are pushed
        down by 1000; for datasets other than FB15k-237 the opposite-side
        entity of the triple is pushed down too.

    Returns a 1-D tensor with one value per score dimension.
    """
    all_idx, all_std, all_rate = [], [], []
    for head, rel, tail in dataset:
        h = entity_embedding[head].unsqueeze(dim=0)
        r = relation_embedding[rel].unsqueeze(dim=0)
        t = entity_embedding[tail].unsqueeze(dim=0)

        rate = score_fn(h, r, t, mode)
        all_rate.append(rate / rate.sum())

        if mode == "N1":
            # Score every entity as the head.
            all_score = score_fn(entity_embedding, r, t, mode)
            std = torch.std(all_score, dim=0)  # (D)
            all_score[torch.tensor(table_head[(rel, tail)]).long()] -= 1000.0
            all_score[head] += 1000.0
            if name_dataset != "FB15k-237":
                all_score[tail] -= 1000.0
            target = head
        else:  # "1N": score every entity as the tail
            all_score = score_fn(h, r, entity_embedding, mode)
            std = torch.std(all_score, dim=0)  # (D)
            all_score[torch.tensor(table_tail[(rel, head)]).long()] -= 1000.0
            if name_dataset != "FB15k-237":
                all_score[head] -= 1000.0
            all_score[tail] += 1000.0
            target = tail

        _, indices = torch.sort(all_score, dim=0, descending=True)
        # Position of the target entity in each dimension's ranking, 1-based.
        all_idx.append((indices.T == target).nonzero()[:, 1] + 1)  # (D)
        all_std.append(std)

    mean_rate = torch.stack(all_rate, dim=0).float().mean(dim=0)
    mean_std = torch.stack(all_std, dim=0).float().mean(dim=0)
    mean_idx = torch.stack(all_idx, dim=0).float().mean(dim=0)
    return mean_std / (mean_idx * mean_rate)


def _category_weighted_dims(rst, category):
    """Average ``rst``'s columns for one relation category, weighted by test counts.

    Returns a ``(1, D)`` tensor. A category whose relations have no test
    triples yields zeros instead of NaN (0 / 0).
    """
    rel_idx = complex_rel_idx_dict[category]
    weights = proportion[rel_idx].unsqueeze(dim=0)
    total = torch.sum(proportion[rel_idx])
    weighted = rst[:, rel_idx] * weights
    if total > 0:
        weighted = weighted / total
    return weighted.sum(dim=-1).unsqueeze(dim=0)


def get_contribution(test_type="N1"):
    """Score how much each embedding dimension contributes, per relation category.

    For every relation with test triples, ``_relation_contribution`` gives a
    per-dimension value. Each relation's column is shifted so its minimum is 0
    and softmax-normalised over dimensions. Columns are then averaged within
    each relation category, weighted by ``proportion`` (test-triple counts).

    Args:
        test_type: 'N1' ranks heads and '1N' ranks tails. 'rel_based' picks
            '1N' for 1-N relations and 'N1' for every other relation, decided
            separately for each relation.

    Returns:
        Tensor of shape ``(4, hidden_dim)``; rows are one2one, many2one,
        one2many, many2many. Downstream, dimensions above a row's mean are
        treated as contributing (True).

    Raises:
        ValueError: if ``test_type`` is not 'N1', '1N' or 'rel_based'.
    """
    if test_type not in ("N1", "1N", "rel_based"):
        raise ValueError(f"unknown test_type {test_type!r}; expected 'N1', '1N' or 'rel_based'")

    rst = torch.zeros(hidden_dim, num_rel)
    print(rst.shape)
    table_head, table_tail = get_table(all_true_triples)
    score_fn = model_func[name_model]
    skip_rel_idx = []

    # Pure analysis: no autograd graph is needed. This avoids keeping one for
    # every (num_ent, D) score matrix built per triple.
    with torch.no_grad():
        for target_rel in range(num_rel):
            # Resolved per relation, so 'rel_based' is honoured for every relation.
            mode = _mode_for_relation(test_type, target_rel)
            dataset = filtered_dataset(test_triples, [target_rel])
            if len(dataset) <= 0:
                print("rel " + str(target_rel) + " skip cause its size is zero")
                skip_rel_idx.append(target_rel)
                continue
            print("dataset size: ", len(dataset))
            rst[:, target_rel] = _relation_contribution(
                dataset, mode, score_fn, table_head, table_tail).cpu()
            print("rel " + str(target_rel) + " done")

    if skip_rel_idx:
        # Fill skipped relations with the mean over the relations that were
        # measured; the skipped columns are still zero and must not bias it.
        kept = [i for i in range(num_rel) if i not in skip_rel_idx]
        if kept:
            rst[:, skip_rel_idx] = rst[:, kept].mean(dim=-1, keepdim=True)

    # Shift each relation column so its minimum is 0, then softmax over dimensions.
    rst = rst - torch.min(rst, dim=0)[0]
    rst = rst.softmax(dim=0)

    return torch.cat([_category_weighted_dims(rst, category)
                      for category in ('one2one', 'many2one', 'one2many', 'many2many')],
                     dim=0)

# Per-dimension contributions for head prediction ("N1") and tail prediction ("1N").
rst_head, rst_tail = (get_contribution(mode) for mode in ("N1", "1N"))

torch.Size([500, 11])
dataset size:  1251
rel 0 done
dataset size:  1074
rel 1 done
dataset size:  122
rel 2 done
dataset size:  56
rel 3 done
dataset size:  253
rel 4 done
dataset size:  114
rel 5 done
dataset size:  172
rel 6 done
dataset size:  24
rel 7 done
dataset size:  26
rel 8 done
dataset size:  39
rel 9 done
dataset size:  3
rel 10 done
torch.Size([500, 11])
dataset size:  1251
rel 0 done
dataset size:  1074
rel 1 done
dataset size:  122
rel 2 done
dataset size:  56
rel 3 done
dataset size:  253
rel 4 done
dataset size:  114
rel 5 done
dataset size:  172
rel 6 done
dataset size:  24
rel 7 done
dataset size:  26
rel 8 done
dataset size:  39
rel 9 done
dataset size:  3
rel 10 done


In [8]:
class JSD(nn.Module):
    """Jensen-Shannon divergence (natural log) between two probability vectors.

    ``forward(p, q)`` flattens both inputs to ``(-1, K)`` rows and computes
    ``0.5 * (KL(p || m) + KL(q || m))`` with ``m = (p + q) / 2``. The sum over
    all rows is divided by the number of rows (the 'batchmean' convention).
    Zero-probability entries contribute 0 (``0 * log 0 = 0``) instead of
    producing NaN.
    """

    @staticmethod
    def _kl_sum(a: torch.Tensor, log_m: torch.Tensor) -> torch.Tensor:
        # Sum over all elements of a * (log a - log m), skipping entries where a == 0.
        terms = a * (a.log() - log_m)
        return torch.where(a > 0, terms, torch.zeros_like(terms)).sum()

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        log_m = (0.5 * (p + q)).log()
        n_rows = log_m.size(0)
        return 0.5 * (self._kl_sum(p, log_m) + self._kl_sum(q, log_m)) / n_rows

In [9]:
# For each category (row) and each side, mark the embedding dimensions whose
# contribution is above that row's mean.
count_head = torch.zeros(num_complex_rel_type, hidden_dim)
count_tail = torch.zeros(num_complex_rel_type, hidden_dim)
for i in range(num_complex_rel_type):
    count_head[i, :] = rst_head[i, :] > rst_head[i, :].mean()
    count_tail[i, :] = rst_tail[i, :] > rst_tail[i, :].mean()

names = ['one2one', 'many2one', 'one2many', 'many2many']
assert len(names) == num_complex_rel_type
for i in range(num_complex_rel_type):
    # Above-mean dimension counts (head, tail), then dims where head and tail agree.
    print(names[i])
    print(count_head[i, :].nonzero().shape, count_tail[i, :].nonzero().shape)
    print((count_head[i, :] == count_tail[i, :]).nonzero().shape)
    print()

print("complex_rel_dist: ", complex_rel_dist)

# contribute_dims: for each category, the number of above-mean dimensions
# averaged over head and tail prediction, normalised into a distribution.
contribute_dims = torch.zeros(num_complex_rel_type)
for i in range(num_complex_rel_type):
    n_head = count_head[i, :].nonzero().shape[0]
    n_tail = count_tail[i, :].nonzero().shape[0]
    contribute_dims[i] = (n_head + n_tail) / 2
contribute_dims /= contribute_dims.sum()
print("contribute_dims: ", contribute_dims)

# Jensen-Shannon divergence between the category mix of the test triples and
# the share of contributing dimensions per category.
jsd = JSD()
output = jsd(complex_rel_dist, contribute_dims)

print(output)


one2one
torch.Size([196, 1]) torch.Size([199, 1])
torch.Size([481, 1])

many2one
torch.Size([279, 1]) torch.Size([282, 1])
torch.Size([493, 1])

one2many
torch.Size([245, 1]) torch.Size([248, 1])
torch.Size([493, 1])

many2many
torch.Size([239, 1]) torch.Size([235, 1])
torch.Size([496, 1])

complex_rel_dist:  tensor([0.0134, 0.4745, 0.1516, 0.3606])
contribute_dims:  tensor([0.2054, 0.2917, 0.2564, 0.2465])
tensor(0.0738)


In [10]:
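### Intersection matrix (uncomment to reproduce the output below): off-diagonal entry [i][j] counts the
### dimensions that are above-mean for both categories i and j in count_head; diagonal entry [i][i]
### is the head/tail-averaged above-mean count for category i.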
# matrix_head = torch.eye(4)
# for i in range(4):
#     for j in range(i+1, 4):
#         matrix_head[i][j] = (count_head[i, :] * count_head[j, :]).nonzero().shape[0]

# matrix_head = matrix_head + matrix_head.T

# matrix_head[0,0] = (count_head[0, :].nonzero().shape[0] + count_tail[0, :].nonzero().shape[0]) / 2
# matrix_head[1,1] = (count_head[1, :].nonzero().shape[0] + count_tail[1, :].nonzero().shape[0]) / 2
# matrix_head[2,2] = (count_head[2, :].nonzero().shape[0] + count_tail[2, :].nonzero().shape[0]) / 2
# matrix_head[3,3] = (count_head[3, :].nonzero().shape[0] + count_tail[3, :].nonzero().shape[0]) / 2

# print(matrix_head.int())

tensor([[197, 125, 105, 125],
        [125, 280, 223, 228],
        [105, 223, 246, 188],
        [125, 228, 188, 237]], dtype=torch.int32)
