# Factor Elimination and Jointrees

**COMP9418-19T3, W08 Tutorial**

- Instructor: Gustavo Batista
- School of Computer Science and Engineering, UNSW Sydney
- Last update: 2nd November 2020, 18:00

In this week's tutorial, we will implement the Factor Elimination (FE) algorithm using elimination trees and jointrees. This algorithm lets us answer queries about the cluster marginals of every cluster in the elimination tree with the same time complexity as the Variable Elimination algorithm.

## Technical prerequisites

You will need certain packages installed to run this notebook.

If you are using ``conda``'s default
[full installation](https://conda.io/docs/install/full.html),
these requirements should all be satisfied already.

To render visualizations of some graphical models, you also need to install Graphviz ([download page](http://www.graphviz.org/download)). We already used this library in Tutorial 1, so you should have it installed. If you do not have it and use the conda installation, run `conda install python-graphviz`.

Once we have done all that, we
import some useful modules for later use.

In [None]:
# Make division default to floating-point, saving confusion
from __future__ import division
from __future__ import print_function

# combinatorics
from itertools import product, combinations
# ordered dictionaries are useful for keeping ordered sets of variables
from collections import OrderedDict as odict
# visualise graphs
from graphviz import Graph
# print factors as formatted tables
from tabulate import tabulate

## The ICU network

Once again we will use a subset of the ICU-Alarm network as a benchmark. However, our code will run for any network. We include here a graphical representation so that you can remember the variables and the dependencies between them.

![ICU Graph](img/ICU_graph.png "Graph exercise")

As we did in previous tutorials, we include a dictionary named `icu_factors` with the CPTs of the nine variables in the subset of the ICU-Alarm network.

In [None]:
icu_factors = {
    'H': {
        'dom': ('H',), 
        'table': odict([
            ((0,), 0.80),
            ((1,), 0.20),
        ])
    },
    
    'V': {
        'dom': ('L', 'H', 'V'), 
        'table': odict([
            ((0, 0, 0), 0.05),
            ((0, 0, 1), 0.95),
            ((0, 1, 0), 0.99),
            ((0, 1, 1), 0.01),
            ((1, 0, 0), 0),
            ((1, 0, 1), 1),
            ((1, 1, 0), 1),
            ((1, 1, 1), 0),
        ])
    },

    'C' : {
        'dom': ('V', 'C'), 
        'table': odict([
            ((0, 0), 0.94),
            ((0, 1), 0.04),
            ((0, 2), 0.02),
            ((1, 0), 0.02),
            ((1, 1), 0.26),
            ((1, 2), 0.72),
        ])
    },

    'L' : {
        'dom': ('L',), 
        'table': odict([
            ((0, ), 0.95),
            ((1, ), 0.05),
        ])
    },

    'S' : {
        'dom': ('L', 'H', 'S'), 
        'table': odict([
            ((0, 0, 0), 0.04),
            ((0, 0, 1), 0.96),
            ((0, 1, 0), 0.48),
            ((0, 1, 1), 0.52),
            ((1, 0, 0), 0.95),
            ((1, 0, 1), 0.05),
            ((1, 1, 0), 0),
            ((1, 1, 1), 1),
        ])
    },

    'O' : {
        'dom': ('S', 'V', 'O'), 
        'table': odict([
            ((0, 0, 0), 0.97),
            ((0, 0, 1), 0.01),
            ((0, 0, 2), 0.02),
            ((0, 1, 0), 0.78),
            ((0, 1, 1), 0.19),
            ((0, 1, 2), 0.03),
            ((1, 0, 0), 0.22),
            ((1, 0, 1), 0.76),
            ((1, 0, 2), 0.02),
            ((1, 1, 0), 0.01),
            ((1, 1, 1), 0.01),
            ((1, 1, 2), 0.98),        
        ])
    },

    'T': {
        'dom': ('A', 'T'), 
        'table': odict([
            ((0, 0), 0.30),
            ((0, 1), 0.70),
            ((1, 0), 1),
            ((1, 1), 0),
        ])
    },

    'B' : {
        'dom': ('O', 'T', 'B'), 
        'table': odict([
            ((0, 0, 0), 1),
            ((0, 0, 1), 0),
            ((0, 0, 2), 0),
            ((0, 1, 0), 0.30),
            ((0, 1, 1), 0.62),
            ((0, 1, 2), 0.08),
            ((1, 0, 0), 0.93),
            ((1, 0, 1), 0.07),
            ((1, 0, 2), 0),
            ((1, 1, 0), 0.02),
            ((1, 1, 1), 0.49),
            ((1, 1, 2), 0.49),
            ((2, 0, 0), 0.90),
            ((2, 0, 1), 0.08),
            ((2, 0, 2), 0.02),
            ((2, 1, 0), 0.01),
            ((2, 1, 1), 0.08),
            ((2, 1, 2), 0.91),        
        ])
    },

    'A' : {
        'dom': ('A',), 
        'table': odict([
            ((0, ), 0.99),
            ((1, ), 0.01),
        ])
    }
}

outcomeSpace = dict(
    H=(0,1),
    L=(0,1),
    A=(0,1),
    V=(0,1),
    S=(0,1),
    T=(0,1),
    C=(0,1,2),
    O=(0,1,2),
    B=(0,1,2),
)

## Elimination Tree

An elimination tree is a data structure that determines the order in which we eliminate the factors. We will define an elimination tree and provide it as input to the FE algorithm.

In an elimination tree, each node of the tree corresponds to one or more factors in the network, although one network factor per node is the most common. In the figure below, we copy the nodes of the ICU Alarm network. We linked the nodes, intending to keep the cluster sizes small. Remember that the width of the elimination tree is defined as the size of the largest cluster minus one.

![Elimination Tree](img/elimination_tree.png "Elimination Tree")

### Exercise

In the next cell, declare the elimination tree specified in the figure above. Remember the tree is an undirected graph.

In [None]:
eTree = dict(
    None                                            # Elimination tree adjacency list definition: + 7 lines
)

##############
# Test code

dot = Graph(engine="neato", comment='ICU Elimination Tree', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    'B': '1,0!',
    'O': '0,1!',
    'C': '1,1!',
    'T': '2,1!',
    'S': '0,2!',
    'V': '1,2!',
    'A': '2,2!',
    'L': '0,3!',
    'H': '1,3!',
}

for v in eTree.keys():
    dot.node(str(v), pos=pos[v])

for v in eTree:
    for w in eTree[v]:
        dot.edge(v, w)
dot

In [None]:
# Answer

eTree = dict(
    L= ('S',),
    H= ('S',),
    C= ('V',),
    V= ('C', 'S'),
    S= ('L', 'H', 'V', 'O'),
    O= ('S', 'B'),
    B= ('O', 'T'),
    T= ('B', 'A'),
    A= ('T',)
)

##############
# Test code

dot = Graph(engine="neato", comment='ICU Elimination Tree', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    'B': '1,0!',
    'O': '0,1!',
    'C': '1,1!',
    'T': '2,1!',
    'S': '0,2!',
    'V': '1,2!',
    'A': '2,2!',
    'L': '0,3!',
    'H': '1,3!',
}

for v in eTree.keys():
    dot.node(str(v), pos=pos[v])

for v in eTree:
    for w in eTree[v]:
        dot.edge(v, w)
dot

An important piece of information in the elimination tree is the separator. Separators are associated with edges of the elimination tree. In an $i-j$ edge, the separator is defined as $S_{ij} = vars(i) \cap vars(j)$, where $vars(k)$ are the variables that appear in the factors in the $k$ side of the edge. 

Separators are relevant because they inform us which variables can be eliminated. In other words, the message associated with an edge $i-j$ is guaranteed to be over the separator $S_{ij}$. Therefore, we need to eliminate all other variables not present in $S_{ij}$ before sending the message.
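
As a small illustration of this definition, the sketch below computes the separators of a hypothetical three-node chain (deliberately not the ICU tree, so it does not give away the exercise). The names `factor_vars`, `tree`, `side_vars` and `separator` are assumptions made for this example only; they are not used elsewhere in the notebook.

```python
# Hypothetical toy tree X - Y - Z with one factor per node.
# factor_vars[i] lists the variables mentioned by the factor at node i.
factor_vars = {'X': ('A',), 'Y': ('A', 'B'), 'Z': ('B', 'C')}
tree = {'X': ('Y',), 'Y': ('X', 'Z'), 'Z': ('Y',)}

def side_vars(start, blocked, tree, factor_vars):
    """Variables of all factors on the `start` side of the edge start-blocked."""
    seen, stack, found = {start}, [start], set()
    while stack:
        node = stack.pop()
        found.update(factor_vars[node])
        for nb in tree[node]:
            # Never cross the edge towards `blocked`
            if nb != blocked and nb not in seen:
                seen.add(nb)
                stack.append(nb)
    return found

def separator(i, j, tree, factor_vars):
    """S_ij: variables that appear on both sides of the edge i-j."""
    return side_vars(i, j, tree, factor_vars) & side_vars(j, i, tree, factor_vars)

# Keep one entry per undirected edge, keyed in alphabetical order
separators = {i + j: tuple(sorted(separator(i, j, tree, factor_vars)))
              for i in tree for j in tree[i] if i < j}
print(separators)
```

In this toy tree, variable `C` appears only on the `Z` side of the edge `Y`-`Z`, so it is not in the separator and must be summed out before `Z` sends its message to `Y`.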

We need to have quick access to separators, so let's declare them in a Python dictionary. Since $S_{ij} = S_{ji}$, we can adopt a canonical form and maintain just one of them. 
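
The canonical form also makes it easy to look up a separator regardless of the direction of the edge, and from there to compute clusters. The sketch below is an illustration under assumptions: it takes the cluster of a node to be the variables of the factors assigned to that node together with the separators of all edges incident to it (the usual definition, which this tutorial does not spell out), and it reuses a hypothetical three-node chain rather than the ICU tree. The helpers `sep` and `cluster` are names invented for this example.

```python
# Hypothetical toy tree X - Y - Z and its separators (canonical keys only)
factor_vars = {'X': ('A',), 'Y': ('A', 'B'), 'Z': ('B', 'C')}
tree = {'X': ('Y',), 'Y': ('X', 'Z'), 'Z': ('Y',)}
separators = {'XY': ('A',), 'YZ': ('B',)}

def sep(i, j, separators):
    """Look up S_ij whichever way round i and j are given.

    Concatenating the sorted names is unambiguous here because
    node names are single letters.
    """
    return separators[''.join(sorted([i, j]))]

def cluster(i, tree, factor_vars, separators):
    """Assumed definition: factor variables at i plus all incident separators."""
    c = set(factor_vars[i])
    for j in tree[i]:
        c |= set(sep(i, j, separators))
    return c

clusters = {i: cluster(i, tree, factor_vars, separators) for i in tree}
# Width of the elimination tree: size of the largest cluster minus one
width = max(len(c) for c in clusters.values()) - 1
print({i: sorted(c) for i, c in clusters.items()}, width)
```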

### Exercise

Declare a dictionary with separators for each edge. For an edge $i-j$ declare only $S_{ij}$ such that $i$ precedes $j$ in alphabetical order.

In [None]:
# Notice we declare LS = ('L',). You do not need to declare SL = ('L',). The code only looks up separators whose key is in alphabetical order

S = dict(
    LS= ('L',),
    None                                                       # Separator dictionary: +7 lines
)

##############
# Test code

dot = Graph(engine="neato", comment='ICU Elimination Tree with Separator', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    'B': '1,0!',
    'O': '0,1!',
    'C': '1,1!',
    'T': '2,1!',
    'S': '0,2!',
    'V': '1,2!',
    'A': '2,2!',
    'L': '0,3!',
    'H': '1,3!',
}

for v in eTree.keys():
    dot.node(str(v), pos=pos[v])

for v in eTree:
    for w in eTree[v]:
        if v < w:
            dot.edge(v, w, ','.join(S[v+w]))
dot

In [None]:
# Answer

# Notice we declare LS = ('L',). You do not need to declare SL = ('L',). The code only looks up separators whose key is in alphabetical order

S = dict(
    LS= ('L',),
    HS= ('H',),
    CV= ('V',),
    SV= ('L','H','V'),
    OS= ('V','S'),
    BO= ('O',),
    BT= ('T',),
    AT= ('A',)
)


##############
# Test code

dot = Graph(engine="neato", comment='ICU Elimination Tree with Separator', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    'B': '1,0!',
    'O': '0,1!',
    'C': '1,1!',
    'T': '2,1!',
    'S': '0,2!',
    'V': '1,2!',
    'A': '2,2!',
    'L': '0,3!',
    'H': '1,3!',
}

for v in eTree.keys():
    dot.node(str(v), pos=pos[v])

for v in eTree:
    for w in eTree[v]:
        if v < w:
            dot.edge(v, w, ','.join(S[v+w]))
dot

## Propagating Messages

We will now implement the code to propagate the messages over the elimination tree. The main steps in the algorithm are the following:

1. Choose a root node $r$ in the tree $T$
2. Pull/collect messages towards root $r$
3. Push/distribute messages away from root $r$
4. Return $\phi_i \prod_k M_{ki}$ for each node $i$ in tree $T$

We can implement the pull step using a depth-first search. The general idea is to use the "backtracking" step of the recursion to calculate the messages. In other words, we run a depth-first search from the root $r$ until we find a dead-end, that is, a node with only one neighbour. We can then calculate the message and associate it with the edge between the dead-end node and its neighbour. This message is guaranteed to be directed towards the root.

When we finish the pull step, we can also use a depth-first search to implement the push step. In this case, we calculate a message every time we traverse an edge. Unlike the pull step, we use the "forward" part of the recursion, which guarantees that the message is directed away from the root.
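
The two traversals can be sketched without any factors, just by recording the order in which messages would be sent. The code below is an illustration on a hypothetical four-node tree; `pull_order` and `push_order` are names invented here and only mimic the control flow described above, not the message computation itself.

```python
def pull_order(curr, previous, tree, order):
    """Post-order DFS: curr sends to previous only after hearing from its other neighbours."""
    for v in tree[curr]:
        if v != previous:
            pull_order(v, curr, tree, order)
    # Backtracking step: the message curr -> previous points towards the root
    if previous is not None:
        order.append((curr, previous))

def push_order(curr, previous, tree, order):
    """Pre-order DFS: a message is sent as soon as an edge is traversed."""
    for v in tree[curr]:
        if v != previous:
            # Forward step: the message curr -> v points away from the root
            order.append((curr, v))
            push_order(v, curr, tree, order)

# Hypothetical tree: X is the root, Y is its only neighbour, Z and W are leaves
tree = {'X': ('Y',), 'Y': ('X', 'Z', 'W'), 'Z': ('Y',), 'W': ('Y',)}
pulls, pushes = [], []
pull_order('X', None, tree, pulls)
push_order('X', None, tree, pushes)
print(pulls)   # messages towards the root
print(pushes)  # messages away from the root
```

Together the two lists contain each edge once in each direction, which is why the full algorithm ends up with two messages per undirected edge.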

One final and important detail: the messages are calculated according to the following equation:

$M_{ij} = project(\phi_i \prod_{k \neq j} M_{ki},S_{ij})$

where $project(\phi, S_{ij})$ is an operation that removes (by marginalization) all variables of the factor $\phi$ except the ones in $S_{ij}$.

The critical detail here is that when we compute a message from node $i$ to $j$ we multiply all incoming messages to $i$ except for the message from $j$. 
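
To make the equation concrete, here is a minimal sketch that computes two messages by hand for the nodes $A$ and $T$, using the CPT values from `icu_factors` and the separators $S_{AT} = \{A\}$ and $S_{BT} = \{T\}$. The functions `multiply` and `project` are simplified stand-ins written for this example (plain dictionaries instead of `odict`); they are assumptions, not the notebook's own `join` and `marginalize`.

```python
from itertools import product

outcomes = {'A': (0, 1), 'T': (0, 1)}

def multiply(f, g):
    """Factor product over the union of both domains."""
    dom = f['dom'] + tuple(v for v in g['dom'] if v not in f['dom'])
    table = {}
    for entry in product(*(outcomes[v] for v in dom)):
        row = dict(zip(dom, entry))
        table[entry] = (f['table'][tuple(row[v] for v in f['dom'])]
                        * g['table'][tuple(row[v] for v in g['dom'])])
    return {'dom': dom, 'table': table}

def project(f, keep):
    """Sum out every variable of f that is not in `keep`."""
    dom = tuple(v for v in f['dom'] if v in keep)
    table = {}
    for entry, p in f['table'].items():
        row = dict(zip(f['dom'], entry))
        key = tuple(row[v] for v in dom)
        table[key] = table.get(key, 0.0) + p
    return {'dom': dom, 'table': table}

phi_A = {'dom': ('A',), 'table': {(0,): 0.99, (1,): 0.01}}
phi_T = {'dom': ('A', 'T'),
         'table': {(0, 0): 0.30, (0, 1): 0.70, (1, 0): 1.0, (1, 1): 0.0}}

# A is a leaf: no incoming messages other than the one from T, which is excluded
M_AT = project(phi_A, {'A'})
# T has neighbours A and B: when sending to B, multiply in only the message from A
M_TB = project(multiply(phi_T, M_AT), {'T'})
print(M_TB['table'])
```

Up to floating-point rounding, the resulting message over $T$ is approximately 0.307 for $T=0$ and 0.693 for $T=1$.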

We start our implementation by reusing our code for the join and marginalization operations, as well as the `printFactor` function. Feel free to use your own code here.

In [None]:
def printFactor(f):
    """
    argument 
    `f`, a factor to print on screen
    """
    # Create an empty list that we will fill in with the probability table entries
    table = list()
    
    # Iterate over all keys and probability values in the table
    for key, item in f['table'].items():
        # Convert the tuple to a list to be able to manipulate it
        k = list(key)
        # Append the probability value to the list with key values
        k.append(item)
        # Append an entire row to the table
        table.append(k)
    # dom is used as table header. We need it converted to list
    dom = list(f['dom'])
    # Append a 'Pr' to indicate the probability column
    dom.append('Pr')
    print(tabulate(table,headers=dom,tablefmt='orgtbl'))

def prob(factor, *entry):
    """
    argument 
    `factor`, a dictionary of domain and probability values,
    `entry`, a list of values, one for each variable in the same order as specified in the factor domain.
    
    Returns p(entry)
    """

    return factor['table'][entry]     # insert your code here, 1 line   

def join(f1, f2, outcomeSpace):
    """
    argument 
    `f1`, first factor to be joined.
    `f2`, second factor to be joined.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns a new factor with a join of f1 and f2
    """
    
    # First, we need to determine the domain of the new factor. It will be union of the domain in f1 and f2
    # But it is important to eliminate the repetitions
    common_vars = list(f1['dom']) + list(set(f2['dom']) - set(f1['dom']))
    
    # We will build a table from scratch, starting with an empty list. Later on, we will transform the list into a odict
    table = list()
    
    # Here is where the magic happens. The product iterator will generate all combinations of variable values
    # as specified in outcomeSpace. Therefore, it will naturally respect observed values
    for entries in product(*[outcomeSpace[node] for node in common_vars]):
        
        # We need to map the entries to the domain of the factors f1 and f2
        entryDict = dict(zip(common_vars, entries))
        f1_entry = (entryDict[var] for var in f1['dom'])
        f2_entry = (entryDict[var] for var in f2['dom'])
        
        # Insert your code here
        p1 = prob(f1, *f1_entry)           # Use the function prob to calculate the probability in factor f1 for entry f1_entry 
        p2 = prob(f2, *f2_entry)           # Use the function prob to calculate the probability in factor f2 for entry f2_entry 
        
        # Create a new table entry with the multiplication of p1 and p2
        table.append((entries, p1 * p2))
    return {'dom': tuple(common_vars), 'table': odict(table)}


def marginalize(f, var, outcomeSpace):
    """
    argument 
    `f`, factor to be marginalized.
    `var`, variable to be summed out.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns a new factor f' with dom(f') = dom(f) - {var}
    """    
    
    # Let's make a copy of f domain and convert it to a list. We need a list to be able to modify its elements
    new_dom = list(f['dom'])
    new_dom.remove(var)            # Remove var from the list new_dom by calling the method remove(). 1 line
    table = list()                 # Create an empty list for table. We will fill in table from scratch. 1 line
    for entries in product(*[outcomeSpace[node] for node in new_dom]):
        s = 0                      # Initialize the summation variable s. 1 line

        # We need to iterate over all possible outcomes of the variable var
        for val in outcomeSpace[var]:
            # To modify the tuple entries, we will need to convert it to a list
            entriesList = list(entries)
            # We need to insert the value of var in the right position in entriesList
            entriesList.insert(f['dom'].index(var), val)
          
            p = prob(f, *tuple(entriesList))     # Calculate the probability of factor f for entriesList. 1 line
            s = s + p                            # Sum over all values of var by accumulating the sum in s. 1 line
            
        # Create a new table entry with the accumulated sum s
        table.append((entries, s))
    return {'dom': tuple(new_dom), 'table': odict(table)}

def normalize(f):
    """
    argument 
    `f`, factor to be normalized.
    
    Returns a new factor f' as a copy of f with entries that sum up to 1
    """ 
    table = list()
    # Use `total` rather than `sum` to avoid shadowing the built-in function
    total = 0
    for k, p in f['table'].items():
        total = total + p
    for k, p in f['table'].items():
        table.append((k, p/total))
    return {'dom': f['dom'], 'table': odict(table)}

### Exercise

We will implement the code to compute messages for all edges in the elimination tree. For each undirected edge there will be two messages: one $M_{ij}$ and the other $M_{ji}$. The messages are stored in a dictionary, so we can easily retrieve them later. Similarly to separators, we use as key the concatenation of the names of the nodes.

Let's start with the function `getMessages`, which calls two auxiliary functions: `pull` and `push`.

In [None]:
def getMessages(factors, root, eTree, S, outcomeSpace):
    """
    argument 
    `factors`, dictionary with all factors.
    `root`, root node.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary with all messages
    """     
    # Initialize dictionary to store messages. The message key will be concatenation of nodes names
    messages = None                                                    # Initialize with an empty dictionary: 1 line
    # For each neighbouring node of root, we start a depth-first search
    for v in eTree[root]:
        # Call pull and store the resulting message in messages[v+root]
        messages[v+root] = None                                        # Call pull to compute the messages toward the root (see next cell): 1 line

    # Call push to recursively push messages out from the root. To start it off, set the previous node to '' or None.                 
    None                                                           # Call push recursively: 1 line
    return messages

In [None]:
# Answer

def getMessages(factors, root, eTree, S, outcomeSpace):
    """
    argument 
    `factors`, dictionary with all factors.
    `root`, root node.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary with all messages
    """     
    # Initialize dictionary to store messages. The message key will be concatenation of nodes names
    messages = dict()
    # For each neighbouring node of root, we start a depth-first search
    for v in eTree[root]:
        # Call pull and store the resulting message in messages[v+root]
        messages[v+root] = pull(v, root, factors, eTree, S, messages, outcomeSpace)

    # Call push to recursively push messages out from the root. To start it off, set the previous node to '' or None.                 
    push(root, '', factors, eTree, S, messages, outcomeSpace)
    
    return messages

### Exercise

We continue with the implementation of pull, a function that computes all messages toward the root.

In [None]:
def pull(curr, previous, factors, eTree, S, messages, outcomeSpace):
    """
    argument 
    `curr`, current node.
    `previous`, node we came from in the search.
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `messages`, dictionary with messages.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns a factor fx with the message from node curr to node previous (i.e., towards the root)
    """
    # fx is an auxiliary factor. Initialize fx with the factor associated with the curr node
    fx = None                                                       # Initialize with factor at curr node: 1 line
    # Depth-first search
    for v in eTree[curr]:
        # This is an important step: avoid returning using the edge we came from
        if not v == previous:
            # Call pull recursively: v must send its message to curr before curr can send its own
            messages[v+curr] = None                                 # Call pull recursively: 1 line
            # Here, we returned from the recursive call. 
            # We need to join the received message with fx
            fx = None                                               # Join fx with received message: 1 line
    # fx has all incoming messages multiplied by the node factor. It is time to marginalize the variables not in S_{ij}
    for v in fx['dom']:
        if not v in S[''.join(sorted([previous,curr]))]:
            # Call marginalize to remove variable v from fx's domain
            fx = None                                               # Call marginalize to remove variable v
    return fx

In [None]:
# Answer

def pull(curr, previous, factors, eTree, S, messages, outcomeSpace):
    """
    argument 
    `curr`, current node.
    `previous`, node we came from in the search.
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `messages`, dictionary with messages.
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns a factor fx with the message from node curr to node previous (i.e., towards the root)
    """
    # fx is an auxiliary factor. Initialize fx with the factor associated with the curr node
    fx = factors[curr]
    # Depth-first search
    for v in eTree[curr]:
        # This is an important step: avoid returning using the edge we came from
        if not v == previous:
            # Call pull recursively: v must send its message to curr before curr can send its own
            messages[v+curr] = pull(v, curr, factors, eTree, S, messages, outcomeSpace)
            # Here, we returned from the recursive call. 
            # We need to join the received message with fx
            fx = join(fx, messages[v+curr], outcomeSpace)
    # fx has all incoming messages multiplied by the node factor. It is time to marginalize the variables not in S_{ij}
    for v in fx['dom']:
        if not v in S[''.join(sorted([previous,curr]))]:
            # Call marginalize to remove variable v from fx's domain
            fx = marginalize(fx, v, outcomeSpace)
    return fx

### Exercise

And finally, we implement a function that computes the messages away from the root.

In [None]:
def push(curr, previous, factors, eTree, S, messages, outcomeSpace):
    """
    argument 
    `curr`, current node.
    `previous`, previous node.
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `messages`, dictionary with messages.
    `outcomeSpace`, dictionary with the domain of each variable
    
    """    
    for v in eTree[curr]:
        # This is an important step: avoid returning using the edge we came from        
        if not v == previous:
            # Initialize messages[curr+v] with the factor associated with the curr node
            messages[curr+v] = None                                     # Initialize with factor at curr node: 1 line
            for w in eTree[curr]:
                # This is an important step: do not consider the incoming message from v when computing the outgoing message to v
                if not v == w:
                    # Join messages coming from w into messages[curr+v]
                    messages[curr+v] = None                             # Join messages coming from w: 1 line

            # messages[curr+v] has all incoming messages multiplied by the node factor. It is time to marginalize the variables not in S_{ij}
            for w in messages[curr+v]['dom']:
                if not w in S[''.join(sorted([v,curr]))]:
                    # Call marginalize to remove variable w from messages[curr+v] domain
                    messages[curr+v] = None                             # Remove variable w from message: 1 line
            # Call push recursively and go to the next node v
            None                                                        # Call push recursively

In [None]:
# Answer

def push(curr, previous, factors, eTree, S, messages, outcomeSpace):
    """
    argument 
    `curr`, current node.
    `previous`, previous node.
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `S`, separators dictionary.
    `messages`, dictionary with messages.
    `outcomeSpace`, dictionary with the domain of each variable
    
    """    
    for v in eTree[curr]:
        # This is an important step: avoid returning using the edge we came from        
        if not v == previous:
            # Initialize messages[curr+v] with the factor associated with the curr node
            messages[curr+v] = factors[curr]
            for w in eTree[curr]:
                # This is an important step: do not consider the incoming message from v when computing the outgoing message to v
                if not v == w:
                    # Join messages coming from w into messages[curr+v]
                    messages[curr+v] = join(messages[curr+v], messages[w+curr], outcomeSpace)

            # messages[curr+v] has all incoming messages multiplied by the node factor. It is time to marginalize the variables not in S_{ij}
            for w in messages[curr+v]['dom']:
                if not w in S[''.join(sorted([v,curr]))]:
                    # Call marginalize to remove variable w from messages[curr+v] domain
                    messages[curr+v] = marginalize(messages[curr+v], w, outcomeSpace)
            # Call push recursively and go to the next node v
            push(v, curr, factors, eTree, S, messages, outcomeSpace)

In [None]:
################
# Test code

m = getMessages(icu_factors, 'B', eTree, S, outcomeSpace)
printFactor(m['TB'])
print()
printFactor(m['BT'])

If your code is correct, you should see the following output:

```
|   T |    Pr |
|-----+-------|
|   0 | 0.307 |
|   1 | 0.693 |

|   T |   Pr |
|-----+------|
|   0 |    1 |
|   1 |    1 |
```

## Querying Cluster Marginals

We can now compute the marginals for one or more variables inside the same cluster. We implement a function `queryCluster`. 

`queryCluster` takes as input a node and a query, among other arguments. The query is a list of variable names, all of which must be in the node's cluster. The function returns the marginal distribution over those variables. For now, we are not conditioning on any evidence; we will do that in the next section.

In [None]:
def queryCluster(factors, eTree, messages, node, query):
    """
    argument 
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `messages`, dictionary with messages between neighbouring nodes
    `node`, a node in the elimination tree whose cluster contain the query variables.
    `query`, a list with query variables
    
    Returns factor with the marginal for the query variables
    """ 
    # fx is an auxiliary factor. Initialize fx with the factor associated with the query node
    fx = None                                                           # Initialize fx: 1 line
    for v in eTree[node]:
        # Call join to multiply fx by the incoming message from each neighbouring node v
        fx = None                                                       # Multiply fx with incoming messages: 1 line
    for v in fx['dom']:
        if not v in query:
            # Call marginalize to remove variable v from fx domain            
            fx = None                                                   # Remove all variables not in the query: 1 line
    return fx
    
##################
# Test code
m = getMessages(icu_factors, 'B', eTree, S, outcomeSpace)
printFactor(queryCluster(icu_factors, eTree, m, 'S', ('L','H')))

In [None]:
# Answer

def queryCluster(factors, eTree, messages, node, query):
    """
    argument 
    `factors`, dictionary with all factors.
    `eTree`, elimination tree.
    `messages`, dictionary with messages between neighbouring nodes
    `node`, a node in the elimination tree whose cluster contain the query variables.
    `query`, a list with query variables
    
    Returns factor with the marginal for the query variables
    """ 
    # fx is an auxiliary factor. Initialize fx with the factor associated with the query node
    fx = factors[node]    
    for v in eTree[node]:
        # Call join to multiply fx by the incoming message from each neighbouring node v
        # (note: this relies on the global variable outcomeSpace defined earlier)
        fx = join(fx, messages[v+node], outcomeSpace)
    for v in fx['dom']:
        if not v in query:
            # Call marginalize to remove variable v from fx domain            
            fx = marginalize(fx, v, outcomeSpace)
    return fx
    
    
##################
# Test code
m = getMessages(icu_factors, 'B', eTree, S, outcomeSpace)
printFactor(queryCluster(icu_factors, eTree, m, 'S', ('L','H')))

If you implemented your code correctly, you should see the following output:

```
|   L |   H |   Pr |
|-----+-----+------|
|   0 |   0 | 0.76 |
|   0 |   1 | 0.19 |
|   1 |   0 | 0.04 |
|   1 |   1 | 0.01 |
```

It is interesting to note that the same marginals can be obtained from different node clusters. For instance, variables $L$ and $H$ are also present in the cluster of node 'V'. Therefore, the query:

In [None]:
printFactor(queryCluster(icu_factors, eTree, m, 'V', ('L','H')))

provides the same output.
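
To check this kind of agreement programmatically rather than by eye, a small comparison helper can be handy. The sketch below is an assumption-laden illustration: `same_marginal` is a hypothetical helper, and the two factors are typed in by hand from the expected output above (with the second one deliberately stored in a different variable order) instead of being produced by `queryCluster`.

```python
def same_marginal(f1, f2, tol=1e-9):
    """True if f1 and f2 hold the same table, even if their variables are ordered differently."""
    if set(f1['dom']) != set(f2['dom']) or len(f1['table']) != len(f2['table']):
        return False
    # idx[k] is the position inside f2's domain of the k-th variable of f1
    idx = [f2['dom'].index(v) for v in f1['dom']]
    for entry2, p2 in f2['table'].items():
        # Re-order the entry of f2 so it can be used as a key of f1
        entry1 = tuple(entry2[i] for i in idx)
        p1 = f1['table'].get(entry1)
        if p1 is None or abs(p1 - p2) > tol:
            return False
    return True

# Hand-typed illustrative values, not computed here
from_S = {'dom': ('L', 'H'),
          'table': {(0, 0): 0.76, (0, 1): 0.19, (1, 0): 0.04, (1, 1): 0.01}}
from_V = {'dom': ('H', 'L'),
          'table': {(0, 0): 0.76, (1, 0): 0.19, (0, 1): 0.04, (1, 1): 0.01}}

print(same_marginal(from_S, from_V))
```

In the notebook itself, the two arguments would be the factors returned by the two `queryCluster` calls above.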

## Including Evidence

The inclusion of evidence allows us to answer queries of the form $P(X, e)$ and, after normalization, $P(X|e)$. There are two main ways to include evidence:

1. Eliminate the rows of the factors that do not match the evidence.
2. Create a new factor, known as evidence indicator, that associates a value 1 to the evidence and 0 otherwise.

For instance, if we want to specify the evidence $B=False$, we can create a factor $\lambda_B$:

| $B$  | $\lambda_B$ |
|:-|:-|
| true | 0           |
| false | 1          | 

Then, we include this factor in any node that has a cluster that contains the variable $B$.

So far, we have used technique (1) to include evidence. In the next cells, we will instead set evidence using evidence indicators, as in (2).
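
As a minimal numeric sketch of how an indicator yields $P(X, e)$ and then $P(X|e)$, the code below builds the joint $P(A, T)$ from the CPT values in `icu_factors`, multiplies it by an indicator for the evidence $T=0$, and normalizes. The evidence $T=0$ and all variable names are choices made for this illustration; plain dictionaries are used instead of the notebook's factor format.

```python
# Joint P(A, T) = P(A) * P(T | A), keyed by (a, t)
joint = {
    (0, 0): 0.99 * 0.30,
    (0, 1): 0.99 * 0.70,
    (1, 0): 0.01 * 1.0,
    (1, 1): 0.01 * 0.0,
}

# Evidence indicator for T = 0: 1 on the observed value, 0 elsewhere
lambda_T = {0: 1, 1: 0}

# Multiply each row by the indicator and sum out T, giving P(A, T=0)
p_A_and_e = {}
for (a, t), p in joint.items():
    p_A_and_e[a] = p_A_and_e.get(a, 0.0) + p * lambda_T[t]

# The normalizing constant is the probability of the evidence, P(T=0)
p_e = sum(p_A_and_e.values())
p_A_given_e = {a: p / p_e for a, p in p_A_and_e.items()}

print(p_A_and_e)
print(p_A_given_e)
```

Rows that disagree with the evidence contribute zero to the sums, which is exactly the effect of deleting them in technique (1).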

### Exercise

Let's implement the function `evidence`. This function will create a new dictionary `lambdas` with evidence indicators. 

In [None]:
def evidence(outcomeSpace, **q_evi):
    """
    argument 
    `outcomeSpace`, dictionary with the domain of each variable
    `q_evi`, dictionary of evidence in the form of variables names and values
    
    Returns dictionary with evidence factors 
    """     
    # Create an empty dictionary
    lambdas = dict()
    for var, evi in q_evi.items():
        # Create an empty dictionary inside lambdas
        lambdas[var] = dict()
        # Domain of the evidence indicator (single variable)
        lambdas[var]['dom'] = None                               # Tuple with variable name: 1 line
        # Probability table for the evidence indicator 
        lambdas[var]['table'] = None                             # odict object with probability of evidence: 1 line
    return lambdas

########################
# Test code

l = evidence(outcomeSpace, H=0,V=1)
printFactor(l['H'])
print()
printFactor(l['V'])

In [None]:
def evidence(outcomeSpace, **q_evi):
    """
    argument 
    `outcomeSpace`, dictionary with the domain of each variable
    `q_evi`, dictionary of evidence in the form of variables names and values
    
    Returns dictionary with evidence factors 
    """     
    # Create an empty dictionary
    lambdas = dict()
    for var, evi in q_evi.items():
        # Create an empty dictionary inside lambdas
        lambdas[var] = dict()
        # Domain of the evidence indicator (single variable)
        lambdas[var]['dom'] = (var,)    # one-element tuple; tuple(var) would split a multi-character name into letters
        # Probability table for the evidence indicator 
        lambdas[var]['table'] = odict(((v,),int(v==evi)) for v in outcomeSpace[var])
    return lambdas

########################
# Test code

l = evidence(outcomeSpace, H=0,V=1)
printFactor(l['H'])
print()
printFactor(l['V'])

If you implemented your code correctly, you should see the following output:

```
|   H |   Pr |
|-----+------|
|   0 |    1 |
|   1 |    0 |

|   V |   Pr |
|-----+------|
|   0 |    0 |
|   1 |    1 |
```

### Exercise

Now, we will implement a function that multiplies the evidence indicators into the network factors. For each indicator, we search the dictionary of factors for a network factor whose domain mentions the evidence variable. For simplicity, we pick the first factor we find. The next cell has the stub code for this function.
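
Picking the first matching factor is enough because the product of all factors is the same whichever factor absorbs the indicator, and a 0/1 indicator is unchanged by being multiplied in more than once. The sketch below checks this on the two factors that mention $A$, using the values from `icu_factors` and the evidence $A=1$ chosen for illustration; plain dictionaries and the names `via_phi_A`, `via_phi_T` and `via_both` are assumptions of this example.

```python
phi_A = {0: 0.99, 1: 0.01}                                     # P(A)
phi_T = {(0, 0): 0.30, (0, 1): 0.70, (1, 0): 1.0, (1, 1): 0.0}  # P(T | A), keyed by (a, t)
lambda_A = {0: 0, 1: 1}                                        # indicator for A = 1

# Indicator absorbed by P(A)
via_phi_A = {(a, t): (phi_A[a] * lambda_A[a]) * phi_T[(a, t)] for (a, t) in phi_T}
# Indicator absorbed by P(T | A)
via_phi_T = {(a, t): phi_A[a] * (phi_T[(a, t)] * lambda_A[a]) for (a, t) in phi_T}
# Indicator absorbed by both factors
via_both = {(a, t): (phi_A[a] * lambda_A[a]) * (phi_T[(a, t)] * lambda_A[a])
            for (a, t) in phi_T}

same = all(abs(via_phi_A[k] - via_phi_T[k]) < 1e-12 and
           abs(via_phi_A[k] - via_both[k]) < 1e-12 for k in phi_T)
print(same)
```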

In [None]:
def distributeLambdas(factors, lambdas, outcomeSpace):
    """
    argument 
    `factors`, dictionary of network factors
    `lambdas`, dictionary of evidence indicators
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary of factors multiplied by the evidence indicators
    """  
    # Let's make a copy of the factors, we will return f
    f = factors.copy()
    # We iterate over all evidence indicator factors in lambdas
    for lambda_key, lambda_value in lambdas.items():
        # This loop is used to search for a factor that has the indicator variable in its domain
        for factor_key, factor_value in f.items():
            # It is simpler if we multiply the indicator factor in the first factor we find that mentions its variable
            if lambda_key in None:                                                        # 1 line
                # Multiply the network factor with the indicator factor
                f[factor_key] = None                                                      # 1 line
                # Stop at the first matching factor
                break
    return f

#################
# Test code

lambdas = evidence(outcomeSpace, B=1)
f = distributeLambdas(icu_factors, lambdas, outcomeSpace)
m = getMessages(f, 'B', eTree, S, outcomeSpace)
printFactor(queryCluster(f, eTree, m, 'A', ('A')))

In [None]:
# Answer

def distributeLambdas(factors, lambdas, outcomeSpace):
    """
    argument 
    `factors`, dictionary of network factors
    `lambdas`, dictionary of evidence indicators
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary of factors multiplied by the evidence indicators
    """  
    # Let's make a copy of the factors, we will return f
    f = factors.copy()
    # We iterate over all evidence indicator factors in lambdas
    for lambda_key, lambda_value in lambdas.items():
        # This loop is used to search for a factor that has the indicator variable in its domain
        for factor_key, factor_value in f.items():
            # It is simpler if we multiply the indicator factor in the first factor we find that mentions its variable
            if lambda_key in factor_value['dom']:
                # Multiply the network factor with the indicator factor
                f[factor_key] = join(factor_value, lambda_value, outcomeSpace)
                # Stop at the first matching factor
                break
    return f

#################
# Test code

lambdas = evidence(outcomeSpace, B=1)
f = distributeLambdas(icu_factors, lambdas, outcomeSpace)
m = getMessages(f, 'B', eTree, S, outcomeSpace)
printFactor(queryCluster(f, eTree, m, 'A', ('A')))

This is the expected output for $P(A, B=1)$. 

```
|   A |          Pr |
|-----+-------------|
|   0 | 0.179175    |
|   1 | 0.000642448 |
```

Note that after normalization, we can get $P(A| B=1)$.

In [None]:
printFactor(normalize(queryCluster(f, eTree, m, 'A', ('A'))))
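
The normalization step can be checked by hand. The following is a minimal sketch that works directly on the two joint values from the expected output above. The helper `normalize_table` is a hypothetical stand-in that operates on a bare table, not the notebook's own `normalize` function.

```python
from collections import OrderedDict as odict

def normalize_table(table):
    """Divide every entry by the table total so that the entries sum to 1."""
    total = sum(table.values())
    if total == 0:
        raise ValueError("evidence has zero probability; cannot normalize")
    return odict((outcome, p / total) for outcome, p in table.items())

# Joint values P(A, B=1) copied from the expected output above
joint = odict([((0,), 0.179175), ((1,), 0.000642448)])

# The normalizing constant is the probability of the evidence, P(B=1)
prob_evidence = sum(joint.values())
posterior = normalize_table(joint)      # P(A | B=1)

print("P(B=1) =", prob_evidence)
for outcome, p in posterior.items():
    print(outcome, p)
```

The sum of the unnormalized entries is the probability of the evidence, so a single query gives both $P(B=1)$ and $P(A|B=1)$.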

# Jointrees

We can reuse our code for elimination trees to implement a Shenoy-Shafer architecture for jointrees. We need first to define a jointree graph.

In terms of implementation, the main differences between a jointree and an elimination tree are the following:

1. Our implementation of elimination trees has assumed that we have one node per factor. Therefore, we named the elimination tree nodes after the variables associated with each factor. However, jointrees typically have clusters that involve several variables. From now on, we will use numbers to designate each jointree node.

2. As our implementation of elimination trees has one node per factor, we did not have to associate factors to nodes. However, for jointrees, we may have more than one factor in the same node. We will need to distribute the factors according to the clusters.

3. With jointrees, we may have nodes with no factor associated. Therefore, we will need to assign a trivial factor to those nodes.

For the ICU network, we handcrafted the following jointree:

![Jointree](img/jointree.png "Jointree")


Let's start specifying the jointree graph, clusters and separators.

## Exercise

Specify the jointree graph, separators and clusters. We wrote the first line of each dictionary.

Number the nodes as follows:
```
   5
   4
1  2  3
```

In [None]:
joinTree = {
    '1': ('2',),
    None                                                    # Jointree graph: 4 lines
}

S = {
    '12': ('S', 'V'),
    None                                                   # Jointree separator: 3 lines
}

C = {
    '1': ('S', 'L', 'H', 'V'),
    None                                                   # Jointree cluster: 4 lines
}

##############
# Test code

dot = Graph(engine="neato", comment='ICU Jointree', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    '1': '0,0!',
    '2': '2,0!',
    '3': '4,0!',
    '4': '2,1.5!',
    '5': '2,3!',
}

for v in joinTree.keys():
    dot.node('\n'.join((str(v),','.join(C[v]))), pos=pos[v])

for v in joinTree:
    for w in joinTree[v]:
        if v < w:
            dot.edge('\n'.join((str(v),','.join(C[v]))), '\n'.join((str(w),','.join(C[w]))), ','.join(S[v+w]))
dot

In [None]:
# Answer

joinTree = {
    '1': ('2',),
    '2': ('1', '3', '4'),
    '3': ('2',),
    '4': ('2', '5'),
    '5': ('4',),
}

S = {
    '12': ('S', 'V'),
    '23': ('V',),
    '24': ('O',),
    '45': ('T',),
}

C = {
    '1': ('S', 'L', 'H', 'V'),
    '2': ('O', 'S', 'V'),
    '3': ('C', 'V'),
    '4': ('B', 'O', 'T'),
    '5': ('A', 'T'),
}

##############
# Test code

dot = Graph(engine="neato", comment='ICU Jointree', strict=True)
dot.attr(overlap="false", splines="true")

pos = {
    '1': '0,0!',
    '2': '2,0!',
    '3': '4,0!',
    '4': '2,1.5!',
    '5': '2,3!',
}

for v in joinTree.keys():
    dot.node('\n'.join((str(v),','.join(C[v]))), pos=pos[v])

for v in joinTree:
    for w in joinTree[v]:
        if v < w:
            dot.edge('\n'.join((str(v),','.join(C[v]))), '\n'.join((str(w),','.join(C[w]))), ','.join(S[v+w]))
dot
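
As a sanity check on the specification above, recall that in a jointree the separator on an edge is the intersection of the two clusters it connects. The sketch below derives the separators from `joinTree` and `C` and compares them with the handwritten ones. The helper name `separators_from_clusters` is an assumption for illustration, not part of the tutorial code.

```python
# Jointree graph and clusters, copied from the answer cell above
joinTree = {
    '1': ('2',),
    '2': ('1', '3', '4'),
    '3': ('2',),
    '4': ('2', '5'),
    '5': ('4',),
}

C = {
    '1': ('S', 'L', 'H', 'V'),
    '2': ('O', 'S', 'V'),
    '3': ('C', 'V'),
    '4': ('B', 'O', 'T'),
    '5': ('A', 'T'),
}

def separators_from_clusters(tree, clusters):
    """Return {edge name: sorted tuple of variables shared by both clusters}."""
    seps = {}
    for v in tree:
        for w in tree[v]:
            if v < w:  # visit each undirected edge once, named as in S
                shared = set(clusters[v]) & set(clusters[w])
                seps[v + w] = tuple(sorted(shared))
    return seps

derived = separators_from_clusters(joinTree, C)
for edge, sep in derived.items():
    print(edge, sep)
```

If a derived separator differs from the corresponding entry in `S`, either the clusters or the separators were typed incorrectly.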

The next function will create one trivial factor for each cluster of the jointree. The idea is that we will start with these trivial factors and design a function to assign network factors to clusters. If a cluster ends up with no assigned factor, its trivial factor keeps the computations correct, since multiplying by a table of ones changes nothing.

Let's start with the function that generates trivial factors according to the jointree clusters.

## Exercise

Implement a function that creates one trivial factor for each jointree cluster.

In [None]:
def genTrivialFactors(clusters):
    """
    argument 
    clusters: dictionary mapping node number to a tuple of variables in that node.
    
    Returns a dictionary of trivial factors, one for each node of the jointree.
    Each trivial factor should have a domain of all variables in that node of the jointree.
    
    """
    # This empty dictionary will store one factor for each jointree cluster
    factors = {}
    # Let's iterate over the cluster keys
    for c in clusters.keys():
        # And create one empty factor for each cluster
        factors[c] = {'dom': None, 'table': odict([])}                                 # 1 line
        for values in product(*[outcomeSpace[var] for var in clusters[c]]):
            # The trivial factors assign value 1 (true) for all entries
            factors[c]['table'][values] = None                                         # 1 line
    return factors

In [None]:
# Answer

def genTrivialFactors(clusters):
    """
    argument 
    clusters: dictionary mapping node number to a tuple of variables in that node.
    
    Returns a dictionary of trivial factors, one for each node of the jointree.
    Each trivial factor should have a domain of all variables in that node of the jointree.
    
    """
    # This empty dictionary will store one factor for each jointree cluster
    factors = {}
    # Let's iterate over the cluster keys
    for c in clusters.keys():
        # And create one empty factor for each cluster
        factors[c] = {'dom': clusters[c], 'table': odict([])}
        for values in product(*[outcomeSpace[var] for var in clusters[c]]):
            # The trivial factors assign value 1 (true) for all entries
            factors[c]['table'][values] = 1
    return factors
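
To see why a trivial factor is a safe starting point, here is a small self-contained sketch showing that multiplying by a table of ones leaves a factor unchanged. The toy domains, the made-up numbers in `phi`, and the helpers `trivial_factor` and `multiply_same_domain` are all assumptions for illustration. The notebook's own `join` handles the general case of different domains.

```python
from collections import OrderedDict as odict
from itertools import product

# Toy binary domains (assumption, not the ICU outcome space)
outcomeSpace = {'A': (0, 1), 'T': (0, 1)}

def trivial_factor(dom, outcomeSpace):
    """Factor over `dom` that assigns 1 to every instantiation."""
    table = odict((values, 1) for values in product(*[outcomeSpace[v] for v in dom]))
    return {'dom': tuple(dom), 'table': table}

def multiply_same_domain(f, g):
    """Pointwise product of two factors defined over the same domain."""
    assert f['dom'] == g['dom']
    return {'dom': f['dom'],
            'table': odict((k, f['table'][k] * g['table'][k]) for k in f['table'])}

# A made-up factor over (A, T); the numbers are illustrative only
phi = {'dom': ('A', 'T'),
       'table': odict([((0, 0), 0.3), ((0, 1), 0.7), ((1, 0), 0.9), ((1, 1), 0.1)])}

one = trivial_factor(('A', 'T'), outcomeSpace)
result = multiply_same_domain(one, phi)

print(len(one['table']))                 # number of instantiations of (A, T)
print(result['table'] == phi['table'])   # the product equals the original factor
```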

Now, we will implement the function that associates the network factors with jointree clusters. Each factor **must** be assigned to one cluster in the jointree whose variables include the factor's domain. This property is called *family preservation*. If no cluster can hold a network factor, then our jointree is incorrect (see jointree property 2: "Each factor in $G$ must appear in some cluster $\textbf{C}_i$"). We raise an exception in this case.
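
The family preservation check can be illustrated on domains alone, without any probability tables. The sketch below uses a hypothetical three-cluster example (the cluster contents, the factor names, and the helper `assign_factors` are all made up for illustration) and relies on Python's `for`/`else`: the `else` branch runs only when the inner loop finishes without reaching `break`.

```python
# Hypothetical toy clusters and factor domains (not the ICU network)
clusters = {'1': ('A', 'B'), '2': ('B',), '3': ('B', 'C')}
factor_doms = {'fA': ('A',), 'fAB': ('A', 'B'), 'fBC': ('B', 'C')}

def assign_factors(factor_doms, clusters):
    """Map each factor name to the first cluster that contains its domain."""
    assignment = {}
    for name, dom in factor_doms.items():
        for c, cluster_vars in clusters.items():
            if set(dom).issubset(cluster_vars):
                assignment[name] = c
                break
        else:
            # Reached only if no cluster matched: family preservation fails
            raise ValueError('no cluster contains the domain of factor ' + name)
    return assignment

assignment = assign_factors(factor_doms, clusters)
# Clusters that received no factor keep only their trivial factor
unused = [c for c in clusters if c not in assignment.values()]

print(assignment)
print(unused)
```

In this toy example cluster `'2'` receives no factor, which is exactly the situation the trivial factors are meant to cover.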

## Exercise

Complete the implementation of the function `distribute`.

In [None]:
def distribute(factors, clusters, outcomeSpace):
    """
    argument 
    `factors`, dictionary of network factors
    `clusters`, dictionary of jointree clusters
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary of factors multiplied according to the jointree clusters
    """  
    # Let's start generating trivial factors for each jointree cluster
    trivialFactors = None                                                                 # 1 line
    for factor_key, factor_value in factors.items():
        for tfactor_key, tfactor_value in trivialFactors.items():
            # We will find a match if the factor domain is a subset of the cluster (trivial factor) domain
            if(set(factor_value['dom']).issubset(None)):                                 # 1 line
                # Let's multiply the factors
                trivialFactors[tfactor_key] = None                                       # 1 line
                # and set the flag accordingly
                matching_found = 1
                break
        else:
            # This else clause will only be executed if the for loop reaches the end. Google python for/else for more info
            raise NameError('FamilyPreservationError')
    return trivialFactors

###################
# Test code

factors = distribute(icu_factors, C, outcomeSpace)
lambdas = evidence(outcomeSpace, B=1)
# Multiply the evidence indicators into the cluster factors
factors = distributeLambdas(factors, lambdas, outcomeSpace)
m = getMessages(factors, '2', joinTree, S, outcomeSpace)
printFactor(queryCluster(factors, joinTree, m, '5', ('A')))

In [None]:
# Answer

def distribute(factors, clusters, outcomeSpace):
    """
    argument 
    `factors`, dictionary of network factors
    `clusters`, dictionary of jointree clusters
    `outcomeSpace`, dictionary with the domain of each variable
    
    Returns dictionary of factors multiplied according to the jointree clusters
    """  
    # Let's start generating trivial factors for each jointree cluster
    trivialFactors = genTrivialFactors(clusters)
    for factor_key, factor_value in factors.items():
        for tfactor_key, tfactor_value in trivialFactors.items():
            # We will find a match if the factor domain is a subset of the cluster (trivial factor) domain
            if(set(factor_value['dom']).issubset(set(tfactor_value['dom']))):
                # Let's multiply the factors
                trivialFactors[tfactor_key] = join(tfactor_value, factor_value, outcomeSpace)
                # and set the flag accordingly
                matching_found = 1
                break
        else:
            # This else clause will only be executed if the for loop reaches the end. Google python for/else for more info
            raise NameError('FamilyPreservationError')
    return trivialFactors

###################
# Test code

factors = distribute(icu_factors, C, outcomeSpace)
lambdas = evidence(outcomeSpace, B=1)
# Multiply the evidence indicators into the cluster factors
factors = distributeLambdas(factors, lambdas, outcomeSpace)
m = getMessages(factors, '2', joinTree, S, outcomeSpace)
printFactor(queryCluster(factors, joinTree, m, '5', ('A')))

Great! We have reached the end of this tutorial.

There are some extensions that we can work on, to make this code more complete.

1. We need to design a smart way to recompute messages when we change the evidence. As we discussed in the lectures, we do not have to recompute all messages.

2. We could make the query more user-friendly if we could find the cluster that has the answer automatically. Also, we could raise an exception if no cluster has all the query variables.

3. We can implement the jointree operations to modify the jointree structure so that we can adjust the clusters to our query needs. The operations are: add a variable, and merge, add, or remove clusters.
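
As a starting point for the second extension, here is one possible sketch of a cluster lookup. The function name `findCluster` and the choice to prefer the smallest matching cluster are assumptions, not part of the tutorial code.

```python
# Jointree clusters, copied from the answer cell above
C = {
    '1': ('S', 'L', 'H', 'V'),
    '2': ('O', 'S', 'V'),
    '3': ('C', 'V'),
    '4': ('B', 'O', 'T'),
    '5': ('A', 'T'),
}

def findCluster(clusters, query):
    """Return the key of the smallest cluster containing every query variable."""
    query = set(query)
    candidates = [c for c, cluster_vars in clusters.items()
                  if query.issubset(cluster_vars)]
    if not candidates:
        raise ValueError('no cluster contains all query variables: '
                         + ', '.join(sorted(query)))
    # Fewer variables in the cluster means less to sum out afterwards
    return min(candidates, key=lambda c: len(clusters[c]))

print(findCluster(C, ('A',)))
print(findCluster(C, ('S', 'V')))
```

The returned key could then be passed as the node argument of `queryCluster`. A query such as `('A', 'B')` raises an error here because no single cluster holds both variables, which is the case the third extension addresses.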

See you next week!