In [None]:
import libspn as spn
import tensorflow as tf

import numpy as np
# TF1 interactive session: it registers itself as the default session, which is
# what lets later cells call `.run()` / `.eval()` without passing a session.
sess = tf.InteractiveSession()

In [None]:
# Inputs
# Indicator variables (IVs) for 2 variables with 2 values each. Later layers
# address the IV outputs by flat index 0-3; presumably [0, 1] are the two
# values of the first variable and [2, 3] those of the second — confirm.
ivs = spn.IVs(num_vars=2, num_vals=2, name="IVs")

# Layer 1 - Input mixtures - Parallel Sums
# Sums-1: 2 sums over IV outputs [0, 1]. The flat weight list presumably holds
# 2 weights per sum in sum order: [0.4, 0.6] and [0.1, 0.9] (each pair sums to 1).
s1 = spn.ParallelSums((ivs, [0, 1]), num_sums=2, name="Sums1")
s1.generate_weights([0.4, 0.6, 0.1, 0.9])

# Sums-2: 2 sums over IV outputs [2, 3], weights laid out as for Sums-1:
# [0.7, 0.3] and [0.8, 0.2].
s2 = spn.ParallelSums((ivs, [2, 3]), num_sums=2, name="Sums2")
s2.generate_weights([0.7, 0.3, 0.8, 0.2])

# Layer 2 - Components - Permuted Products
# Products built from the outputs of s1 and s2. Layer 3 indexes outputs 0-3,
# i.e. 4 products, presumably one per (s1 output, s2 output) pair — confirm
# the ordering against spn.PermProducts.
p1_4 = spn.PermProducts(s1, s2, name="PermComp_1-4")

# Layer 3 - Mixing components - Generic Sums
# 3 sums, each over a different 3-element subset of the Layer 2 products:
# [0, 2, 3], [1, 2, 3] and [0, 1, 3]. Weights are 3 per sum, one row per sum
# below; each row sums to 1.
s3 = spn.Sums((p1_4, [0, 2, 3]), 
              (p1_4, [1, 2, 3]),
              (p1_4, [0, 1, 3]),
              num_sums=3, name="Sums3")
s3.generate_weights([0.2, 0.3, 0.5, 
                     0.1, 0.3, 0.6,
                     0.2, 0.7, 0.1])

# Layer 4 - Components - Generic Products
# 3 products, each over a different pair of Layer 3 sums: (0, 1), (0, 2), (1, 2).
# NOTE(review): every Layer 3 sum depends (via Layer 2) on both s1 and s2, so
# each product multiplies children with overlapping inputs, i.e. the network
# is not decomposable. Presumably intentional for exercising generic nodes —
# confirm.
p5_7 = spn.Products((s3, [0, 1]), 
                    (s3, [0, 2]), 
                    (s3, [1, 2]), 
                    num_prods=3, name="Products_5-7")

# Layer 5 - Mixing components - Single Sum
# Root: a single sum over the 3 products with weights 0.5, 0.2, 0.3 (sum to 1).
root = spn.Sum((p5_7, [0]), 
               (p5_7, [1]), 
               (p5_7, [2]), name="Mixture")
root.generate_weights([0.5, 0.2, 0.3])

# Init weights
# Runs the op returned by spn.initialize_weights(root), presumably initializing
# the weights of every node under `root`, in the default session (the
# InteractiveSession created in the first cell).
spn.initialize_weights(root).run()

In [None]:
# Visualize the network rooted at `root` (presumably rendered inline as the
# cell's output) to sanity-check the layer wiring built above.
spn.display_spn_graph(root)

In [None]:
# Feed
# One row for every assignment of {-1, 0, 1} to the two IV variables (9 rows).
# The first variable varies slowest: (-1, -1), (-1, 0), (-1, 1), (0, -1), ...
# NOTE(review): -1 presumably means "marginalized out" for spn.IVs — confirm.
values = np.arange(-1, 2)
points = np.stack(np.meshgrid(values, values, indexing="ij"), axis=-1)
feed = points.reshape(-1, points.shape[-1])

# True value
# Expected marginal root value for each of the 9 feed rows, as a (9, 1) column
# in the SPN's configured float dtype. Three entries per line, one line per
# value (-1, 0, 1) of the first variable, following the feed row order.
true_values = np.array(
    [1.0, 0.58593, 0.05493,
     0.03667, 0.0213741, 0.0020421,
     0.63567, 0.3727701, 0.0348381],
    dtype=spn.conf.dtype.as_numpy_dtype,
).reshape(-1, 1)

# True MPE value
# Expected MPE root value for each feed row; same layout, shape and dtype.
true_mpe_values = np.array(
    [0.07776, 0.07776, 0.00486,
     0.0032256, 0.0032256, 0.0002688,
     0.07776, 0.07776, 0.00486],
    dtype=spn.conf.dtype.as_numpy_dtype,
).reshape(-1, 1)

In [None]:
# Value
# Marginal-inference ops for the root: the plain value, then its log-space
# counterpart (right-hand side is evaluated left to right, same build order).
value, log_value = root.get_value(), root.get_log_value()

In [None]:
# Evaluate the marginal root value directly and in log space (exponentiated
# back to linear scale), then check both against the expected values.
value_array = value.eval(feed_dict={ivs: feed})
value_array_log = np.exp(log_value.eval(feed_dict={ivs: feed}))

# Side-by-side diagnostics, printed before asserting so they are visible even
# when a check fails. Columns: direct, log-space (exponentiated), expected.
print(np.hstack([value_array, value_array_log, true_values]))

# err_msg tells the two checks apart; without it, a failure in the log-space
# path reports exactly like a failure in the direct path.
np.testing.assert_almost_equal(value_array, true_values,
                               err_msg="marginal value (direct) mismatch")
np.testing.assert_almost_equal(value_array_log, true_values,
                               err_msg="marginal value (log space) mismatch")

In [None]:
# MPE Value
# Same pair of ops as the marginal case, but under MPE inference: the plain
# value is built first, then the log-space one.
mpe_value, log_mpe_value = (root.get_value(spn.InferenceType.MPE),
                            root.get_log_value(spn.InferenceType.MPE))

In [None]:
# Evaluate the MPE root value directly and in log space (exponentiated back to
# linear scale), then check both against the expected MPE values.
mpe_value_array = mpe_value.eval(feed_dict={ivs: feed})
mpe_value_array_log = np.exp(log_mpe_value.eval(feed_dict={ivs: feed}))

# Side-by-side diagnostics, printed before asserting so they are visible even
# when a check fails. Columns: direct, log-space (exponentiated), expected.
print(np.hstack([mpe_value_array, mpe_value_array_log, true_mpe_values]))

# err_msg tells the two checks apart; without it, a failure in the log-space
# path reports exactly like a failure in the direct path.
np.testing.assert_almost_equal(mpe_value_array, true_mpe_values,
                               err_msg="MPE value (direct) mismatch")
np.testing.assert_almost_equal(mpe_value_array_log, true_mpe_values,
                               err_msg="MPE value (log space) mismatch")

In [None]:
# MPE Path with MPE value inference (value_inference_type=MPE)
# NOTE(review): this cell is entirely commented out and has no expected counts
# to assert against. TODO: either enable it with an assertion on the counts or
# delete the cell.
#mpe_mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False)
#mpe_mpe_path_gen.get_mpe_path(root)

#print(mpe_mpe_path_gen.counts[ivs].eval(feed_dict={ivs: feed}))

In [None]:
# MPE Path with marginal value inference (value_inference_type=MARGINAL)
# NOTE(review): this cell is entirely commented out and has no expected counts
# to assert against. TODO: either enable it with an assertion on the counts or
# delete the cell.
#mpe_marginal_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MARGINAL, log=False)
#mpe_marginal_path_gen.get_mpe_path(root)

#print(mpe_marginal_path_gen.counts[ivs].eval(feed_dict={ivs: feed}))