/
size_checks.hpp
135 lines (124 loc) · 4.45 KB
/
size_checks.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/**
* @file size_checks.hpp
* @author Kirill Mishchenko
* @author Bisakh Mondal
*
* Utility for checking same size & same dimensionality.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_UTIL_SIZE_CHECKS_HPP
#define MLPACK_UTIL_SIZE_CHECKS_HPP
namespace mlpack {
namespace util {
/**
 * Check that the given data points and labels contain the same number of
 * points, and throw if they do not.
 *
 * By default the number of points of each object is its number of columns
 * (`n_cols`). Either object can instead be treated as transposed, in which case
 * its number of rows (`n_rows`) is used as its number of points.
 *
 * This overload is selected when the second argument is not of an integral
 * type; see the integral overload for checking against a known size.
 *
 * @param data Data whose number of points will be checked.
 * @param label Labels (or any other per-point object, e.g. weights) whose
 *     number of points must match that of `data`.
 * @param callerDescription A description of the caller that can be used for
 *     error generation.
 * @param addInfo Name to use for labels for precise error generation. Default
 *     is "labels"; for example, "weights" could also be used.
 * @param isDataTranspose If true, `data.n_rows` is used as the number of data
 *     points instead of `data.n_cols`. Default is false.
 * @param isLabelTranspose If true, `label.n_rows` is used as the number of
 *     label points instead of `label.n_cols`. Default is false.
 * @throws std::invalid_argument if the two numbers of points differ.
 */
template<typename DataType, typename LabelsType>
inline void CheckSameSizes(
    const DataType& data,
    const LabelsType& label,
    const std::string& callerDescription,
    const std::string& addInfo = "labels",
    const bool& isDataTranspose = false,
    const bool& isLabelTranspose = false,
    const typename std::enable_if<
        !std::is_integral<LabelsType>::value>::type* = nullptr)
{
  const size_t dataPoints = isDataTranspose ? data.n_rows : data.n_cols;
  const size_t labelPoints = isLabelTranspose ? label.n_rows : label.n_cols;
  if (dataPoints != labelPoints)
  {
    // No trailing std::endl: an exception message should not carry a newline,
    // and the CheckSameDimensionality() messages do not have one either.
    std::ostringstream oss;
    oss << callerDescription << ": number of points (" << dataPoints << ") "
        << "does not match number of " << addInfo << " (" << labelPoints
        << ")!";
    throw std::invalid_argument(oss.str());
  }
}
/**
 * An overload of CheckSameSizes() for when the expected number of points is
 * already known as an integer rather than given by a second matrix-like object.
 * This overload is selected when the second argument has an integral type.
 *
 * @param data Data whose number of points (`data.n_cols`) will be checked.
 * @param size Expected number of points.
 * @param callerDescription A description of the caller that can be used for
 *     error generation.
 * @param addInfo Name to use for the compared quantity for precise error
 *     generation. Default is "labels"; for example, "weights" could also be
 *     used.
 * @throws std::invalid_argument if `data.n_cols` differs from `size`.
 */
template<typename DataType, typename SizeType>
inline void CheckSameSizes(
    const DataType& data,
    const SizeType& size,
    const std::string& callerDescription,
    const std::string& addInfo = "labels",
    const typename std::enable_if<std::is_integral<SizeType>::value>::type* =
        nullptr)
{
  if (data.n_cols != size)
  {
    // No trailing std::endl: an exception message should not carry a newline,
    // and the CheckSameDimensionality() messages do not have one either.
    std::ostringstream oss;
    oss << callerDescription << ": number of points (" << data.n_cols << ") "
        << "does not match number of " << addInfo << " (" << size << ")!";
    throw std::invalid_argument(oss.str());
  }
}
/**
 * Verify that a dataset has the dimensionality expected by a model, where the
 * expected dimensionality is given by another matrix-like object.
 *
 * Dimensionality is taken to be the number of rows (`n_rows`) of each object.
 * This overload is selected when the second argument is not of an integral
 * type.
 *
 * @param data Dataset whose dimensionality is checked.
 * @param dimension Matrix-like object whose `n_rows` is the dimensionality the
 *     model expects.
 * @param callerDescription A description of the caller that can be used for
 *     error generation.
 * @param addInfo Name to use for the dataset for precise error generation.
 *     Default is "dataset"; for example, "weights" could also be used.
 * @throws std::invalid_argument if the two numbers of rows differ.
 */
template<typename DataType, typename DimType>
inline void CheckSameDimensionality(
    const DataType& data,
    const DimType& dimension,
    const std::string& callerDescription,
    const std::string& addInfo = "dataset",
    const typename std::enable_if<!std::is_integral<DimType>::value>::type* = 0)
{
  const auto givenDims = data.n_rows;
  const auto expectedDims = dimension.n_rows;

  // Matching dimensionality: nothing to report.
  if (givenDims == expectedDims)
    return;

  std::ostringstream message;
  message << callerDescription << ": dimensionality of " << addInfo << " ("
          << givenDims << ") is not equal to the dimensionality of the model"
          " (" << expectedDims << ")!";
  throw std::invalid_argument(message.str());
}
/**
 * An overload of CheckSameDimensionality() for when the model's dimensionality
 * is already known as an integer. This overload is selected when the second
 * argument has an integral type.
 *
 * @param data Dataset whose dimensionality (`data.n_rows`) is checked.
 * @param dimension Dimensionality the model expects.
 * @param callerDescription A description of the caller that can be used for
 *     error generation.
 * @param addInfo Name to use for the dataset for precise error generation.
 *     Default is "dataset".
 * @throws std::invalid_argument if `data.n_rows` differs from `dimension`.
 */
template<typename DataType, typename DimType>
inline void CheckSameDimensionality(
    const DataType& data,
    const DimType& dimension,
    const std::string& callerDescription,
    const std::string& addInfo = "dataset",
    const typename std::enable_if<std::is_integral<DimType>::value>::type* = 0)
{
  // Matching dimensionality: nothing to report.
  if (data.n_rows == dimension)
    return;

  std::ostringstream message;
  message << callerDescription << ": dimensionality of " << addInfo << " ("
          << data.n_rows << ") is not equal to the dimensionality of the model"
          " (" << dimension << ")!";
  throw std::invalid_argument(message.str());
}
} // namespace util
} // namespace mlpack
#endif