Skip to content

Commit

Permalink
Fixes the removal of some Python derived objects produced when a Pyth…
Browse files Browse the repository at this point in the history
…on derived type is passed as argument of some functions.

This commit solves this problem by ensuring the Python side is kept alive.

Related to: pybind/pybind11#1333
  • Loading branch information
davenza committed Jan 23, 2022
1 parent e473028 commit d070afc
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 66 deletions.
59 changes: 55 additions & 4 deletions pybnesian/factors/factors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,18 @@ class FactorType {

virtual bool is_python_derived() const { return false; }

static std::shared_ptr<FactorType> keep_python_alive(std::shared_ptr<FactorType>& f) {
static std::shared_ptr<FactorType>& keep_python_alive(std::shared_ptr<FactorType>& f) {
if (f && f->is_python_derived()) {
auto o = py::cast(f);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<FactorType*>();
f = std::shared_ptr<FactorType>(keep_python_state_alive, ptr);
}

return f;
}

static std::shared_ptr<FactorType> keep_python_alive(const std::shared_ptr<FactorType>& f) {
if (f && f->is_python_derived()) {
auto o = py::cast(f);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand All @@ -47,12 +58,21 @@ class FactorType {
return f;
}

static std::vector<std::shared_ptr<FactorType>> keep_vector_python_alive(
static std::vector<std::shared_ptr<FactorType>>& keep_vector_python_alive(
std::vector<std::shared_ptr<FactorType>>& v) {
for (auto& f : v) {
FactorType::keep_python_alive(f);
}

return v;
}

static std::vector<std::shared_ptr<FactorType>> keep_vector_python_alive(
const std::vector<std::shared_ptr<FactorType>>& v) {
std::vector<std::shared_ptr<FactorType>> fv;
fv.reserve(v.size());

for (auto& f : v) {
for (const auto& f : v) {
fv.push_back(FactorType::keep_python_alive(f));
}

Expand Down Expand Up @@ -105,7 +125,18 @@ class Factor {

virtual bool is_python_derived() const { return false; }

static std::shared_ptr<Factor> keep_python_alive(std::shared_ptr<Factor>& f) {
static std::shared_ptr<Factor>& keep_python_alive(std::shared_ptr<Factor>& f) {
if (f && f->is_python_derived()) {
auto o = py::cast(f);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<Factor*>();
f = std::shared_ptr<Factor>(keep_python_state_alive, ptr);
}

return f;
}

static std::shared_ptr<Factor> keep_python_alive(const std::shared_ptr<Factor>& f) {
if (f && f->is_python_derived()) {
auto o = py::cast(f);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand All @@ -116,6 +147,26 @@ class Factor {
return f;
}

static std::vector<std::shared_ptr<Factor>>& keep_vector_python_alive(std::vector<std::shared_ptr<Factor>>& v) {
for (auto& f : v) {
Factor::keep_python_alive(f);
}

return v;
}

static std::vector<std::shared_ptr<Factor>> keep_vector_python_alive(
const std::vector<std::shared_ptr<Factor>>& v) {
std::vector<std::shared_ptr<Factor>> fv;
fv.reserve(v.size());

for (const auto& f : v) {
fv.push_back(Factor::keep_python_alive(f));
}

return fv;
}

const std::string& variable() const { return m_variable; }

const std::vector<std::string>& evidence() const { return m_evidence; }
Expand Down
13 changes: 12 additions & 1 deletion pybnesian/kde/BandwidthSelector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@ class BandwidthSelector {

virtual bool is_python_derived() const { return false; }

static std::shared_ptr<BandwidthSelector> keep_python_alive(std::shared_ptr<BandwidthSelector>& b) {
static std::shared_ptr<BandwidthSelector>& keep_python_alive(std::shared_ptr<BandwidthSelector>& b) {
if (b && b->is_python_derived()) {
auto o = py::cast(b);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<BandwidthSelector*>();
b = std::shared_ptr<BandwidthSelector>(keep_python_state_alive, ptr);
}

return b;
}

static std::shared_ptr<BandwidthSelector> keep_python_alive(const std::shared_ptr<BandwidthSelector>& b) {
if (b && b->is_python_derived()) {
auto o = py::cast(b);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand Down
57 changes: 56 additions & 1 deletion pybnesian/learning/operators/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,18 @@ class Operator {
virtual bool operator==(const Operator& a) const = 0;
bool operator!=(const Operator& a) const { return !(*this == a); }

static std::shared_ptr<Operator> keep_python_alive(std::shared_ptr<Operator>& op) {
static std::shared_ptr<Operator>& keep_python_alive(std::shared_ptr<Operator>& op) {
if (op && op->is_python_derived()) {
auto o = py::cast(op);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<Operator*>();
op = std::shared_ptr<Operator>(keep_python_state_alive, ptr);
}

return op;
}

static std::shared_ptr<Operator> keep_python_alive(const std::shared_ptr<Operator>& op) {
if (op && op->is_python_derived()) {
auto o = py::cast(op);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand Down Expand Up @@ -330,6 +341,7 @@ class OperatorSet {
public:
OperatorSet() : m_local_cache(nullptr), m_owns_local_cache(false) {}
virtual ~OperatorSet() {}
virtual bool is_python_derived() const { return false; }
virtual void cache_scores(const BayesianNetworkBase&, const Score&) = 0;
virtual std::shared_ptr<Operator> find_max(const BayesianNetworkBase&) const = 0;
virtual std::shared_ptr<Operator> find_max(const BayesianNetworkBase&, const OperatorTabuSet&) const = 0;
Expand All @@ -356,6 +368,49 @@ class OperatorSet {
virtual void set_type_whitelist(const FactorTypeVector&){};
virtual void finished() { m_local_cache = nullptr; }

static std::shared_ptr<OperatorSet>& keep_python_alive(std::shared_ptr<OperatorSet>& op_set) {
if (op_set && op_set->is_python_derived()) {
auto o = py::cast(op_set);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<OperatorSet*>();
op_set = std::shared_ptr<OperatorSet>(keep_python_state_alive, ptr);
}

return op_set;
}

static std::shared_ptr<OperatorSet> keep_python_alive(const std::shared_ptr<OperatorSet>& op_set) {
if (op_set && op_set->is_python_derived()) {
auto o = py::cast(op_set);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<OperatorSet*>();
return std::shared_ptr<OperatorSet>(keep_python_state_alive, ptr);
}

return op_set;
}

static std::vector<std::shared_ptr<OperatorSet>>& keep_vector_python_alive(
std::vector<std::shared_ptr<OperatorSet>>& v) {
for (auto& op_set : v) {
OperatorSet::keep_python_alive(op_set);
}

return v;
}

static std::vector<std::shared_ptr<OperatorSet>> keep_vector_python_alive(
const std::vector<std::shared_ptr<OperatorSet>>& v) {
std::vector<std::shared_ptr<OperatorSet>> fv;
fv.reserve(v.size());

for (const auto& op_set : v) {
fv.push_back(OperatorSet::keep_python_alive(op_set));
}

return fv;
}

protected:
bool owns_local_cache() const { return m_owns_local_cache; }

Expand Down
40 changes: 37 additions & 3 deletions pybnesian/models/BayesianNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,18 @@ class BayesianNetworkBase : public std::enable_shared_from_this<BayesianNetworkB
}
}

static std::shared_ptr<BayesianNetworkBase> keep_python_alive(std::shared_ptr<BayesianNetworkBase>& m) {
static std::shared_ptr<BayesianNetworkBase>& keep_python_alive(std::shared_ptr<BayesianNetworkBase>& m) {
if (m && m->is_python_derived()) {
auto o = py::cast(m);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<BayesianNetworkBase*>();
m = std::shared_ptr<BayesianNetworkBase>(keep_python_state_alive, ptr);
}

return m;
}

static std::shared_ptr<BayesianNetworkBase> keep_python_alive(const std::shared_ptr<BayesianNetworkBase>& m) {
if (m && m->is_python_derived()) {
auto o = py::cast(m);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand Down Expand Up @@ -182,8 +193,20 @@ class ConditionalBayesianNetworkBase : public BayesianNetworkBase {
}
}

static std::shared_ptr<ConditionalBayesianNetworkBase> keep_python_alive(
static std::shared_ptr<ConditionalBayesianNetworkBase>& keep_python_alive(
std::shared_ptr<ConditionalBayesianNetworkBase>& m) {
if (m && m->is_python_derived()) {
auto o = py::cast(m);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<ConditionalBayesianNetworkBase*>();
m = std::shared_ptr<ConditionalBayesianNetworkBase>(keep_python_state_alive, ptr);
}

return m;
}

static std::shared_ptr<ConditionalBayesianNetworkBase> keep_python_alive(
const std::shared_ptr<ConditionalBayesianNetworkBase>& m) {
if (m && m->is_python_derived()) {
auto o = py::cast(m);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand All @@ -208,7 +231,18 @@ class BayesianNetworkType {
virtual std::shared_ptr<ConditionalBayesianNetworkBase> new_cbn(
const std::vector<std::string>& nodes, const std::vector<std::string>& interface_nodes) const = 0;

static std::shared_ptr<BayesianNetworkType> keep_python_alive(std::shared_ptr<BayesianNetworkType>& s) {
static std::shared_ptr<BayesianNetworkType>& keep_python_alive(std::shared_ptr<BayesianNetworkType>& s) {
if (s && s->is_python_derived()) {
auto o = py::cast(s);
auto keep_python_state_alive = std::make_shared<py::object>(o);
auto ptr = o.cast<BayesianNetworkType*>();
s = std::shared_ptr<BayesianNetworkType>(keep_python_state_alive, ptr);
}

return s;
}

static std::shared_ptr<BayesianNetworkType> keep_python_alive(const std::shared_ptr<BayesianNetworkType>& s) {
if (s && s->is_python_derived()) {
auto o = py::cast(s);
auto keep_python_state_alive = std::make_shared<py::object>(o);
Expand Down
10 changes: 9 additions & 1 deletion pybnesian/models/HeterogeneousBN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

namespace models {

MapDataToFactor keep_MapDataToFactor_alive(MapDataToFactor& m) {
MapDataToFactor& keep_MapDataToFactor_alive(MapDataToFactor& m) {
for (auto& item : m) {
FactorType::keep_vector_python_alive(item.second);
}

return m;
}

MapDataToFactor keep_MapDataToFactor_alive(const MapDataToFactor& m) {
MapDataToFactor alive;

for (auto& item : m) {
Expand Down
6 changes: 5 additions & 1 deletion pybnesian/models/HeterogeneousBN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ class DataTypeEqualTo {
using MapDataToFactor = std::
unordered_map<std::shared_ptr<DataType>, std::vector<std::shared_ptr<FactorType>>, DataTypeHash, DataTypeEqualTo>;

MapDataToFactor keep_MapDataToFactor_alive(MapDataToFactor& m);
MapDataToFactor& keep_MapDataToFactor_alive(MapDataToFactor& m);
MapDataToFactor keep_MapDataToFactor_alive(const MapDataToFactor& m);

class HeterogeneousBNType : public BayesianNetworkType {
public:
HeterogeneousBNType(const HeterogeneousBNType&) = delete;
void operator=(const HeterogeneousBNType&) = delete;

HeterogeneousBNType(HeterogeneousBNType&&) = default;
HeterogeneousBNType& operator=(HeterogeneousBNType&&) = default;

HeterogeneousBNType(std::vector<std::shared_ptr<FactorType>> default_ft)
: m_default_ftype(default_ft), m_default_ftypes(), m_single_default(true) {
if (default_ft.empty()) throw std::invalid_argument("Default factor_type cannot be empty.");
Expand Down
3 changes: 3 additions & 0 deletions pybnesian/models/HomogeneousBN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class HomogeneousBNType : public BayesianNetworkType {
HomogeneousBNType(const HomogeneousBNType&) = delete;
void operator=(const HomogeneousBNType&) = delete;

HomogeneousBNType(HomogeneousBNType&&) = default;
HomogeneousBNType& operator=(HomogeneousBNType&&) = default;

HomogeneousBNType(std::shared_ptr<FactorType> ft) : m_ftype(ft) {
if (ft == nullptr) throw std::invalid_argument("factor_type cannot be null.");

Expand Down
26 changes: 19 additions & 7 deletions pybnesian/pybindings/pybindings_factors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class PyFactorType : public FactorType {

try {
auto f = o.cast<std::shared_ptr<Factor>>();
return Factor::keep_python_alive(f);
Factor::keep_python_alive(f);
return f;
} catch (py::cast_error& e) {
throw std::runtime_error("The returned object of FactorType::new_factor is not a Factor.");
}
Expand All @@ -78,7 +79,8 @@ class PyFactorType : public FactorType {

try {
auto f = o.cast<std::shared_ptr<Factor>>();
return Factor::keep_python_alive(f);
Factor::keep_python_alive(f);
return f;
} catch (py::cast_error& e) {
throw std::runtime_error("The returned object of FactorType::new_factor is not a Factor.");
}
Expand Down Expand Up @@ -160,7 +162,7 @@ class PyFactor : public Factor {
try {
m_type = o.cast<std::shared_ptr<FactorType>>();
// Keep the type in the class member, so type_ref() can return a valid reference.
m_type = FactorType::keep_python_alive(m_type);
FactorType::keep_python_alive(m_type);
return m_type;
} catch (py::cast_error& e) {
throw std::runtime_error("The returned object of Factor::type is not a FactorType.");
Expand Down Expand Up @@ -735,10 +737,20 @@ Removes the assignment for the ``variable``.

py::class_<HCKDE, Factor, std::shared_ptr<HCKDE>>(root, "HCKDE")
.def(py::init<std::string, std::vector<std::string>>())
.def(py::init<std::string, std::vector<std::string>, std::shared_ptr<BandwidthSelector>>())
.def(py::init<std::string,
std::vector<std::string>,
std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash>>())
.def(py::init<>([](std::string variable,
std::vector<std::string> evidence,
std::shared_ptr<BandwidthSelector> bandwidth_selector) {
return HCKDE(variable, evidence, BandwidthSelector::keep_python_alive(bandwidth_selector));
}), py::arg("variable"), py::arg("evidence"), py::arg("bandwidth_selector"))
.def(py::init<>([](std::string variable,
std::vector<std::string> evidence,
std::unordered_map<Assignment, std::tuple<std::shared_ptr<BandwidthSelector>>, AssignmentHash> args) {
for (auto& arg : args) {
BandwidthSelector::keep_python_alive(std::get<0>(arg.second));
}

return HCKDE(variable, evidence, args);
}), py::arg("variable"), py::arg("evidence"), py::arg("bandwidth_selector"))
.def("conditional_factor", &HCKDE::conditional_factor, py::return_value_policy::reference_internal)
.def(py::pickle([](const HCKDE& self) { return self.__getstate__(); },
[](py::tuple t) { return HCKDE::__setstate__(t); }));
Expand Down
Loading

0 comments on commit d070afc

Please sign in to comment.