In [1]:
import os
import random
import torch
import numpy as np
from torch.nn import functional as F

dataset_dir = "./umls"
all_trip_file = os.path.join(dataset_dir, "all.txt")
relations_file = os.path.join(dataset_dir, "relations.txt")
entities_file = os.path.join(dataset_dir, "entities.txt")

In [2]:
def read_xxx_to_id(file_path, num_read=None):
    xxx2id = {}
    with open(file_path, 'r') as file:
        lines = file.readlines()
        if num_read:
            lines = random.sample(lines, num_read)
        for line in lines:
            line = line.strip()
            xxx2id[line] = len(xxx2id)
    return xxx2id

def parse_triplets(triplets_file: str,
                   rel2id: dict,
                   ent2id: dict):
    """Read triplets (relation, head, tail)."""
    triplets = []
    with open(triplets_file, 'r') as file:
        for line in file:
            line = line.strip().split('\t')
            assert(len(line) == 3)
            try:
                triplets.append(
                    (
                        rel2id[line[1]],
                        ent2id[line[0]],
                        ent2id[line[2]]
                    )
                )
            except KeyError:
                pass
    return triplets

## Read file and resample

In [3]:
num_entities = None
num_relations = None

In [4]:
rel2id = read_xxx_to_id(relations_file, num_relations)
id2rel = {ident: rel for rel, ident in rel2id.items()}

ent2id = read_xxx_to_id(entities_file, num_entities)
id2ent = {ident: ent for ent, ident in ent2id.items()}

all_facts = parse_triplets(all_trip_file, rel2id, ent2id)
# relation to (head, tail)
rel2ht = {rel: [] for rel in id2rel.keys()}
for (r, h, t) in all_facts:
    rel2ht[r].append((h, t))

num_rel, num_ent, num_trip = len(rel2id), len(ent2id), len(all_facts)

num_rel, num_ent, num_trip

(46, 135, 6529)

## Compute bifurcation

In [5]:
from collections import defaultdict

def get_head_bifur(head_tail_pairs: list, k: int):
    bifur_cnt = 0
    head2tails = defaultdict(list)
    for h, t in head_tail_pairs:
        head2tails[h].append(t)
        if len(head2tails[h]) == k:
            bifur_cnt += 1
    return head2tails, bifur_cnt / len(head2tails.keys())

def get_tail_bifur(head_tail_pairs: list, k: int):
    bifur_cnt = 0
    tail2heads = defaultdict(list)
    for h, t in head_tail_pairs:
        tail2heads[t].append(h)
        if len(tail2heads[t]) == k:
            bifur_cnt += 1
    return tail2heads, bifur_cnt / len(tail2heads.keys())

In [6]:
max_lambda = 7
backward = False

for rel, pairs in rel2ht.items():
    print(f"[{id2rel[rel]}]")
    for k in range(2, max_lambda + 1):
        head2tails, bifur_head = get_head_bifur(pairs, k)
        print(f"bifur_head({k}): {bifur_head:.2f}")
        if backward:
            tail2heads, bifur_tail = get_tail_bifur(pairs, k)
            print(f"bifur_tail({k}): {bifur_tail:.2f}")
    print("\n")

[Manifestation_of]
bifur_head(2): 1.00
bifur_head(3): 0.82
bifur_head(4): 0.82
bifur_head(5): 0.82
bifur_head(6): 0.82
bifur_head(7): 0.82


[Evaluation_of]
bifur_head(2): 1.00
bifur_head(3): 1.00
bifur_head(4): 1.00
bifur_head(5): 1.00
bifur_head(6): 1.00
bifur_head(7): 1.00


[Performs]
bifur_head(2): 1.00
bifur_head(3): 1.00
bifur_head(4): 1.00
bifur_head(5): 1.00
bifur_head(6): 1.00
bifur_head(7): 1.00


[Ingredient_of]
bifur_head(2): 0.00
bifur_head(3): 0.00
bifur_head(4): 0.00
bifur_head(5): 0.00
bifur_head(6): 0.00
bifur_head(7): 0.00


[Location_of]
bifur_head(2): 1.00
bifur_head(3): 1.00
bifur_head(4): 1.00
bifur_head(5): 0.96
bifur_head(6): 0.96
bifur_head(7): 0.96


[Affects]
bifur_head(2): 0.96
bifur_head(3): 0.96
bifur_head(4): 0.91
bifur_head(5): 0.91
bifur_head(6): 0.91
bifur_head(7): 0.91


[Contains]
bifur_head(2): 0.11
bifur_head(3): 0.11
bifur_head(4): 0.00
bifur_head(5): 0.00
bifur_head(6): 0.00
bifur_head(7): 0.00


[Exhibits]
bifur_head(2): 1.00
bifur_head(3): 1.0

## Compute macro, micro and comprehensive saturation

In [7]:
def get_adjacency_matrices(triplets,
                           num_relations: int,
                           num_entities: int
                           ):
    """Compute adjacency matrix from all triplets
       in preparation for creating sparse matrix in torch.
    """
    matrix = {
        r: ([[0, 0]], [0.], [num_entities, num_entities])
        for r in range(num_relations)
    }
    for triplet in triplets:
        rel = triplet[0]
        head = triplet[1]
        tail = triplet[2]
        value = 1.
        matrix[rel][0].append([head, tail])
        matrix[rel][1].append(value)

    for rel, mat in matrix.items():
        matrix[rel] = torch.sparse.FloatTensor(
            torch.LongTensor(mat[0]).t(),
            torch.FloatTensor(mat[1]),
            mat[2]
        )

    return matrix

# `adj_matrix`:  adjacency matrix
# `head_nodes`: head nodes list
# return: a dict of `batch_size` head nodes {head: {tail: num_paths}}
def one_hop_from_head(adj_matrix: list,
                      head_nodes: list
                      ):
    # (batch_size, num_entities)
    v_x = F.one_hot(torch.LongTensor(head_nodes), adj_matrix.size(0)).float()
    # (num_entities, num_entities)
    result = torch.matmul(adj_matrix.t(), v_x.t())
    result = result.t().numpy()
    indices = np.argwhere(result > 0)
    # {head: {tail: num_paths}}
    ret = {head: [] for head in head_nodes}
    for row, col in indices:
        # `indices`: (batch, column) ==> (batch, tail)
        ret[head_nodes[row]].append(col)
    return ret

In [8]:
adj_matrices = get_adjacency_matrices(all_facts, num_rel, num_ent)
adj_matrices[0], one_hop_from_head(adj_matrices[0], list(ent2id.values())[:2])

(tensor(indices=tensor([[  0, 105, 118,  65,  84,  38,  81,  19,  19,  38,  81,
                          55,  13, 118, 118,  38,  37,  51,  38,  81,  25,  19,
                         111,  57,  37,  51,  37,  51,  57,  51,  51,  55,  25,
                          77,  65, 118,  59,  25,  25,  84,  37,  55,  13,  57,
                          37, 118,  13,  55,  25,  37,  57, 118,  37,  84,  37,
                          59,  55,  65,  57,  59,  81,  13,  59,  57,  81,  84,
                          38,  51,  65,  25,  55,  38,  51,  37,  57,  55,  55,
                          56,  13,  13,  81,  19,  13,  84,  77,  65,  81,  25,
                          37,  37,  55,  84,  65,  77,  13,  84,  38,  84,  55,
                          38,  84,  38,  57,  84,  81,  77,  81,  65,  55, 118,
                          38,  13, 118,  55,  65, 118,  37,  51,  19,  77,  38,
                         118,  37,  57,  13,  37,  55,  13, 111,  57,  38,  56,
                          59,  51,  57, 

In [15]:
from itertools import permutations

topk_macro = 5
topk_micro = 10
topk_comp = 10

max_rule_len = 3  # T >= 2

In [10]:
from time import time
from collections import defaultdict

start = time()
# relations
relations = list(id2rel.keys())
# initial paths {hop: list of path tuples}
paths = {1: [tuple([rel]) for rel in relations]}

# get number of triplets under each relation
num_rel2trip = {rel: len(rel2ht[rel]) for rel in id2rel.keys()}
# get triplets under each relation
rel_head2tails = {rel: {} for rel in id2rel.keys()}
for (r, h, t) in all_facts:
    if h not in rel_head2tails[r]:
        rel_head2tails[r][h] = set()
    rel_head2tails[r][h].add(t)

# `macro_saturations`: {rel: {path: saturation}}
macro_saturations = {rel: defaultdict(int) for rel in id2rel.keys()}
# `tmp_micro_saturations`: {rel: {path: {(head, tail): num_paths}}}
tmp_micro_saturations = {
    rel: defaultdict(dict) for rel in id2rel.keys()
}
# `micro_saturations`: {rel: {path: saturation}}
micro_saturations = {rel: defaultdict(int) for rel in id2rel.keys()}
# `total_paths_pairs`: {rel: {(head, tail): num_total_paths}}
total_paths_pairs = {
    rel: {
        pair: 0 for pair in rel2ht[rel]
    } for rel in id2rel.keys()
}
# `last_hops`: {rel: {head: {path: last_hop_nodes}}}
last_hops = {
    rel: {
        head: {
        } for head in rel_head2tails[rel].keys()
    } for rel in relations
}
for rel, head2path in last_hops.items():
    for head in head2path:
        for path in paths[1]:
            if head in rel_head2tails[path[0]]:
                head2path[head][path] = rel_head2tails[path[0]][head]
            else:
                head2path[head][path] = set()

# traverse length of single rule path
for t in range(2, max_rule_len+1):
    print(t)
    paths[t] = []
    # traverse paths of length `t-1`
    for last_path in paths[t-1]:
        # generate new paths
        for r in relations:
            new_path = last_path + tuple([r])
            paths[t].append(new_path)
            matrix = adj_matrices[r]
            # compute saturation for each relation
            for rel in relations:
                # compute next hops from last hops
                pair_occur = {pair: False for pair in rel2ht[rel]}
                # TODO matrix optimization
                for head in last_hops[rel].keys():
                    last_hop = last_hops[rel][head][last_path]  # get last hops
                    if not last_hop:
                        last_hops[rel][head][new_path] = set()
                        continue
                    last2next_hops = one_hop_from_head(matrix, list(last_hop))  # one more hop
                    next_hops = set()
                    for tail, hops in last2next_hops.items():
                        next_hops |= set(hops)
                        for hop in hops:
                            # successful reasoning
                            if hop in rel_head2tails[rel][head]:
                                pair_occur[(head, hop)] = True
                                if (head, hop) not in tmp_micro_saturations[rel][new_path]:
                                    tmp_micro_saturations[rel][new_path][(head, hop)] = 0.
                                tmp_micro_saturations[rel][new_path][(head, hop)] += 1.
                                total_paths_pairs[rel][(head, hop)] += 1.
                    last_hops[rel][head][new_path] = next_hops
                macro_saturations[rel][new_path] = np.mean(list(pair_occur.values()))
for rel in relations:
    for path, pairs in tmp_micro_saturations[rel].items():
        for pair, num_path in pairs.items():
            # `pair`: (head, tail)
            micro_saturations[rel][path] += num_path / total_paths_pairs[rel][pair]
        if len(tmp_micro_saturations[rel][path]) != 0:
            micro_saturations[rel][path] /= num_rel2trip[rel]

print(f"{time() - start}s")

2
3
970.2510817050934s


### Macro saturation

In [12]:
for rel in macro_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(macro_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_macro:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}%")        
    print("\n")

('Manifestation_of', 'Result_of'): 1.00%
('Manifestation_of', 'Affects', 'Result_of'): 1.00%
('Manifestation_of', 'Co-occurs_with', 'Result_of'): 1.00%
('Manifestation_of', 'Result_of', 'Result_of'): 1.00%
('Manifestation_of', 'Occurs_in', 'Result_of'): 1.00%


('Evaluation_of', 'Affects'): 0.81%
('Evaluation_of', 'Manifestation_of', 'Affects'): 0.81%
('Evaluation_of', 'Affects', 'Affects'): 0.81%
('Evaluation_of', 'Result_of', 'Affects'): 0.81%
('Evaluation_of', 'Isa', 'Affects'): 0.79%


('Performs', 'Affects', 'Performs'): 1.00%
('Performs', 'Associated_with', 'Performs'): 1.00%
('Exhibits', 'Associated_with', 'Performs'): 1.00%
('Uses', 'Affects', 'Performs'): 1.00%
('Produces', 'Affects', 'Performs'): 1.00%


('Causes', 'Result_of', 'Uses'): 1.00%
('Causes', 'Occurs_in', 'Uses'): 1.00%
('Causes', 'Occurs_in', 'Produces'): 1.00%
('Causes', 'Produces', 'Ingredient_of'): 1.00%
('Isa', 'Ingredient_of'): 0.96%


('Location_of', 'Produces', 'Location_of'): 0.73%
('Location_of', 'Affects

('Connected_to', 'Surrounds'): 1.00%
('Location_of', 'Location_of', 'Surrounds'): 1.00%
('Location_of', 'Location_of', 'Interconnects'): 1.00%
('Location_of', 'Part_of', 'Adjacent_to'): 1.00%
('Location_of', 'Disrupts', 'Surrounds'): 1.00%


('Affects', 'Affects'): 1.00%
('Result_of', 'Affects'): 1.00%
('Affects', 'Manifestation_of', 'Affects'): 1.00%
('Affects', 'Affects', 'Affects'): 1.00%
('Affects', 'Affects', 'Process_of'): 1.00%


('Complicates', 'Manifestation_of', 'Disrupts'): 1.00%
('Complicates', 'Co-occurs_with', 'Disrupts'): 1.00%
('Complicates', 'Result_of', 'Disrupts'): 1.00%
('Complicates', 'Complicates', 'Disrupts'): 1.00%
('Complicates', 'Occurs_in', 'Disrupts'): 1.00%


('Isa', 'Produces'): 0.86%
('Produces', 'Affects', 'Produces'): 0.83%
('Produces', 'Disrupts', 'Produces'): 0.77%
('Produces', 'Part_of', 'Produces'): 0.75%
('Produces', 'Derivative_of', 'Produces'): 0.75%


('Indicates', 'Produces', 'Location_of'): 1.00%
('Indicates', 'Result_of'): 0.89%
('Indicates',

### Micro saturation

In [13]:
for rel in micro_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(micro_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_micro:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}%")        
    print("\n")

('Manifestation_of', 'Result_of', 'Result_of'): 0.01%
('Associated_with', 'Result_of', 'Result_of'): 0.01%
('Result_of', 'Measures', 'Affects'): 0.00%
('Result_of', 'Assesses_effect_of', 'Affects'): 0.00%
('Manifestation_of', 'Result_of', 'Affects'): 0.00%
('Manifestation_of', 'Affects', 'Result_of'): 0.00%
('Manifestation_of', 'Process_of', 'Result_of'): 0.00%
('Associated_with', 'Result_of', 'Affects'): 0.00%
('Result_of', 'Result_of', 'Result_of'): 0.00%
('Manifestation_of', 'Co-occurs_with', 'Result_of'): 0.00%


('Evaluation_of', 'Associated_with', 'Performs'): 0.07%
('Result_of', 'Produces', 'Performs'): 0.02%
('Evaluation_of', 'Isa'): 0.01%
('Evaluation_of', 'Affects', 'Performs'): 0.01%
('Evaluation_of', 'Associated_with', 'Result_of'): 0.01%
('Evaluation_of', 'Affects', 'Result_of'): 0.01%
('Evaluation_of', 'Diagnoses', 'Result_of'): 0.01%
('Evaluation_of', 'Treats', 'Result_of'): 0.01%
('Result_of', 'Result_of', 'Result_of'): 0.01%
('Evaluation_of', 'Complicates', 'Result_of'

### Comprehensive saturation

In [16]:
comp_saturations = {
    rel: {} for rel in micro_saturations
}

for rel in micro_saturations:
    for path, value in micro_saturations[rel].items():
        comp_saturations[rel][path] = value * macro_saturations[rel][path]

for rel in comp_saturations:
    print(f"{id2rel[rel]:=^50}")
    sorted_items = sorted(comp_saturations[rel].items(), key=lambda x: x[1], reverse=True)
    for i, (path, saturation) in enumerate(sorted_items):
        if i == topk_comp:
            break
        print(f"{tuple(id2rel[r] for r in path)}: {saturation:.2f}%")        
    print("\n")

('Manifestation_of', 'Result_of', 'Result_of'): 0.01%
('Associated_with', 'Result_of', 'Result_of'): 0.01%
('Manifestation_of', 'Affects', 'Result_of'): 0.00%
('Manifestation_of', 'Result_of', 'Affects'): 0.00%
('Manifestation_of', 'Process_of', 'Result_of'): 0.00%
('Manifestation_of', 'Co-occurs_with', 'Result_of'): 0.00%
('Manifestation_of', 'Result_of'): 0.00%
('Associated_with', 'Result_of', 'Affects'): 0.00%
('Manifestation_of', 'Precedes', 'Result_of'): 0.00%
('Manifestation_of', 'Affects', 'Affects'): 0.00%


('Evaluation_of', 'Associated_with', 'Performs'): 0.02%
('Evaluation_of', 'Associated_with', 'Result_of'): 0.01%
('Evaluation_of', 'Affects', 'Result_of'): 0.01%
('Evaluation_of', 'Complicates', 'Result_of'): 0.01%
('Result_of', 'Produces', 'Performs'): 0.01%
('Evaluation_of', 'Isa'): 0.01%
('Evaluation_of', 'Result_of', 'Affects'): 0.00%
('Evaluation_of', 'Affects', 'Affects'): 0.00%
('Manifestation_of', 'Result_of', 'Affects'): 0.00%
('Associated_with', 'Result_of', 'Affe