In [None]:
import libspn as spn
import tensorflow as tf
import numpy as np
# Interactive session: later cells call `.run()` / `.eval()` on ops and tensors
# without passing a session explicitly.
sess = tf.InteractiveSession()

In [None]:
# Indicator variables: 2 variables with 2 values each (indicator indices 0-3)
ivs = spn.IVs(num_vars=2, num_vals=2, name="IVs")


def _weighted_sum(indices, weights, name):
    """Create a sum node over the given `ivs` indices and generate its weights.

    Args:
        indices: Indices of the `ivs` outputs connected to the sum.
        weights: Values passed to `generate_weights` for the new node.
        name: Name of the sum node.

    Returns:
        The created sum node.
    """
    node = spn.Sum((ivs, indices), name=name)
    node.generate_weights(weights)
    return node


# Two mixtures over the first variable (indicators 0 and 1)
s11 = _weighted_sum([0, 1], [0.4, 0.6], "Sum1.1")
s12 = _weighted_sum([0, 1], [0.1, 0.9], "Sum1.2")
# Two mixtures over the second variable (indicators 2 and 3)
s21 = _weighted_sum([2, 3], [0.7, 0.3], "Sum2.1")
s22 = _weighted_sum([2, 3], [0.8, 0.2], "Sum2.2")

# Three product components, each taking one pair of the inputs listed below
p = spn.Products(
    s11, s21,  # component 1
    s11, s22,  # component 2
    s12, s22,  # component 3
    num_prods=3, name="Comp1-2-3")

# Root: mixture of the three components
root = spn.Sum(p, name="Mixture")
root.generate_weights([0.5, 0.2, 0.3])

# Run the weight-initialization op in the default session
init_weights_op = spn.initialize_weights(root)
init_weights_op.run()

In [None]:
# Optional: uncomment the next line to visualize the SPN graph.
#spn.display_spn_graph(root)

In [None]:
# Feed: every combination of the values -1, 0, 1 for the two variables,
# one sample per row, ordered (-1,-1), (-1,0), (-1,1), (0,-1), ..., (1,1).
# NOTE(review): -1 presumably marks a variable as unobserved -- consistent with
# the first expected value being 1.0; confirm against the libspn IVs docs.
values = np.arange(-1, 2)
points = np.array(np.meshgrid(values, values)).T
feed = points.reshape(-1, points.shape[-1])

# All expected arrays use the dtype configured in libspn
np_dtype = spn.conf.dtype.as_numpy_dtype

# Expected root value (marginal inference), one row per feed sample
true_values = np.array(
    [1.0, 0.75, 0.25, 0.31, 0.228, 0.082, 0.69, 0.522, 0.168],
    dtype=np_dtype).reshape(-1, 1)

# Expected root value (MPE inference), one row per feed sample
true_mpe_values = np.array(
    [0.216, 0.216, 0.09, 0.14, 0.14, 0.06, 0.216, 0.216, 0.09],
    dtype=np_dtype).reshape(-1, 1)

# Expected IVs counts of the MPE path when values use marginal inference.
# Columns correspond to the 4 indicators of `ivs` (indices 0-3).
true_ivs_counts_marginal = np.array(
    [[0, 1, 1, 0],   # feed (-1, -1)
     [0, 1, 1, 0],   # feed (-1,  0)
     [0, 1, 0, 1],   # feed (-1,  1)
     [1, 0, 1, 0],   # feed ( 0, -1)
     [1, 0, 1, 0],   # feed ( 0,  0)
     [1, 0, 0, 1],   # feed ( 0,  1)
     [0, 1, 1, 0],   # feed ( 1, -1)
     [0, 1, 1, 0],   # feed ( 1,  0)
     [0, 1, 0, 1]],  # feed ( 1,  1)
    dtype=np_dtype)

# Expected IVs counts of the MPE path when values use MPE inference
# (same numbers as the marginal case for this network).
true_ivs_counts_mpe = np.array(
    [[0, 1, 1, 0],   # feed (-1, -1)
     [0, 1, 1, 0],   # feed (-1,  0)
     [0, 1, 0, 1],   # feed (-1,  1)
     [1, 0, 1, 0],   # feed ( 0, -1)
     [1, 0, 1, 0],   # feed ( 0,  0)
     [1, 0, 0, 1],   # feed ( 0,  1)
     [0, 1, 1, 0],   # feed ( 1, -1)
     [0, 1, 1, 0],   # feed ( 1,  0)
     [0, 1, 0, 1]],  # feed ( 1,  1)
    dtype=np_dtype)

In [None]:
# Ops computing the root value in linear space and in log space
value = root.get_value()
log_value = root.get_log_value()

In [None]:
# Evaluate both value ops on the same feed; the log result is mapped back
# to linear space so it can be compared with the same expected array.
ivs_feed = {ivs: feed}
value_array = value.eval(feed_dict=ivs_feed)
log_value_array = log_value.eval(feed_dict=ivs_feed)
value_array_log = np.exp(log_value_array)

print(value_array)
print(true_values)
np.testing.assert_almost_equal(value_array, true_values)
np.testing.assert_almost_equal(value_array_log, true_values)

In [None]:
# Ops computing the root value with MPE inference, in linear and log space
mpe_inference = spn.InferenceType.MPE
mpe_value = root.get_value(mpe_inference)
log_mpe_value = root.get_log_value(mpe_inference)

In [None]:
# Evaluate both MPE value ops on the same feed; the log result is mapped back
# to linear space so it can be compared with the same expected array.
ivs_feed = {ivs: feed}
mpe_value_array = mpe_value.eval(feed_dict=ivs_feed)
log_mpe_value_array = log_mpe_value.eval(feed_dict=ivs_feed)
mpe_value_array_log = np.exp(log_mpe_value_array)

print(mpe_value_array)
print(true_mpe_values)
np.testing.assert_almost_equal(mpe_value_array, true_mpe_values)
np.testing.assert_almost_equal(mpe_value_array_log, true_mpe_values)

In [None]:
# MPE path built with MPE value inference (non-log computation, log=False)
mpe_mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False)
mpe_mpe_path_gen.get_mpe_path(root)
# Evaluate the counts that the path generator holds for the IVs node
ivs_counts_mpe = mpe_mpe_path_gen.counts[ivs].eval(feed_dict={ivs: feed})

# Show the counts, then compare them with the expected MPE-inference counts
print("ivs_counts_mpe: \n", ivs_counts_mpe)
np.testing.assert_almost_equal(ivs_counts_mpe, true_ivs_counts_mpe)

In [None]:
# MPE path built with MARGINAL value inference (non-log computation, log=False)
marginal_inference = spn.InferenceType.MARGINAL
mpe_marginal_path_gen = spn.MPEPath(value_inference_type=marginal_inference, log=False)
mpe_marginal_path_gen.get_mpe_path(root)
# Evaluate the counts that the path generator holds for the IVs node
ivs_counts_op = mpe_marginal_path_gen.counts[ivs]
ivs_counts_marginal = ivs_counts_op.eval(feed_dict={ivs: feed})

print("ivs_counts_marginal: \n", ivs_counts_marginal)
np.testing.assert_almost_equal(ivs_counts_marginal, true_ivs_counts_marginal)