# Exact inference with Belief Propagation

This notebook is inspired by [Jessica Stringham's work](https://jessicastringham.net).

We will perform exact inference with sum-product message passing, also known as belief propagation, on tree-structured factor graphs (i.e., factor graphs without loops). We restrict ourselves to discrete distributions and implement everything from scratch with NumPy, without dedicated inference libraries, to better understand the algorithm.

In [9]:
import numpy as np

### Probability distributions

First of all, we need to represent a discrete probability distribution and check that it is normalized.
For example, a discrete conditional distribution $p(v_1 | h_1)$ can be represented by a 2D array in which each column is a distribution over $v_1$ for a fixed value of $h_1$:

|   | $h_1=a$ | $h_1=b$ | $h_1=c$|
|---|-----|-----|----|
| $v_1=0$ | 0.4  | 0.8  | 0.9|
| $v_1=1$ | 0.6 | 0.2  | 0.1|

We can build a class for distributions that stores the probability array together with the labels of its axes.


In [None]:
class Distribution():

    """
    Discrete probability distributions, expressed using labeled arrays
    probs: array of probability values
    axes_labels: list of axes names
    """
    
    def __init__(self, probs, axes_labels):
        self.probs = probs
        self.axes_labels = axes_labels

    def get_axes(self):
        # returns a dictionary mapping each axis name to its axis index
        return {name: axis for axis, name in enumerate(self.axes_labels)}
    
    def get_other_axes_from(self, axis_label):
        # returns a tuple with the indices of all axes except axis_label
        return tuple(axis for axis, name in enumerate(self.axes_labels) if name != axis_label)
    
    def is_valid_conditional(self, variable_name):
        # variable_name is the name of the variable for which we are computing the distribution, e.g. in p(y|x) it is 'y'
        return np.all(np.isclose(np.sum(self.probs, axis=self.get_axes()[variable_name]), 1.0))
    
    def is_valid_joint(self):
        return np.all(np.isclose(np.sum(self.probs), 1.0))

In [None]:
# Let's see the previous distribution:

p_v1_given_h1 = Distribution(np.array([[0.4, 0.8, 0.9], [0.6, 0.2, 0.1]]), ['v1', 'h1'])

print('Is p(v1|h1) a valid conditional distribution? ', p_v1_given_h1.is_valid_conditional('v1'))
print('Is p(v1|h1) a valid joint distribution? ', p_v1_given_h1.is_valid_joint())

# Consider also a joint distribution and a conditional distribution with more than one conditioning variable

p_h1 = Distribution(np.array([0.6, 0.3, 0.1]), ['h1'])

print('Is p(h1) a valid conditional distribution? ', p_h1.is_valid_conditional('h1'))
print('Is p(h1) a valid joint distribution? ', p_h1.is_valid_joint())

p_v1_given_h0_h1 = Distribution(np.array([[[0.9, 0.2, 0.7], [0.3, 0.2, 0.5]],[[0.1, 0.8, 0.3], [0.7, 0.8, 0.5]]]), ['v1', 'h0', 'h1'])
print('Is p(v1|h0, h1) a valid conditional distribution? ', p_v1_given_h0_h1.is_valid_conditional('v1'))
print('Is p(v1|h0, h1) a valid joint distribution? ', p_v1_given_h0_h1.is_valid_joint())

Is p(v1|h1) a valid conditional distribution?  True
Is p(v1|h1) a valid joint distribution?  False
Is p(h1) a valid conditional distribution?  True
Is p(h1) a valid joint distribution?  True
Is p(v1|h0, h1) a valid conditional distribution?  True
Is p(v1|h0, h1) a valid joint distribution?  False


We need to multiply distributions such as $p(v_1|h_1,\dots,h_n)\,p(h_i)$, where $p(h_i)$ is a 1D array.
We can exploit NumPy broadcasting, but first we need to reshape $p(h_i)$ so that its entries lie along the axis of $h_i$ in $p(v_1|h_1,\dots,h_n)$, with size-1 dimensions everywhere else.
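To see why the reshape is needed, consider multiplying $p(v_1|h_0,h_1)$ by a distribution over $h_0$, whose axis is not the last one. NumPy aligns shapes starting from the trailing axis, so the plain product fails. The sketch below uses a hypothetical prior `p_h0` (not defined in the notebook) and mirrors the reshape performed in `multiply`.

```python
import numpy as np

# p(v1 | h0, h1) from above, axes ordered as (v1, h0, h1): shape (2, 2, 3)
p_v1_given_h0_h1 = np.array([[[0.9, 0.2, 0.7], [0.3, 0.2, 0.5]],
                             [[0.1, 0.8, 0.3], [0.7, 0.8, 0.5]]])
# Hypothetical prior over h0 (not defined in the notebook), shape (2,)
p_h0 = np.array([0.4, 0.6])

# Broadcasting aligns trailing axes, so (2, 2, 3) * (2,) raises an error
try:
    p_v1_given_h0_h1 * p_h0
except ValueError:
    print("shapes (2, 2, 3) and (2,) do not broadcast")

# Reshape p(h0) so that its entries lie along the h0 axis (axis 1): shape (1, 2, 1)
axis = 1
reshaped = p_h0.reshape([-1 if i == axis else 1 for i in range(p_v1_given_h0_h1.ndim)])
product = p_v1_given_h0_h1 * reshaped  # broadcasts to shape (2, 2, 3)

# Same result as scaling each slice [:, j, :] by p(h0 = j) explicitly
expected = np.stack([p_v1_given_h0_h1[:, j, :] * p_h0[j] for j in range(2)], axis=1)
print(reshaped.shape, np.allclose(product, expected))
```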

In [None]:
def multiply(p_v_given_h, p_hi):
    ''' 
    Compute the product of the distributions p(v|h1,..,hn)p(hi) where p(hi) is a 1D array
    '''

    # Get the axis corresponding to hi in the conditional distribution
    axis = p_v_given_h.get_axes()[next(iter(p_hi.get_axes()))]

    # Reshape p(hi) in order to exploit broadcasting. Consider also the case in which p(hi) is a scalar.
    if p_hi.probs.shape != ():  # Check if p_hi is not a scalar
        reshaped_p_hi = p_hi.probs.reshape([-1 if i == axis else 1 for i in range(p_v_given_h.probs.ndim)])
    else:
        reshaped_p_hi = p_hi.probs  # Scalar, no reshaping needed

    return Distribution(p_v_given_h.probs * reshaped_p_hi, p_v_given_h.axes_labels)


In [None]:
p_v1_h1 = multiply(p_v1_given_h1, p_h1)
print(p_v1_h1.probs)
print(p_v1_h1.is_valid_joint())

p_v1_h1_given_h0 = multiply(p_v1_given_h0_h1, p_h1)
print(p_v1_h1_given_h0.probs)

[[0.24 0.24 0.09]
 [0.36 0.06 0.01]]
True
[[[0.54 0.06 0.07]
  [0.18 0.06 0.05]]

 [[0.06 0.24 0.03]
  [0.42 0.24 0.05]]]


### Factor graphs

Factor graphs are bipartite graphs with two types of nodes, variable nodes and factor nodes; edges can only connect nodes of different types. Consider for example:

![factor_ex](imgs/factor_example.png)



In [None]:
class Node(object):
    
    def __init__(self, name):
        """
        Initialize a node with a given name and an empty list of neighbors.
        
        :param name: The name of the node (string).
        """
        self.name = name
        self.neighbors = []

    def is_valid_neighbor(self, neighbor):
        """
        This method should be implemented in subclasses to define valid neighbor relationships.
        """
        raise NotImplementedError()

    def add_neighbor(self, neighbor):
        """
        Adds a neighbor to the node after checking if it is a valid connection.
        
        :param neighbor: The node to be added as a neighbor.
        :raises AssertionError: If the neighbor is not valid according to is_valid_neighbor.
        """
        assert self.is_valid_neighbor(neighbor)
        self.neighbors.append(neighbor)


class Variable(Node):
    def is_valid_neighbor(self, factor):
        """
        Checks if the given neighbor is a valid Factor.
        
        :param factor: The node to check.
        :return: True if the neighbor is an instance of Factor, otherwise False.
        """
        return isinstance(factor, Factor)  # Variables can only be connected to Factors


class Factor(Node):
    def is_valid_neighbor(self, variable):
        """
        Checks if the given neighbor is a valid Variable.
        
        :param variable: The node to check.
        :return: True if the neighbor is an instance of Variable, otherwise False.
        """
        return isinstance(variable, Variable)  # Factors can only be connected to Variables

    def __init__(self, name):
        """
        Initialize a Factor node, which extends Node and contains additional data storage.
        
        :param name: The name of the factor node.
        """
        super(Factor, self).__init__(name)
        self.data = None  # Placeholder for storing factor-specific data


We can write some parsing functions to build a factor graph from a string that represents the factorization of the joint probability distribution, such as `p(h1)p(h2|h1)p(v1|h1)p(v2|h2)`.

In [None]:
from collections import namedtuple

# Define a named tuple `ParsedTerm` to represent parsed terms in a structured way.
# This will store:
# - `term`: The main term being parsed.
# - `var_name`: The variable name associated with the term.
# - `given`: The conditions (if any) under which the term is considered.

ParsedTerm = namedtuple('ParsedTerm', [
    'term',      # The main term (e.g., probability expression)
    'var_name',  # The variable name involved in the term
    'given',     # Any conditions or dependencies (e.g., given another variable)
])

def _parse_term(term):
    """
    Parses a term of the form "(a|b,c)" and extracts variables and conditioned-on variables.
    
    :param term: A string representing a probability term in the format "(var|given1,given2,...)" 
                 or simply "(var)" if there are no conditions.
    :return: A tuple (var, given), where:
             - var: A list of variables being considered.
             - given: A list of variables that the term is conditioned on.
    """
    
    # Ensure the term starts with '(' and ends with ')'
    assert term[0] == '(' and term[-1] == ')', "Term must be enclosed in parentheses"
    
    # Extract the content inside the parentheses
    term_variables = term[1:-1]

    # Handle conditional probability notation (i.e., presence of '|')
    if '|' in term_variables:
        var, given = term_variables.split('|')  # Split into variable and conditions
        var = var.split(',')  # Convert variable part into a list
        given = given.split(',')  # Convert given (conditional) part into a list
    else:
        var = term_variables.split(',')  # No conditions, just split the variable(s)
        given = []  # No conditioned-on variables

    return var, given  # Return parsed variable(s) and given condition(s)


def _parse_model_string_into_terms(model_string):
    """
    Parses a model string into a list of ParsedTerm objects.

    :param model_string: A string representing a probabilistic model, where 
                         terms are prefixed with 'p' (e.g., "p(A|B)p(B)").
    :return: A list of ParsedTerm namedtuples, each containing:
             - 'term': The probability expression (e.g., "p(A|B)").
             - 'var_name': The main variable(s) in the term.
             - 'given': The conditional variables (if any).
    """
    
    return [
        ParsedTerm('p' + term, *_parse_term(term))  # Prepend 'p' to maintain original term format
        for term in model_string.split('p')  # Split on 'p' (assumes no variable name contains a 'p')
        if term  # Ignore empty terms from splitting
    ]

def parse_model_into_variables_and_factors(model_string):
    """
    Parses a probabilistic model string into a dictionary of variables and a list of factors.

    :param model_string: A string representing a probabilistic model, where terms follow the format
                         "p(h1)p(h2|h1)p(v1|h1)p(v2|h2)".
    :return: A tuple (factors, variables) where:
             - factors: A list of Factor objects representing the relationships in the model.
             - variables: A dictionary mapping variable names to Variable objects.
    """
    
    # Step 1: Parse the model string into ParsedTerm objects
    parsed_terms = _parse_model_string_into_terms(model_string)
    
    # Step 2: Extract all unique variables from the model and store them in a dictionary
    variables = {}

    for parsed_term in parsed_terms:
        # Iterate over all variables in the parsed term
        for term in parsed_term.var_name:
            # If the variable hasn't been seen before, create a new Variable object
            if term not in variables:
                variables[term] = Variable(term)

    # Step 3: Create Factor objects from the parsed terms and establish neighbor relationships
    factors = []
    
    for parsed_term in parsed_terms:
        # Create a new Factor object using the full probability expression (e.g., "p(v1|h1)")
        new_factor = Factor(parsed_term.term)

        # Get all variable names involved in this factor (both the main variable and the given conditions)
        all_var_names = parsed_term.var_name + parsed_term.given

        for var_name in all_var_names:
            # Connect the factor to the corresponding Variable object
            new_factor.add_neighbor(variables[var_name])
            # Connect the Variable object to the Factor
            variables[var_name].add_neighbor(new_factor)

        # Store the new factor
        factors.append(new_factor)

    return factors, variables
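`_parse_model_string_into_terms` splits the model string on every letter `p`, so it implicitly assumes that no variable name contains a `p` (a variable called `temp` would be cut in two). A minimal alternative sketch, assuming every term has the form `p(...)` with no nested parentheses, extracts the terms with a regular expression instead. `parse_terms_regex` is a hypothetical helper that returns plain tuples rather than `ParsedTerm` objects.

```python
import re

def parse_terms_regex(model_string):
    """Hypothetical helper: return (term, var_names, given) for each 'p(...)' term."""
    parsed = []
    # Capture whatever lies between a literal 'p(' and the next ')'
    for inside in re.findall(r'p\(([^)]*)\)', model_string):
        var, _, given = inside.partition('|')
        parsed.append(('p(' + inside + ')',
                       var.split(','),
                       given.split(',') if given else []))
    return parsed

print(parse_terms_regex("p(pressure)p(temp|pressure)"))
# [('p(pressure)', ['pressure'], []), ('p(temp|pressure)', ['temp'], ['pressure'])]
```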


We can combine factor nodes and variable nodes to create a factor graph and add a distribution to each factor node.
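The sum-product algorithm gives exact marginals only when the factor graph is a tree, and the recursive message computation implemented below assumes there are no loops (on a cycle, a message would depend on itself). A minimal sketch of such a check, assuming each factor's scope is given as a list of variable names (the dictionary format and `is_tree_factor_graph` are hypothetical): a graph is a tree exactly when it is connected and has one fewer edge than nodes.

```python
from collections import deque

def is_tree_factor_graph(factor_scopes):
    """Return True if the bipartite factor graph is connected and has no loops.

    factor_scopes: hypothetical format, a dict mapping each factor name to the
    list of variable names it touches (e.g. 'p(h2|h1)' -> ['h2', 'h1']).
    """
    # Undirected adjacency list; tag nodes so factor and variable names cannot clash
    adjacency = {}
    n_edges = 0
    for factor, scope in factor_scopes.items():
        for var in scope:
            adjacency.setdefault(('factor', factor), []).append(('var', var))
            adjacency.setdefault(('var', var), []).append(('factor', factor))
            n_edges += 1

    # Breadth-first search from an arbitrary node to test connectivity
    start = next(iter(adjacency))
    seen = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)

    # A connected graph is a tree exactly when it has |V| - 1 edges
    return len(seen) == len(adjacency) and n_edges == len(adjacency) - 1


chain = {'p(h1)': ['h1'], 'p(h2|h1)': ['h2', 'h1'],
         'p(v1|h1)': ['v1', 'h1'], 'p(v2|h2)': ['v2', 'h2']}
loopy = {'f_a': ['x1', 'x2'], 'f_b': ['x2', 'x3'], 'f_c': ['x3', 'x1']}
print(is_tree_factor_graph(chain), is_tree_factor_graph(loopy))  # True False
```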

In [None]:
class PGM(object):
    """
    Represents a Probabilistic Graphical Model (PGM) with variables and factors.
    """

    def __init__(self, factors, variables):
        """
        Initializes the PGM with a set of factors and variables.

        :param factors: A list of Factor objects representing conditional dependencies.
        :param variables: A dictionary mapping variable names to Variable objects.
        """
        self._factors = factors
        self._variables = variables

    @classmethod
    def from_string(cls, model_string):
        """
        Constructs a PGM from a model string.

        :param model_string: A string defining the probabilistic model, 
                             e.g., "p(h1)p(h2|h1)p(v1|h1)p(v2|h2)".
        :return: An instance of PGM initialized with the parsed factors and variables.
        """
        factors, variables = parse_model_into_variables_and_factors(model_string)
        return PGM(factors, variables)

    def set_distributions(self, data):
        """
        Assigns probability distributions to the factors in the PGM.

        :param data: A dictionary mapping factor names to probability distributions.
                     Each entry contains a distribution with 'axes_labels' and 'probs'.
        :raises ValueError: If a factor's expected axes do not match the provided data.
        """
        var_dims = {}  # Dictionary to track variable dimensions across factors

        for factor in self._factors:
            factor_data = data[factor.name]  # Retrieve the corresponding data distribution

            # Ensure that all expected axes (variables) are present in the data
            if set(factor_data.axes_labels) != set(v.name for v in factor.neighbors):
                missing_axes = set(v.name for v in factor.neighbors) - set(data[factor.name].axes_labels)
                raise ValueError("data[{}] is missing axes: {}".format(factor.name, missing_axes))

            # Check and store variable dimensions to ensure consistency
            for var_name, dim in zip(factor_data.axes_labels, factor_data.probs.shape):
                if var_name not in var_dims:
                    var_dims[var_name] = dim  # Store dimension for the first time
                elif var_dims[var_name] != dim:
                    raise ValueError(
                        "data[{}] axis has the wrong size, {}. Expected {}".format(
                            factor.name, dim, var_dims[var_name]
                        )
                    )

            # Assign the data to the factor
            factor.data = data[factor.name]

    def variable_from_name(self, var_name):
        """
        Retrieves a Variable object by its name.

        :param var_name: The name of the variable to retrieve.
        :return: The corresponding Variable object.
        """
        return self._variables[var_name]


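Notice that, in the example above, the joint distribution is the product of the factors, $p(x_1,\dots,x_5) = f_1(x_1,x_3)\,f_2(x_2,x_3)\,f_3(x_3,x_4,x_5)$, so the marginal can be written as a nested combination of sums and products:

$$
\begin{aligned}
p(x_5) &= \sum_{x_1, x_2, x_3, x_4} p(x_1, x_2, x_3, x_4, x_5) \\
&= \sum_{x_3, x_4} f_3(x_3, x_4, x_5) \bigg[\sum_{x_1} f_1(x_1, x_3)\bigg] \bigg[\sum_{x_2} f_2(x_2, x_3)\bigg]
\end{aligned}
$$

Each bracketed sum can be interpreted as a message. A factor sends a message to a variable by multiplying its table by the messages from its other neighbors and summing out every other variable; a variable sends a message to a factor by multiplying the messages it receives from its other neighboring factors.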
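The rearrangement above can be checked numerically. A minimal sketch, assuming binary variables and random non-negative tables for the factors `f1`, `f2`, `f3` (hypothetical values, not the ones in the figure): the brute-force sum over the full joint table and the nested sum-product expression should give the same unnormalized marginal.

```python
import numpy as np

# Hypothetical random non-negative factors over binary variables
rng = np.random.default_rng(0)
f1 = rng.random((2, 2))        # axes: (x1, x3)
f2 = rng.random((2, 2))        # axes: (x2, x3)
f3 = rng.random((2, 2, 2))     # axes: (x3, x4, x5)

# Brute force: full table over (x1, x2, x3, x4, x5), then sum out x1..x4
joint = np.einsum('ac,bc,cde->abcde', f1, f2, f3)
brute = joint.sum(axis=(0, 1, 2, 3))

# Nested sums: messages to x3 obtained by summing out x1 and x2 ...
msg_f1_x3 = f1.sum(axis=0)     # sum_{x1} f1(x1, x3)
msg_f2_x3 = f2.sum(axis=0)     # sum_{x2} f2(x2, x3)
# ... multiplied into f3, then x3 and x4 are summed out
nested = np.einsum('cde,c,c->e', f3, msg_f1_x3, msg_f2_x3)

p_x5 = nested / nested.sum()   # normalize, since the factors need not be normalized
print(np.allclose(brute, nested))  # True
```

The brute-force table grows exponentially with the number of variables, while the nested form only ever handles the factor tables and vectors over single variables.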

In [None]:
class Messages(object):
    """
    Handles message passing in a probabilistic graphical model.
    Implements variable-to-factor and factor-to-variable message computations 
    for belief propagation.
    """

    def __init__(self):
        """
        Initializes an empty dictionary to store messages between variables and factors.
        """
        self.messages = {}

    def _variable_to_factor_messages(self, variable, factor):
        """
        Computes the message from a variable to a factor.
        
        :param variable: The Variable object sending the message.
        :param factor: The Factor object receiving the message.
        :return: The computed message (a probability distribution).
        
        - The message is computed as the product of all incoming messages to the variable,
          excluding messages coming from the factor itself.
        - If no messages exist (base case), return a uniform distribution (or 1).
        """
        # Product of the messages from all neighboring factors except the recipient
        incoming_messages = [
            self.factor_to_variable_message(neighbor_factor, variable)
            for neighbor_factor in variable.neighbors
            if neighbor_factor.name != factor.name
        ]
        # np.prod over an empty list returns 1.0, which covers the leaf (base) case
        return np.prod(incoming_messages, axis=0)

    def _factor_to_variable_messages(self, factor, variable):
        """
        Computes the message from a factor to a variable.
        
        :param factor: The Factor object sending the message.
        :param variable: The Variable object receiving the message.
        :return: The computed message (a probability distribution).
        
        - The message is computed by multiplying the factor's probability distribution
          with all incoming messages from other neighboring variables.
        - The result is then marginalized over all variables except the target variable.
        """

        # Wrap a copy of the factor's table so that factor.data is never modified
        factor_dist = Distribution(np.copy(factor.data.probs), factor.data.axes_labels)

        for neighbor_variable in factor.neighbors:
            if neighbor_variable.name == variable.name:
                continue  # Skip the target variable itself

            # Retrieve the incoming message from the variable to the factor and multiply
            incoming_message = self.variable_to_factor_messages(neighbor_variable, factor)
            factor_dist = multiply(factor_dist, Distribution(incoming_message, [neighbor_variable.name]))

        # Sum over all axes except for the target variable to marginalize them out
        other_axes = factor_dist.get_other_axes_from(variable.name)
        return np.squeeze(np.sum(factor_dist.probs, axis=other_axes))

    def marginal(self, variable):
        """
        Computes the marginal probability distribution of a variable.
        
        :param variable: The Variable object whose marginal is being computed.
        :return: The normalized marginal probability distribution.
        
        - The marginal is proportional to the product of all incoming messages to the variable.
        - The result is then normalized to ensure it represents a valid probability distribution.
        """

        # The marginal is proportional to the product of all incoming factor-to-variable messages
        unnormalized = np.prod([
            self.factor_to_variable_message(neighbor_factor, variable)
            for neighbor_factor in variable.neighbors
        ], axis=0)

        # Normalize the resulting probability distribution before returning
        return unnormalized / np.sum(unnormalized)

    def variable_to_factor_messages(self, variable, factor):
        """
        Retrieves or computes the message from a variable to a factor.
        
        :param variable: The Variable object sending the message.
        :param factor: The Factor object receiving the message.
        :return: The computed or cached message.
        """
        message_name = (variable.name, factor.name)
        
        if message_name not in self.messages:
            self.messages[message_name] = self._variable_to_factor_messages(variable, factor)

        return self.messages[message_name]

    def factor_to_variable_message(self, factor, variable):
        """
        Retrieves or computes the message from a factor to a variable.
        
        :param factor: The Factor object sending the message.
        :param variable: The Variable object receiving the message.
        :return: The computed or cached message.
        """
        message_name = (factor.name, variable.name)
        
        if message_name not in self.messages:
            self.messages[message_name] = self._factor_to_variable_messages(factor, variable)

        return self.messages[message_name]


We can try to build the following factor graph:

![factor1](imgs/factor2.png)

In [None]:
p_h1 = Distribution(np.array([0.2, 0.8]), ['h1'])
p_h2_given_h1 = Distribution(np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1'])
p_v1_given_h1 = Distribution(np.array([[0.6, 0.1], [0.4, 0.9]]), ['v1', 'h1'])
p_v2_given_h2 = Distribution(p_v1_given_h1.probs, ['v2', 'h2'])

pgm = PGM.from_string("p(h1)p(h2|h1)p(v1|h1)p(v2|h2)")

pgm.set_distributions({
    "p(h1)": p_h1,
    "p(h2|h1)": p_h2_given_h1,
    "p(v1|h1)": p_v1_given_h1,
    "p(v2|h2)": p_v2_given_h2,
})

Then we can compute the marginal distribution $p(v_2)$:

In [None]:
m = Messages()
m.marginal(pgm.variable_from_name('v2'))

array([0.23, 0.77])

In [None]:
m.messages

{('p(h1)', 'h1'): array([0.2, 0.8]),
 ('v1', 'p(v1|h1)'): 1.0,
 ('p(v1|h1)', 'h1'): array([1., 1.]),
 ('h1', 'p(h2|h1)'): array([0.2, 0.8]),
 ('p(h2|h1)', 'h2'): array([0.26, 0.74]),
 ('h2', 'p(v2|h2)'): array([0.26, 0.74]),
 ('p(v2|h2)', 'v2'): array([0.23, 0.77])}

In [None]:
m.marginal(pgm.variable_from_name('v1'))

array([0.2, 0.8])
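Since this graph has only four binary variables, the marginals can be cross-checked by brute force: build the full joint table $p(h_1,h_2,v_1,v_2)$ and sum out every other variable. A sketch with `np.einsum`, reusing the probability tables defined above (written here as plain arrays); the results should agree with the output of `m.marginal`.

```python
import numpy as np

# Same tables as above; the first axis is always the child variable
p_h1 = np.array([0.2, 0.8])                          # p(h1)
p_h2_given_h1 = np.array([[0.5, 0.2], [0.5, 0.8]])   # axes (h2, h1)
p_v1_given_h1 = np.array([[0.6, 0.1], [0.4, 0.9]])   # axes (v1, h1)
p_v2_given_h2 = p_v1_given_h1                        # axes (v2, h2)

# Full joint p(h1, h2, v1, v2), axes ordered (h1, h2, v1, v2)
joint = np.einsum('a,ba,ca,db->abcd', p_h1, p_h2_given_h1, p_v1_given_h1, p_v2_given_h2)

# Marginals: sum out every axis except the one of interest
p_v1 = joint.sum(axis=(0, 1, 3))
p_v2 = joint.sum(axis=(0, 1, 2))
print(np.isclose(joint.sum(), 1.0), p_v1, p_v2)
```

The joint table has one entry per joint configuration, so this check becomes infeasible as the number of variables grows, which is exactly what message passing avoids on trees.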

### Exercise 1

(From Bayesian Reasoning and Machine Learning, David Barber) You live in a house with three rooms, labelled 1, 2, 3. There is a door between rooms 1 and 2 and another between rooms 2 and 3. One cannot directly pass between rooms 1 and 3 in one time-step. An annoying fly is buzzing from one room to another and there is some smelly cheese in room 1 which seems to attract the fly more. Using $x_t$ for the room the fly is in at time $t$, with $\mathrm{dom}(x_t) = \{1,2,3\}$, the movement of the fly can be described by a transition:
$p(x_{t+1} = i|x_t = j) = M_{ij}$

where $M$ is the transition matrix:

$$
M = \begin{bmatrix}
0.7 & 0.5 & 0 \\
0.3 & 0.3 & 0.5 \\
0 & 0.2 & 0.5 \\
\end{bmatrix}
$$

Given that the fly is in room 1 at time $t = 1$, what is the probability of room occupancy at time $t = 5$? Assume a Markov chain defined by the joint distribution

$$p(x_1, \dots, x_T) = p(x_1) \prod_{t=1}^{T-1} p(x_{t+1}|x_t)$$

We are asked to compute $p(x_5|x_1 = 1)$, which is given by

$$p(x_5|x_1 = 1) = \sum_{x_2, x_3, x_4} p(x_5|x_4)\,p(x_4|x_3)\,p(x_3|x_2)\,p(x_2|x_1 = 1)$$
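In matrix form, each sum over an intermediate state is a matrix-vector product, so $p(x_5|x_1=1) = M^4 e_1$, where $e_1$ is the indicator vector of room 1. A short NumPy sketch that can serve as a cross-check for a belief-propagation solution (it does not use the factor-graph classes; rooms are indexed from 0 in the code):

```python
import numpy as np

# Transition matrix: M[i, j] = p(x_{t+1} = i | x_t = j), rooms indexed 0, 1, 2
M = np.array([[0.7, 0.5, 0.0],
              [0.3, 0.3, 0.5],
              [0.0, 0.2, 0.5]])
assert np.allclose(M.sum(axis=0), 1.0)  # each column is a distribution over x_{t+1}

p_x1 = np.array([1.0, 0.0, 0.0])  # the fly starts in room 1

# Propagate four steps: p(x_{t+1}) = M @ p(x_t)
p_x5 = p_x1
for _ in range(4):
    p_x5 = M @ p_x5

print(np.allclose(p_x5, np.linalg.matrix_power(M, 4) @ p_x1))  # True
```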




### Exercise 2: Hidden Markov Models

Imagine you're trying to guess someone's mood without directly asking them or using brain electrodes. Instead, you observe their facial expressions, whether they're smiling or frowning, to make an educated guess.

We assume moods can be categorized into two states: good and bad. When you meet someone for the first time, there's a 70% chance they're in a good mood and a 30% chance they're in a bad mood.

If someone is in a good mood, there's an 80% chance they'll still be in a good mood at the next time step and a 20% chance they'll switch to a bad mood. The same switching probabilities apply if they start in a bad mood: an 80% chance of staying in a bad mood and a 20% chance of switching to a good mood.

Lastly, when someone is in a good mood, they're 90% likely to smile and 10% likely to frown. Conversely, if they're in a bad mood, they have a 10% chance of smiling and a 90% chance of frowning.

The model is summarized in the graph below.

Your task is to use these probabilities to infer the first and second hidden mood states, i.e. the probability that the first mood is good or bad and the probability that the second mood is good or bad, given the observed sequence of facial expressions [smiling, frowning].

![factor1](imgs/mood.png)
(image by Y. Natsume)
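As a starting point, the probabilities above can be encoded as arrays following the notebook's convention (child variable on the first axis). The sketch below only sets up the model and leaves the inference to the exercise; the variable names and the state encoding (0 = good, 1 = bad; 0 = smiling, 1 = frowning) are assumptions. One possible way to use the observations is to take, for each time step, the row of the emission table that matches the observed expression, which gives a non-negative vector over that time step's mood.

```python
import numpy as np

# Assumed encoding: mood 0 = good, 1 = bad; expression 0 = smiling, 1 = frowning
p_m1 = np.array([0.7, 0.3])                  # p(m1): prior over the first mood

# p(m_{t+1} | m_t), axes (m_{t+1}, m_t): column j is the next-mood distribution given mood j
p_next_given_prev = np.array([[0.8, 0.2],
                              [0.2, 0.8]])

# p(o_t | m_t), axes (o_t, m_t): column j is the expression distribution given mood j
p_obs_given_mood = np.array([[0.9, 0.1],
                             [0.1, 0.9]])

# Each column of a conditional table must sum to 1
for table in (p_next_given_prev, p_obs_given_mood):
    assert np.allclose(table.sum(axis=0), 1.0)
assert np.isclose(p_m1.sum(), 1.0)

observed = [0, 1]  # [smiling, frowning]
# Likelihood vectors p(o_t = observed_t | m_t), as functions of m_t
likelihoods = [p_obs_given_mood[o, :] for o in observed]
print(likelihoods)
```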