Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@ffg macro #131

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -18,9 +19,9 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
Documenter = "0.25.2"
ForwardDiff = "0.10.12"
SpecialFunctions = "0.8.0, 0.10.3"
StatsBase = "0.32.2, 0.33.1"
StatsFuns = "0.9.5"
julia = "1"
1,030 changes: 1,030 additions & 0 deletions demo/FFG_nonlinear_kalman_filter.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/ForneyLab.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ include("factor_graph.jl")
# Composite nodes
include("factor_nodes/composite.jl")

# Code generation for models
include("codegen/graph.jl")
include("codegen/helpers.jl")
include("codegen/variable.jl")

# Generic methods
include("algorithms/cluster.jl")
include("message_passing.jl")
Expand Down
35 changes: 35 additions & 0 deletions src/algorithms/inference_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,41 @@ messagePassingAlgorithm(target_variable::Variable,
id=Symbol(""),
free_energy=false) = messagePassingAlgorithm([target_variable], pfz; id=id, free_energy=free_energy)


# Shorthands for algorithm compilation by passing only variable ids
function messagePassingAlgorithm(target_variable_ids::Vector{Symbol}, # Quantities of interest
pfz::PosteriorFactorization=currentPosteriorFactorization();
id=Symbol(""),
free_energy=false)

target_variables = Vector{Variable}(undef, length(target_variable_ids))

for (i, target_variable_id) in enumerate(target_variable_ids)
target_variable = get(currentGraph().variables, target_variable_id, nothing)
if isnothing(target_variable)
error("Variable with id $(target_variable_id) does not exist.")
else
target_variables[i] = target_variable
end
end

return messagePassingAlgorithm(target_variables, pfz; id=id, free_energy=free_energy)
end


function messagePassingAlgorithm(target_variable_id::Symbol,
pfz::PosteriorFactorization=currentPosteriorFactorization();
id=Symbol(""),
free_energy=false)

target_variable = get(currentGraph().variables, target_variable_id, nothing)
if isnothing(target_variable)
error("Variable with id $(target_variable_id) does not exist.")
end

return messagePassingAlgorithm([target_variable], pfz; id=id, free_energy=free_energy)
end

function interfaceToScheduleEntry(algo::InferenceAlgorithm)
mapping = Dict{Interface, ScheduleEntry}()
for (id, pf) in algo.posterior_factorization
Expand Down
26 changes: 26 additions & 0 deletions src/algorithms/posterior_factor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,32 @@ end
PosteriorFactor(seed_variable::Variable; pfz=currentPosteriorFactorization(), id=generateId(PosteriorFactor)) = PosteriorFactor(Set([seed_variable]), pfz=pfz, id=id)
PosteriorFactor(seed_variables::Vector{Variable}; pfz=currentPosteriorFactorization(), id=generateId(PosteriorFactor)) = PosteriorFactor(Set(seed_variables), pfz=pfz, id=id)

function PosteriorFactor(seed_variable_id::Symbol; pfz=currentPosteriorFactorization(), id=generateId(PosteriorFactor))

seed_variable = get(currentGraph().variables, seed_variable_id, nothing)
if isnothing(seed_variable)
error("Variable with id $(seed_variable_id) does not exist.")
end

return PosteriorFactor(Set([seed_variable]), pfz=pfz, id=id)
end

function PosteriorFactor(seed_variable_ids::Vector{Symbol}; pfz=currentPosteriorFactorization(), id=generateId(PosteriorFactor))
seed_variables = Vector{Variable}(undef, length(seed_variable_ids))

for (i, seed_variable_id) in enumerate(seed_variable_ids)
seed_variable = get(currentGraph().variables, seed_variable_id, nothing)
if isnothing(seed_variable)
error("Variable with id $(seed_variable_id) does not exist.")
else
seed_variables[i] = seed_variable
end
end

return PosteriorFactor(Set(seed_variables), pfz=pfz, id=id)
end


"""
`messagePassingSchedule()` generates a message passing schedule for the posterior factor
"""
Expand Down
16 changes: 16 additions & 0 deletions src/algorithms/posterior_factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ function PosteriorFactorization(args::Vararg{Union{T, Set{T}, Vector{T}} where T
return pfz
end

"""
Construct a `PosteriorFactorization` consisting of one `PosteriorFactor` for each argument addressed by its id
"""
function PosteriorFactorization(args::Vararg{Union{Symbol, Set{Symbol}, Vector{Symbol}}}; ids=Symbol[])
pfz = PosteriorFactorization()
isempty(ids) || (length(ids) == length(args)) || error("Length of ids must match length of posterior factor arguments")
for (i, arg) in enumerate(args)
if isempty(ids)
PosteriorFactor(arg, id=generateId(PosteriorFactor))
else
PosteriorFactor(arg, id=ids[i])
end
end
return pfz
end

iterate(pfz::PosteriorFactorization) = iterate(pfz.posterior_factors)
iterate(pfz::PosteriorFactorization, state) = iterate(pfz.posterior_factors, state)
values(pfz::PosteriorFactorization) = values(pfz.posterior_factors)
Expand Down
26 changes: 26 additions & 0 deletions src/codegen/graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
export @ffg

using MacroTools: postwalk, rmlines, prettify, @capture

macro ffg(expr::Expr)
return esc(postwalk(rmlines, generate_model(expr)))
end

function generate_model(expr::Expr)

@capture(expr, (mname_(margs__) = body_) | (function mname_(margs__) body_ end)) || error("Model definition has to be a function.")

body = postwalk(rewrite_expression, body)

graph_sym = gensym(:factor_graph)

result = quote
function $mname($(margs...))
$(graph_sym) = FactorGraph()
$body
return $(graph_sym)
end
end

return result
end
23 changes: 23 additions & 0 deletions src/codegen/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Extract options dictionary from expression
function get_options(expr::Expr)
options = Dict{Symbol,Any}()

expr = postwalk(expr) do x
if @capture(x, lhs_ where {exoptions__})
for option in exoptions
@capture(option, key_ = value_)
options[key] = value
end
return lhs
end
return x
end

if !isempty(options)
return expr, options
else
return expr, nothing
end
end

get_options(a::Any) = a, nothing
88 changes: 88 additions & 0 deletions src/codegen/variable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
function rewrite_expression(expression::Expr)
expr = if @capture(expression, var_ ~ rhs_)
rewrite_tilde_expression(var, rhs)
elseif @capture(expression, var_ = rhs_)
(rhs, options) = get_options(rhs)
if options === nothing
return expression
end
rewrite_assign_expression(var, rhs, options)
elseif is_for(expression)
rewrite_for_block(expression)
else
expression
end
return expr
end

rewrite_expression(ex::Any) = ex

function rewrite_tilde_expression(var, rhs)

(rhs, options) = get_options(rhs)
@capture(rhs, pdist_(params__))

var_id = extract_variable_id(var, options)

# Build total expression
return :(
begin
# Use existing Variable if it exists, otherwise create a new one
$(var) = try
$(var)
catch _
Variable(id = $(var_id))
end

# Create new variable if:
# - the existing object is not a Variable
# - the existing object is a Variable from another FactorGraph
if (!isa($(var), Variable)
|| !haskey(currentGraph().variables, $(var).id)
|| currentGraph().variables[$(var).id] !== $(var))

$(var) = Variable(id = $(var_id))
end

$(pdist)($(var), $(params...))
$(var)
end
)
end

function rewrite_assign_expression(var, rhs, options)

var_id = extract_variable_id(var, options)

var_id_sym = gensym()

return :(
begin
$(var) = $(rhs)
$(var_id_sym) = $(var_id)
if $(var_id_sym) != :auto
# update id of newly created Variable
currentGraph().variables[$(var_id_sym)] = $(var)
delete!(currentGraph().variables, $(var).id)
$(var).id = $(var_id_sym)
end
$(var)
end
)
end

# for loop
is_for(expr::Expr) = expr.head === :for
is_for(expr) = false

function rewrite_for_block(def)
body_block = def.args[2]

for (i, expr) in enumerate(body_block.args)
body_block.args[i] = rewrite_expression(expr)
end

return quote
$(def)
end
end
4 changes: 2 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ end
# If variable expression is a symbol
# RV x ...
function extract_variable_id(expr::Symbol, options)
if haskey(options, :id)
if (options !== nothing) && haskey(options, :id)
return check_id_available(options[:id])
else
return guard_variable_id(:($(string(expr))))
Expand All @@ -194,7 +194,7 @@ end
# If variable expression is an indexing expression
# RV x[i] ...
function extract_variable_id(expr::Expr, options)
if haskey(options, :id)
if (options !== nothing) && haskey(options, :id)
return check_id_available(options[:id])
else
argstr = map(arg -> :(string($arg)), @view expr.args[2:end])
Expand Down