Skip to content

Commit

Permalink
Re #9682. Implemented binary search
Browse files Browse the repository at this point in the history
Replaced the linear search with a binary search and added unit tests.
  • Loading branch information
Michael Wedel committed Jun 23, 2014
1 parent 1804557 commit ce9f979
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
4 changes: 4 additions & 0 deletions Code/Mantid/Framework/Kernel/inc/MantidKernel/Interpolation.h
Expand Up @@ -62,10 +62,14 @@ class MANTID_KERNEL_DLL Interpolation
/// unit of y-axis
Unit_sptr m_yUnit;

protected:
size_t findIndexOfNextLargerValue(const std::vector<double> &data, double key, size_t range_start, size_t range_end) const;

public:

/// Constructor default to linear interpolation and x-unit set to TOF
Interpolation();
virtual ~Interpolation() { }

/// add data point
void addPoint(const double& xx, const double& yy);
Expand Down
38 changes: 28 additions & 10 deletions Code/Mantid/Framework/Kernel/src/Interpolation.cpp
Expand Up @@ -21,6 +21,26 @@ namespace Kernel
m_yUnit = UnitFactory::Instance().create("TOF");
}

size_t Interpolation::findIndexOfNextLargerValue(const std::vector<double> &data, double key, size_t range_start, size_t range_end) const
{
if(range_end < range_start)
{
throw std::range_error("Value is outside array range.");
}

size_t center = range_start + (range_end - range_start) / 2;

if(data[center] > key && data[center - 1] <= key) {
return center;
}

if(data[center] <= key) {
return findIndexOfNextLargerValue(data, key, center + 1, range_end);
} else {
return findIndexOfNextLargerValue(data, key, range_start, center - 1);
}
}

void Interpolation::setXUnit(const std::string& unit)
{
m_xUnit = UnitFactory::Instance().create(unit);
Expand Down Expand Up @@ -48,7 +68,7 @@ namespace Kernel

// check first if at is within the limits of interpolation interval

if ( at <= m_x[0] )
if ( at < m_x[0] )
{
return m_y[0]-(m_x[0]-at)*(m_y[1]-m_y[0])/(m_x[1]-m_x[0]);
}
Expand All @@ -58,16 +78,14 @@ namespace Kernel
return m_y[N-1]+(at-m_x[N-1])*(m_y[N-1]-m_y[N-2])/(m_x[N-1]-m_x[N-2]);
}

// otherwise

for (unsigned int i = 1; i < N; i++)
{
if ( m_x[i] > at )
{
return m_y[i-1] + (at-m_x[i-1])*(m_y[i]-m_y[i-1])/(m_x[i]-m_x[i-1]);
}
try {
// otherwise
// General case. Find index of next largest value by binary search.
size_t idx = findIndexOfNextLargerValue(m_x, at, 1, N - 1);
return m_y[idx-1] + (at - m_x[idx-1])*(m_y[idx]-m_y[idx-1])/(m_x[idx]-m_x[idx-1]);
} catch(std::range_error) {
return 0.0;
}
return 0.0;
}

/** Add point in the interpolation.
Expand Down
38 changes: 38 additions & 0 deletions Code/Mantid/Framework/Kernel/test/InterpolationTest.h
Expand Up @@ -163,6 +163,34 @@ class InterpolationTest : public CxxTest::TestSuite
checkInterpolationResults(interpolation);
}

void testFindIndexOfNextLargerValue()
{
TestableInterpolation interpolation;

size_t N = m_tableXValues.size();

// lower limit - can be treated like general case
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 200.0, 1, N - 1), 1);

// Exact interpolation points
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 201.0, 1, N - 1), 2);
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 202.0, 1, N - 1), 3);
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 203.0, 1, N - 1), 4);

// Arbitrary interpolation points
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 200.5, 1, N - 1), 1);
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 201.25, 1, N - 1), 2);
TS_ASSERT_EQUALS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 203.5, 1, N - 1), 4);


// upper limit - must be covered as edge case before this can ever be called.
TS_ASSERT_THROWS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 204.0, 1, N - 1), std::range_error);

// outside interpolation limits - edge cases as well
TS_ASSERT_THROWS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 199, 1, N - 1), std::range_error)
TS_ASSERT_THROWS(interpolation.findIndexOfNextLargerValue(m_tableXValues, 2000.0, 1, N - 1), std::range_error)
}

private:
Interpolation getInitializedInterpolation(std::string xUnit, std::string yUnit)
{
Expand Down Expand Up @@ -240,6 +268,16 @@ class InterpolationTest : public CxxTest::TestSuite
// Values outside interpolation range
std::vector<double> m_outsideXValues;
std::vector<double> m_outsideYValues;

// For the test of findIndexOfNextLargerValue access to protected member is needed
class TestableInterpolation : public Interpolation {
friend class InterpolationTest;

public:
TestableInterpolation() : Interpolation()
{ }
~TestableInterpolation() { }
};
};

#endif /*INTERPOLATIONTEST_H_*/

0 comments on commit ce9f979

Please sign in to comment.