# Implementing Additional Nodes

This demo shows how to build a custom node in ForneyLab. As an example, we develop a node that represents the logical `AND` operator.

Additional nodes can be developed in an external package and included in ForneyLab, or added directly to the ForneyLab core. Four directories are relevant for implementing new node functionality:

- `/test/factor_nodes` : contains tests for node construction and update rule lookup and execution;
- `/src/factor_nodes` : implements node constructors and related methods;
- `/src/update_rules` : implements patterns for matching update rules;
- `/src/engines/julia/update_rules` : implements the actual update computations.

## Message Updates

The first step towards implementing a new node in ForneyLab is to define the node function and derive the required message updates. The `AND` node represents a deterministic constraint between three binary variables, $x, y, z \in \{0, 1\}$, where $y$ is the logical AND of $x$ and $z$. For logical operations on binary variables, the constraints imposed by the node function $f(y, x, z)$ can be represented in a truth table. For the `AND` node, the truth table lists the allowed configurations, for which $f(y, x, z)=1$. All other configurations have $f(y, x, z)=0$:
\begin{array}{c | c c}
    y & x & z\\
    \hline
    0 & 0 & 0\\
    0 & 1 & 0\\
    0 & 0 & 1\\
    1 & 1 & 1
\end{array}
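To make the node function concrete, the following plain-Julia sketch encodes $f(y, x, z)$ as an indicator function and enumerates the configurations it allows. The name `f_and` is hypothetical and independent of ForneyLab. It only restates the truth table in code.

```julia
# Indicator form of the AND node function: 1 if y equals x AND z, else 0.
# On binary variables this is the Kronecker delta δ(y - x*z).
f_and(y, x, z) = (y == (x & z)) ? 1 : 0

# Enumerate all 8 binary configurations (x varying fastest) and keep the allowed ones
valid = [(y, x, z) for y in 0:1 for z in 0:1 for x in 0:1 if f_and(y, x, z) == 1]

for (y, x, z) in valid
    println("y = $y, x = $x, z = $z")
end
```

The four printed rows correspond to the four rows of the truth table, in the same order.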

### Forward message for $y$
We can now derive the forward sum-product message for $y$, given Bernoulli inbound messages $\mathcal{B}er(x | p_x)$ and $\mathcal{B}er(z | p_z)$ for $x$ and $z$, as
\begin{align*}
    \overrightarrow{\mu}_y(y) &= \sum_x \sum_z \overrightarrow{\mu}_x(x)\,\overrightarrow{\mu}_z(z)\,f(y, x, z)\\
    &= \sum_x \sum_z \mathcal{B}er\!(x | p_x)\,\mathcal{B}er\!(z | p_z)\,f(y, x, z)\\
    &= \mathcal{B}er\!(x=0 | p_x)\,\mathcal{B}er\!(z=0 | p_z)f(y, x=0, z=0) + \mathcal{B}er\!(x=1 | p_x)\,\mathcal{B}er\!(z=0 | p_z)f(y, x=1, z=0) +\\
    &\quad\mathcal{B}er\!(x=0 | p_x)\,\mathcal{B}er\!(z=1 | p_z)f(y, x=0, z=1) + \mathcal{B}er\!(x=1 | p_x)\,\mathcal{B}er\!(z=1 | p_z)f(y, x=1, z=1)\\
    &= (1 - p_x)(1 - p_z)f(y, x=0, z=0) + p_x(1 - p_z)f(y, x=1, z=0) +\\
    &\quad (1 - p_x)p_z f(y, x=0, z=1) + p_x p_z f(y, x=1, z=1)\,.
\end{align*}
Using the truth table, we can substitute $y$ and evaluate the terms, such that
\begin{align*}
    \overrightarrow{\mu}_y(y) &= \begin{cases} p_x p_z &\text{ if } y=1\\
    (1 - p_x)(1 - p_z) + p_x(1 - p_z) + (1 - p_x)p_z &\text{ if } y=0 \end{cases}\\
    &= \begin{cases} p_x p_z &\text{ if } y=1\\
    1 - p_x p_z &\text{ if } y=0 \end{cases}\\
    &= \mathcal{B}er(y | p_x p_z)\,.
\end{align*}
This derivation might appear overly elaborate. The same result follows directly from the product rule of probability: $y=1$ if and only if both $x=1$ and $z=1$, which has probability $p_x p_z$ under the independent inbound messages. However, the derivation follows the general recipe for sum-product updates, and this recipe is useful for deriving the backward messages.
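The recipe can also be checked numerically. The sketch below evaluates the forward message by brute-force summation over $x$ and $z$. It uses the hypothetical helpers `f_and` (node function) and `ber` (Bernoulli probability of a value), which are plain Julia and not part of ForneyLab.

```julia
f_and(y, x, z) = (y == (x & z)) ? 1 : 0

# Probability that a Bernoulli(p) variable takes the value v ∈ {0, 1}
ber(v, p) = v == 1 ? p : 1 - p

# Forward sum-product message for y: for each y, sum μx(x) μz(z) f(y, x, z) over x and z.
# Returns [value at y = 0, value at y = 1]. Since f maps every (x, z) to exactly one y,
# the two values already sum to one.
function forward_y(p_x, p_z)
    return [sum(ber(x, p_x) * ber(z, p_z) * f_and(y, x, z) for x in 0:1, z in 0:1) for y in 0:1]
end

msg = forward_y(0.4, 0.5)
println(msg)  # second entry: p_x * p_z = 0.2
```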

### Backward message for $x$
The backward message for $x$ follows from the same recipe, as
\begin{align*}
    \overleftarrow{\mu}_x(x) &= \sum_y \sum_z \overleftarrow{\mu}_y(y)\,\overrightarrow{\mu}_z(z)\,f(y, x, z)\\
    &= \sum_y \sum_z \mathcal{B}er\!(y | p_y)\,\mathcal{B}er\!(z | p_z)\,\delta(y - xz)\\
    &= \mathcal{B}er\!(y=0 | p_y)\,\mathcal{B}er\!(z=0 | p_z)\,f(y=0, x, z=0) + \mathcal{B}er\!(y=1 | p_y)\,\mathcal{B}er\!(z=0 | p_z)\,f(y=1, x, z=0) +\\
    &\quad \mathcal{B}er\!(y=0 | p_y)\,\mathcal{B}er\!(z=1 | p_z)\,f(y=0, x, z=1) + \mathcal{B}er\!(y=1 | p_y)\,\mathcal{B}er\!(z=1 | p_z)\,f(y=1, x, z=1)\,.
\end{align*}    
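The truth table determines which terms survive:

- The combination $y=1, z=0$ is disallowed, so $f(y=1, x, z=0)=0$ for both values of $x$.
- The combination $y=0, z=0$ is always allowed, so $f(y=0, x, z=0)=1$ for both values of $x$.
- For $z=1$ the constraint requires $y=x$. Hence $f(y=0, x, z=1)=1$ only for $x=0$, and $f(y=1, x, z=1)=1$ only for $x=1$.

This leads to
\begin{align*}
    \overleftarrow{\mu}_x(x) &= (1 - p_y)(1 - p_z)f(y=0, x, z=0) + 0 + (1 - p_y)p_z f(y=0, x, z=1) + p_y p_z f(y=1, x, z=1)\\
    &= \begin{cases}(1-p_y)(1-p_z) + p_y p_z &\text{ if } x=1\\
        (1-p_y)(1-p_z) + (1-p_y)p_z &\text{ if } x=0 \end{cases}\\
    &\propto \mathcal{B}er\left(x \bigg| \frac{a + p_y p_z}{2a + p_z}\right)\,, \text{ with } a=(1-p_y)(1-p_z)\,,
\end{align*}
where the normalizing constant is the sum of both cases, $2(1-p_y)(1-p_z) + p_y p_z + (1-p_y)p_z = 2a + p_z$.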
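The backward message for $x$ can be verified in the same way. The sketch below sums $\overleftarrow{\mu}_y(y)\,\overrightarrow{\mu}_z(z)\,f(y, x, z)$ over $y$ and $z$, normalizes the result, and compares it with the closed form $(a + p_y p_z)/(2a + p_z)$. The helper names are hypothetical, and the inputs $p_y = 0.1$ and $p_z = 0.25$ are arbitrary.

```julia
f_and(y, x, z) = (y == (x & z)) ? 1 : 0
ber(v, p) = v == 1 ? p : 1 - p

# Backward sum-product message for x by brute force: for each x, sum
# μy(y) μz(z) f(y, x, z) over y and z, then normalize and return P(x = 1)
function backward_x(p_y, p_z)
    m = [sum(ber(y, p_y) * ber(z, p_z) * f_and(y, x, z) for y in 0:1, z in 0:1) for x in 0:1]
    return m[2] / (m[1] + m[2])
end

# Closed form from the derivation, with a = (1 - p_y)(1 - p_z)
function backward_x_closed(p_y, p_z)
    a = (1 - p_y) * (1 - p_z)
    return (a + p_y * p_z) / (2a + p_z)
end

println(backward_x(0.1, 0.25))         # ≈ 0.4375
println(backward_x_closed(0.1, 0.25))  # ≈ 0.4375

# Both agree on a grid of inbound parameters
println(all(backward_x(py, pz) ≈ backward_x_closed(py, pz) for py in 0.05:0.1:0.95, pz in 0.05:0.1:0.95))  # true
```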

### Backward message for $z$
By symmetry, exchanging the roles of $x$ and $z$, the backward message for $z$ follows as
\begin{align*}
    \overleftarrow{\mu}_z(z) \propto \mathcal{B}er\left(z \bigg| \frac{a + p_y p_x}{2a + p_x}\right)\,, \text{ with } a=(1-p_y)(1-p_x)\,.
\end{align*}
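The symmetry argument rests on $f(y, x, z) = f(y, z, x)$, that is, on the commutativity of AND. The sketch below checks this property over all configurations. It also compares a brute-force backward message for $z$ with the closed form. The helper names are again hypothetical and independent of ForneyLab.

```julia
f_and(y, x, z) = (y == (x & z)) ? 1 : 0
ber(v, p) = v == 1 ? p : 1 - p

# AND is commutative: swapping x and z leaves f unchanged for every configuration
is_symmetric = all(f_and(y, x, z) == f_and(y, z, x) for y in 0:1, x in 0:1, z in 0:1)

# Brute-force backward message for z, returned as P(z = 1)
function backward_z(p_y, p_x)
    m = [sum(ber(y, p_y) * ber(x, p_x) * f_and(y, x, z) for y in 0:1, x in 0:1) for z in 0:1]
    return m[2] / (m[1] + m[2])
end

# Closed form obtained by exchanging x and z, with a = (1 - p_y)(1 - p_x)
function backward_z_closed(p_y, p_x)
    a = (1 - p_y) * (1 - p_x)
    return (a + p_y * p_x) / (2a + p_x)
end

println(is_symmetric)                                           # true
println(backward_z(0.1, 0.25) ≈ backward_z_closed(0.1, 0.25))  # true
```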

## Unit Tests
With the sum-product messages derived, it is good practice to write unit tests for the `And` factor node before implementing it. Tests for existing nodes can be adapted for this purpose and placed in a `/test/factor_nodes/test_and.jl` file. The tests fail at this point because the node and its update rules have not been defined yet.

In [1]:
module AndTest

using Test
using ForneyLab
using ForneyLab: outboundType, isApplicable
using ForneyLab: SPAndOutNBB, SPAndIn1BNB, SPAndIn2BBN


#-------------
# Update rules
#-------------

@testset "SPAndOutNBB" begin
    @test SPAndOutNBB <: SumProductRule{And}
    @test outboundType(SPAndOutNBB) == Message{Bernoulli}
    @test isApplicable(SPAndOutNBB, [Nothing, Message{Bernoulli}, Message{Bernoulli}]) 
    @test !isApplicable(SPAndOutNBB, [Message{Bernoulli}, Nothing, Message{Bernoulli}]) 

    @test ruleSPAndOutNBB(nothing, Message(Univariate, Bernoulli, p=0.4), Message(Univariate, Bernoulli, p=0.5)) == Message(Univariate, Bernoulli, p=0.2)
end

@testset "SPAndIn1BNB" begin
    @test SPAndIn1BNB <: SumProductRule{And}
    @test outboundType(SPAndIn1BNB) == Message{Bernoulli}
    @test isApplicable(SPAndIn1BNB, [Message{Bernoulli}, Nothing, Message{Bernoulli}]) 
    @test !isApplicable(SPAndIn1BNB, [Message{Bernoulli}, Message{Bernoulli}, Nothing]) 

    @test ruleSPAndIn1BNB(Message(Univariate, Bernoulli, p=0.1), nothing, Message(Univariate, Bernoulli, p=0.25)) == Message(Univariate, Bernoulli, p=0.4375)
end

@testset "SPAndIn2BBN" begin
    @test SPAndIn2BBN <: SumProductRule{And}
    @test outboundType(SPAndIn2BBN) == Message{Bernoulli}
    @test isApplicable(SPAndIn2BBN, [Message{Bernoulli}, Message{Bernoulli}, Nothing]) 
    @test !isApplicable(SPAndIn2BBN, [Nothing, Message{Bernoulli}, Message{Bernoulli}]) 

    @test ruleSPAndIn2BBN(Message(Univariate, Bernoulli, p=0.1), Message(Univariate, Bernoulli, p=0.25), nothing) == Message(Univariate, Bernoulli, p=0.4375)
end

end

UndefVarError: SPAndOutNBB not defined

## Factor Node Definition
With the tests in place, we define the new `And` factor node. This definition can be included in a `/src/factor_nodes/and.jl` file.

In [2]:
import Base: &
using ForneyLab
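using ForneyLab: @ensureVariables, generateId, addNode!, associate!, DeltaFactor
import ForneyLab: slug # Explicit import, so that the slug method below extends ForneyLab's function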

"""
Description:

    A logical and-constraint factor node.

    f(out,in1,in2) = δ(out - in1*in2)

Interfaces:

    1. out
    2. in1
    3. in2

Construction:

    And(out, in1, in2, id=:some_id)
"""
mutable struct And <: DeltaFactor
    id::Symbol
    interfaces::Vector{Interface}
    i::Dict{Symbol,Interface}

    function And(out, in1, in2; id=generateId(And))
        @ensureVariables(out, in1, in2)
        self = new(id, Array{Interface}(undef, 3), Dict{Symbol,Interface}())
        addNode!(currentGraph(), self)
        self.i[:out] = self.interfaces[1] = associate!(Interface(self), out)
        self.i[:in1] = self.interfaces[2] = associate!(Interface(self), in1)
        self.i[:in2] = self.interfaces[3] = associate!(Interface(self), in2)

        return self
    end
end

slug(::Type{And}) = "&" # Symbol for node visualization

# Define extra syntax for model definition
function (&)(in1::Variable, in2::Variable)
    out = Variable()
    And(out, in1, in2)
    return out
end
;

## Update Rule Definitions

The update rules can be registered in a `/src/update_rules/and.jl` file as

In [3]:
@sumProductRule(:node_type     => And,
                :outbound_type => Message{Bernoulli},
                :inbound_types => (Nothing, Message{Bernoulli}, Message{Bernoulli}),
                :name          => SPAndOutNBB)

@sumProductRule(:node_type     => And,
                :outbound_type => Message{Bernoulli},
                :inbound_types => (Message{Bernoulli}, Nothing, Message{Bernoulli}),
                :name          => SPAndIn1BNB)

@sumProductRule(:node_type     => And,
                :outbound_type => Message{Bernoulli},
                :inbound_types => (Message{Bernoulli}, Message{Bernoulli}, Nothing),
                :name          => SPAndIn2BBN)
;

## Message Update Computations

Finally, the actual message update computations can be implemented in a `/src/engines/julia/update_rules/and.jl` file, as

In [4]:
function ruleSPAndOutNBB(msg_out::Nothing, msg_in1::Message{Bernoulli}, msg_in2::Message{Bernoulli})
    p_in1 = msg_in1.dist.params[:p]
    p_in2 = msg_in2.dist.params[:p]

    return Message(Univariate, Bernoulli, p=p_in1*p_in2)
end

# Backward message towards in1, with a = (1 - p_out)(1 - p_in2)
function ruleSPAndIn1BNB(msg_out::Message{Bernoulli}, msg_in1::Nothing, msg_in2::Message{Bernoulli})
    p_out = msg_out.dist.params[:p]
    p_in2 = msg_in2.dist.params[:p]
    a = (1 - p_out)*(1 - p_in2)

    return Message(Univariate, Bernoulli, p=(a + p_out*p_in2)/(2*a + p_in2))
end

# Backward message towards in2, with a = (1 - p_out)(1 - p_in1)
function ruleSPAndIn2BBN(msg_out::Message{Bernoulli}, msg_in1::Message{Bernoulli}, msg_in2::Nothing)
    p_out = msg_out.dist.params[:p]
    p_in1 = msg_in1.dist.params[:p]
    a = (1 - p_out)*(1 - p_in1)

    return Message(Univariate, Bernoulli, p=(a + p_out*p_in1)/(2*a + p_in1))
end
;
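A few limiting cases offer a quick sanity check of the backward update before we run inference. The sketch below restates the backward formula for `in1`, $p = (a + p_{out} p_{in2})/(2a + p_{in2})$ with $a = (1 - p_{out})(1 - p_{in2})$, as a plain function of probabilities. The helper `and_in1_p` is hypothetical. It bypasses ForneyLab's `Message` types, so it runs without ForneyLab.

```julia
# Hypothetical helper: the backward-message arithmetic for in1,
# written as a plain function of Bernoulli parameters (no ForneyLab types)
function and_in1_p(p_out, p_in2)
    a = (1 - p_out) * (1 - p_in2)
    return (a + p_out * p_in2) / (2a + p_in2)
end

# y is certainly 1, so x must be 1
println(and_in1_p(1.0, 0.3))  # 1.0
# y is certainly 0 and z is certainly 1, so x must be 0
println(and_in1_p(0.0, 1.0))  # 0.0
# z is certainly 0, so y = 0 regardless of x and the message on x is uninformative
println(and_in1_p(0.7, 0.0))  # 0.5
```

The same limiting cases apply to `in2`, with the roles of the two inputs exchanged.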

## Using the And Node

We can now use the `And` node for modeling and inference.

In [5]:
# Define a probabilistic model
fg = FactorGraph()

@RV x ~ Bernoulli(0.5)
@RV z ~ Bernoulli(0.5)
@RV y = x & z # Shorthand definition for the And node
Bernoulli(y, 0.9) # Bernoulli factor on y, expressing a belief of 0.9 that y = 1
;

In [6]:
ForneyLab.draw(fg) # Inspect the factor graph

In [7]:
# Derive a message passing algorithm
algo = messagePassingAlgorithm(x)
code = algorithmSourceCode(algo)
eval(Meta.parse(code))
;

In [8]:
println(code) # Inspect algorithm source code

begin

function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 4))

messages[1] = ruleSPBernoulliOutNP(nothing, Message(Univariate, PointMass, m=0.5))
messages[2] = ruleSPBernoulliOutNP(nothing, Message(Univariate, PointMass, m=0.9))
messages[3] = ruleSPBernoulliOutNP(nothing, Message(Univariate, PointMass, m=0.5))
messages[4] = ruleSPAndIn1BNB(messages[2], nothing, messages[3])

marginals[:x] = messages[1].dist * messages[4].dist

return marginals

end

end # block


In [9]:
data = Dict()
marginals = step!(data)
marginals[:x] # Inspect result

Ber(p=0.83)
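This factor graph is a tree, so sum-product message passing yields exact marginals, and the marginal from `step!` should agree with a brute-force computation. The following plain-Julia sketch enumerates all configurations of $x$ and $z$, with $y$ fixed by the AND constraint, and normalizes the weights for $x$. The helpers `ber` and `posterior_x` are hypothetical and independent of ForneyLab.

```julia
# Probability that a Bernoulli(p) variable takes the value v ∈ {0, 1}
ber(v, p) = v == 1 ? p : 1 - p

# Exact posterior P(x = 1) for the model
#   x ~ Ber(p_x), z ~ Ber(p_z), y = x AND z, plus a Ber(p_y) factor on y
function posterior_x(p_x, p_z, p_y)
    w = zeros(2)  # unnormalized posterior weights for x = 0 and x = 1
    for x in 0:1, z in 0:1
        y = x & z  # the AND node fixes y deterministically
        w[x + 1] += ber(x, p_x) * ber(z, p_z) * ber(y, p_y)
    end
    return w[2] / sum(w)
end

println(posterior_x(0.5, 0.5, 0.9))  # ≈ 0.8333, i.e. 0.25 / 0.3
```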
