In [12]:
from factorgraph import VariableNode, FactorNode, FactorGraph

# Two binary variables joined by one factor that scores 0.95 when a is true
# and b is false, and 0.05 for every other joint assignment.
var_a, var_b = (VariableNode(name) for name in ("a", "b"))
fac = FactorNode("f", "0.95 if V[a] and not V[b] else 0.05")
fg = FactorGraph([var_a, var_b], [fac])

# Exact marginals; judging by the output, each node's `prob` is P(var = 1)
# (P(a=1) = 1.0 / 1.1, P(b=1) = 0.1 / 1.1).
fg.marginal_inference_exhaustive()
print(fg)

VariableNode(name=a, val=None, fixed=False, prob=0.9090909090909091)
VariableNode(name=b, val=None, fixed=False, prob=0.09090909090909091)
FactorNode(name=f, formula=0.95 if V[a] and not V[b] else 0.05, prob=None)



In [36]:
# Condition on the observation a = 0: assign the value, then fix the node.
# In the output below `a` is reported as fixed with prob=None, and b becomes
# 0.5 because the factor evaluates to 0.05 for every b once a = 0.
var_a.val = 0
var_a.fix()
fg.marginal_inference_exhaustive()
print(fg)

VariableNode(name=a, val=0, fixed=True, prob=None)
VariableNode(name=b, val=None, fixed=False, prob=0.5)
FactorNode(name=f, formula=0.95 if V[a] and not V[b] else 0.05, prob=None)



In [37]:
import copy
import itertools
import pprint

# Shared pretty-printer for the nested message dicts (4-space nesting, default 80-column width).
pp = pprint.PrettyPrinter(indent=4, width=80)


def getvar(v):
    """Return the VariableNode registered under the name ``v``.

    Only the two nodes of this notebook are known: ``"a"`` -> ``var_a`` and
    ``"b"`` -> ``var_b`` (looked up at call time, so re-running the cell that
    builds them is picked up).

    Raises:
        KeyError: if ``v`` is not one of the known variable names, instead of
            silently resolving every unknown name to ``var_b``.
    """
    variables = {"a": var_a, "b": var_b}
    if v not in variables:
        raise KeyError(f"unknown variable name: {v!r}")
    return variables[v]


# Start every edge of the a - f - b graph with the uniform message (P(0), P(1)).
v2f_message = {name: {"f": (0.5, 0.5)} for name in ("a", "b")}
f2v_message = {"f": {name: (0.5, 0.5) for name in ("a", "b")}}
message_dict = {"v2f": v2f_message, "f2v": f2v_message}


def update_message_dict(message_dict):
    """Run one synchronous round of sum-product message updates.

    All new messages are computed from the messages in ``message_dict``; the
    input is not modified and a new dict with the same structure is returned.

    Args:
        message_dict: ``{"v2f": {var: {factor: (p0, p1)}},
            "f2v": {factor: {var: (p0, p1)}}}`` with binary messages keyed by
            variable / factor name.

    Returns:
        A deep copy of ``message_dict`` holding the updated messages.

    Raises:
        NotImplementedError: if a variable has more than one neighbouring
            factor, or a factor has a single neighbouring variable.

    Side effects:
        While summing out a factor, ``.val`` is temporarily assigned on the
        VariableNodes returned by ``getvar``. Their previous values are
        restored before returning, even on error, so values set earlier
        (e.g. ``var_a.val = 0`` followed by ``var_a.fix()``) are not clobbered.

    NOTE(review): every factor is evaluated with the global ``fac`` on
    ``[var_a, var_b]``, which only matches the single-factor graph of this
    notebook. Fixed variables are not treated as evidence here.
    """
    new_message_dict = copy.deepcopy(message_dict)

    # Variable -> factor: a variable whose only neighbour is the target factor
    # sends the uniform message. The general case (product of the messages
    # from the other factors) is not implemented yet.
    for v, vdict in message_dict["v2f"].items():
        for f in vdict:
            if len(vdict) > 1:
                raise NotImplementedError(
                    "variable-to-factor messages for variables with more than one factor"
                )
            new_message_dict["v2f"][v][f] = (0.5, 0.5)  # uniform distribution

    # Snapshot current assignments of every variable we are about to overwrite.
    var_names = {name for fdict in message_dict["f2v"].values() for name in fdict}
    saved_vals = {name: getvar(name).val for name in var_names}

    try:
        # Factor -> variable: for each state of v, sum the factor value over all
        # joint states of the other neighbours, weighted by their incoming
        # (previous-round) variable-to-factor messages, then normalise.
        for f, fdict in message_dict["f2v"].items():
            for v in fdict:
                # Ordered tuple so both zips below pair names with values identically.
                vstars = tuple(name for name in fdict if name != v)
                if not vstars:
                    raise NotImplementedError(
                        "factor-to-variable messages for single-variable factors"
                    )
                message = []
                for vval in (0, 1):
                    getvar(v).val = vval
                    prob = 0
                    for val_vec in itertools.product((0, 1), repeat=len(vstars)):
                        for vstar, val in zip(vstars, val_vec):
                            getvar(vstar).val = val
                        prob_vstar = fac.calc_prob([var_a, var_b], None)
                        for vstar, val in zip(vstars, val_vec):
                            prob_vstar *= message_dict["v2f"][vstar][f][val]
                        prob += prob_vstar
                    message.append(prob)
                sum_prob = sum(message)
                new_message_dict["f2v"][f][v] = tuple(p / sum_prob for p in message)
    finally:
        # Undo the temporary assignments so the caller's graph state survives.
        for name, val in saved_vals.items():
            getvar(name).val = val

    return new_message_dict


def get_marginal(message_dict):
    """Print the normalised belief of every variable; returns None.

    For each variable in ``message_dict["v2f"]``, the incoming messages
    ``message_dict["f2v"][f][var]`` from all of its neighbouring factors ``f``
    are multiplied component-wise. The product is normalised to sum to 1 and
    printed as ``"<var>: <p0>, <p1>"``.
    """
    for var_name, neighbours in message_dict["v2f"].items():
        # Multiplicative identity for both states before folding in messages.
        belief_0 = belief_1 = 1
        for fac_name in neighbours:
            msg = message_dict["f2v"][fac_name][var_name]
            belief_0, belief_1 = belief_0 * msg[0], belief_1 * msg[1]
        norm = belief_0 + belief_1
        print(f"{var_name}: {belief_0 / norm}, {belief_1 / norm}")


# Show the current (freshly initialised, uniform) messages and the beliefs they imply.
pp.pprint(message_dict)
get_marginal(message_dict)


{   'f2v': {'f': {'a': (0.5, 0.5), 'b': (0.5, 0.5)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.5, 0.5
b: 0.5, 0.5


In [38]:
# Reset every edge to the uniform message and show the starting beliefs.
v2f_message = {name: {"f": (0.5, 0.5)} for name in ("a", "b")}
f2v_message = {"f": {name: (0.5, 0.5) for name in ("a", "b")}}
message_dict = {"v2f": v2f_message, "f2v": f2v_message}
print("iteration 0")
pp.pprint(message_dict)
get_marginal(message_dict)
print()

# Ten synchronous sum-product rounds. On this two-variable tree the beliefs
# already equal the exhaustive marginals of cell In[12] after round 1.
for i in range(10):
    print(f"iteration {i + 1}")
    message_dict = update_message_dict(message_dict)
    pp.pprint(message_dict)
    get_marginal(message_dict)
    print()

iteration 0
{   'f2v': {'f': {'a': (0.5, 0.5), 'b': (0.5, 0.5)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.5, 0.5
b: 0.5, 0.5

iteration 1
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
                        'b': (0.9090909090909091, 0.09090909090909091)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.09090909090909091, 0.9090909090909091
b: 0.9090909090909091, 0.09090909090909091

iteration 2
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
                        'b': (0.9090909090909091, 0.09090909090909091)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.09090909090909091, 0.9090909090909091
b: 0.9090909090909091, 0.09090909090909091

iteration 3
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
                        'b': (0.9090909090909091, 0.09090909090909091)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.090909090909

In [39]:
import copy
import itertools
import pprint

# Shared pretty-printer for the nested message dicts (4-space nesting, default 80-column width).
pp = pprint.PrettyPrinter(indent=4, width=80)


def getvar(v):
    """Return the VariableNode registered under the name ``v``.

    Only the two nodes of this notebook are known: ``"a"`` -> ``var_a`` and
    ``"b"`` -> ``var_b`` (looked up at call time).

    NOTE(review): this cell redefines ``getvar`` from cell In[37]; consider
    keeping a single definition.

    Raises:
        KeyError: if ``v`` is not one of the known variable names, instead of
            silently resolving every unknown name to ``var_b``.
    """
    variables = {"a": var_a, "b": var_b}
    if v not in variables:
        raise KeyError(f"unknown variable name: {v!r}")
    return variables[v]


# Start every edge of the a - f - b graph with the uniform message (P(0), P(1)).
v2f_message = {name: {"f": (0.5, 0.5)} for name in ("a", "b")}
f2v_message = {"f": {name: (0.5, 0.5) for name in ("a", "b")}}
message_dict = {"v2f": v2f_message, "f2v": f2v_message}


def update_message_dict(message_dict):
    """Run one synchronous round of sum-product message updates.

    All new messages are computed from the messages in ``message_dict``; the
    input is not modified and a new dict with the same structure is returned.

    NOTE(review): this cell redefines ``update_message_dict`` from cell
    In[37]; consider keeping a single definition.

    Args:
        message_dict: ``{"v2f": {var: {factor: (p0, p1)}},
            "f2v": {factor: {var: (p0, p1)}}}`` with binary messages keyed by
            variable / factor name.

    Returns:
        A deep copy of ``message_dict`` holding the updated messages.

    Raises:
        NotImplementedError: if a variable has more than one neighbouring
            factor, or a factor has a single neighbouring variable.

    Side effects:
        While summing out a factor, ``.val`` is temporarily assigned on the
        VariableNodes returned by ``getvar``. Their previous values are
        restored before returning, even on error, so values set earlier
        (e.g. ``var_a.val = 0`` followed by ``var_a.fix()``) are not clobbered.

    NOTE(review): every factor is evaluated with the global ``fac`` on
    ``[var_a, var_b]``, which only matches the single-factor graph of this
    notebook. Fixed variables are not treated as evidence here.
    """
    new_message_dict = copy.deepcopy(message_dict)

    # Variable -> factor: a variable whose only neighbour is the target factor
    # sends the uniform message. The general case (product of the messages
    # from the other factors) is not implemented yet.
    for v, vdict in message_dict["v2f"].items():
        for f in vdict:
            if len(vdict) > 1:
                raise NotImplementedError(
                    "variable-to-factor messages for variables with more than one factor"
                )
            new_message_dict["v2f"][v][f] = (0.5, 0.5)  # uniform distribution

    # Snapshot current assignments of every variable we are about to overwrite.
    var_names = {name for fdict in message_dict["f2v"].values() for name in fdict}
    saved_vals = {name: getvar(name).val for name in var_names}

    try:
        # Factor -> variable: for each state of v, sum the factor value over all
        # joint states of the other neighbours, weighted by their incoming
        # (previous-round) variable-to-factor messages, then normalise.
        for f, fdict in message_dict["f2v"].items():
            for v in fdict:
                # Ordered tuple so both zips below pair names with values identically.
                vstars = tuple(name for name in fdict if name != v)
                if not vstars:
                    raise NotImplementedError(
                        "factor-to-variable messages for single-variable factors"
                    )
                message = []
                for vval in (0, 1):
                    getvar(v).val = vval
                    prob = 0
                    for val_vec in itertools.product((0, 1), repeat=len(vstars)):
                        for vstar, val in zip(vstars, val_vec):
                            getvar(vstar).val = val
                        prob_vstar = fac.calc_prob([var_a, var_b], None)
                        for vstar, val in zip(vstars, val_vec):
                            prob_vstar *= message_dict["v2f"][vstar][f][val]
                        prob += prob_vstar
                    message.append(prob)
                sum_prob = sum(message)
                new_message_dict["f2v"][f][v] = tuple(p / sum_prob for p in message)
    finally:
        # Undo the temporary assignments so the caller's graph state survives.
        for name, val in saved_vals.items():
            getvar(name).val = val

    return new_message_dict


def get_marginal(message_dict):
    """Print the normalised belief of every variable; returns None.

    For each variable in ``message_dict["v2f"]``, the incoming messages
    ``message_dict["f2v"][f][var]`` from all of its neighbouring factors ``f``
    are multiplied component-wise. The product is normalised to sum to 1 and
    printed as ``"<var>: <p0>, <p1>"``.

    NOTE(review): this cell redefines ``get_marginal`` from cell In[37];
    consider keeping a single definition.
    """
    for var_name, neighbours in message_dict["v2f"].items():
        # Multiplicative identity for both states before folding in messages.
        belief_0 = belief_1 = 1
        for fac_name in neighbours:
            msg = message_dict["f2v"][fac_name][var_name]
            belief_0, belief_1 = belief_0 * msg[0], belief_1 * msg[1]
        norm = belief_0 + belief_1
        print(f"{var_name}: {belief_0 / norm}, {belief_1 / norm}")


# Show the current (freshly initialised, uniform) messages and the beliefs they imply.
pp.pprint(message_dict)
get_marginal(message_dict)

# Reset every edge to the uniform message and show the starting beliefs.
v2f_message = {name: {"f": (0.5, 0.5)} for name in ("a", "b")}
f2v_message = {"f": {name: (0.5, 0.5) for name in ("a", "b")}}
message_dict = {"v2f": v2f_message, "f2v": f2v_message}
print("iteration 0")
pp.pprint(message_dict)
get_marginal(message_dict)
print()

# Ten synchronous sum-product rounds. On this two-variable tree the beliefs
# already equal the exhaustive marginals of cell In[12] after round 1.
for i in range(10):
    print(f"iteration {i + 1}")
    message_dict = update_message_dict(message_dict)
    pp.pprint(message_dict)
    get_marginal(message_dict)
    print()

{   'f2v': {'f': {'a': (0.5, 0.5), 'b': (0.5, 0.5)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.5, 0.5
b: 0.5, 0.5
iteration 0
{   'f2v': {'f': {'a': (0.5, 0.5), 'b': (0.5, 0.5)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.5, 0.5
b: 0.5, 0.5

iteration 1
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
                        'b': (0.9090909090909091, 0.09090909090909091)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.09090909090909091, 0.9090909090909091
b: 0.9090909090909091, 0.09090909090909091

iteration 2
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
                        'b': (0.9090909090909091, 0.09090909090909091)}},
    'v2f': {'a': {'f': (0.5, 0.5)}, 'b': {'f': (0.5, 0.5)}}}
a: 0.09090909090909091, 0.9090909090909091
b: 0.9090909090909091, 0.09090909090909091

iteration 3
{   'f2v': {   'f': {   'a': (0.09090909090909091, 0.9090909090909091),
             