In [None]:
import pickle as pkl
import anytree
import numpy as np

In [None]:
# Each relation between two nodes n1 and n2 is a 3-element boolean mask.
# Position i refers to the region described by idx_rel_map[i]
# (0: in n1 / out n2, 1: in both, 2: out n1 / in n2). gen_rules turns a set
# bit into "this region must be non-empty" and a clear bit into an ir atom.
equals = np.array([False, True, False])
overlaps = np.array([True, True, True])
included_in = np.array([False, True, True])
disjoint = np.array([True, False, True])
includes = np.array([True, True, False])

In [None]:
# Relation symbol (as written in the articulation data) -> region mask.
sym_dict = dict([
    ("=", equals),
    ("o", overlaps),
    ("<", included_in),
    (">", includes),
    ("!", disjoint),
])
# Every known relation symbol; gen_rules negates the ones not listed.
all_rels = set(sym_dict)

In [None]:
# Region index -> (predicate for n1, predicate for n2) describing that region.
idx_rel_map = dict(enumerate([
    ("in", "out"),
    ("in", "in"),
    ("out", "in"),
]))

In [None]:
def rec_bitwise_and(fs):
    """
    Element-wise bitwise AND of all arrays in ``fs``.

    Args:
        fs: non-empty iterable of numpy arrays (or array-likes) that broadcast
            together; in this notebook they are the boolean relation masks.

    Returns:
        The AND of every element. With a single element, that element itself
        is returned unchanged (no copy).

    Raises:
        ValueError: if ``fs`` is empty. The previous divide-and-conquer version
            recursed forever on an empty list.
    """
    iterator = iter(fs)
    try:
        result = next(iterator)
    except StopIteration:
        raise ValueError("rec_bitwise_and() requires at least one array") from None
    # AND is associative, so a left fold gives the same result as the former
    # balanced recursion.
    for f in iterator:
        result = np.bitwise_and(result, f)
    return result

In [None]:
def rec_bitwise_and_not(fs):
    """
    Element-wise AND of the bitwise complements of the arrays in ``fs``.

    For boolean masks the result is True exactly where every mask is False.
    """
    inverted = [np.invert(f) for f in fs]
    return rec_bitwise_and(inverted)

In [None]:
def intersection(fs):
    """
    Despite its name, return the positions where all masks in ``fs`` agree.

    The result is (AND of all masks) OR (AND of all complemented masks). For
    boolean masks it is True where every mask has the same value there (all
    set or all clear). With a single mask it is True everywhere.
    """
    return np.bitwise_or(rec_bitwise_and(fs), rec_bitwise_and_not(fs))

In [None]:
def not_filter_helper(r_var, rel1, rel2, n1, n2, sign):
    """
    Render a clingo ``#count`` comparison over regions ``r_var``.

    The result has the form
    ``#count {R : vrs(R), rel1(n1, R), rel2(n2, R)} <sign> 0``,
    i.e. it counts ``vrs`` regions with the given memberships in n1 and n2
    and compares that count to 0 with ``sign`` (e.g. ``=`` or ``>``).
    """
    template = "#count {{{0} : vrs({0}), {1}({3}, {0}), {2}({4}, {0})}} {5} 0"
    fields = (r_var, rel1, rel2, n1, n2, sign)
    return template.format(*fields)

In [None]:
def not_filter(n1, n2, rel):
    """
    Build an integrity constraint that rules out relation mask ``rel`` between n1 and n2.

    For each region index i (see idx_rel_map), the constraint has one
    ``#count`` term using its own variable (``A``, ``B``, ...). The term
    requires ``> 0`` where ``rel[i]`` is set and ``= 0`` where it is clear.
    The constraint therefore fires exactly when the region pattern of n1 and
    n2 matches ``rel``.
    """
    conditions = [
        not_filter_helper(chr(ord('A') + pos), idx_rel_map[pos][0], idx_rel_map[pos][1], n1, n2, ">" if bit else "=")
        for pos, bit in enumerate(rel)
    ]
    return ":- {}.".format(",\n   ".join(conditions))

In [None]:
def ir_helper(rel1: str, rel2: str, n1: str, n2: str, idx: int, rel_var="X"):
    """
    Render the rule ``ir(V, r<idx>) :- rel1(n1, V), rel2(n2, V).``

    It marks every region V with the given memberships in n1 and n2 with an
    ``ir`` atom tagged ``r<idx>``.
    """
    template = "ir({0}, r{1}) :- {2}({3}, {0}), {4}({5}, {0})."
    fields = (rel_var, idx, rel1, n1, rel2, n2)
    return template.format(*fields)

In [None]:
def vr_ir_helper(rel1: str, rel2: str, n1: str, n2: str, idx: int, rel_var="X"):
    """
    Like ir_helper, but with the disjunctive head ``vr(V, r<idx>) ; ir(V, r<idx>)``.

    Used when the region may be either ``vr`` or ``ir``.
    """
    disjunct = "vr({}, r{}) ; ".format(rel_var, idx)
    ir_rule = ir_helper(rel1, rel2, n1, n2, idx, rel_var)
    return disjunct + ir_rule

In [None]:
def gen_coverage_rule(parent: str, children: list):  # For every non-leaf node
    """
    Render the coverage rule ``out(parent, X) :- out(c1, X), ..., out(cn, X).``

    A region that is outside every child is outside the parent.
    """
    body = ", ".join('out({}, X)'.format(child) for child in children)
    return "{} :- {}.".format('out({}, X)'.format(parent), body)

In [None]:
def gen_concept2_rule(node: str, tax_num: int):
    """Render the fact ``concept2(node, tax_num).`` (gen_tax_rules emits it for non-leaf nodes)."""
    template = "concept2({}, {})."
    return template.format(node, tax_num)
def gen_concept_rule(node: str, tax_num: int, concept_num: int):
    """Render the fact ``concept(node, tax_num, concept_num).`` (gen_tax_rules emits it for leaves)."""
    template = "concept({}, {}, {})."
    return template.format(node, tax_num, concept_num)

In [None]:
def gen_sibling_disjointness(n1: str, n2: str, idx: int):  # For every pair of siblings
    """
    Render the rules that make siblings n1 and n2 disjoint.

    Returns three newline-separated rules:
      * an ``ir(X, r<idx>)`` rule for regions inside both n1 and n2;
      * a constraint forbidding zero ``vrs`` regions inside n1 and outside n2;
      * the same constraint with n1 and n2 swapped.
    """
    shared_region = ir_helper(n1=n1, n2=n2, rel1="in", rel2="in", idx=idx)
    constraints = [
        ":- {}.".format(not_filter_helper(n1=first, n2=second, rel1="in", rel2="out", sign="=", r_var="X"))
        for first, second in ((n1, n2), (n2, n1))
    ]
    return "\n".join([shared_region] + constraints)

In [None]:
def gen_isa_rule(child: str, parent: str, idx: int):  # For every parent-child relation
    """
    Render the rules encoding that ``child`` is contained in ``parent``.

    Returns two newline-separated rules:
      * an ``ir(X, r<idx>)`` rule for regions inside child but outside parent;
      * a constraint forbidding zero ``vrs`` regions inside both child and parent.
    """
    outside_parent = ir_helper(n1=child, n2=parent, rel1="in", rel2="out", idx=idx)
    nonempty_overlap = ":- {}.".format(
        not_filter_helper(n1=child, n2=parent, rel1="in", rel2="in", sign="=", r_var="X"))
    return "\n".join((outside_parent, nonempty_overlap))

In [None]:
# Module-level counter that gen_rules increments (via `global rule_count`) on
# every call. Its value becomes the "r<N>" tag in the ir/vr atoms that
# gen_rules emits. Re-run this cell to restart numbering.
rule_count = 0

In [None]:
#rule_count = 21

In [None]:
#print("\n".join(gen_rules("1_A", "2_A", ["<", "="])))

In [None]:
#print("\n".join(gen_rules("1_B", "2_B", ["="])))

In [None]:
#print("\n".join(gen_rules("1_C", "2_F", ["<", "o"])))

In [None]:
#print("\n".join(gen_rules("1_D", "2_D", ["="])))

In [None]:
def gen_tax_rules(root, tax_id=0, concept_count=0):
    """
    Recursively generate the structural rules for the taxonomy subtree at ``root``.

    For a non-leaf node this emits, in order:
      * sibling-disjointness rules for every unordered pair of its children;
      * one coverage rule;
      * a ``concept2`` fact for the node;
      * an isa rule per child.
    The rules of each child subtree are then appended. A leaf yields only a
    ``concept(name, tax_id, n)`` fact, where n is the running ``concept_count``.

    Args:
        root: tree node exposing ``.name`` and ``.children`` (presumably an
            anytree node; confirm against the pickle producer).
        tax_id: taxonomy number written into the concept/concept2 facts.
        concept_count: index assigned to the next leaf encountered.

    Returns:
        Tuple ``(sibling_disjointness, coverage, concept, isa, next_concept_count)``.
        The first four items are lists of rule strings.
    """
    kids = root.children
    if len(kids) == 0:
        # Leaf: number it and hand the advanced counter back to the caller.
        leaf_fact = gen_concept_rule(root.name, tax_id, concept_count)
        return [], [], [leaf_fact], [], concept_count + 1

    names = [kid.name for kid in kids]
    sd_r = [
        gen_sibling_disjointness(names[a], names[b], 0)
        for a in range(len(names))
        for b in range(a + 1, len(names))
    ]
    cov_r = [gen_coverage_rule(root.name, names)]
    concept_r = [gen_concept2_rule(root.name, tax_id)]
    isa_r = [gen_isa_rule(name, root.name, 0) for name in names]

    for kid in kids:
        # Thread concept_count through the subtrees so leaf indices stay unique.
        sub_sd, sub_cov, sub_concept, sub_isa, concept_count = gen_tax_rules(kid, tax_id, concept_count)
        sd_r += sub_sd
        cov_r += sub_cov
        concept_r += sub_concept
        isa_r += sub_isa
    return sd_r, cov_r, concept_r, isa_r, concept_count

In [None]:
def gen_rules(n1, n2, rels):
    """
    Generates the rules that must be encoded in clingo to represent the
    list of possible relations (rels) between given nodes n1 and n2.

    Each relation symbol maps (via sym_dict) to a boolean mask over the regions
    listed in idx_rel_map. The generated rules are:
      * one ``not_filter`` constraint for every known symbol NOT in ``rels``,
        which forbids that region pattern;
      * for every region index i:
          - all allowed masks set bit i   -> constraint forbidding a zero
            count of ``vrs`` regions there;
          - all allowed masks clear bit i -> ``ir`` rule for the region;
          - the allowed masks disagree    -> disjunctive ``vr ; ir`` rule.

    Side effect: increments the module-level ``rule_count``. The new value is
    the ``r<N>`` tag used in the generated ``ir``/``vr`` atoms.

    Args:
        n1, n2: node names, inserted verbatim into the rules.
        rels: sequence of relation symbols (keys of sym_dict), e.g. ["<", "="].
            Surrounding whitespace is ignored and empty entries are dropped.

    Returns:
        List of rule strings.

    Raises:
        ValueError: if ``rels`` contains no relation symbol.
        KeyError: if a symbol is not a key of sym_dict.
    """
    global rule_count
    symbols = [r.strip() for r in rels if r.strip()]
    if not symbols:
        raise ValueError("gen_rules({}, {}): no relation symbol given".format(n1, n2))
    masks = [sym_dict[s] for s in symbols]  # KeyError on an unknown symbol
    rule_count += 1

    rules = []
    # Forbid every relation that was not listed. sorted() keeps the output
    # order identical across runs; set iteration order depends on hash seeding.
    for excluded in sorted(all_rels - set(symbols)):
        rules.append(not_filter(n1, n2, sym_dict[excluded]))

    # agree[i] is True where all allowed masks share the same bit i. With a
    # single mask it is True everywhere, so one code path covers both cases.
    agree = intersection(masks)
    for i in range(len(agree)):
        rel1, rel2 = idx_rel_map[i]
        if not agree[i]:
            # The allowed relations disagree on this region: either vr or ir.
            rules.append(vr_ir_helper(rel1, rel2, n1, n2, rule_count))
        elif masks[0][i]:
            # Every allowed relation needs this region to be non-empty.
            rules.append(":- {}.".format(not_filter_helper("X", rel1, rel2, n1, n2, "=")))
        else:
            # No allowed relation uses this region.
            rules.append(ir_helper(rel1, rel2, n1, n2, rule_count))
    return rules

In [None]:
#print("\n".join(gen_rules("1_E", "2_G", ["o", "="])))

In [None]:
# MIR and Decoding rules (Standard)

In [None]:
def decoding_rules():
    """
    Return the fixed ("standard") clingo rules that decode region hints into
    ``rel(X, Y, "<symbol>")`` atoms. There are four groups, in this order:

    1. For every pair of distinct relation symbols, a constraint forbidding
       a pair (X, Y) of concept2 nodes from holding both.
    2. One constraint requiring at least one relation for every pair of
       concept2 nodes X, Y with N1 < N2 where neither X nor Y is ``ncf``.
    3. One definition per symbol: ``rel(X, Y, sym)`` holds when
       ``hint(X, Y, i)`` is true exactly for the bits set in ``sym_dict[sym]``.
    4. For every region index i, ``hint(X, Y, i)`` holds when some ``vrs``
       region R has the memberships ``idx_rel_map[i]`` with respect to X and Y.

    Symbols are iterated in sorted order, so the generated text is identical
    on every run (iterating the set directly depends on hash seeding).
    """
    rel_list = sorted(all_rels)
    n_regions = len(idx_rel_map)

    mir_rules = []
    for i in range(len(rel_list)):
        for j in range(i + 1, len(rel_list)):
            # The second concept2 literal binds Y (previously X was duplicated).
            # This matches the concept2(X, N1), concept2(Y, N2) pattern below.
            mir_rules.append(
                ':- rel(X, Y, "{}"), rel(X, Y, "{}"), concept2(X, N1), concept2(Y, N2).'.format(
                    rel_list[i], rel_list[j]))

    body = ['not rel(X, Y, "{}")'.format(rel) for rel in rel_list]
    body += ['concept2(X, N1)', 'concept2(Y, N2)', 'N1 < N2', 'not ncf(X)', 'not ncf(Y)']
    at_least_one_rule = ':- {}.'.format(", ".join(body))

    rel_def = []
    for rel in rel_list:
        # Positive hint literal for a set bit, negated ("not ") for a clear bit.
        literals = ['{1}hint(X, Y, {0})'.format(i, "" if sym_dict[rel][i] else "not ")
                    for i in range(n_regions)]
        rel_def.append('rel(X, Y, "{}") :- {}.'.format(rel, ", ".join(literals)))

    hint_rules = []
    for i in range(n_regions):
        hint_rules.append(
            'hint(X, Y, {}) :- concept2(X, N1), concept2(Y, N2), N1 < N2, vrs(R), {}(X, R), {}(Y, R), not ncf(X), not ncf(Y).'.format(
                i, idx_rel_map[i][0], idx_rel_map[i][1]))

    return mir_rules + [at_least_one_rule] + rel_def + hint_rules

In [None]:
# Show the standard decoding rules, one per line.
print(*decoding_rules(), sep="\n")

In [None]:
# Load the pickled taxonomy trees. The rule-generation cell indexes this as
# anytree_[tax_name][tax_name].children[0], so it is presumably a mapping from
# taxonomy name to a mapping of anytree nodes. TODO confirm against the producer.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The path is relative to the notebook's working directory.
anytree_ = None
with open('Temp_Pickle_Data/cen_test/anytree.pkl', 'rb') as f:
    anytree_ = pkl.load(f)

In [None]:
# Display the loaded keys; the rule-generation loop uses them as taxonomy names.
anytree_.keys()

In [None]:
# Generate the structural rules of every taxonomy. Each taxonomy's tax_id is
# its position in anytree_.keys(). Leaf numbering (concept_count) restarts at
# 0 per taxonomy because gen_tax_rules is called with its default.
sibling_disjointness_rules = []
coverage_rules = []
concept_rules = []
isa_rules = []

for i, tax_name in enumerate(anytree_.keys()):
    # The tree used is the first child of anytree_[tax_name][tax_name].
    # NOTE(review): confirm this matches how the pickle was built.
    root = anytree_[tax_name][tax_name].children[0]
    # gen_tax_rules returns (sibling-disjointness, coverage, concept, isa rules,
    # next concept index); the last item is not needed here.
    all_rules = gen_tax_rules(root, tax_id=i)
    sibling_disjointness_rules.extend(all_rules[0])
    coverage_rules.extend(all_rules[1])
    concept_rules.extend(all_rules[2])
    isa_rules.extend(all_rules[3])

In [None]:
# One sibling-disjointness rule group per line.
print(*sibling_disjointness_rules, sep="\n")

In [None]:
# One coverage rule per non-leaf node.
print(*coverage_rules, sep="\n")

In [None]:
# concept2 facts (non-leaf nodes) and concept facts (leaves).
print(*concept_rules, sep="\n")

In [None]:
# One isa rule group per parent-child edge.
print(*isa_rules, sep="\n")

In [None]:
# Load the articulations between taxonomies. The next cell iterates it with
# .iterrows() and reads the columns 'Node1', 'Relation' and 'Node2', so it is
# presumably a pandas DataFrame. TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
articulations = None
with open('Temp_Pickle_Data/cen_test/taxDesc.pkl', 'rb') as f:
    articulations = pkl.load(f)
print(articulations)

In [None]:
# Translate each articulation row into clingo rules. 'Relation' holds one or
# more comma-separated relation symbols (keys of sym_dict). Rows whose first
# token is 'parent' are skipped.
articulation_rules = []
for idx, row in articulations.iterrows():
    n1 = row['Node1']
    # Strip whitespace so a value like "<, =" yields "=" rather than " =",
    # which is not a key of sym_dict and would raise KeyError in gen_rules.
    rel = [symbol.strip() for symbol in row['Relation'].split(",")]
    n2 = row['Node2']
    if rel[0] != 'parent':
        articulation_rules.extend(gen_rules(n1, n2, rel))
print("\n".join(articulation_rules))