In [1]:
import libspn as spn
import tensorflow as tf

import numpy as np
# Interactive session: it installs itself as the default session, which is why
# the later `.run()` / `.eval()` calls pass no explicit session.
sess = tf.InteractiveSession()

In [2]:
# An example network mixing different multi-node types (ParallelSums,
# PermProducts, Sums, Products, Sum, Product), with several node types
# per layer:
# Layer 1: ParallelSums & Sums
# Layer 2: PermProducts
# Layer 3: Sum & Sums
# Layer 4: Products & Product
# Layer 5: Sum

# Inputs: 2 variables with 2 values each (IV columns 0-3 are indexed below)
ivs = spn.IVs(num_vars=2, num_vals=2, name="IVs")

# Layer 1 - Input mixtures
# Sums-1 - ParallelSums: 2 sums sharing the same input, IV columns [0, 1]
s1 = spn.ParallelSums((ivs, [0, 1]), num_sums=2, name="Sums1")
s1.generate_weights([0.4, 0.6, 0.1, 0.9])

# Sums-2 - Sums: 2 sums, each given its own input (IV columns [2, 3])
s2 = spn.Sums((ivs, [2, 3]), (ivs, [2, 3]), num_sums=2, name="Sums2")
s2.generate_weights([0.7, 0.3, 0.8, 0.2])

# Layer 2 - Components - Permuted Products
# Products-1..4: pairings of s1 outputs with s2 outputs (outputs indexed 0-3)
p1_4 = spn.PermProducts(s1, s2, name="PermComp_1-4")

# Layer 3 - Mixing components - Single Sum
# Sum-3.1
s31 = spn.Sum((p1_4, [0, 2, 3]), name="Sums31")
s31.generate_weights([0.2, 0.3, 0.5])

# Layer 3 - Mixing components - Generic Sums
# Sums-3.2 and 3.3
s32_33 = spn.Sums((p1_4, [1, 2, 3]),
                  (p1_4, [0, 1, 3]),
                  num_sums=2, name="Sums32_33")
s32_33.generate_weights([0.1, 0.3, 0.6,
                         0.2, 0.7, 0.1])


# Layer 4 - Components - Generic Products
# Products-5 and 6
p5_6 = spn.Products(s31, (s32_33, [0]),
                    s31, (s32_33, [1]),
                    num_prods=2, name="Products_5-6")

# Layer 4 - Components - Single Product
# Product-7: built with the single-node `Product` type, as the header above
# promises; a one-product `Products` would leave `Product` untested.
p7 = spn.Product((s32_33, [0, 1]), name="Product_7")

# Layer 5 - Mixing components - Single Sum
# Root
root = spn.Sum((p5_6, [0]),
               (p5_6, [1]),
               p7, name="Mixture")
root.generate_weights([0.5, 0.2, 0.3])

# Init weights
spn.initialize_weights(root).run()

In [3]:
# Optional: uncomment to render the SPN graph for visual inspection.
#spn.display_spn_graph(root)

In [4]:
# Feed: all combinations of IV values for the 2 variables, each value taken
# from {-1, 0, 1}; rows are ordered with the first variable varying slowest.
# (-1 presumably marks a missing/marginalized value - row 0 evaluates to 1.0.)
values = np.arange(-1, 2)
grid_rows, grid_cols = np.meshgrid(values, values, indexing="ij")
points = np.stack((grid_rows, grid_cols), axis=-1)
feed = points.reshape(-1, 2)

# Expected outputs, one row per feed row (same order as `feed`).
np_dtype = spn.conf.dtype.as_numpy_dtype

# True (marginal) value of the root
true_values = np.array([[1.0],
                        [0.58593],
                        [0.05493],
                        [0.03667],
                        [0.0213741],
                        [0.0020421],
                        [0.63567],
                        [0.3727701],
                        [0.0348381]], dtype=np_dtype)

# True MPE value of the root
true_mpe_values = np.array([[0.07776],
                            [0.07776],
                            [0.00486],
                            [0.0032256],
                            [0.0032256],
                            [0.0002688],
                            [0.07776],
                            [0.07776],
                            [0.00486]], dtype=np_dtype)

# True IV counts (4 columns, one per IV) of the MPE path when the path is
# computed from MARGINAL values
true_ivs_counts_marginal = np.array([[0, 2, 2, 0],
                                     [0, 2, 2, 0],
                                     [0, 2, 0, 2],
                                     [2, 0, 2, 0],
                                     [2, 0, 2, 0],
                                     [2, 0, 0, 2],
                                     [0, 2, 2, 0],
                                     [0, 2, 2, 0],
                                     [0, 2, 0, 2]], dtype=np_dtype)

# True IV counts (4 columns, one per IV) of the MPE path when the path is
# computed from MPE values
true_ivs_counts_mpe = np.array([[0, 2, 2, 0],
                                [0, 2, 2, 0],
                                [0, 2, 0, 2],
                                [2, 0, 2, 0],
                                [2, 0, 2, 0],
                                [2, 0, 0, 2],
                                [0, 2, 2, 0],
                                [0, 2, 2, 0],
                                [0, 2, 0, 2]], dtype=np_dtype)

In [5]:
# Root value ops under the default inference type (presumably MARGINAL),
# in linear space and in log space
value = root.get_value()
log_value = root.get_log_value()

In [6]:
# Fetch both value tensors in one session run; the log-space result is
# mapped back with exp so both are compared against the same expectations.
value_array, log_value_array = sess.run([value, log_value],
                                        feed_dict={ivs: feed})
value_array_log = np.exp(log_value_array)

print(value_array)
print(true_values)
for computed in (value_array, value_array_log):
    np.testing.assert_almost_equal(computed, true_values)

[[ 1.        ]
 [ 0.58593005]
 [ 0.05493001]
 [ 0.03667   ]
 [ 0.0213741 ]
 [ 0.0020421 ]
 [ 0.63567001]
 [ 0.37277013]
 [ 0.0348381 ]]
[[ 1.        ]
 [ 0.58592999]
 [ 0.05493   ]
 [ 0.03667   ]
 [ 0.0213741 ]
 [ 0.0020421 ]
 [ 0.63567001]
 [ 0.3727701 ]
 [ 0.0348381 ]]


In [7]:
# Root MPE value ops, in linear space and in log space
mpe_inference = spn.InferenceType.MPE
mpe_value = root.get_value(mpe_inference)
log_mpe_value = root.get_log_value(mpe_inference)

In [8]:
# Fetch both MPE value tensors in one session run; the log-space result is
# mapped back with exp so both are compared against the same expectations.
mpe_value_array, log_mpe_value_array = sess.run([mpe_value, log_mpe_value],
                                                feed_dict={ivs: feed})
mpe_value_array_log = np.exp(log_mpe_value_array)

print(mpe_value_array)
print(true_mpe_values)
for computed in (mpe_value_array, mpe_value_array_log):
    np.testing.assert_almost_equal(computed, true_mpe_values)

[[ 0.07776  ]
 [ 0.07776  ]
 [ 0.00486  ]
 [ 0.0032256]
 [ 0.0032256]
 [ 0.0002688]
 [ 0.07776  ]
 [ 0.07776  ]
 [ 0.00486  ]]
[[ 0.07776  ]
 [ 0.07776  ]
 [ 0.00486  ]
 [ 0.0032256]
 [ 0.0032256]
 [ 0.0002688]
 [ 0.07776  ]
 [ 0.07776  ]
 [ 0.00486  ]]


In [9]:
# MPE path computed from MPE values (log=False: linear-space values).
# counts[ivs] gives, per feed row, one count per IV column.
mpe_mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False)
mpe_mpe_path_gen.get_mpe_path(root)
ivs_counts_mpe = mpe_mpe_path_gen.counts[ivs].eval(feed_dict={ivs: feed})

# Label now matches the variable it reports.
print("ivs_counts_mpe: \n", ivs_counts_mpe)
np.testing.assert_almost_equal(ivs_counts_mpe, true_ivs_counts_mpe)

ivs_counts_mep: 
 [[ 0.  2.  2.  0.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  0.  2.]
 [ 2.  0.  2.  0.]
 [ 2.  0.  2.  0.]
 [ 2.  0.  0.  2.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  0.  2.]]


In [10]:
# MPE path computed from MARGINAL values instead of MPE values
# (log=False: linear-space values); checked against the expected IV counts.
marginal_inference = spn.InferenceType.MARGINAL
mpe_marginal_path_gen = spn.MPEPath(value_inference_type=marginal_inference,
                                    log=False)
mpe_marginal_path_gen.get_mpe_path(root)
ivs_counts_marginal = (mpe_marginal_path_gen
                       .counts[ivs]
                       .eval(feed_dict={ivs: feed}))

print("ivs_counts_marginal: \n", ivs_counts_marginal)
np.testing.assert_almost_equal(ivs_counts_marginal, true_ivs_counts_marginal)

ivs_counts_marginal: 
 [[ 0.  2.  2.  0.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  0.  2.]
 [ 2.  0.  2.  0.]
 [ 2.  0.  2.  0.]
 [ 2.  0.  0.  2.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  2.  0.]
 [ 0.  2.  0.  2.]]
