# Inference in Bayesian Network by Variable Elimination

---

In the alarm network, we might have the following questions:

- If there was an earthquake, how likely is it that Mary will call you?
- If both John and Mary called you, how likely is it that there was a burglary?
- If Mary called you, how likely is it that John will call you as well?

<img src="img/alarm-bn.png" width=500></img>

Answering such questions is called **inference** in a Bayesian network. Formally, inference in a Bayesian network is defined as follows.

> **DEFINITION**: Given a set of **query** nodes $[X_1, \dots, X_n]$, and a set of **evidence** nodes with their **observed values** $[Y_1 = y_1, \dots, Y_m = y_m]$, the **inference** is to calculate the conditional probabilities $P(X_1, \dots, X_n\ |\ Y_1 = y_1, \dots, Y_m = y_m)$.
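
As an illustration, the three questions above can be read as the following queries. This assumes the nodes are abbreviated by their initials ($B$ for burglary, $E$ for earthquake, $J$ for John calling, $M$ for Mary calling), as in the rest of this tutorial, and that $t$ denotes "true".

$$
\begin{aligned}
& P(M\ |\ E = t) \\
& P(B\ |\ J = t, M = t) \\
& P(J\ |\ M = t)
\end{aligned}
$$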

Inference in Bayesian networks is very **flexible**: any node in the network can be a query node or an evidence node. Depending on the choice of query and evidence nodes, some commonly encountered reasoning scenarios are:

- **Causal reasoning**: the evidence nodes are the causes of the query nodes. This is forward reasoning.
- **Diagnostic reasoning**: the evidence nodes are the effects of the query nodes. This is backward reasoning.
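- **Inter-causal reasoning**: the query nodes and some of the evidence nodes are different causes of a common effect that is also observed. Observing one cause changes the belief in the other causes ("explaining away").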

There are many different algorithms for inference in Bayesian networks. In this tutorial, we introduce a well-known algorithm called **variable elimination**, which is an exact inference algorithm that speeds up inference by enumeration.

## Inference by Enumeration <a name="exact"></a>

---

Let's consider the following inference scenario in the alarm network.

> **QUESTION**: What is the conditional probability of burglary, given that John calls you? That is, what is $P(B\ |\ J = t)$?

This conditional probability cannot be calculated directly from $B$ and $J$ alone. We need to consider other nodes ($E$, $A$ and $M$) in the network as well.

To calculate $P(B\ |\ J = t)$ in the alarm network, we have

- **Query** node: $B$
- **Evidence** node: $J$
- **Observation**: $J = t$
- **Hidden** nodes: $\{E, A, M\}$

After considering the hidden nodes $\{E, A, M\}$, we can calculate $P(B\ |\ J = t)$ as follows.

$$
\begin{aligned}
& P(B\ |\ J = t) & \\
& = \frac{P(B, J = t)}{P(J = t)} & \hspace{50pt} \textrm{[product rule]} \\
& = \alpha * \sum_{E \in \{t, f\}}\sum_{A \in \{t, f\}}\sum_{M \in \{t, f\}}P(B, E, A, J = t, M) & \hspace{50pt} \textrm{[sum rule]} \\
& = \alpha * \sum_{E \in \{t, f\}}\sum_{A \in \{t, f\}}\sum_{M \in \{t, f\}}P(B) * P(E) * P(A\ |\ B, E) * P(J = t\ |\ A) * P(M\ |\ A) & \hspace{50pt} \textrm{[Factorisation]}
\end{aligned}
$$

where $\alpha = \frac{1}{P(J = t)}$ is the **normalisation factor**. It does not need to be calculated explicitly: we can simply normalise the values obtained for all possible query variable values so that they add up to 1.

We can see that, except for $\alpha$, all the probabilities $P(B)$, $P(E)$, $P(A\ |\ B, E)$, $P(J = t\ |\ A)$, $P(M\ |\ A)$ can be read directly from the probability tables of the network. Therefore, we have found a way to calculate the conditional probability.
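
The enumeration above can be sketched in a few lines of Python. This is only a minimal sketch: the numbers are copied from the probability tables listed later in this tutorial, while the dictionary layout and the function names (`joint_with_j_true`, `burglary_given_john_calls`) are assumptions made for illustration.

```Python
from itertools import product

# CPTs of the alarm network, copied from the tables in this tutorial.
P_B = {"t": 0.001, "f": 0.999}
P_E = {"t": 0.002, "f": 0.998}
# P(A = t | B, E), keyed by (B, E)
P_A_TRUE = {("t", "t"): 0.95, ("t", "f"): 0.94, ("f", "t"): 0.29, ("f", "f"): 0.001}
P_J_TRUE = {"t": 0.9, "f": 0.05}  # P(J = t | A), keyed by A
P_M_TRUE = {"t": 0.7, "f": 0.01}  # P(M = t | A), keyed by A


def joint_with_j_true(b, e, a, m):
    """P(B=b, E=e, A=a, J=t, M=m), using the factorisation of the network."""
    p_a = P_A_TRUE[(b, e)] if a == "t" else 1.0 - P_A_TRUE[(b, e)]
    p_m = P_M_TRUE[a] if m == "t" else 1.0 - P_M_TRUE[a]
    return P_B[b] * P_E[e] * p_a * P_J_TRUE[a] * p_m


def burglary_given_john_calls():
    """Return P(B | J = t) by summing out E, A and M, then normalising."""
    unnormalised = {
        b: sum(joint_with_j_true(b, e, a, m) for e, a, m in product("tf", repeat=3))
        for b in "tf"
    }
    total = sum(unnormalised.values())  # equals P(J = t), i.e. 1 / alpha
    return {b: p / total for b, p in unnormalised.items()}


posterior = burglary_given_john_calls()
print(posterior)
```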

<!-- In general, given the query nodes $\{X_1, \dots, X_n\}$, evidence nodes $\{Y_1 = y_1, \dots, Y_m = y_m\}$ and hidden nodes $\{H_1, \dots, H_k\}$, the inference can be done as follows.

$$
\begin{aligned}
& P(X_1, \dots, X_n\ |\ Y_1 = y_1, \dots, Y_m = y_m) \\ 
& = \frac{P(X_1, \dots, X_n, Y_1 = y_1, \dots, Y_m = y_m)}{P(Y_1 = y_1, \dots, Y_m = y_m)} & \hspace{50pt} \textrm{[product rule]} \\
& = \alpha * \sum_{[h_1, \dots, h_k] \in \\ \Omega(H_1, \dots, H_k)} P(X_1, \dots, X_n, Y_1 = y_1, \dots, Y_m = y_m, H_1 = h_1, \dots, H_k = h_k) & \hspace{50pt} \textrm{[sum rule]} \\
& = \alpha * \sum_{[h_1, \dots, h_k] \in \\ \Omega(H_1, \dots, H_k)} \prod_{i=1}^{n} P(X_i\ |\ parents(X_i)) * \prod_{i=1}^{m} P(y_i\ |\ parents(Y_i)) * \prod_{i=1}^{k} P(h_i\ |\ parents(H_i)) & \hspace{50pt} \textrm{[Factorisation]}
\end{aligned}
$$

where $\alpha = \frac{1}{P(Y_1 = y_1, \dots, Y_m = y_m)}$ is the **normalisation factor**, and is not needed to calculate (we can simply normalise the conditional probabilities for all possible query variable values so that they add up to 1). -->


## Computational Complexity

---

If we directly calculate $P(B\ |\ J = t)$, how many operations are needed? Consider the last line of the derivation above (ignoring $\alpha$):

$$
\sum_{E \in \{t, f\}}\sum_{A \in \{t, f\}}\sum_{M \in \{t, f\}}P(B) * P(E) * P(A\ |\ B, E) * P(J = t\ |\ A) * P(M\ |\ A)
$$

For each $B \in \{t, f\}$, we have $2 \times 2 \times 2 = 8$ terms to be added. In total, there are <a style="color: blue;">$2 \times 7 = 14$</a> additions.

For each $B \in \{t, f\}$, $E \in \{t, f\}$, $A \in \{t, f\}$ and $M \in \{t, f\}$, there are 5 probabilities to be multiplied, needing 4 multiplications. In total, there are <a style="color: red;">$2^4 \times 4 = 64$</a> multiplications.


<!-- In general, to calculate

$$
\sum_{[h_1, \dots, h_k] \in \\ \Omega(H_1, \dots, H_k)} \prod_{i=1}^{n} P(X_i\ |\ parents(X_i)) * \prod_{i=1}^{m} P(y_i\ |\ parents(Y_i)) * \prod_{i=1}^{k} P(h_i\ |\ parents(H_i))
$$

- For each possible query values, there are $|\Omega(H_1)| * \dots * |\Omega(H_k)|$ terms to be added, which is the number of possible value combinations of the hidden variables. There are $|\Omega(H_1)| * \dots * |\Omega(H_k)| - 1$ number of additions.
- There are $|\Omega(X_1)| * \dots * |\Omega(X_n)|$ possible query values. Therefore, **there are $|\Omega(X_1)| * \dots * |\Omega(X_n)| * (|\Omega(H_1)| * \dots * |\Omega(H_k)| - 1)$ additions in total.**
- Each term has $n+m+k$ probabilities to be multiplied, needing $n+m+k-1$ multiplications.
- There are $|\Omega(X_1)| * \dots * |\Omega(X_n)| * |\Omega(H_1)| * \dots * |\Omega(H_k)|$ terms in total. Therefore, **there are $|\Omega(X_1)| * \dots * |\Omega(X_n)| * |\Omega(H_1)| * \dots * |\Omega(H_k)| *(n+m+k-1)$ multiplications in total.** -->

In large and complex Bayesian networks, this computation can become intractable, because the number of terms grows exponentially with the number of hidden nodes. For example, if all the variables are binary, i.e. $|\Omega(X)| = 2$, and we have 1 query node, 3 evidence nodes and 10 hidden nodes, then the total number of multiplications will be

$$
2^{(1+10)} \times (1+3+10-1) = 26624.
$$
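
The counting argument above can be written as a small helper. This is a sketch under two assumptions: every variable is binary, and each term multiplies exactly one probability per node (as in the alarm example). The function name `enumeration_cost` is hypothetical.

```Python
def enumeration_cost(n_query, n_evidence, n_hidden):
    """Count the operations of inference by enumeration for binary variables."""
    query_values = 2 ** n_query      # value combinations of the query nodes
    hidden_values = 2 ** n_hidden    # value combinations of the hidden nodes
    # For each query value, adding up `hidden_values` terms takes one fewer addition.
    additions = query_values * (hidden_values - 1)
    # Each term multiplies one probability per node.
    probabilities_per_term = n_query + n_evidence + n_hidden
    multiplications = query_values * hidden_values * (probabilities_per_term - 1)
    return additions, multiplications


print(enumeration_cost(1, 1, 3))   # alarm example: (14, 64)
print(enumeration_cost(1, 3, 10))  # larger example: (2046, 26624)
```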

## Speed Up by Variable Elimination

---

The **variable elimination** algorithm is a very important exact inference algorithm that speeds up the above calculation process. The key idea is to **eliminate hidden variables as early as possible**. 

To introduce the variable elimination algorithm, we first introduce a different view of calculating the probabilities, which is **factor operations**. 

> **DEFINITION**: A **factor** of some random variables is a **table** with one row for each possible combination of values of the random variables, where each row stores a number. Note that the **table value can be any function** of the random variables; it does not have to be a probability.

To calculate

$$
\sum_{E \in \{t, f\}}\sum_{A \in \{t, f\}}\sum_{M \in \{t, f\}}P(B) * P(E) * P(A\ |\ B, E) * P(J = t\ |\ A) * P(M\ |\ A)
$$

We can define five initial **factors**, one for each probability in this calculation.

$f_1(B) = P(B)$:

| B | P(B) |
| - | --------------- |
| t |    0.001        |
| f |    0.999        |

$f_2(E) = P(E)$:

| E | P(E) |
| - | --------------- |
| t |    0.002        |
| f |    0.998        |  

$f_3(A, B, E) = P(A\ |\ B, E)$:

| A | B | E | P(A &#124; B, E) |
| - | - | - | --------------- |
| t | t | t |   0.95        |
| f | t | t |   0.05        |
| t | t | f |   0.94        |
| f | t | f |   0.06        |
| t | f | t |   0.29        |
| f | f | t |   0.71        |
| t | f | f |   0.001        |
| f | f | f |   0.999        |    

$f_4(A) = P(J=t\ |\ A)$:

| A | P(J=t &#124; A) |
| - | --------------- |
| t |    0.9        |
| f |    0.05        | 

$f_5(M, A) = P(M\ |\ A)$:

| M | A | P(M &#124; A) |
| - | - | --------------- |
| t | t |   0.7        |
| f | t |   0.3        |  
| t | f |   0.01        |
| f | f |   0.99        |  

> **DEFINITION**: The **join** operation between two factors $f_1$ and $f_2$, denoted as $f_1 \otimes f_2$, is a table over the *union* of the variables in $f_1$ and $f_2$, where each row is the product of the rows of $f_1$ and $f_2$ that have the same variable values as that row.

In the above example, $f_1(B) \otimes f_2(E) = P(B) * P(E)$ is shown as follows. It converts two 2-row tables into a 4-row table, leading to 4 multiplications.

| B | E | P(B) * P(E) |
| - | - | --------------- |
| t | t |   0.001 * 0.002 = 0.000002        |
| f | t |   0.999 * 0.002 = 0.001998      |  
| t | f |   0.001 * 0.998 = 0.000998        |
| f | f |   0.999 * 0.998 = 0.997002       | 

On the other hand, $f_4(A) \otimes f_5(M, A) = P(J = t\ |\ A) * P(M\ |\ A)$ is shown as follows.

| M | A | P(J = t &#124; A) * P(M &#124; A) |
| - | - | --------------- |
| t | t |   0.9 * 0.7 = 0.63        |
| f | t |   0.9 * 0.3 = 0.27        |  
| t | f |   0.05 * 0.01 = 0.0005       |
| f | f |   0.05 * 0.99 = 0.0495       | 

Because $f_4$ and $f_5$ share the variable $A$, the resultant table still has 4 rows, the same as the original $f_5$. In general, **the complexity of the join operator depends on the sizes of the joined factors and their overlapping variables**.
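
A join can be sketched in Python as follows. The representation is an assumption made for illustration: a factor is a `(variables, table)` pair, where `table` maps a tuple of values (in the order of `variables`) to a number, and every variable takes values in `("t", "f")`. The function name `join` is likewise illustrative.

```Python
from itertools import product


def join(f, g, domain=("t", "f")):
    """Join two factors, each given as a (variables, table) pair."""
    f_vars, f_table = f
    g_vars, g_table = g
    # Union of the variables: f's variables first, then g's new ones.
    out_vars = f_vars + tuple(v for v in g_vars if v not in f_vars)
    out_table = {}
    for values in product(domain, repeat=len(out_vars)):
        row = dict(zip(out_vars, values))
        # Look up the rows of f and g that agree with this row.
        f_key = tuple(row[v] for v in f_vars)
        g_key = tuple(row[v] for v in g_vars)
        out_table[values] = f_table[f_key] * g_table[g_key]
    return out_vars, out_table


f4 = (("A",), {("t",): 0.9, ("f",): 0.05})
f5 = (("M", "A"), {("t", "t"): 0.7, ("f", "t"): 0.3, ("t", "f"): 0.01, ("f", "f"): 0.99})

joined_vars, joined_table = join(f5, f4)
print(joined_vars)  # ('M', 'A')
for values, number in joined_table.items():
    print(values, round(number, 6))
```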

> **DEFINITION**: The **elimination/sum-out** operation of a factor $f$ on a variable $X$ of $f$, denoted as $\sum_{X}f$, is a table over all the variables of $f$ except $X$, where each row is the sum of all the rows in $f$ with the corresponding values of the remaining variables.

For example, if we **eliminate/sum-out** $M$ in $f_4(A) \otimes f_5(M, A)$, then we can obtain the following factor $\sum_{M}(f_4(A) \otimes f_5(M, A))$, where each row is $P(J = t\ |\ A) * P(M = t\ |\ A) + P(J = t\ |\ A) * P(M = f\ |\ A)$ for each $A$ value.

| A | P(J = t &#124; A) P(M = t &#124; A) + P(J = t &#124; A) P(M = f &#124; A) |
| - | --------------- |
| t |    0.63 + 0.27 = 0.9        |
| f |    0.0005 + 0.0495 = 0.05        | 

Elimination/Sum-out can reduce the size of the factor.
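
The sum-out operation can be sketched in the same spirit. As an assumption for illustration, a factor is again a `(variables, table)` pair whose `table` maps value tuples to numbers; the name `sum_out` is hypothetical.

```Python
def sum_out(factor, variable):
    """Eliminate `variable` from a factor given as a (variables, table) pair."""
    variables, table = factor
    index = variables.index(variable)
    out_vars = variables[:index] + variables[index + 1:]
    out_table = {}
    for values, number in table.items():
        # Drop the eliminated variable's value and add up the matching rows.
        key = values[:index] + values[index + 1:]
        out_table[key] = out_table.get(key, 0.0) + number
    return out_vars, out_table


# The joined factor f4(A) ⊗ f5(M, A) from the table above.
f4_join_f5 = (
    ("M", "A"),
    {("t", "t"): 0.63, ("f", "t"): 0.27, ("t", "f"): 0.0005, ("f", "f"): 0.0495},
)

summed_vars, summed_table = sum_out(f4_join_f5, "M")
print(summed_vars)  # ('A',)
print({values: round(number, 6) for values, number in summed_table.items()})
```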

Then, we can rewrite the calculation of the conditional probabilities in terms of factor operations. The sum

$$
\sum_{E \in \{t, f\}}\sum_{A \in \{t, f\}}\sum_{M \in \{t, f\}}P(B) * P(E) * P(A\ |\ B, E) * P(J = t\ |\ A) * P(M\ |\ A)
$$

becomes

$$
\sum_{E}\sum_{A}\sum_{M}f_1(B) \otimes f_2(E) \otimes f_3(A, B, E) \otimes f_4(A) \otimes f_5(M, A)
$$

Note that the sum over a variable can be moved inside the joins, past every factor that does not contain that variable. To save computational cost, we should eliminate hidden variables as early as possible to reduce the size of the tables for later join operations. The **variable elimination** algorithm is designed to do exactly this.

```Python
# Pseudocode
def variable_elimination(query_nodes, evidence_nodes, observations, hidden_nodes):
    Set all_nodes = query_nodes + evidence_nodes + hidden_nodes
    # Initialise the factors
    factors = []
    for node in all_nodes:
        Initialise factor = P(node | parents(node)), restricted to the observations
        Add factor into factors

    Set sorted_hidden_nodes = hidden_nodes sorted in some elimination order

    # At each iteration, eliminate one hidden node
    for node in sorted_hidden_nodes:
        Remove all the factors containing node from factors, and join them
        Eliminate node from the joined factor
        Add the resulting factor into factors

    # Only factors over the query nodes remain at this point
    Join all the remaining factors
    Normalise the probabilities in the final factor
    return the final factor
```
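
The pseudocode above can be turned into a small runnable sketch. Several choices here are assumptions made for illustration and are not part of the pseudocode: a factor is a `(variables, table)` pair, all variables take values in `("t", "f")`, and `variable_elimination` receives the initial factors with the observations already applied instead of building them from the network. The initial factors are the five tables $f_1, \dots, f_5$ given above.

```Python
from itertools import product


def join(f, g, domain=("t", "f")):
    """Multiply two factors; a factor is a (variables, table) pair."""
    (f_vars, f_table), (g_vars, g_table) = f, g
    out_vars = f_vars + tuple(v for v in g_vars if v not in f_vars)
    out_table = {}
    for values in product(domain, repeat=len(out_vars)):
        row = dict(zip(out_vars, values))
        f_key = tuple(row[v] for v in f_vars)
        g_key = tuple(row[v] for v in g_vars)
        out_table[values] = f_table[f_key] * g_table[g_key]
    return out_vars, out_table


def sum_out(factor, variable):
    """Eliminate `variable` by adding up rows that agree on the other variables."""
    variables, table = factor
    i = variables.index(variable)
    out_table = {}
    for values, number in table.items():
        key = values[:i] + values[i + 1:]
        out_table[key] = out_table.get(key, 0.0) + number
    return variables[:i] + variables[i + 1:], out_table


def variable_elimination(factors, elimination_order):
    """Eliminate the hidden variables in the given order, then normalise."""
    factors = list(factors)
    for node in elimination_order:
        related = [f for f in factors if node in f[0]]
        if not related:
            continue  # no factor mentions this node, nothing to eliminate
        factors = [f for f in factors if node not in f[0]]
        joined = related[0]
        for f in related[1:]:
            joined = join(joined, f)
        factors.append(sum_out(joined, node))
    # Only factors over the query variables remain: join and normalise them.
    final = factors[0]
    for f in factors[1:]:
        final = join(final, f)
    total = sum(final[1].values())
    return final[0], {values: number / total for values, number in final[1].items()}


# Initial factors for P(B | J = t), copied from the tables above.
f1 = (("B",), {("t",): 0.001, ("f",): 0.999})
f2 = (("E",), {("t",): 0.002, ("f",): 0.998})
f3 = (("A", "B", "E"), {
    ("t", "t", "t"): 0.95, ("f", "t", "t"): 0.05,
    ("t", "t", "f"): 0.94, ("f", "t", "f"): 0.06,
    ("t", "f", "t"): 0.29, ("f", "f", "t"): 0.71,
    ("t", "f", "f"): 0.001, ("f", "f", "f"): 0.999,
})
f4 = (("A",), {("t",): 0.9, ("f",): 0.05})  # P(J = t | A)
f5 = (("M", "A"), {("t", "t"): 0.7, ("f", "t"): 0.3, ("t", "f"): 0.01, ("f", "f"): 0.99})

query_vars, posterior = variable_elimination([f1, f2, f3, f4, f5], ["M", "A", "E"])
print(query_vars)  # ('B',)
print({values: round(number, 4) for values, number in posterior.items()})
```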

We now show how the `variable_elimination` algorithm calculates $P(B\ |\ J = t)$ through the following equation. Let the elimination order of the hidden variables be $M \rightarrow A \rightarrow E$.

$$
\begin{aligned}
& \sum_{E}\sum_{A}\sum_{M}f_1(B) \otimes f_2(E) \otimes f_3(A, B, E) \otimes f_4(A) \otimes f_5(M, A) \\
& = \underbrace{f_1(B) \otimes \underbrace{\sum_{E}\Big( f_2(E) \otimes \underbrace{\sum_{A}\big( f_3(A, B, E) \otimes f_4(A) \otimes \underbrace{\sum_{M} f_5(M, A)}_{f_6(A)}\big)}_{f_8(B, E)}\Big)}_{f_{10}(B)}}_{f_{11}(B)}
\end{aligned}
$$

**Iteration 1**: Eliminate $M$ from $f_5(M, A)$, which is the only factor containing $M$, to get $f_6(A)$. It costs <a style="color: blue;">2 additions</a>.

| A | f6(A) |
| - | --------------- |
| t |    0.7 + 0.3 = 1.0        |
| f |    0.01 + 0.99 = 1.0        |

**Iteration 2**: 

First, join all the factors containing $A$, $f_3(A, B, E)$, $f_4(A)$ and $f_6(A)$, to obtain $f_7(A, B, E)$. It costs <span style="color: red;">16 multiplications</span>.
    
| A | B | E | f7(A, B, E) |
| - | - | - | --------------- |
| t | t | t |   0.95 &#042; 0.9 * 1.0 = 0.855       |
| f | t | t |   0.05 &#042; 0.05 * 1.0 = 0.0025       |
| t | t | f |   0.94 &#042; 0.9 * 1.0 = 0.846      |
| f | t | f |   0.06 &#042; 0.05 * 1.0 = 0.003      |
| t | f | t |   0.29 &#042; 0.9 * 1.0 = 0.261      |
| f | f | t |   0.71 &#042; 0.05 * 1.0 = 0.0355      |
| t | f | f |   0.001 &#042; 0.9 * 1.0 = 0.0009      |
| f | f | f |   0.999 &#042; 0.05 * 1.0 = 0.04995      | 
    
Second, eliminate $A$ from $f_7(A, B, E)$ to obtain $f_8(B, E)$. It costs <span style="color: blue;">4 additions</span>.
    
| B | E | f8(B, E) |
| - | - | --------------- |
| t | t |   0.855 + 0.0025 = 0.8575          |
| f | t |   0.261 + 0.0355 = 0.2965       |  
| t | f |   0.846 + 0.003 = 0.849       |
| f | f |   0.0009 + 0.04995 = 0.05085    | 
    
**Iteration 3**:

First, join all the factors containing $E$, $f_2(E)$ and $f_8(B, E)$, to obtain $f_9(B, E)$. It costs <span style="color: red;">4 multiplications</span>.
    
| B | E | f9(B, E) |
| - | - | --------------- |
| t | t |   0.002 * 0.8575 = 0.001715        |
| f | t |   0.002 * 0.2965 = 0.000593      |  
| t | f |   0.998 * 0.849 = 0.847302        |
| f | f |   0.998 * 0.05085 = 0.0507483       | 
    
Second, eliminate $E$ from $f_9(B, E)$ to obtain $f_{10}(B)$. It costs <span style="color: blue;">2 additions</span>.
    
| B | f10(B) |
| - | --------------- |
| t |    0.001715 + 0.847302 = 0.849017       |
| f |    0.000593 + 0.0507483 = 0.0513413        |

**Iteration 4**: Join all the factors containing $B$, $f_1(B)$ and $f_{10}(B)$, to obtain $f_{11}(B)$. It costs <span style="color: red;">2 multiplications</span>.

| B | f11(B) |
| - | --------------- |
| t |    0.001 * 0.849017 = 0.000849017       |
| f |    0.999 * 0.0513413 = 0.0512899587        |

In total, variable elimination costs only <span style="color: blue;">8 additions</span> and <span style="color: red;">22 multiplications</span>, far fewer than the 14 additions and 64 multiplications of direct enumeration.

**Finally**, we normalise the probabilities in $f_{11}(B)$ to obtain the final factor:

| B | norm f11(B) |
| - | --------------- |
| t |    0.000849017 / (0.000849017 + 0.0512899587) = 0.01628372994      |
| f |    0.0512899587 / (0.000849017 + 0.0512899587) = 0.98371627005        |

Let's verify the results with the `VariableElimination` class in the `pgmpy` library. First, we build the alarm network as follows. This requires the `pgmpy` library, which can be installed by running `pip install pgmpy`.

```Python
from pgmpy.models import BayesianNetwork
from pgmpy.factors.discrete import TabularCPD

# Define the network structure
alarm_model = BayesianNetwork(
    [
        ("Burglary", "Alarm"),
        ("Earthquake", "Alarm"),
        ("Alarm", "JohnCall"),
        ("Alarm", "MaryCall"),
    ]
)

# Define the probability tables by TabularCPD
cpd_burglary = TabularCPD(
    variable="Burglary", variable_card=2, values=[[0.999], [0.001]]
)

cpd_earthquake = TabularCPD(
    variable="Earthquake", variable_card=2, values=[[0.998], [0.002]]
)

cpd_alarm = TabularCPD(
    variable="Alarm",
    variable_card=2,
    values=[[0.999, 0.71, 0.06, 0.05], [0.001, 0.29, 0.94, 0.95]],
    evidence=["Burglary", "Earthquake"],
    evidence_card=[2, 2],
)

cpd_johncall = TabularCPD(
    variable="JohnCall",
    variable_card=2,
    values=[[0.95, 0.1], [0.05, 0.9]],
    evidence=["Alarm"],
    evidence_card=[2],
)

cpd_marycall = TabularCPD(
    variable="MaryCall",
    variable_card=2,
    values=[[0.99, 0.3], [0.01, 0.7]],
    evidence=["Alarm"],
    evidence_card=[2],
)

# Associating the probability tables with the model structure
alarm_model.add_cpds(
    cpd_burglary, cpd_earthquake, cpd_alarm, cpd_johncall, cpd_marycall
)
```

Then, we run the `VariableElimination.query` method to calculate $P(B\ |\ J = t)$.

```Python
from pgmpy.inference import VariableElimination

alarm_infer = VariableElimination(alarm_model)

q = alarm_infer.query(variables=["Burglary"], evidence={"JohnCall": 1})
print(q)
```

```
+-------------+-----------------+
| Burglary    |   phi(Burglary) |
+=============+=================+
| Burglary(0) |          0.9837 |
+-------------+-----------------+
| Burglary(1) |          0.0163 |
+-------------+-----------------+
```


We can see that the `pgmpy` library gives the same results, which verifies the correctness of our calculation.

---

- More tutorials can be found [here](https://github.com/meiyi1986/tutorials).
- [Yi Mei's homepage](https://meiyi1986.github.io/)