In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Factor graph of the homework: N variables (rows of G) and K factors (columns of G).
N, K = 6, 5

# G[i, j] == 1 means that variable i is attached to factor j.
G = np.array([
    [1, 0, 0, 0, 0],  # variable 0: factor 0
    [1, 0, 1, 0, 1],  # variable 1: factors 0, 2 and 4
    [0, 1, 1, 1, 0],  # variable 2: factors 1, 2 and 3
    [0, 1, 0, 0, 0],  # variable 3: factor 1
    [0, 0, 0, 0, 1],  # variable 4: factor 4
    [0, 0, 0, 1, 0],  # variable 5: factor 3
])
# guard against a mismatch between the declared sizes and the matrix literal
assert (N, K) == G.shape

In [5]:
# F[j, p, q]: value of factor j when the attached variable with the smaller numerical
# index equals p and the attached variable with the larger numerical index equals q.
# Every entry starts at 1; only the entries listed below deviate from that.
F = np.ones((K, 2, 2))
F[0, 1, 1] = 5    # factor 0, both variables equal to 1
F[1, 0, 1] = 0.5  # factor 1, lower variable 0 and higher variable 1
F[2, 0, 0] = 0    # factor 2, both variables equal to 0 (configuration ruled out)
F[3, 0, 0] = 2    # factor 3, both variables equal to 0

# a

![Graph for Homework 2, a](hw2_a_graph.png)

The graph is singly connected, since there is exactly one path between any two variables $u$ and $v$; in other words, the factor graph is a tree and contains no loops.

# b

In [6]:
def computeMarginals(G, F):
    """
    Compute marginal probabilities of all variables in the graph with the sum-product
    algorithm. Return: N x 2 matrix of marginal probabilities.

    Parameters
    ----------
    G : numpy array of shape (N, K)
        Adjacency matrix of the factor graph; ``G[i, j] == 1`` when variable ``i`` is
        attached to factor ``j``. The message schedule assumes a singly connected graph in
        which every factor is attached to exactly two variables.
    F : numpy array of shape (K, 2, 2)
        Factor tables; ``F[j, p, q]`` is the value of factor ``j`` when its lower-indexed
        variable equals ``p`` and its higher-indexed variable equals ``q``.

    Returns
    -------
    numpy array of shape (N, 2)
        Row ``i`` holds the normalized marginal distribution of variable ``i``.

    Raises
    ------
    NotImplementedError
        If the traversal reaches a factor that has no variable besides the one it was
        reached from (a factor leaf).
    AssertionError
        If a message would be sent twice, or if some edge of ``G`` did not receive a message
        in both directions (for instance because it cannot be reached from the root).
    """

    # take the graph size from G itself rather than from notebook-level globals, so the
    # function also works for graphs other than the one defined at the top of the notebook
    N, K = G.shape

    # VF[i, j, :] is the message from variable i to factor j, FV[j, i, :] the message from
    # factor j to variable i; both start out as all-ones (the neutral message)
    VF = np.ones(shape=(N, K, 2))
    FV = np.ones(shape=(K, N, 2))
    # bookkeeping of which directed messages were already sent
    Vsent = np.zeros(shape=(N, K))
    Fsent = np.zeros(shape=(K, N))
    # directed edges (from_node, to_node, from_node_is_variable), ordered root -> leaves
    forward_pass_order = []

    def send_message(node1, node2, was_variable):
        """
        Send the message from node1 to node2. `was_variable` tells whether node1 is a
        variable (then node2 is a factor) or a factor (then node2 is a variable).
        """
        if was_variable:
            # make sure the message was not sent yet
            assert Vsent[node1, node2] == 0

            # all factors attached to this variable, except the one we are messaging
            connections = np.where(G[node1, :] == 1)[0]
            others = connections[connections != node2]

            # product of the incoming factor messages; for a leaf variable `others` is
            # empty and the product over nothing is the all-ones message
            VF[node1, node2, :] = FV[others, node1, :].prod(axis=0)

            # keep track of which messages were sent from variable to factor
            Vsent[node1, node2] = 1
        else:
            # make sure the message was not sent yet
            assert Fsent[node1, node2] == 0

            # the (two) variables that this factor is connected to, in increasing order
            connections = np.where(G[:, node1] == 1)[0]
            if node2 == connections[0]:
                # target is the lower-indexed variable (first axis of F[node1]):
                # sum over the higher-indexed variable, i.e. a matrix-vector product
                incoming = VF[connections[1], node1, :]
                FV[node1, node2, :] = F[node1, :, :] @ incoming
            else:
                # target is the higher-indexed variable (second axis of F[node1]):
                # transpose so that the sum runs over the lower-indexed variable
                incoming = VF[connections[0], node1, :]
                FV[node1, node2, :] = F[node1, :, :].T @ incoming

            # keep track of which messages were sent from factor to variable
            Fsent[node1, node2] = 1

    def traverse_graph(index, vector, is_variable):
        """
        Walk the factor graph depth-first, starting at node `index`, and append every
        directed edge that is crossed to `forward_pass_order`. `vector` marks the
        neighbours of `index` that still have to be visited.
        """
        if not vector.any():
            # base case
            if is_variable:
                # a variable leaf has no further dependencies
                return
            # we don't have factor leaves, but we could have them. we can use this for
            # observed variables
            raise NotImplementedError("factor leaves are not supported")

        for neighbour_index, connected in enumerate(vector):
            if not connected:
                continue

            # record the edge index -> neighbour
            forward_pass_order.append((index, neighbour_index, is_variable))

            # neighbours of the neighbour: a column of G when the neighbour is a factor,
            # a row of G when it is a variable
            if is_variable:
                remaining = G[:, neighbour_index].copy()
            else:
                remaining = G[neighbour_index, :].copy()
            # do not walk back over the edge we just came from
            remaining[index] = 0

            # continue recursively from the neighbour, which is of the other node type
            traverse_graph(neighbour_index, remaining, not is_variable)

    # the last variable serves as the root; for the 6-variable homework graph this is
    # node 6 (index 5), exactly as before
    root = N - 1

    # traverse the graph
    traverse_graph(root, G[root, :], True)

    # in opposite order - i.e. from leaves to the root -, send messages
    for node1, node2, was_variable in reversed(forward_pass_order):
        send_message(node2, node1, not was_variable)

    # send messages from root to leaves
    for node1, node2, was_variable in forward_pass_order:
        send_message(node1, node2, was_variable)

    # end state: make sure that all messages were sent that had to be sent
    assert np.equal(Vsent, G).all()
    assert np.equal(Fsent, G.T).all()

    # initialize variable that keeps our marginals
    B = np.zeros(shape=(N, 2))

    for variable in range(N):
        # product of all incoming factor messages; factors that are not attached to the
        # variable still hold their initial all-ones message and do not change the product
        B[variable, :] = FV[:, variable, :].prod(axis=0)

    # normalize each row; keepdims makes the row sums broadcast against B directly
    return B / B.sum(axis=1, keepdims=True)

# Run sum-product on the example graph and display the N x 2 matrix of marginals.
B = computeMarginals(G, F)
B

array([[0.21186441, 0.78813559],
       [0.13559322, 0.86440678],
       [0.45762712, 0.54237288],
       [0.57627119, 0.42372881],
       [0.5       , 0.5       ],
       [0.57627119, 0.42372881]])

# c

If I only care about the marginal probability of one of the variables, I would modify the algorithm to use that variable's index as the root (where my implementation sets `root`). I would then only need the first pass, in which messages flow from the leaves towards the root; the second pass, from the root back to the leaves, can be skipped entirely, so a single graph traversal suffices. Finally, I would skip the calculation of marginal probabilities for every variable at the end (where my implementation says `for variable in range(N):`) and instead just calculate the marginal probability of the one variable I want.

# d

In [7]:
def bruteForce(G, F):
    """
    Enumerate all 2^N configurations of variables. Return: N x 2 matrix of marginal probabilities.

    Parameters
    ----------
    G : numpy array of shape (N, K)
        Adjacency matrix of the factor graph; ``G[i, j] == 1`` when variable ``i`` is
        attached to factor ``j``. The first two variables attached to a factor are used
        as its arguments.
    F : numpy array of shape (K, 2, 2)
        Factor tables; ``F[j, p, q]`` is the value of factor ``j`` when its lower-indexed
        variable equals ``p`` and its higher-indexed variable equals ``q``.

    Returns
    -------
    numpy array of shape (N, 2)
        Row ``i`` holds the normalized marginal distribution of variable ``i``.
    """

    # number of variables comes from G itself rather than from a notebook-level global
    N = G.shape[0]

    # variables attached to each factor, in increasing index order; this does not depend
    # on the configuration, so it is computed once instead of once per configuration
    scopes = [np.where(G[:, factor_index] == 1)[0] for factor_index in range(len(F))]

    # possibilities[x_0, ..., x_{N-1}] is the unnormalized probability of a configuration
    possibilities = np.zeros(shape=(2,) * N)

    for configuration in np.ndindex(*possibilities.shape):
        # product of all factor values at this configuration
        result = 1
        for factor, scope in zip(F, scopes):
            result *= factor[configuration[scope[0]], configuration[scope[1]]]
        # save the probability
        possibilities[configuration] = result

    # normalize
    possibilities = possibilities / possibilities.sum()

    B = np.zeros(shape=(N, 2))

    for variable in range(N):
        # marginalize by summing probabilities across all axes except the variable's own
        other_axes = tuple(axis for axis in range(N) if axis != variable)
        B[variable, :] = possibilities.sum(axis=other_axes)
    return B

# Run the brute-force enumeration on the same graph and display its N x 2 matrix of marginals.
B_bruteforce = bruteForce(G, F)
B_bruteforce

array([[0.21186441, 0.78813559],
       [0.13559322, 0.86440678],
       [0.45762712, 0.54237288],
       [0.57627119, 0.42372881],
       [0.5       , 0.5       ],
       [0.57627119, 0.42372881]])

In [8]:
# Check numerically (not just by eye) that sum-product agrees with the brute-force
# enumeration, then display the sum-product marginals again for comparison.
np.testing.assert_allclose(B, B_bruteforce)
B

array([[0.21186441, 0.78813559],
       [0.13559322, 0.86440678],
       [0.45762712, 0.54237288],
       [0.57627119, 0.42372881],
       [0.5       , 0.5       ],
       [0.57627119, 0.42372881]])

In [12]:
# Time the sum-product implementation on the example graph.
%timeit computeMarginals(G, F)

355 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
# Time the brute-force enumeration on the same graph.
%timeit bruteForce(G, F)

1.94 ms ± 8.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### Comparison

`B` and `B_bruteforce` are identical. However, the sum-product algorithm only requires two passes over the graph to calculate all marginals, whereas the brute-force approach has to evaluate every one of the $2^N$ joint configurations before it can marginalize. In our example this makes the brute-force approach roughly five times slower than the sum-product algorithm (1.94 ms vs. 355 µs), and the gap grows exponentially with the number of variables.

# e

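You could condition on observed variables by attaching a unary (single-variable) factor to each observed variable that is non-zero only for the observed value, e.g. multiplying the variable by $(0, 1)$ when it is observed to be 1, so that only one value remains possible. In the code above, this would have to be implemented in the `raise NotImplementedError` line, which is reached when the traversal arrives at a factor with no further variables (a factor leaf). This was not implemented since we did not have any factor leaves in our tree.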

Obviously, we would also have to update the matrices $G$ and $F$.