Skip to content

Commit

Permalink
Improve interpolation docs and bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Dec 5, 2023
1 parent b3f303a commit 9100ff7
Show file tree
Hide file tree
Showing 24 changed files with 241 additions and 72 deletions.
19 changes: 7 additions & 12 deletions include/gsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,19 @@

#include <gsl/gsl_interp.h>

#include "interpolation.h"

namespace gauss2d::fit {

/**
* See GSL docs for 1D interpolation types, or (as of 2.7):
* https://www.gnu.org/software/gsl/doc/html/interp.html#d-interpolation-types
*/
enum class GSLInterpType {
linear, ///< Linear interpolation.
polynomial, ///< Polynomial interpolation.
cspline, ///< Cubic spline with natural boundary conditions.
akima, ///< Non-rounded Akima spline with natural boundary conditions.
};

static const std::unordered_map<GSLInterpType, const gsl_interp_type*> GSLInterpTypes {
{GSLInterpType::linear, gsl_interp_linear},
{GSLInterpType::polynomial, gsl_interp_polynomial},
{GSLInterpType::cspline, gsl_interp_cspline},
{GSLInterpType::akima, gsl_interp_akima}
static const std::unordered_map<InterpType, const gsl_interp_type*> GSLInterpTypes {
{InterpType::linear, gsl_interp_linear},
{InterpType::polynomial, gsl_interp_polynomial},
{InterpType::cspline, gsl_interp_cspline},
{InterpType::akima, gsl_interp_akima}
};
} // namespace gauss2d::fit

Expand Down
12 changes: 10 additions & 2 deletions include/gslinterpolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifdef GAUSS2D_FIT_HAS_GSL

#include "gsl.h"
#include "interpolation.h"

#include "gauss2d/object.h"

#include <gsl/gsl_errno.h>
Expand All @@ -23,21 +25,27 @@ class GSLInterpolator : public Object {
std::vector<double> _y;

public:
const GSLInterpType interp_type;
const InterpType interp_type;
static constexpr InterpType INTERPTYPE_DEFAULT = InterpType::cspline;

/// Get the interpolant value for a knot of the given index
double get_knot_x(size_t idx) const;
/// Get the interpolated function value for a knot of the given index
double get_knot_y(size_t idx) const;

/// Get the interpolated function value for a given interpolant value
double eval(double x) const;
/// Get the derivative of the interpolated function value for a given interpolant value
double eval_deriv(double x) const;
/// Get the number of knots
size_t size() const;

std::string repr(bool name_keywords = false) const override;
std::string str() const override;

explicit GSLInterpolator(
std::vector<double> x, std::vector<double> y,
const GSLInterpType interp_type = GSLInterpType::cspline);
InterpType interp_type = INTERPTYPE_DEFAULT);
~GSLInterpolator();
};

Expand Down
8 changes: 6 additions & 2 deletions include/gslsersicmixinterpolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,24 @@ namespace gauss2d::fit {
class GSLSersicMixInterpolator : public SersicMixInterpolator {
private:
mutable double _final_correction = 1.;
const InterpType _interp_type;
const unsigned short _order;
std::vector<std::pair<std::unique_ptr<GSLInterpolator>, std::unique_ptr<GSLInterpolator>>> _interps;

public:
bool correct_final_integral = true;
const GSLInterpType interp_type;
static constexpr InterpType INTERPTYPE_DEFAULT = InterpType::cspline;
/// The knot positions and values.
const std::vector<SersicMixValues>& knots;

/// Get the multiplicative factor required to adjust the integral for the final order
/// component such that the sum of all integral factors is unity (normalized).
double get_final_correction() const;

std::vector<IntegralSize> get_integralsizes(double sersicindex) const override;
std::vector<IntegralSize> get_integralsizes_derivs(double sersicindex) const override;

InterpType get_interptype() const override;
unsigned short get_order() const override;

const double sersicindex_min;
Expand All @@ -44,7 +48,7 @@ class GSLSersicMixInterpolator : public SersicMixInterpolator {

explicit GSLSersicMixInterpolator(
unsigned short order = SERSICMIX_ORDER_DEFAULT,
const GSLInterpType interp_type=GSLInterpType::cspline);
InterpType interp_type=INTERPTYPE_DEFAULT);
~GSLSersicMixInterpolator();
};

Expand Down
15 changes: 15 additions & 0 deletions include/interpolation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef GAUSS2DFIT_INTERPOLATION_H
#define GAUSS2DFIT_INTERPOLATION_H

namespace gauss2d::fit {

enum class InterpType {
linear, ///< Linear interpolation.
polynomial, ///< Polynomial interpolation.
cspline, ///< Cubic spline interpolation.
akima, ///< Akima spline with natural boundary conditions.
};

} // namespace gauss2d::fit

#endif // GAUSS2DFIT_INTERPOLATION_H
1 change: 1 addition & 0 deletions include/linearsersicmixinterpolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class LinearSersicMixInterpolator : public SersicMixInterpolator {
std::vector<IntegralSize> get_integralsizes(double sersicindex) const override;
std::vector<IntegralSize> get_integralsizes_derivs(double sersicindex) const override;

InterpType get_interptype() const override;
unsigned short get_order() const override;

const double sersicindex_min;
Expand Down
1 change: 1 addition & 0 deletions include/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ headers = [
'gslinterpolator.h',
'gslsersicmixinterpolator.h',
'integralmodel.h',
'interpolation.h',
'linearintegralmodel.h',
'linearsersicmixinterpolator.h',
'math.h',
Expand Down
2 changes: 2 additions & 0 deletions include/sersicmix.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stdexcept>
#include <vector>

#include "interpolation.h"
#include "gauss2d/object.h"

namespace gauss2d::fit {
Expand Down Expand Up @@ -39,6 +40,7 @@ class SersicMixInterpolator : public Object {
virtual std::vector<IntegralSize> get_integralsizes(double sersicindex) const = 0;
virtual std::vector<IntegralSize> get_integralsizes_derivs(double sersicindex) const = 0;

virtual InterpType get_interptype() const = 0;
virtual unsigned short get_order() const = 0;
};

Expand Down
3 changes: 2 additions & 1 deletion include/sersicmixcomponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class SersicMixComponentIndexParameter : public SersicIndexParameter {
/// Return the size ratio derivative for a given Gaussian sub-component index
double get_sizeratio_deriv(unsigned short index) const;

unsigned short order;
InterpType get_interptype() const;
unsigned short get_order() const;

void set_value(double value) override;
void set_value_transformed(double value_transformed) override;
Expand Down
7 changes: 1 addition & 6 deletions python/gauss2d/fit/gsl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,7 @@ using namespace pybind11::literals;
namespace g2f = gauss2d::fit;

void bind_gsl(py::module &m) {
auto _e = py::enum_<g2f::GSLInterpType>(m, "GSLInterpType")
.value("linear", g2f::GSLInterpType::linear)
.value("polynomial", g2f::GSLInterpType::polynomial)
.value("cspline", g2f::GSLInterpType::cspline)
.value("akima", g2f::GSLInterpType::akima)
;
// Placeholder for now - Python doesn't need to know about gsl_interp_type
}

#endif // GAUSS2D_FIT_HAS_GSL
62 changes: 62 additions & 0 deletions python/gauss2d/fit/gslinterpolator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* This file is part of gauss2dfit.
*
* Developed for the LSST Data Management System.
* This product includes software developed by the LSST Project
* (https://www.lsst.org).
* See the COPYRIGHT file at the top-level directory of this distribution
* for details of code ownership.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

#ifdef GAUSS2D_FIT_HAS_GSL

#include <pybind11/attr.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>

#include "pybind11.h"

#include "gauss2d/fit/gsl.h"
#include "gauss2d/fit/gslinterpolator.h"

namespace py = pybind11;
using namespace pybind11::literals;

namespace g2f = gauss2d::fit;

void bind_gslsersicmixinterpolator(py::module &m) {
auto _e = py::class_<g2f::GSLSersicMixInterpolator, std::shared_ptr<g2f::GSLSersicMixInterpolator>,
g2f::SersicMixInterpolator>(m, "GSLSersicMixInterpolator")
.def(py::init<short, const g2f::GSLInterpType>(),
"order"_a = g2f::SERSICMIX_ORDER_DEFAULT,
"interp_type"_a = g2f::GSLInterpType::cspline)
.def_readwrite("correct_final_integral",
&g2f::GSLSersicMixInterpolator::correct_final_integral)
.def_property_readonly("final_correction",
&g2f::GSLSersicMixInterpolator::get_final_correction)
.def("integralsizes", &g2f::GSLSersicMixInterpolator::get_integralsizes)
.def("integralsizes_derivs",
&g2f::GSLSersicMixInterpolator::get_integralsizes_derivs)
.def_property_readonly("interp_type", &g2f::GSLSersicMixInterpolator::interp_type)
.def_property_readonly("order", &g2f::GSLSersicMixInterpolator::get_order)
.def("__repr__",
[](const g2f::GSLSersicMixInterpolator &self) { return self.repr(true); })
.def("__str__", &g2f::GSLSersicMixInterpolator::str);
}

#endif // GAUSS2D_FIT_HAS_GSL
5 changes: 3 additions & 2 deletions python/gauss2d/fit/gslsersicmixinterpolator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,17 @@ namespace g2f = gauss2d::fit;
void bind_gslsersicmixinterpolator(py::module &m) {
auto _e = py::class_<g2f::GSLSersicMixInterpolator, std::shared_ptr<g2f::GSLSersicMixInterpolator>,
g2f::SersicMixInterpolator>(m, "GSLSersicMixInterpolator")
.def(py::init<short, const g2f::GSLInterpType>(),
.def(py::init<short, const g2f::InterpType>(),
"order"_a = g2f::SERSICMIX_ORDER_DEFAULT,
"interp_type"_a = g2f::GSLInterpType::cspline)
"interp_type"_a = g2f::GSLSersicMixInterpolator::INTERPTYPE_DEFAULT)
.def_readwrite("correct_final_integral",
&g2f::GSLSersicMixInterpolator::correct_final_integral)
.def_property_readonly("final_correction",
&g2f::GSLSersicMixInterpolator::get_final_correction)
.def("integralsizes", &g2f::GSLSersicMixInterpolator::get_integralsizes)
.def("integralsizes_derivs",
&g2f::GSLSersicMixInterpolator::get_integralsizes_derivs)
.def_property_readonly("interptype", &g2f::GSLSersicMixInterpolator::get_interptype)
.def_property_readonly("order", &g2f::GSLSersicMixInterpolator::get_order)
.def("__repr__",
[](const g2f::GSLSersicMixInterpolator &self) { return self.repr(true); })
Expand Down
44 changes: 44 additions & 0 deletions python/gauss2d/fit/interpolation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* This file is part of gauss2dfit.
*
* Developed for the LSST Data Management System.
* This product includes software developed by the LSST Project
* (https://www.lsst.org).
* See the COPYRIGHT file at the top-level directory of this distribution
* for details of code ownership.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

#include <pybind11/attr.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "pybind11.h"

#include "gauss2d/fit/interpolation.h"

namespace py = pybind11;
using namespace pybind11::literals;

namespace g2f = gauss2d::fit;

void bind_interpolation(py::module &m) {
auto _e = py::enum_<g2f::InterpType>(m, "InterpType")
.value("linear", g2f::InterpType::linear)
.value("polynomial", g2f::InterpType::polynomial)
.value("cspline", g2f::InterpType::cspline)
.value("akima", g2f::InterpType::akima)
;
}
1 change: 1 addition & 0 deletions python/gauss2d/fit/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ sources = [
'gsl.cc',
'gslsersicmixinterpolator.cc',
'integralmodel.cc',
'interpolation.cc',
'linearintegralmodel.cc',
'linearsersicmixinterpolator.cc',
'model.cc',
Expand Down
1 change: 1 addition & 0 deletions python/gauss2d/fit/pybind11.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ PYBIND11_MODULE(_gauss2d_fit, m) {
bind_component(m);
bind_componentmixture(m);
bind_integralmodel(m);
bind_interpolation(m);
bind_prior(m);
bind_sersicmix(m);
#ifdef GAUSS2D_FIT_HAS_GSL
Expand Down
1 change: 1 addition & 0 deletions python/gauss2d/fit/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void bind_gsl(py::module &m);
void bind_gslsersicmixinterpolator(py::module &m);
#endif
void bind_integralmodel(py::module &m);
void bind_interpolation(py::module &m);
void bind_linearintegralmodel(py::module &m);
void bind_linearsersicmixinterpolator(py::module &m);
void bind_model(py::module &m);
Expand Down
3 changes: 2 additions & 1 deletion python/gauss2d/fit/sersicmixcomponent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ void bind_sersicmixcomponent(py::module &m) {
py::class_<C, std::shared_ptr<C>, Base>(m, pyclass_name.c_str()))
// new properties
.def_property_readonly("integralratio", &g2f::SersicMixComponentIndexParameter::get_integralratio)
.def_readonly("order", &g2f::SersicMixComponentIndexParameter::order)
.def_property_readonly("interptype", &g2f::SersicMixComponentIndexParameter::get_interptype)
.def_property_readonly("order", &g2f::SersicMixComponentIndexParameter::get_order)
.def_property_readonly("sizeratio", &g2f::SersicMixComponentIndexParameter::get_sizeratio)
// constructor with added arg
.def(py::init<T, std::shared_ptr<const parameters::Limits<T>>,
Expand Down
2 changes: 1 addition & 1 deletion src/gslinterpolator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ size_t GSLInterpolator::size() const {
}

GSLInterpolator::GSLInterpolator(std::vector<double> x, std::vector<double> y,
const GSLInterpType interp_type_)
const InterpType interp_type_)
: _n_knots(x.size()),
_x(x),
_y(y),
Expand Down
15 changes: 9 additions & 6 deletions src/gslsersicmixinterpolator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,17 @@ std::vector<IntegralSize> GSLSersicMixInterpolator::get_integralsizes_derivs(dou
}
if(correct_final_integral) {
const auto & interp = _interps[max_ord];
result.push_back({correct_final_integral*interp.first->eval_deriv(sersicindex), interp.second->eval_deriv(sersicindex)});
result.push_back({correct_final_integral*interp.first->eval_deriv(sersicindex),
interp.second->eval_deriv(sersicindex)});
} else {
_final_correction = 1.;
}

return result;
}

InterpType GSLSersicMixInterpolator::get_interptype() const { return _interp_type; }

unsigned short GSLSersicMixInterpolator::get_order() const { return _order; }

std::string GSLSersicMixInterpolator::repr(bool name_keywords) const {
Expand All @@ -89,9 +92,9 @@ std::string GSLSersicMixInterpolator::str() const {
return "GSLSersicMixInterpolator(order=" + std::to_string(_order) + ")";
}

GSLSersicMixInterpolator::GSLSersicMixInterpolator(unsigned short order, const GSLInterpType interp_type_)
: _order(order),
interp_type(interp_type_),
GSLSersicMixInterpolator::GSLSersicMixInterpolator(unsigned short order, const InterpType interp_type_)
: _interp_type(interp_type_),
_order(order),
knots(get_sersic_mix_knots(order)),
sersicindex_min(knots[0].sersicindex),
sersicindex_max(knots.back().sersicindex) {
Expand All @@ -117,8 +120,8 @@ GSLSersicMixInterpolator::GSLSersicMixInterpolator(unsigned short order, const G
}
for(size_t iord = 0; iord < _order; ++iord) {
_interps.push_back({
std::make_unique<GSLInterpolator>(sersics, integrals[iord], interp_type),
std::make_unique<GSLInterpolator>(sersics, sigmas[iord], interp_type)
std::make_unique<GSLInterpolator>(sersics, integrals[iord], _interp_type),
std::make_unique<GSLInterpolator>(sersics, sigmas[iord], _interp_type)
});
}
}
Expand Down
Loading

0 comments on commit 9100ff7

Please sign in to comment.