-
Notifications
You must be signed in to change notification settings - Fork 122
/
SeqDomain.cpp
175 lines (163 loc) · 5.5 KB
/
SeqDomain.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
//----------------------------------------------------------------------
// Includes
//----------------------------------------------------------------------
#include "MantidCurveFitting/SeqDomain.h"
#include "MantidCurveFitting/ParDomain.h"
namespace Mantid
{
namespace CurveFitting
{
/// Return the number of points in the domain
size_t SeqDomain::size() const
{
size_t n = 0;
for(auto it = m_creators.begin(); it != m_creators.end(); ++it)
{
n += (**it).getDomainSize();
}
return n;
}
/**
 * Return the number of parts (member domains) in this domain.
 * @return The number of registered domain creators; one part per creator.
 */
size_t SeqDomain::getNDomains() const
{
  const size_t nParts = m_creators.size();
  return nParts;
}
/**
 * Return the i-th domain and its values, creating them if necessary.
 * If the i-th domain is not the one currently held, the currently held domain
 * and values are released before the i-th ones are created.
 * @param i :: Index of domain to return.
 * @param domain :: Output pointer to the returned domain.
 * @param values :: Output pointer to the returned values.
 * @throws std::range_error if i is not smaller than the number of creators.
 */
void SeqDomain::getDomainAndValues(size_t i, API::FunctionDomain_sptr& domain, API::IFunctionValues_sptr& values) const
{
  if ( i >= m_creators.size() )
  {
    throw std::range_error("Function domain index is out of range.");
  }

  // The requested domain can be reused only if it is the current one and it
  // has actually been created.
  const bool alreadyAvailable = ( i == m_currentIndex ) && m_domain[i];
  if ( !alreadyAvailable )
  {
    // Release the domain and values held for the current index first ...
    m_domain[m_currentIndex].reset();
    m_values[m_currentIndex].reset();
    // ... then let the i-th creator build the requested ones.
    m_creators[i]->createDomain( m_domain[i], m_values[i] );
    m_currentIndex = i;
  }

  domain = m_domain[i];
  values = m_values[i];
}
/**
 * Register a new domain creator.
 * Empty domain and values pointers are appended alongside the creator so that
 * the creator, domain and values containers stay the same length.
 * @param creator :: A shared pointer to a new creator.
 */
void SeqDomain::addCreator( API::IDomainCreator_sptr creator )
{
  // Default-constructed placeholders; getDomainAndValues() fills these slots
  // by calling the creator's createDomain().
  API::FunctionDomain_sptr emptyDomain;
  API::IFunctionValues_sptr emptyValues;

  m_creators.push_back( creator );
  m_domain.push_back( emptyDomain );
  m_values.push_back( emptyValues );
}
/**
 * Factory for sequential-style domains: creates either a SeqDomain (domains are
 * processed sequentially) or a ParDomain (for parallel calculations).
 * @param type :: Either Sequential or Parallel
 * @return A raw pointer to a newly allocated object. This function keeps no
 *   reference to it, so the caller is responsible for deleting it.
 * @throws std::invalid_argument if type is neither Sequential nor Parallel.
 */
SeqDomain* SeqDomain::create(API::IDomainCreator::DomainType type)
{
  switch (type)
  {
  case API::IDomainCreator::Sequential:
    return new SeqDomain;
  case API::IDomainCreator::Parallel:
    return new ParDomain;
  default:
    throw std::invalid_argument("Unknown SeqDomain type");
  }
}
/**
* Calculate the value of a least squares cost function
* @param leastSquares :: The least squares cost func to calculate the value for
*/
void SeqDomain::leastSquaresVal(const CostFuncLeastSquares& leastSquares)
{
API::FunctionDomain_sptr domain;
API::IFunctionValues_sptr values;
const size_t n = getNDomains();
for(size_t i = 0; i < n; ++i)
{
values.reset();
getDomainAndValues( i, domain, values );
auto simpleValues = boost::dynamic_pointer_cast<API::FunctionValues>(values);
if (!simpleValues)
{
throw std::runtime_error("LeastSquares: unsupported IFunctionValues.");
}
leastSquares.addVal( domain, simpleValues );
}
}
//------------------------------------------------------------------------------------------------
/**
 * Accumulate the value of a RWP cost function over all member domains.
 * Each domain is obtained in turn and handed, together with its values, to
 * rwp.addVal().
 * @param rwp :: The RWP cost func to calculate the value for
 * @throws std::runtime_error if the values of a domain cannot be cast to API::FunctionValues.
 */
void SeqDomain::rwpVal(const CostFuncRwp& rwp)
{
  API::FunctionDomain_sptr domain;
  API::IFunctionValues_sptr values;
  const size_t n = getNDomains();
  for(size_t i = 0; i < n; ++i)
  {
    // Drop the local reference to the previous values before asking for the
    // next domain (getDomainAndValues releases its own reference to them).
    values.reset();
    getDomainAndValues( i, domain, values );
    auto simpleValues = boost::dynamic_pointer_cast<API::FunctionValues>(values);
    if (!simpleValues)
    {
      // Report the failure against the RWP cost function (the message used to
      // say "LeastSquares", copied from leastSquaresVal), matching the message
      // used by rwpValDerivHessian.
      throw std::runtime_error("Rwp: unsupported IFunctionValues.");
    }
    rwp.addVal( domain, simpleValues );
  }
}
/**
* Calculate the value, first and second derivatives of a least squares cost function
* @param leastSquares :: The least squares cost func to calculate the value for
* @param evalFunction :: Flag to evaluate the value of the cost function
* @param evalDeriv :: Flag to evaluate the first derivatives
* @param evalHessian :: Flag to evaluate the Hessian (second derivatives)
*/
void SeqDomain::leastSquaresValDerivHessian(const CostFuncLeastSquares& leastSquares, bool evalFunction, bool evalDeriv, bool evalHessian)
{
API::FunctionDomain_sptr domain;
API::IFunctionValues_sptr values;
const size_t n = getNDomains();
for(size_t i = 0; i < n; ++i)
{
values.reset();
getDomainAndValues( i, domain, values );
auto simpleValues = boost::dynamic_pointer_cast<API::FunctionValues>(values);
if (!simpleValues)
{
throw std::runtime_error("LeastSquares: unsupported IFunctionValues.");
}
leastSquares.addValDerivHessian(leastSquares.getFittingFunction(),domain,simpleValues,evalFunction,evalDeriv,evalHessian);
}
}
/**
* Calculate the value, first and second derivatives of a RWP cost function
* @param rwp :: The rwp cost func to calculate the value for
* @param evalFunction :: Flag to evaluate the value of the cost function
* @param evalDeriv :: Flag to evaluate the first derivatives
* @param evalHessian :: Flag to evaluate the Hessian (second derivatives)
*/
void SeqDomain::rwpValDerivHessian(const CostFuncRwp& rwp, bool evalFunction, bool evalDeriv, bool evalHessian)
{
API::FunctionDomain_sptr domain;
API::IFunctionValues_sptr values;
const size_t n = getNDomains();
for(size_t i = 0; i < n; ++i)
{
values.reset();
getDomainAndValues( i, domain, values );
auto simpleValues = boost::dynamic_pointer_cast<API::FunctionValues>(values);
if (!simpleValues)
{
throw std::runtime_error("Rwp: unsupported IFunctionValues.");
}
rwp.addValDerivHessian(rwp.getFittingFunction(),domain,simpleValues,evalFunction,evalDeriv,evalHessian);
}
}
} // namespace CurveFitting
} // namespace Mantid