In [1]:
# Tree edges keyed as (smaller node, larger node). Each value is the potential
# psi returns when the two endpoints take different values (1 - value when equal).
edges = {
    # branch through node 2
    (1, 2): 0.1, (2, 3): 0.1, (2, 4): 0.2,
    # branch through node 5
    (1, 5): 0.1, (5, 6): 0.1, (5, 7): 0.4,
    # branch through node 8
    (1, 8): 0.1, (8, 9): 0.5, (8, 10): 0.3,
}

In [2]:
def psi(u, v, u_val, v_val):
    """Pairwise edge potential between neighbouring nodes u and v.

    The edge is looked up in `edges` with the smaller node id first, so the
    argument order of u and v does not matter. Returns 1 - p when u_val and
    v_val agree and p when they differ, where p is the stored edge value.
    """
    flip = edges[(min(u, v), max(u, v))]
    return 1-flip if u_val == v_val else flip


In [3]:
from collections import defaultdict
# Undirected adjacency lists: every edge is recorded from both endpoints.
graph = defaultdict(list)
for (u, v), p in edges.items():
    graph[u] += [v]
    graph[v] += [u]


In [4]:

# Evidence: observed value (0 or 1) per leaf node; later cells rebind this dict.
observed = {3: 1, 4: 0, 6: 0, 7: 0, 9: 1, 10: 1}
# Working state filled in by collect()/distribute() during one inference run.
local_tables, marginal_tables, collect_msgs = {}, {}, {}

def collect(v, u):
    """Upward (collect) pass of sum-product message passing on the tree `graph`.

    Recursively processes the subtree hanging below v, where u is v's parent
    (None when v is the root of this pass).

    Side effects:
      * local_tables[v] is set to v's evidence indicator times the product of
        the messages from v's children.
      * root v: marginal_tables[v] is set to that same table object.
      * non-root v: collect_msgs[(v, u)] stores the upward message.

    Returns the 2-entry message v -> u indexed by x_u, or None for the root.
    """
    incoming = [collect(w, v) for w in graph[v] if w != u]

    # evidence indicator: zero out the state that contradicts the observation
    table = [1, 1]
    if v in observed:
        table[1-observed[v]] = 0
    for state in [0, 1]:
        for m in incoming:
            table[state] *= m[state]
    local_tables[v] = table

    if u is None:
        # root: its local table already holds every message, so it is the marginal
        marginal_tables[v] = table
        return None

    # message to the parent: sum over x_v of table[x_v] * psi(x_u, x_v)
    up = [
        sum(table[state]*psi(u, v, parent_state, state) for state in [0, 1])
        for parent_state in [0, 1]
    ]
    collect_msgs[(v, u)] = up
    return up

def distribute(u, v):
    """Downward (distribute) pass of sum-product, run after collect() from the same root.

    For a non-root v with parent u (whose marginal is already final), the
    message u -> v is obtained by dividing v's own upward message
    collect_msgs[(v, u)] out of marginal_tables[u] and summing over x_u with
    psi. marginal_tables[v] then becomes local_tables[v] times that message.
    Recurses into v's children. Raises ZeroDivisionError if an entry of
    collect_msgs[(v, u)] is zero.
    """
    if u is not None:
        # message u -> v: u's marginal with v's upward contribution divided out
        incoming = [
            sum(marginal_tables[u][x_u]/collect_msgs[(v, u)][x_u]*psi(u, v, x_u, x_v)
                for x_u in [0, 1])
            for x_v in [0, 1]
        ]
        # final (unnormalized) marginal of v
        marginal_tables[v] = [local_tables[v][x_v]*incoming[x_v] for x_v in [0, 1]]
    for w in graph[v]:
        if w != u:
            distribute(v, w)

def normalize(table):
    """Return a new list with each entry of `table` divided by the sum of all entries.

    An empty table yields an empty list. A non-empty table whose entries sum
    to zero raises ZeroDivisionError.
    """
    # Compute the total once rather than re-summing the table for every entry.
    total = sum(table)
    return [t/total for t in table]

def collect_distribute(root):
    """Run sum-product inference rooted at `root` and print the results.

    Clears the shared tables, runs collect() and then distribute() from
    `root`, and prints the root, the sum of the root's unnormalized table
    (labelled "likelihood"), and the normalized marginal of every unobserved
    node in ascending node order.
    """
    for store in (local_tables, marginal_tables, collect_msgs):
        store.clear()

    collect(root, None)
    distribute(None, root)

    likelihood = sum(marginal_tables[root])
    conditionals = {}
    for v, table in marginal_tables.items():
        if v not in observed:
            conditionals[v] = normalize(table)

    print(f"root: {root}")
    print(f"likelihood: {likelihood:.3f}")
    for v in sorted(conditionals):
        table = conditionals[v]
        print(f"p({v}) = [{table[0]:.2f}, {table[1]:.2f}]")


In [5]:
# Evidence set A; the printed results should agree whichever root is used.
observed = {3: 0, 4: 1, 6: 1, 7: 0, 9: 0, 10: 1}
for root in (1, 2, 6):
    collect_distribute(root)

root: 1
likelihood: 0.012
p(1) = [0.21, 0.79]
p(2) = [0.36, 0.64]
p(5) = [0.14, 0.86]
p(8) = [0.20, 0.80]
root: 2
likelihood: 0.012
p(1) = [0.21, 0.79]
p(2) = [0.36, 0.64]
p(5) = [0.14, 0.86]
p(8) = [0.20, 0.80]
root: 6
likelihood: 0.012
p(1) = [0.21, 0.79]
p(2) = [0.36, 0.64]
p(5) = [0.14, 0.86]
p(8) = [0.20, 0.80]


In [6]:
# Evidence set B (node 4 flipped to 0); rerun from several roots.
observed = {3: 0, 4: 0, 6: 1, 7: 0, 9: 0, 10: 1}
for root in (1, 2, 6):
    collect_distribute(root)

root: 1
likelihood: 0.020
p(1) = [0.50, 0.50]
p(2) = [0.90, 0.10]
p(5) = [0.31, 0.69]
p(8) = [0.42, 0.58]
root: 2
likelihood: 0.020
p(1) = [0.50, 0.50]
p(2) = [0.90, 0.10]
p(5) = [0.31, 0.69]
p(8) = [0.42, 0.58]
root: 6
likelihood: 0.020
p(1) = [0.50, 0.50]
p(2) = [0.90, 0.10]
p(5) = [0.31, 0.69]
p(8) = [0.42, 0.58]


In [7]:
# Evidence set C: every leaf observed as 1; rerun from several roots.
observed = {3: 1, 4: 1, 6: 1, 7: 1, 9: 1, 10: 1}
for root in (1, 2, 6):
    collect_distribute(root)

root: 1
likelihood: 0.106
p(1) = [0.01, 0.99]
p(2) = [0.01, 0.99]
p(5) = [0.01, 0.99]
p(8) = [0.06, 0.94]
root: 2
likelihood: 0.106
p(1) = [0.01, 0.99]
p(2) = [0.01, 0.99]
p(5) = [0.01, 0.99]
p(8) = [0.06, 0.94]
root: 6
likelihood: 0.106
p(1) = [0.01, 0.99]
p(2) = [0.01, 0.99]
p(5) = [0.01, 0.99]
p(8) = [0.06, 0.94]


In [8]:
# Working state for the max-product pass: per-node back-pointers filled by
# collect2() and the assignment reconstructed from them by distribute2().
retrieval_tables, max_prob_assignment = {}, {}

def collect2(v, u):
    """Upward (collect) pass of max-product message passing on the tree `graph`.

    Recursively processes the subtree hanging below v, where u is v's parent
    (None when v is the root). Evidence in `observed` clamps v's table.

    Side effects:
      * non-root v: retrieval_tables[v][x_u] records the x_v that maximizes
        the message for each parent value x_u (read later by distribute2).
      * root v: max_prob_assignment[v] is set to the maximizing root value.

    Returns:
      For a non-root v, the 2-entry max-message v -> u indexed by x_u.
      For the root, the largest product of edge potentials over assignments
      consistent with the evidence.
    """
    children = [child for child in graph[v] if child != u]
    msgs = [collect2(child, v) for child in children]

    # compute local table: evidence indicator times the children's messages
    local_table = [1, 1]
    if v in observed:
        local_table[1-observed[v]] = 0
    for x_v in [0, 1]:
        for msg in msgs:
            local_table[x_v] *= msg[x_v]

    if u is None:
        # v is root: pick the maximizing value (the first one wins ties)
        max_prob = max(local_table)
        max_prob_assignment[v] = local_table.index(max_prob)
        return max_prob

    # compute upward message and remember the argmax for the traceback
    msg = [0, 0]
    retrieval_tables[v] = [None, None]
    for x_u in [0, 1]:
        candidates = [local_table[x_v]*psi(u, v, x_u, x_v) for x_v in [0, 1]]
        # Take the argmax directly instead of comparing against a running max
        # that starts at 0. That way a back-pointer is always recorded, even
        # when every candidate is 0 (a zero potential making the evidence
        # impossible, or float underflow in a deep tree). Ties still resolve
        # to the first maximizer.
        best = candidates.index(max(candidates))
        msg[x_u] = candidates[best]
        retrieval_tables[v][x_u] = best
    return msg

def distribute2(u, v):
    """Top-down traceback of the max-product pass, run after collect2() from the same root.

    For a non-root v, sets max_prob_assignment[v] from the back-pointer
    retrieval_tables[v], indexed by the value already chosen for parent u.
    Then recurses into v's children.
    """
    if u is not None:
        parent_choice = max_prob_assignment[u]
        max_prob_assignment[v] = retrieval_tables[v][parent_choice]
    for w in graph[v]:
        if w == u:
            continue
        distribute2(v, w)

def collect_distribute2(root):
    """Run max-product (MAP) inference rooted at `root` and print the results.

    Clears the shared tables, runs collect2() and then distribute2() from
    `root`, and prints the root, the maximum probability found, and the
    chosen value of every unobserved node in ascending node order.
    """
    # NOTE(review): collect2 never writes local_tables, so clearing it here only
    # discards whatever the last sum-product run left behind. It looks like a
    # leftover from collect_distribute; it is kept to preserve existing behavior.
    for store in (local_tables, retrieval_tables, max_prob_assignment):
        store.clear()

    max_prob = collect2(root, None)
    distribute2(None, root)
    print(f"root: {root}")
    print(f"max_prob: {max_prob:.4f}")
    hidden = [(v, v_val) for v, v_val in sorted(max_prob_assignment.items())
              if v not in observed]
    for v, v_val in hidden:
        print(f"x_{v}* = {v_val}")


In [9]:
# Evidence set A for MAP inference; the printed results should agree across roots.
observed = {3: 0, 4: 1, 6: 1, 7: 0, 9: 0, 10: 1}
for root in (1, 2, 6):
    collect_distribute2(root)

root: 1
max_prob: 0.0073
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1
root: 2
max_prob: 0.0073
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1
root: 6
max_prob: 0.0073
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1


In [10]:
# Evidence set B (node 4 flipped to 0) for MAP inference.
observed = {3: 0, 4: 0, 6: 1, 7: 0, 9: 0, 10: 1}
for root in (1, 2, 6):
    collect_distribute2(root)

root: 1
max_prob: 0.0073
x_1* = 1
x_2* = 0
x_5* = 1
x_8* = 1
root: 2
max_prob: 0.0073
x_1* = 1
x_2* = 0
x_5* = 1
x_8* = 1
root: 6
max_prob: 0.0073
x_1* = 1
x_2* = 0
x_5* = 1
x_8* = 1


In [11]:
# Evidence set C (every leaf observed as 1) for MAP inference.
observed = {3: 1, 4: 1, 6: 1, 7: 1, 9: 1, 10: 1}
for root in (1, 2, 6):
    collect_distribute2(root)

root: 1
max_prob: 0.0992
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1
root: 2
max_prob: 0.0992
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1
root: 6
max_prob: 0.0992
x_1* = 1
x_2* = 1
x_5* = 1
x_8* = 1
