In [79]:
import logging; logging.basicConfig(level=logging.INFO)
import torch
import numpy as np
import logictensornetworks as ltn
import networkx as nx
import itertools

# Data

Parent relationships (the knowledge is assumed complete).

In [80]:
entities = ["sue", "diana", "john", "edna", "paul", "francis", "john2",
                "john3", "john4", "joe", "jennifer", "juliet", "janice",
                "joey", "tom", "bonnie", "katie"]

parents = [
        ("sue", "diana"),
        ("john", "diana"),
        ("sue", "bonnie"),
        ("john", "bonnie"),
        ("sue", "tom"),
        ("john", "tom"),
        ("diana", "katie"),
        ("paul", "katie"),
        ("edna", "sue"),
        ("john2", "sue"),
        ("edna", "john3"),
        ("john2", "john3"),
        ("francis", "john"),
        ("john4", "john"),
        ("francis", "janice"),
        ("john4", "janice"),
        ("janice", "jennifer"),
        ("joe", "jennifer"),
        ("janice", "juliet"),
        ("joe", "juliet"),
        ("janice", "joey"),
        ("joe", "joey")]

all_relationships = list(itertools.product(entities, repeat=2))
not_parents = [item for item in all_relationships if item not in parents]

Visualized in a graph.

In [81]:
# # Ground Truth Parents
# parDG_truth = nx.DiGraph(parents)
# pos= nx.drawing.nx_agraph.graphviz_layout(parDG_truth, prog='dot')
# nx.draw(parDG_truth,pos,with_labels=True)

Ancestor relationships and visualization in a graph.

In [82]:
# Ground Truth Ancestors
def get_descendants(entity, DG):
    all_d = []
    direct_d = list(DG.successors(entity))
    all_d += direct_d
    for d in direct_d:
        all_d += get_descendants(d, DG)
    return all_d

ancestors = []
for e in entities:
    for d in get_descendants(e, parDG_truth):
        ancestors.append((e,d))

# ancDG_truth = nx.DiGraph(ancestors)
# pos= nx.drawing.nx_agraph.graphviz_layout(parDG_truth, prog='dot')
# nx.draw(ancDG_truth,pos,with_labels=True)

# LTN

Every individual is grounded as a trainable LTN constant in $\mathbb{R}^2$. The grounding of the predicates `Parent` and `Ancestor` (modelled by multi-layer perceptrons) are learned at the same times as the embeddings for the individuals.

We give the complete parent relationships in the knowledgebase. However, we don't give any ancestor relationships; they are to be inferred using a set of rules.

In [83]:
embedding_size = 4

Ancestor = ltn.Predicate.MLP([embedding_size,embedding_size],hidden_layer_sizes=[8,8])
Parent = ltn.Predicate.MLP([embedding_size,embedding_size],hidden_layer_sizes=[8,8])

g_e = {
    l: ltn.constant(np.random.uniform(low=0.,high=1.,size=embedding_size), trainable=True) 
    for l in entities
}


Knowledge Base

In [84]:
Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())
And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())
Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())
Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())
Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(p=5),semantics="forall")
Exists = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMean(p=5),semantics="exists")

In [85]:
formula_aggregator = ltn.fuzzy_ops.Aggreg_pMeanError(p=5)

# defining the theory
#@tf.function
def axioms():
    # Variables created in the training loop, so tf.GradientTape
    # keeps track of the connection with the trainable constants.
    a = ltn.variable("a",torch.stack(list(g_e.values())))
    b = ltn.variable("b",torch.stack(list(g_e.values())))
    c = ltn.variable("c",torch.stack(list(g_e.values())))

    ## Complete knowledge about parent relationships.
    ## The ancestor relationships are to be learned with these additional rules.
    axioms = [
        # forall pairs of individuals in the parent relationships: Parent(ancestor,child)
        Parent([g_e[a],g_e[c]])
        for a,c in parents
    ] + \
    [
        # forall pairs of individuals not in the parent relationships: Not(Parent([n_parent,n_child])))
        Not(Parent([g_e[a],g_e[c]]))
        for a,c in not_parents
    ] + \
    [
        # if a is parent of b, then a is ancestor of b
        Forall((a,b), Implies(Parent([a,b]),Ancestor([a,b]))),
        # parent is anti reflexive
        Forall(a, Not(Parent([a,a]))),
        # ancestor is anti reflexive
        Forall(a, Not(Ancestor([a,a]))),
        # parent is anti symmetric
        Forall((a,b), Implies(Parent([a,b]),Not(Parent([b,a])))),
        # if a is parent of an ancestor of c, a is an ancestor of c too
        Forall(
            (a,b,c),
            Implies(And(Parent([a,b]),Ancestor([b,c])), Ancestor([a,c])),
            p=6
        ),
        # if a is an ancestor of b, a is a parent of b OR a parent of an ancestor of b
        Forall(
            (a,b),
            Implies(Ancestor([a,b]),
                    Or(Parent([a,b]),
                       Exists(c, And(Ancestor([a,c]),Parent([c,b])),p=6)
                      )
                   )
        )
    ]
    # computing sat_level
    axioms = torch.stack([torch.squeeze(ax) for ax in axioms])
    sat_level = formula_aggregator(axioms, axis=0)
    return sat_level, axioms

print("Initial sat level %.5f"%axioms()[0])

Initial sat level 0.51109


Training

In [86]:
trainable_variables = list(Parent.parameters())\
                      +list(Ancestor.parameters())\
                      +list(g_e.values())
optimizer = torch.optim.Adam(trainable_variables, lr=0.001)

In [87]:
for epoch in range(2000):
    optimizer.zero_grad()
    loss_value = 1. - axioms()[0]
    loss_value.backward()
    optimizer.step()
    if epoch%200 == 0:
        print("Epoch %d: Sat Level %.3f"%(epoch, axioms()[0]))
print("Training finished at Epoch %d with Sat Level %.3f"%(epoch, axioms()[0]))

Epoch 0: Sat Level 0.514
Epoch 100: Sat Level 0.579
Epoch 200: Sat Level 0.585
Epoch 300: Sat Level 0.597
Epoch 400: Sat Level 0.611
Epoch 500: Sat Level 0.617
Epoch 600: Sat Level 0.625
Epoch 700: Sat Level 0.641
Epoch 800: Sat Level 0.655
Epoch 900: Sat Level 0.684
Epoch 1000: Sat Level 0.727
Epoch 1100: Sat Level 0.821
Epoch 1200: Sat Level 0.918
Epoch 1300: Sat Level 0.923
Epoch 1400: Sat Level 0.923
Epoch 1500: Sat Level 0.923
Epoch 1600: Sat Level 0.924
Epoch 1700: Sat Level 0.924
Epoch 1800: Sat Level 0.934
Epoch 1900: Sat Level 0.935
Training finished at Epoch 1999 with Sat Level 0.935


# Querying additional axioms

Additional axioms:
1. forall a,b,c: (Ancestor(a,b) & Parent(b,c)) -> Ancestor (a,c)
2. forall a,b: Ancestor(a,b) -> ~Ancestor(b,a)
3. forall a,b,c: (Parent(a,b) & Parent(b,c)) -> Ancestor(a,c)
4. forall a,b,c: (Ancestor(a,b) & Ancestor(b,c)) -> Ancestor(a,c)

In [88]:
a = ltn.variable("a",torch.stack(list(g_e.values())))
b = ltn.variable("b",torch.stack(list(g_e.values())))
c = ltn.variable("c",torch.stack(list(g_e.values())))

In [89]:
Forall((a,b,c), 
       Implies(And(Ancestor([a,b]),Parent([b,c])), Ancestor([a,c]))
)

tensor(0.7676, grad_fn=<RsubBackward1>)

In [90]:
Forall((a,b), 
       Implies(Ancestor([a,b]), Not(Ancestor([b,a])))
)

tensor(0.9729, grad_fn=<RsubBackward1>)

In [91]:
Forall((a,b,c),
       Implies(And(Parent([a,b]),Parent([b,c])), Ancestor([a,c]))
)

tensor(0.7775, grad_fn=<RsubBackward1>)

In [92]:
Forall((a,b,c),
       Implies(And(Parent([a,b]),Parent([b,c])), Ancestor([a,c]))
)

tensor(0.7775, grad_fn=<RsubBackward1>)

# Visualize Results

In [93]:
threshold = 0.5
parents_test = [
    (e1,e2) for e1 in entities for e2 in entities
    if (Parent([g_e[e1],g_e[e2]]).detach().numpy() > 0.5)
]

# parDG_test = nx.DiGraph(parents_test)
# pos= nx.drawing.nx_agraph.graphviz_layout(parDG_truth, prog='dot')
# nx.draw(parDG_test,pos,with_labels=True)

In [94]:
threshold = 0.5
ancestors_test = [
    (e1,e2) for e1 in entities for e2 in entities
    if (Ancestor([g_e[e1],g_e[e2]]).detach().numpy() > 0.5)
]

# ancDG_test = nx.DiGraph(ancestors_test)
# pos= nx.drawing.nx_agraph.graphviz_layout(parDG_test, prog='dot')
# nx.draw(ancDG_test,pos,with_labels=True)

In [95]:
not_ancestors = [item for item in all_relationships if item not in ancestors]

## 3 ##
is_ancestor = [Ancestor([g_e[a],g_e[c]]) for a,c in ancestors]
is_ancestor = torch.stack([torch.squeeze(ax) for ax in is_ancestor])
formula_aggregator(is_ancestor, axis=0)

tensor(0.2844, grad_fn=<RsubBackward1>)

In [96]:
## 4 ##
isnot_ancestor = [Not(Ancestor([g_e[a],g_e[c]])) for a,c in not_ancestors]
isnot_ancestor = torch.stack([torch.squeeze(ax) for ax in isnot_ancestor])
formula_aggregator(isnot_ancestor, axis=0)

tensor(0.9257, grad_fn=<RsubBackward1>)

In [97]:
is_ancestor>0.5

tensor([ True,  True,  True, False,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True, False, False, False, False,  True,  True,  True,  True,  True,
        False,  True,  True,  True,  True,  True, False, False, False, False,
         True,  True,  True,  True,  True,  True])

In [100]:
sum(is_ancestor.detach().numpy()>0.5)/46

0.7391304347826086

In [101]:
sum(isnot_ancestor.detach().numpy()>0.5)/243

1.0