forked from AmbaPant/mantid
-
Notifications
You must be signed in to change notification settings - Fork 1
/
FunctionAdapterTestCommon.h
97 lines (88 loc) · 4.11 KB
/
FunctionAdapterTestCommon.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Mantid Repository : https://github.com/mantidproject/mantid
//
// Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
// NScD Oak Ridge National Laboratory, European Spallation Source,
// Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once
#include "MantidAPI/FunctionFactory.h"
#include "MantidAPI/IPeakFunction.h"
#include <boost/algorithm/string/replace.hpp>
#include <boost/python/detail/wrap_python.hpp>

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace Mantid {
namespace PythonInterface {
// Returns the Python source "blueprint" for a test function class deriving
// from FunctionType. The generic template is deliberately declared but NOT
// defined: instantiating it with a FunctionType that has no explicit
// specialization produces a linker error (undefined reference). If you hit
// that error, write a new specialization for the given function type.
template <typename FunctionType> std::string clsBlueprint(bool includeDerivative);
/**
 * Python source blueprint for a test function deriving from IFunction1D.
 *
 * Placeholders, substituted by subscribeTestFunction():
 *   {0} - the Python class name
 *   {1} - body of function1D(self, x)
 *   {2} - body of functionDeriv1D(self, x, jacobian); only emitted when
 *         includeDerivative is true
 *
 * Method definitions are indented by 4 spaces and their own bodies by 8.
 * The code substituted for {1} and {2} is inserted verbatim. It must therefore
 * be indented more deeply than 4 spaces (8 by convention) to form a valid
 * Python method body.
 *
 * @param includeDerivative If true, also emit a functionDeriv1D method.
 * @return Python source that defines the class and subscribes it with
 *         FunctionFactory.Instance().
 */
template <> inline std::string clsBlueprint<Mantid::API::IFunction1D>(bool includeDerivative) {
  std::string blueprint = "from mantid.api import IFunction1D, FunctionFactory\n"
                          "class {0}(IFunction1D):\n"
                          "    def init(self):\n"
                          "        self.declareParameter('A', 1.0)\n"
                          "    def function1D(self, x):\n"
                          "{1}\n";
  if (includeDerivative) {
    blueprint.append("    def functionDeriv1D(self, x, jacobian):\n"
                     "{2}\n");
  }
  // Module-level statement (no indentation): runs once the class body is complete.
  blueprint.append("FunctionFactory.Instance().subscribe({0})\n");
  return blueprint;
}
/**
 * Python source blueprint for a test function deriving from IPeakFunction.
 *
 * Placeholders, substituted by subscribeTestFunction():
 *   {0} - the Python class name
 *   {1} - body of functionLocal(self, x)
 *   {2} - body of functionDerivLocal(self, x, jacobian); only emitted when
 *         includeDerivative is true
 *
 * Method definitions are indented by 4 spaces and their own bodies by 8.
 * The code substituted for {1} and {2} is inserted verbatim. It must therefore
 * be indented more deeply than 4 spaces (8 by convention) to form a valid
 * Python method body.
 *
 * @param includeDerivative If true, also emit a functionDerivLocal method.
 * @return Python source that defines the class and subscribes it with
 *         FunctionFactory.Instance().
 */
template <> inline std::string clsBlueprint<Mantid::API::IPeakFunction>(bool includeDerivative) {
  std::string blueprint = "from mantid.api import IPeakFunction, FunctionFactory\n"
                          "class {0}(IPeakFunction):\n"
                          "    def init(self):\n"
                          "        self.declareParameter('A', 1.0)\n"
                          "    def functionLocal(self, x):\n"
                          "{1}\n";
  if (includeDerivative) {
    blueprint.append("    def functionDerivLocal(self, x, jacobian):\n"
                     "{2}\n");
  }
  // Stub centre/height/fwhm getters (fixed values) and no-op setters, followed by
  // the module-level registration. Every method header ends with "\n" so that
  // its body sits on its own, further-indented line.
  blueprint.append("    def centre(self):\n"
                   "        return 0.0\n"
                   "    def setCentre(self, x):\n"
                   "        pass\n"
                   "    def height(self):\n"
                   "        return 1.0\n"
                   "    def setHeight(self, x):\n"
                   "        pass\n"
                   "    def fwhm(self):\n"
                   "        return 0.1\n"
                   "    def setFwhm(self, x):\n"
                   "        pass\n"
                   "FunctionFactory.Instance().subscribe({0})\n");
  return blueprint;
}
/**
 * Builds the Python source for a test function of type FunctionType and
 * executes it in the embedded interpreter. This defines the class and
 * subscribes it with the FunctionFactory.
 *
 * @param clsName Name of the Python class to define; substituted for {0}.
 * @param functionImpl Python code inserted verbatim as the function body ({1}).
 * @param derivImpl Python code inserted verbatim as the derivative body ({2}).
 *        When empty, no derivative method is generated.
 * @throws std::runtime_error if the generated Python fails to execute. The
 *         message includes the generated source, to help diagnose
 *         indentation or syntax mistakes.
 */
template <typename FunctionType>
void subscribeTestFunction(const std::string &clsName, const std::string &functionImpl,
                           const std::string &derivImpl) {
  using boost::algorithm::replace_all;
  using boost::algorithm::replace_all_copy;
  const bool includeDerivative = !derivImpl.empty();
  std::string blueprint = replace_all_copy(clsBlueprint<FunctionType>(includeDerivative), "{0}", clsName);
  replace_all(blueprint, "{1}", functionImpl);
  if (includeDerivative) {
    replace_all(blueprint, "{2}", derivImpl);
  }
  // PyRun_SimpleString returns -1 when the executed code raised a Python exception.
  // Fail here rather than let the caller hit a confusing "not registered" error later.
  if (PyRun_SimpleString(blueprint.c_str()) != 0) {
    throw std::runtime_error("Failed to execute Python definition of test function '" + clsName + "':\n" +
                             blueprint);
  }
}
/**
 * Defines and registers a Python test function, then creates an instance of
 * it through the FunctionFactory.
 *
 * @param clsName Name of the Python class to define and then create.
 * @param functionImpl Python body for the function method.
 * @param derivImpl Optional Python body for the derivative method; an empty
 *        string means no derivative method is generated.
 * @return The created function cast to FunctionType, or a null pointer if the
 *         created object is not a FunctionType.
 */
template <typename FunctionType>
std::shared_ptr<FunctionType> createTestFunction(std::string clsName, std::string functionImpl,
                                                 std::string derivImpl = "") {
  subscribeTestFunction<FunctionType>(clsName, std::move(functionImpl), std::move(derivImpl));
  auto &factory = Mantid::API::FunctionFactory::Instance();
  const auto created = factory.createFunction(clsName);
  return std::dynamic_pointer_cast<FunctionType>(created);
}
class FunctionAdapterTestJacobian : public Mantid::API::Jacobian {
public:
FunctionAdapterTestJacobian(size_t ny, size_t np) : m_np(np), m_data(ny * np) {}
void set(size_t iY, size_t iP, double value) override { m_data[iY * m_np + iP] = value; }
double get(size_t iY, size_t iP) override { return m_data[iY * m_np + iP]; }
void zero() override { m_data.assign(m_data.size(), 0.0); }
private:
size_t m_np;
std::vector<double> m_data;
};
} // namespace PythonInterface
} // namespace Mantid