Skip to content

Commit

Permalink
Merge pull request #31 from varunagrawal/more-improvements-2
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Feb 26, 2022
2 parents 3e59b0e + 8619c7c commit 2ed3f5f
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 12 deletions.
10 changes: 5 additions & 5 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
void GaussianMixture::print(const std::string &s,
const KeyFormatter &keyFormatter) const {
std::cout << (s.empty() ? "" : s + " ");
std::cout << "GaussianMixture [";
std::cout << "GaussianMixture [ ";

for (Key key : frontals()) std::cout << keyFormatter(key) << " ";
std::cout << "| ";
if (parents().size()) std::cout << "| ";
for (Key key : parents()) std::cout << keyFormatter(key) << " ";

std::cout << "]";
std::cout << "{\n";

auto valueFormatter = [](const GaussianFactor::shared_ptr &v) {
auto printCapture = [](const GaussianFactor::shared_ptr &p) {
auto valueFormatter = [&](const GaussianFactor::shared_ptr &v) {
auto printCapture = [&](const GaussianFactor::shared_ptr &p) {
RedirectCout rd;
p->print();
p->print("", keyFormatter);
std::string s = rd.str();
return s;
};
Expand Down
9 changes: 7 additions & 2 deletions gtsam/hybrid/IncrementalHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
// We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph
std::unordered_set<Key> allVars(ordering.begin(), ordering.end());
for (auto &&conditional : hybridBayesNet_) {

// TODO(Varun) Using a for-range loop doesn't work since some of the
// conditionals are invalid pointers
for (size_t i = 0; i < hybridBayesNet_.size(); i++) {
auto conditional = hybridBayesNet_.at(i);
// Flag indicating if a conditional will be updated due to factors in
// `graph`
bool marked_for_update = false;
Expand All @@ -53,7 +57,8 @@ void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
// If a conditional is due to be updated, we remove if from the
// previous bayes net.
if (marked_for_update) {
auto it = find(hybridBayesNet_.begin(), hybridBayesNet_.end(), conditional);
auto it = find(hybridBayesNet_.begin(), hybridBayesNet_.end(),
conditional);
hybridBayesNet_.erase(it);
}
break;
Expand Down
31 changes: 30 additions & 1 deletion gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ virtual class GaussianMixture : gtsam::DCGaussianMixtureFactor {
class DCFactorGraph {
DCFactorGraph();
gtsam::DiscreteKeys discreteKeys() const;

size_t size() const;
bool empty() const;
void remove(size_t i);
void resize(size_t size);
size_t nrFactors() const;

void print(const std::string& str = "DCFactorGraph",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/hybrid/HybridFactorGraph.h>
Expand Down Expand Up @@ -84,6 +94,11 @@ class NonlinearHybridFactorGraph {

const gtsam::NonlinearFactorGraph& nonlinearGraph() const;

//TODO(Varun) issues with templated inheritance
const gtsam::DiscreteFactorGraph& discreteGraph() const;
const gtsam::DCFactorGraph& dcGraph() const;
gtsam::DiscreteKeys discreteKeys() const;

gtsam::GaussianHybridFactorGraph linearize(
const gtsam::Values& continuousValues) const;

Expand Down Expand Up @@ -113,13 +128,27 @@ class GaussianHybridFactorGraph {
void print(const std::string& str = "GaussianHybridFactorGraph",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

//TODO(Varun) issues with templated inheritance
const gtsam::DiscreteFactorGraph& discreteGraph() const;
const gtsam::DCFactorGraph& dcGraph() const;
gtsam::DiscreteKeys discreteKeys() const;

};

#include <gtsam/hybrid/IncrementalHybrid.h>

class IncrementalHybrid {
IncrementalHybrid();

void update(gtsam::GaussianHybridFactorGraph graph,
const gtsam::Ordering& ordering);
const gtsam::Ordering& ordering,
boost::optional<size_t> maxNrLeaves = nullptr);

GaussianMixture* gaussianMixture(size_t index) const;
const DiscreteFactorGraph& remainingDiscreteGraph() const;
const HybridBayesNet& hybridBayesNet() const;
const GaussianHybridFactorGraph& remainingFactorGraph() const;
};

} // namespace gtsam
6 changes: 3 additions & 3 deletions gtsam/hybrid/tests/testHybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ TEST(HybridFactorGraph, Printing) {
// Expected output for hybridBayesNet.
string expected_hybridBayesNet = R"(
size: 3
factor 0: GaussianMixture [x1 | x2 m1 ]{
factor 0: GaussianMixture [ x1 | x2 m1 ]{
Choice(m1)
0 Leaf Jacobian factor on 2 keys:
Conditional density [x1]
Expand All @@ -572,7 +572,7 @@ factor 0: GaussianMixture [x1 | x2 m1 ]{
}
factor 1: GaussianMixture [x2 | x3 m2 m1 ]{
factor 1: GaussianMixture [ x2 | x3 m2 m1 ]{
Choice(m2)
0 Choice(m1)
0 0 Leaf Jacobian factor on 2 keys:
Expand Down Expand Up @@ -609,7 +609,7 @@ factor 1: GaussianMixture [x2 | x3 m2 m1 ]{
}
factor 2: GaussianMixture [x3 | m2 m1 ]{
factor 2: GaussianMixture [ x3 | m2 m1 ]{
Choice(m2)
0 Choice(m1)
0 0 Leaf Jacobian factor on 1 keys:
Expand Down
2 changes: 2 additions & 0 deletions gtsam/inference/AbstractConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class GTSAM_EXPORT AbstractConditional {

virtual Parents parents() const = 0;
/// @}

//TODO(Varun) add iterators so we can call hybridBayesNet.keyVector()
};

/// traits
Expand Down
9 changes: 9 additions & 0 deletions python/gtsam/preamble/hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,12 @@
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/

#include <pybind11/stl.h>

// Support for binding boost::optional types in C++11.
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
namespace pybind11 { namespace detail {
template <typename T>
struct type_caster<boost::optional<T>> : optional_caster<boost::optional<T>> {};
}}
8 changes: 7 additions & 1 deletion python/gtsam/tests/test_Hybrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gtsam
import numpy as np
from gtsam import GaussianHybridFactorGraph
from gtsam import GaussianHybridFactorGraph, IncrementalHybrid
from gtsam.utils.test_case import GtsamTestCase


Expand All @@ -11,3 +11,9 @@ def setUp(self) -> None:
def test_elimination(self):
# Check if constructed correctly
self.assertIsInstance(self.ghfg, GaussianHybridFactorGraph)

def test_incremental(self):
ordering = gtsam.Ordering()
inc = IncrementalHybrid()
inc.update(self.ghfg, ordering)
# inc.update(self.ghfg, ordering, 4)

0 comments on commit 2ed3f5f

Please sign in to comment.