# Tutorial 2: Inference

In [1]:
import libspn as spn

### Building a Test Graph with Initialized Weights

In [5]:
# Indicator-variable (IVs) node for 2 variables with 2 values each.
# Its outputs are addressed by index: 0-1 are used for the first variable
# and 2-3 for the second one (see the sum nodes below).
iv_x = spn.IVs(num_vars=2, num_vals=2, name="iv_x")

# Note that inputs are given as tuples (node, indices).
# generate_weights() is given one initial weight per connected input.
# sum_11 / sum_12: two differently weighted sums over the first variable.
sum_11 = spn.Sum((iv_x, [0,1]), name="sum_11")
sum_11.generate_weights([0.4, 0.6])
# Again same node, same indices, but different weights
sum_12 = spn.Sum((iv_x, [0,1]), name="sum_12")
sum_12.generate_weights([0.1, 0.9])
# Same node, but indices of the second variable
# sum_21 / sum_22: two differently weighted sums over the second variable.
sum_21 = spn.Sum((iv_x, [2,3]), name="sum_21")
sum_21.generate_weights([0.7, 0.3])
sum_22 = spn.Sum((iv_x, [2,3]), name="sum_22")
sum_22.generate_weights([0.8, 0.2])

# Each product combines one sum over the first variable with one sum over
# the second variable.
# Product node, taking sum_11 (variable 1) and sum_21 (variable 2)
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
# Product node, taking sum_11 (variable 1) and sum_22 (variable 2)
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
# Product node, taking sum_12 (variable 1) and sum_22 (variable 2)
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")
# Root: weighted sum over the three products.
root = spn.Sum(prod_1, prod_2, prod_3, name="root")
root.generate_weights([0.5, 0.2, 0.3])
# IVs node generated for the root sum itself; it is fed and queried as
# `iv_y` in the cells below.
# NOTE(review): presumably the latent variable selecting among the three
# products -- confirm against the libspn docs.
iv_y = root.generate_ivs(name="iv_y")

### Visualizing the SPN Graph

In [6]:
# Display the graph of the SPN rooted at `root`.
spn.display_spn_graph(root)

### Add Value Ops
Now generate the marginal and MPE value ops. Internally, a `Value` object is instantiated with a specific `InferenceType`, and then `Value(inference_type).get_value(self)` is called, where `self` is the root node. The implementation of `get_value` is:
```python
    class Value:
    """
    ...
    """

        def get_value(self, root):
            """Assemble a TF operation computing the values of nodes of the SPN
            rooted in ``root``.

            Returns the operation computing the value for the ``root``. Operations
            computing values for other nodes can be obtained using :obj:`values`.

            Args:
                root (Node): The root node of the SPN graph.

            Returns:
                Tensor: A tensor of shape ``[None, num_outputs]``, where the first
                dimension corresponds to the batch size.
            """
            def fun(node, *args):
                with tf.name_scope(node.name):
                    if (self._inference_type == InferenceType.MARGINAL
                        or (self._inference_type is None and
                            node.inference_type == InferenceType.MARGINAL)):
                        return node._compute_value(*args)
                    else:
                        return node._compute_mpe_value(*args)

            self._values = {}
            with tf.name_scope("Value"):
                return compute_graph_up(root, val_fun=fun,
                                        all_values=self._values)
    ```

And the implementation of `compute_graph_up`:
```python

def compute_graph_up(root, val_fun, const_fun=None, all_values=None):
    """Computes a certain value for the ``root`` node in the graph, assuming
    that for op nodes, the value depends on values produced by inputs of the op
    node. For this, it traverses the graph depth-first from the ``root`` node
    to the leaf nodes.

    Args:
        root (Node): The root of the SPN graph.
        val_fun (function): A function ``val_fun(node, *args)`` producing a
            certain value for the ``node``. For an op node, it will have
            additional arguments with values produced for the input nodes of
            ``node``.  The arguments will NOT be added if ``const_fun``
            returns ``True`` for the node. The arguments can be ``None`` if
            the input was empty.
        const_fun (function): A function ``const_fun(node)`` that should return
            ``True`` if the value generated by ``val_fun`` does not depend on
            the values generated for the input nodes, i.e. it is a constant
            function. If set to ``None``, it is assumed to always return
            ``False``, i.e. no ``val_fun`` is a constant function.
        all_values (dict): A dictionary indexed by ``node`` in which values
            computed for each node will be stored. Can be set to ``None``.

    Returns:
        The value for the ``root`` node.
    """
    if all_values is None:  # Dictionary of computed values indexed by node
        all_values = {}
    stack = deque()  # Stack of nodes to process
    stack.append(root)

    last_val = None
    while stack:
        next_node = stack[-1]
        # Was this node already processed?
        # This might happen if the node is referenced by several parents
        if next_node not in all_values:
            if next_node.is_op:
                # OpNode
                input_vals = []
                all_input_vals = True
                if const_fun is None or const_fun(next_node) is False:
                    # Gather input values for non-const val fun
                    for inpt in next_node.inputs:
                        if inpt:  # Input is not empty
                            try:
                                # Check if input_node in all_vals
                                input_vals.append(all_values[inpt.node])
                            except KeyError:
                                """ At least one input is missing """
                                all_input_vals = False
                                stack.append(inpt.node)
                        else:
                            # This input was empty, use None as value
                            """ When can an input be empty? """
                            input_vals.append(None)
                # Got all inputs?
                if all_input_vals:
                    """ This is where val_fun is called, which itself chooses MPE or marginal """
                    last_val = val_fun(next_node, *input_vals)
                    all_values[next_node] = last_val
                    stack.pop()
            else:
                # VarNode, ParamNode
                """ VarNode and ParamNode don't have their own inputs, so we don't have *input_vals """
                last_val = val_fun(next_node)
                all_values[next_node] = last_val
                stack.pop()
        else:
            stack.pop()

    return last_val
```

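So what do the MPE and marginal functions of a node look like? `get_value` dispatches to `_compute_value` or `_compute_mpe_value`. The snippet below shows a related method instead, the `Product` node's `_compute_mpe_path`, which passes counts on to the node's inputs: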
```python
def _compute_mpe_path(self, counts, *value_values, add_random=False, use_unweighted=False):
    """ Not sure what value_values could be, should just be input_values (right?) """
    # Check inputs
    if not self._values:
        raise StructureError("%s is missing input values." % self)

    def process_input(v_input, v_value):
        """ Get size of v_value tensor, if v_input.indices is set just returns len(v_input) """
        input_size = v_input.get_size(v_value)
        # Tile the counts if input is larger than 1
        """ 
        Tiling is used for separating counting across minibatch I guess. It tiles 
        `counts` once along the first axis, and `input_size` times along second axis, 
        meaning for each of these inputs, it provides a slot initialized at the value 
        of `counts`
        """
        return (tf.tile(counts, [1, input_size])
                if input_size > 1 else counts)

    # For each input, pass counts to all elements selected by indices
    value_counts = [(process_input(v_input, v_value), v_value)
                    for v_input, v_value
                    in zip(self._values, value_values)]
    # TODO: Scatter to input tensors can be merged with tiling to reduce
    # the amount of operations.
        
    """ Returns an op that scatters the counts to input nodes """
    return self._scatter_to_input_tensors(*value_counts)
```

Getting there, now let's see what `_scatter_to_input_tensors(*value_counts)` does:
```python
def _scatter_to_input_tensors(self, *tuples):
    """For each input, scatter the given tensor to elements indicated by
    input indices. This reverses what ``gather_input_tensors`` is doing.
    If input indices are ``None``, it adds no operations and forwards the
    tensor as is. If the input is disconnected or ``None`` is given in
    ``*tuples``, ``None`` is returned for that input.

    Args:
        *tuples (tuple): For each input, a tuple ``(tensor, input_tensor)``,
            where ``tensor`` is the tensor to be scattered, and
            ``input_tensor`` is the tensor produced by the input node. The
            second tensor is used only to retrieve the appropriate dimensions.

    Returns:
        list of Tensor: A list of tensors containing scattered values.
    """
    with tf.name_scope("scatter_to_input_tensors", values=[t[0] for t in tuples]):
        return tuple(None if not i or t is None          # (None, ...)
                     else t[0] if i.indices is None      # (t[0], ...) first element of tuple
                     else utils.scatter_cols(            # utils.scatter_cols(t[0], i.indices)
                         t[0], i.indices,
                         int(t[1].get_shape()
                             [0 if t[1].get_shape().ndims == 1 else 1]))
                     for i, t in zip(self.inputs, tuples))

```
So this returns a (potentially large) tuple that contains, for each input, either `None`, `t[0]` itself, or a `scatter_cols` op built from `t[0]`. The Python wrapper `scatter_cols` is:

```python
def scatter_cols(params, indices, num_out_cols, name=None):
    """Scatter columns of a 2D tensor or values of a 1D tensor into a tensor
    with the same number of dimensions and ``num_out_cols`` columns or values.

    Args:
        params (Tensor): A 1D or 2D tensor.
        indices (array_like): A 1D integer array indexing the columns in the
                              output array to which ``params`` is scattered.
        num_out_cols (int): The number of columns in the output tensor.
        name (str): A name for the operation (optional).

    Returns:
        Tensor: Has the same dtype and number of dimensions as ``params``.
    """
    with tf.name_scope(name, "scatter_cols", [params, indices]):
        # Check input
        params = tf.convert_to_tensor(params, name="params")
        indices = np.asarray(indices)
        # Check params
        param_shape = params.get_shape()
        param_dims = param_shape.ndims
        if param_dims == 1:
            param_size = param_shape[0].value
        elif param_dims == 2:
            param_size = param_shape[1].value
        else:
            raise ValueError("'params' must be 1D or 2D")
        # We need the size defined for optimizations
        if param_size is None:
            raise RuntimeError("The indexed dimension of 'params' is not specified")
        # Check num_out_cols
        if not isinstance(num_out_cols, int):
            raise ValueError("'num_out_cols' must be integer, not %s"
                             % type(num_out_cols))
        if num_out_cols < param_size:
            raise ValueError("'num_out_cols' must be larger than the size of "
                             "the indexed dimension of 'params'")
        # Check indices
        if indices.ndim != 1:
            raise ValueError("'indices' must be 1D")
        if indices.size != param_size:
            raise ValueError("Sizes of 'indices' and the indexed dimension of "
                             "'params' must be the same")
        if not np.issubdtype(indices.dtype, np.integer):
            raise ValueError("'indices' must be integer, not %s"
                             % indices.dtype)
        if np.any((indices < 0) | (indices >= num_out_cols)):
            raise ValueError("'indices' must be smaller than 'num_out_cols'")
        if len(set(indices)) != len(indices):
            raise ValueError("'indices' cannot contain duplicates")
        # Define op
        if num_out_cols == 1:
            # Scatter to a single column tensor, it must be from 1 column
            # tensor and the indices must include it. Just forward the tensor.
            return params
        elif num_out_cols == indices.size and np.all(np.ediff1d(indices) == 1):
            # Output equals input
            return params
        elif param_size == 1:
            # Scatter a single column tensor to a multi-column tensor
            if param_dims == 1:
                # Just pad with zeros, pad is fastest and offers smallest graph
                return tf.pad(params, [[indices[0], num_out_cols - indices[0] - 1]])
            else:
                # Currently pad is fastest (for GPU) and builds smaller graph
                # if conf.custom_scatter_cols:
                #     return ops.scatter_cols(
                #         params, indices,
                #         pad_elem=tf.constant(0, dtype=params.dtype),
                #         num_out_col=num_out_cols)
                # else:
                return tf.pad(params, [[0, 0],
                                       [indices[0], num_out_cols - indices[0] - 1]])
        else:
            # Scatter a multi-column tensor to a multi-column tensor
            if param_dims == 1:
                # Use custom built op
                if conf.custom_scatter_cols:
                    return ops.scatter_cols(
                        params, indices,
                        pad_elem=tf.constant(0, dtype=params.dtype),
                        num_out_col=num_out_cols)
                else:
                    # add zero on first axis                                               
                    with_zeros = tf.concat(values=([0], params), axis=0)
                    # Init zero array for gathering indices
                    gather_indices = np.zeros(num_out_cols, dtype=int)
                    # Set gathering indices as a list [1, ..., indices.size]
                    # Output positions not listed in `indices` keep gather index 0,
                    # so they read the zero prepended above. This expresses the
                    # scatter as a gather from the zero-padded tensor.
                    gather_indices[indices] = np.arange(indices.size) + 1
                    # Use gather cols
                    return gather_cols(with_zeros, gather_indices)
            else:
                if conf.custom_scatter_cols:
                    # Use custom built op
                    return ops.scatter_cols(
                        params, indices,
                        pad_elem=tf.constant(0, dtype=params.dtype),
                        num_out_col=num_out_cols)
                else:
                    # Set zero columns
                    zero_col = tf.zeros((tf.shape(params)[0], 1),
                                        dtype=params.dtype)
                    # Concat zeros with
                    with_zeros = tf.concat(values=(zero_col, params), axis=1)
                    gather_indices = np.zeros(num_out_cols, dtype=int)
                    gather_indices[indices] = np.arange(indices.size) + 1
                    return gather_cols(with_zeros, gather_indices)


```

In [7]:
# Generates the assign operation by gathering init ops from the nodes under
# `root`; it is run inside each session below before any value is evaluated.
init_weights = spn.initialize_weights(root)

# Value op of the root under marginal inference
# (each node contributes `_compute_value`, see `Value.get_value` above)
marginal_val = root.get_value(inference_type=spn.InferenceType.MARGINAL)
# Value op of the root under MPE inference
# (each node contributes `_compute_mpe_value` instead)
mpe_val = root.get_value(inference_type=spn.InferenceType.MPE)

### Calculate Values

In [5]:

# One row per sample, one column per variable; -1 stands for NOEVIDENCE.
iv_x_arr = [
    [0, 1],    # iv1 = 0, iv2 = 1
    [1, 0],    # iv1 = 1, iv2 = 0
    [1, -1],   # iv1 = 1, iv2 = NOEVIDENCE
    [-1, -1],  # iv1 = iv2 = NOEVIDENCE
]

# iv_y is left without evidence for every sample in the batch.
iv_y_arr = [[-1]] * len(iv_x_arr)

# The same evidence is fed to both value ops.
evidence = {iv_x: iv_x_arr, iv_y: iv_y_arr}

with spn.session() as (sess, _):
    # Weights must be initialized in this session before evaluating values.
    init_weights.run()
    marginal_val_arr = sess.run(marginal_val, feed_dict=evidence)
    mpe_val_arr = sess.run(mpe_val, feed_dict=evidence)

print(marginal_val_arr)
print(mpe_val_arr)

[[ 0.082     ]
 [ 0.52200001]
 [ 0.69      ]
 [ 1.        ]]
[[ 0.06      ]
 [ 0.21600001]
 [ 0.21600001]
 [ 0.21600001]]


### Add MPE State Ops

In [8]:
# Generator for MPE state ops, configured to use MPE inference for values.
# NOTE(review): presumably `value_inference_type` selects how the upward
# value pass is computed before the MPE path is traced -- confirm in libspn.
mpe_state_gen = spn.MPEState(value_inference_type=spn.InferenceType.MPE)
# One state op is returned per requested IVs node, in the order given.
iv_x_state, iv_y_state = mpe_state_gen.get_state(root, iv_x, iv_y)

In [9]:
# A single sample with no evidence on any variable (-1 everywhere), so the
# returned states are the MPE assignment of the whole network.
no_evidence = {iv_x: [[-1, -1]], iv_y: [[-1]]}

with spn.session() as (sess, _):
    # Weights must be initialized in this session before evaluating states.
    init_weights.run()
    iv_x_state_arr, iv_y_state_arr = sess.run(
        [iv_x_state, iv_y_state], feed_dict=no_evidence)

print(iv_x_state_arr)
print(iv_y_state_arr)

[[1 0]]
[[2]]
