Implementation of the sum-product algorithm for computing marginal distributions efficiently. This script assumes that the graph is acyclic. The algorithm is run on the following Bayesian network:

<img src="graph.png" style="width: 700px;">

In [1]:
import numpy as np
import pdb


class Vertex(object):
    """Base class for the vertices of the message-passing graph.

    The four lists are index-aligned: entry ``i`` of ``invoice``,
    ``outvoice`` and ``past_outvoice`` refers to ``neighbours[i]``.
    """

    def __init__(self):
        self.neighbours = []     # adjacent vertices
        self.invoice = []        # last message received from each neighbour
        self.outvoice = []       # message to be sent to each neighbour
        self.past_outvoice = []  # snapshot of ``outvoice`` taken by updateStep()

    def acceptMessage(self, sender, message):
        """Store ``message`` as the incoming message from ``sender``.

        Raises:
            ValueError: if ``sender`` is not in ``self.neighbours``.
        """
        self.invoice[self.neighbours.index(sender)] = message

    def sendMessages(self):
        """Hand every outgoing message to the neighbour it is addressed to."""
        for i, message in enumerate(self.outvoice):
            self.neighbours[i].acceptMessage(self, message)

    def normalizeMessages(self):
        """Rescale each outgoing message so that its entries sum to one."""
        self.outvoice = [x / np.sum(x) for x in self.outvoice]

    def updateStep(self):
        """Remember the current outgoing messages for the convergence test."""
        # Shallow copy: the message arrays are replaced, never mutated in
        # place, so sharing them between the two lists is safe.
        self.past_outvoice = self.outvoice[:]

    def convergence(self, epsi):
        """Tell whether the outgoing messages are stable.

        Args:
            epsi: largest absolute per-entry change still considered stable.

        Returns:
            ``False`` as soon as one entry of one outgoing message differs
            from its previous value by more than ``epsi``; ``True`` otherwise.
        """
        for current, previous in zip(self.outvoice, self.past_outvoice):
            if np.any(np.absolute(current - previous) > epsi):
                return False
        # The converged case must be reported explicitly: falling off the end
        # would return None, which callers treat as "not converged".
        return True




class VariableNode(Vertex):
    """Variable vertex: a named discrete variable with ``dim`` states."""

    def __init__(self, name, dim):
        """
        Args:
            name: label of the variable.
            dim: number of states; messages are arrays of shape ``(dim, 1)``.
        """
        super().__init__()
        self.name = name
        self.dim = dim

    def setMessages(self):
        """Recompute the outgoing message for every neighbour.

        The message for neighbour ``i`` is the element-wise product of the
        messages received from all *other* neighbours, normalized to sum to
        one.  A variable with exactly one neighbour is left untouched: its
        outgoing message keeps whatever value it currently holds.
        """
        if len(self.neighbours) == 1:
            return

        self.updateStep()
        for i in range(len(self.invoice)):
            # Compare positions by value; identity comparison of ints only
            # works by accident of the interpreter's small-int cache.
            others = [msg for j, msg in enumerate(self.invoice) if j != i]
            if others:
                self.outvoice[i] = np.prod(others, 0)
            else:
                # Only one incoming message in total: echo it back.
                self.outvoice[i] = self.invoice[i]

        self.normalizeMessages()

class FunctionNode(Vertex):
    """Factor vertex: a table of values over its neighbouring variables."""

    def __init__(self, name, probs, *args):
        """Create the factor and wire it to its variables.

        Args:
            name: label of the factor.
            probs: NumPy array with the factor values.  With several
                variables, axis ``k`` must have length ``args[k].dim``.  With
                a single variable the array is used directly as the outgoing
                message, so it should be a ``(dim, 1)`` column.
            *args: the variable nodes this factor connects, in axis order.
                Each must expose ``dim`` and the ``Vertex`` message lists.
        """
        super().__init__()
        self.name = name
        self.probs = probs
        self.neighbours = list(args)
        self.neighboursTot = len(self.neighbours)

        for vertex in self.neighbours:
            # Slots on the factor side, initialised to all-ones messages.
            self.invoice.append(np.ones((vertex.dim, 1)))
            self.outvoice.append(np.ones((vertex.dim, 1)))
            self.past_outvoice.append(np.ones((vertex.dim, 1)))

            # Matching slots on the variable side, and the back-link.
            vertex.neighbours.append(self)
            vertex.invoice.append(np.ones((vertex.dim, 1)))
            vertex.outvoice.append(np.ones((vertex.dim, 1)))
            vertex.past_outvoice.append(np.ones((vertex.dim, 1)))

    def setMessages(self):
        """Recompute the outgoing message for every neighbouring variable.

        For neighbour ``i`` the factor table is multiplied by the incoming
        messages of all other neighbours and summed over every axis except
        ``i``.  A factor with a single neighbour sends its table as is.  All
        outgoing messages are then normalized to sum to one.

        ``self.invoice`` is only read, so calling this method again before
        new messages arrive gives the same result instead of failing.
        """
        self.updateStep()
        total = self.neighboursTot

        expanded = []
        if total > 1:
            # Lay each incoming message along its own axis of the table and
            # broadcast it to the full table shape.
            for axis, message in enumerate(self.invoice):
                shape = [1] * total
                shape[axis] = -1
                expanded.append(
                    np.broadcast_to(message.reshape(shape), self.probs.shape)
                )

        for i in range(total):
            if total > 1:
                # Compare positions by value, not identity.
                others = [msg for j, msg in enumerate(expanded) if j != i]
                weighted = np.multiply(np.prod(others, 0), self.probs)
                # Bring axis i to the front, then sum out all other axes.
                weighted = np.moveaxis(weighted, i, 0)
                summed = np.sum(weighted, tuple(range(1, total)))
                outgoing = summed.reshape((summed.shape[0], 1))
            else:
                outgoing = self.probs

            self.outvoice[i] = outgoing

        self.normalizeMessages()

        
    

In [2]:

class Graph:
    """Container for variable and factor nodes that runs sum-product on them."""

    def __init__(self):
        self.varNode = {}         # variable name -> VariableNode
        self.funcNode = []        # factor nodes, in insertion order
        self.convergence = False  # outcome of the most recent sumProduct() run

    def addVarNode(self, name, dim):
        """Create a variable with ``dim`` states, register it under ``name``
        (replacing any variable already stored under that name) and return it.
        """
        newVar = VariableNode(name, dim)
        self.varNode[name] = newVar
        return newVar

    def addFunctionNode(self, name, probs, *args):
        """Create a factor over the variables ``args`` with table ``probs``,
        register it and return it.
        """
        addedFuncNode = FunctionNode(name, probs, *args)
        self.funcNode.append(addedFuncNode)
        return addedFuncNode

    def sumProduct(self, max, epsi):
        """Pass messages until they stabilize or ``max`` iterations are done.

        Each iteration first lets every factor compute and send its messages,
        then every variable.  The run stops early once every node reports
        convergence for tolerance ``epsi``; the outcome is stored in
        ``self.convergence``.

        Args:
            max: maximum number of iterations (the name shadows the builtin
                but is kept for compatibility with existing callers).
            epsi: tolerance forwarded to each node's ``convergence`` test.
        """
        # Start from a clean flag: a flag left over from an earlier run would
        # otherwise turn this call into a no-op, even if the graph changed.
        self.convergence = False

        t = 0
        while t < max and not self.convergence:
            t = t + 1
            print("Nous sommes a l'iteration numero: ", t)

            for func in self.funcNode:
                func.setMessages()
                func.sendMessages()

            for var in self.varNode.values():
                var.setMessages()
                var.sendMessages()

            # Variables are checked first; factors only if all variables are
            # stable.  Both checks stop at the first unstable node.
            self.convergence = bool(
                all(var.convergence(epsi) for var in self.varNode.values())
                and all(func.convergence(epsi) for func in self.funcNode)
            )

    def marginals(self, max, epsi):
        """Run sum-product and return the marginal of every variable.

        Args:
            max: maximum number of iterations for ``sumProduct``.
            epsi: convergence tolerance for ``sumProduct``.

        Returns:
            dict mapping each variable name to the product of its incoming
            messages divided by the sum of that product.
        """
        print("Debut du procede...")
        self.sumProduct(max, epsi)

        marginals = {}
        for name, var in self.varNode.items():
            belief = 1
            for message in var.invoice:
                belief = belief * message
            marginals[name] = belief / np.sum(belief)

        return marginals

In [4]:


# Build the example factor graph.
G = Graph()

# Variables, each with its number of states.
a = G.addVarNode('a', 3)
b = G.addVarNode('b', 2)
c = G.addVarNode('c', 4)

# Unary factors are given as column vectors.
Pa = np.array([[0.3], [0.6], [0.1]])
G.addFunctionNode("Pa", Pa, a)

Pc = np.array([[0.2], [0.1], [0.2], [0.5]])
G.addFunctionNode("Pc", Pc, c)

# Factor joining all three variables; axes are ordered (b, a, c) to match the
# order in which the variables are passed below.  A factor may connect any
# number of variables.
P_b_given_a_c = np.array([
    [[0.1, 0.2, 0.3, 0.4], [0.25, 0.25, 0.25, 0.25], [0.5, 0.4, 0.05, 0.05]],
    [[0.3, 0.2, 0.3, 0.2], [0.1, 0.2, 0.2, 0.5], [0.1, 0.1, 0.1, 0.7]],
])
G.addFunctionNode("Pbac", P_b_given_a_c, b, a, c)

# Run sum-product, then show the marginal of each variable.
marg = G.marginals(100, 1e-5)
distA, distB, distC = marg['a'], marg['b'], marg['c']
print(distA, distB, distC, sep="\n")

Debut du procede...
[[ 0.28546256]
 [ 0.61321586]
 [ 0.10132159]]
[[ 0.45374449]
 [ 0.54625551]]
[[ 0.13744493]
 [ 0.07753304]
 [ 0.16387665]
 [ 0.62114537]]
