In [1]:
import os
import random
import torch
import numpy as np
from torch.nn import functional as F

dataset_dir = "./family/"
all_trip_file = os.path.join(dataset_dir, "all.txt")
relations_file = os.path.join(dataset_dir, "relations.txt")
entities_file = os.path.join(dataset_dir, "entities.txt")

In [2]:
def read_xxx_to_id(file_path):
    xxx2id = {}
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            xxx2id[line] = len(xxx2id)
    return xxx2id

def parse_triplets(triplets_file: str,
                   rel2id: dict,
                   ent2id: dict):
    """Read triplets (relation, head, tail)."""
    triplets = []
    with open(triplets_file, 'r') as file:
        for line in file:
            line = line.strip().split('\t')
            assert(len(line) == 3)
            try:
                triplets.append(
                    (
                        rel2id[line[1]],
                        ent2id[line[0]],
                        ent2id[line[2]]
                    )
                )
            except KeyError:
                pass
    return triplets

## Read file and resample

In [3]:
resample = False
num_entities = False

In [4]:
rel2id = read_xxx_to_id(relations_file)
id2rel = {ident: rel for rel, ident in rel2id.items()}

ent2id = read_xxx_to_id(entities_file)
if resample:
    ent2id = {ent: ident for (ent, ident) in random.sample(list(ent2id.items()), num_entities)}
id2ent = {ident: ent for ent, ident in ent2id.items()}

all_facts = parse_triplets(all_trip_file, rel2id, ent2id)
# relation to (head, tail)
rel2ht = {rel: [] for rel in id2rel.keys()}
for (r, h, t) in all_facts:
    rel2ht[r].append((h, t))

num_rel, num_ent, num_trip = len(rel2id), len(ent2id), len(all_facts)

num_rel, num_ent, num_trip

(12, 3007, 28356)

## Compute macro, micro and comprehensive saturations

In [7]:
def get_adjacency_matrices(triplets,
                           num_relations: int,
                           num_entities: int
                           ):
    """Compute adjacency matrix from all triplets
       in preparation for creating sparse matrix in torch.
    """
    matrix = {
        r: ([[0, 0]], [0.], [num_entities, num_entities])
        for r in range(num_relations)
    }
    for triplet in triplets:
        rel = triplet[0]
        head = triplet[1]
        tail = triplet[2]
        value = 1.
        matrix[rel][0].append([head, tail])
        matrix[rel][1].append(value)

    for rel, mat in matrix.items():
        matrix[rel] = torch.sparse.FloatTensor(
            torch.LongTensor(mat[0]).t(),
            torch.FloatTensor(mat[1]),
            mat[2]
        )

    return matrix

# `adj_matrices`:  adjacency matrices, ORDER matters!!!
# `head_nodes`: head nodes list
# return: a list of `batch_size` nodes
def from_head_hops(adj_matrices: list,
                   head_nodes: list
                   ):
    # (batch_size, num_entities)
    v_x = F.one_hot(torch.LongTensor(head_nodes), adj_matrix[0].size(0)).float()
    # (num_entities, num_entities)
    result = torch.matmul(adj_matrices[0].t(), v_x.t())
    for mat in adj_matrices[1:]:
        result = torch.mm(mat.t(), result)
    # (batch_size, num_entites)
    result = result.t().numpy()
    indices = np.argwhere(result > 0)
    # {head: {tail: num_paths}}
    ret = {head: {} for head in head_nodes}
    for row, col in indices:
        # `row`: (row, column) ==> (head, tail)
        ret[head_nodes[row]][col] = result[row, col]
    return ret

In [8]:
adj_matrix = get_adjacency_matrices(all_facts, num_rel, num_ent)
adj_matrix[0], from_head_hops([adj_matrix[1], adj_matrix[2]], list(ent2id.values())[:2])

(tensor(indices=tensor([[   0, 2675,  267,  ..., 1973, 2096, 2097],
                        [   0, 2964,  369,  ..., 2023, 2866, 2866]]),
        values=tensor([0., 1., 1.,  ..., 1., 1., 1.]),
        size=(3007, 3007), nnz=2991, layout=torch.sparse_coo),
 {0: {}, 1: {0: 5.0, 1110: 5.0}})

In [9]:
from itertools import permutations

topk_macro = 10
topk_micro = 10
topk_comp = 10
max_rule_len = 2
relations = list(id2rel.keys())
paths_permut = [(rel1, rel2) for rel1 in relations for rel2 in relations]

len(paths_permut)

144

In [10]:
from time import time
from collections import defaultdict

start = time()

macro_saturations = {rel: {path: 0. for path in paths_permut} for rel in id2rel.keys()}
tmp_micro_saturations = {rel: {path: {} for path in paths_permut} for rel in id2rel.keys()}  # {path: {(head, tail): num_paths}}
micro_saturations = {rel: {path: 0. for path in paths_permut} for rel in id2rel.keys()}
total_paths_pairs = {rel: defaultdict(int) for rel in id2rel.keys()}  # {(head, tail): num_total_paths}
# get number of triplets under each relation
num_rel2trip = {rel: len(rel2ht[rel]) for rel in id2rel.keys()}
# get triplets under each relation
rel_head2tails = {rel: defaultdict(list) for rel in id2rel.keys()}
for (r, h, t) in all_facts:
    rel_head2tails[r][h].append(t)
    
for rel in rel_head2tails:
    if not rel_head2tails[rel]:
        continue
    for path in macro_saturations[rel].keys():
        matrices = [adj_matrix[r] for r in path]
        heads = list(rel_head2tails[rel].keys())
        num_paths_from_heads = from_head_hops(matrices, heads)
        for head, tails in rel_head2tails[rel].items():
            for tail in tails:
                if tail in num_paths_from_heads[head]:
                    macro_saturations[rel][path] += 1.
                    tmp_micro_saturations[rel][path][(head, tail)] = num_paths_from_heads[head][tail]
                    total_paths_pairs[rel][(head, tail)] += num_paths_from_heads[head][tail]
        macro_saturations[rel][path] /= num_rel2trip[rel]
    for path, pairs in tmp_micro_saturations[rel].items():
        for pair, num_path in pairs.items():
            # `pair`: (head, tail)
            micro_saturations[rel][path] += num_path / total_paths_pairs[rel][pair]
        if len(tmp_micro_saturations[rel][path]) != 0:
            micro_saturations[rel][path] /= num_rel2trip[rel]

print(f"{time() - start}s")

44.323203802108765s


### Macro saturation

In [11]:
for rel in macro_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(macro_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_macro:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}")        
    print("\n")

('sister', 'uncle'): 0.89
('sister', 'aunt'): 0.85
('aunt', 'brother'): 0.83
('aunt', 'sister'): 0.75
('sister', 'father'): 0.66
('sister', 'mother'): 0.34
('aunt', 'husband'): 0.02
('aunt', 'wife'): 0.02
('wife', 'uncle'): 0.00
('mother', 'wife'): 0.00


('son', 'father'): 1.00
('son', 'mother'): 0.98
('brother', 'brother'): 0.86
('brother', 'sister'): 0.81
('nephew', 'uncle'): 0.77
('nephew', 'aunt'): 0.68
('uncle', 'nephew'): 0.64
('uncle', 'niece'): 0.58
('father', 'nephew'): 0.33
('uncle', 'son'): 0.28


('sister', 'son'): 0.68
('sister', 'daughter'): 0.61
('daughter', 'husband'): 0.46
('daughter', 'wife'): 0.46
('niece', 'brother'): 0.38
('niece', 'sister'): 0.33
('wife', 'nephew'): 0.01
('niece', 'husband'): 0.00
('wife', 'brother'): 0.00
('daughter', 'niece'): 0.00


('husband', 'mother'): 0.85
('father', 'brother'): 0.62
('father', 'sister'): 0.54
('brother', 'uncle'): 0.45
('brother', 'aunt'): 0.40
('uncle', 'wife'): 0.01
('uncle', 'husband'): 0.00
('uncle', 'mother'): 0.00
(

### Micro saturation

In [12]:
for rel in micro_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(micro_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_micro:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}")
    print("\n")

('sister', 'uncle'): 0.26
('sister', 'aunt'): 0.22
('aunt', 'brother'): 0.21
('aunt', 'sister'): 0.18
('sister', 'father'): 0.09
('sister', 'mother'): 0.05
('aunt', 'wife'): 0.00
('aunt', 'husband'): 0.00
('wife', 'uncle'): 0.00
('mother', 'wife'): 0.00


('brother', 'brother'): 0.14
('nephew', 'uncle'): 0.13
('brother', 'sister'): 0.13
('uncle', 'nephew'): 0.12
('nephew', 'aunt'): 0.11
('uncle', 'niece'): 0.09
('son', 'father'): 0.08
('son', 'mother'): 0.07
('father', 'nephew'): 0.04
('uncle', 'son'): 0.04


('sister', 'son'): 0.25
('sister', 'daughter'): 0.20
('daughter', 'husband'): 0.15
('daughter', 'wife'): 0.14
('niece', 'brother'): 0.10
('niece', 'sister'): 0.09
('wife', 'nephew'): 0.00
('wife', 'son'): 0.00
('daughter', 'niece'): 0.00
('niece', 'nephew'): 0.00


('husband', 'mother'): 0.26
('father', 'brother'): 0.22
('father', 'sister'): 0.17
('brother', 'uncle'): 0.13
('brother', 'aunt'): 0.11
('uncle', 'wife'): 0.00
('uncle', 'husband'): 0.00
('uncle', 'mother'): 0.00
('uncl

### Comprehensive saturation

In [13]:
comp_saturations = {
    rel: {} for rel in micro_saturations
}

for rel in micro_saturations:
    for path, value in micro_saturations[rel].items():
        comp_saturations[rel][path] = value * macro_saturations[rel][path]

for rel in comp_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(comp_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_comp:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}%")        
    print("\n")

('sister', 'uncle'): 0.23%
('sister', 'aunt'): 0.19%
('aunt', 'brother'): 0.17%
('aunt', 'sister'): 0.13%
('sister', 'father'): 0.06%
('sister', 'mother'): 0.02%
('aunt', 'wife'): 0.00%
('aunt', 'husband'): 0.00%
('wife', 'uncle'): 0.00%
('mother', 'wife'): 0.00%


('brother', 'brother'): 0.12%
('nephew', 'uncle'): 0.10%
('brother', 'sister'): 0.10%
('son', 'father'): 0.08%
('nephew', 'aunt'): 0.08%
('uncle', 'nephew'): 0.08%
('son', 'mother'): 0.07%
('uncle', 'niece'): 0.05%
('father', 'nephew'): 0.01%
('uncle', 'son'): 0.01%


('sister', 'son'): 0.17%
('sister', 'daughter'): 0.12%
('daughter', 'husband'): 0.07%
('daughter', 'wife'): 0.06%
('niece', 'brother'): 0.04%
('niece', 'sister'): 0.03%
('wife', 'nephew'): 0.00%
('wife', 'brother'): 0.00%
('niece', 'husband'): 0.00%
('daughter', 'niece'): 0.00%


('husband', 'mother'): 0.22%
('father', 'brother'): 0.14%
('father', 'sister'): 0.09%
('brother', 'uncle'): 0.06%
('brother', 'aunt'): 0.04%
('uncle', 'wife'): 0.00%
('uncle', 'husband