// unit_tests.cpp (114 lines, 93 loc, 2.77 KB)
// Copyright © 2016 Ondrej Martinsky, All rights reserved
// www.quantandfinancial.com
#include <assert.h>
#include <math.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include "ad_engine.hpp"
#include "unit_tests.hpp"
using namespace std;
// Estimates f'(x) with a symmetric (central) finite difference:
//     (f(x + h) - f(x - h)) / (2h),  with a fixed step h = 1e-6.
// Used as an independent reference for derivatives computed by the AD engine.
double numerical_derivative(function<double(double)> f, double x)
{
    const double step = 1e-6;
    const double upper = f(x + step);
    const double lower = f(x - step);
    return (upper - lower) / (2 * step);
}
// Estimates a partial derivative of f(x, y) with a symmetric (central) finite
// difference using a fixed step h = 1e-6.
//   which == 0      -> d f / d x   : (f(x + h, y) - f(x - h, y)) / (2h)
//   any other value -> d f / d y   : (f(x, y + h) - f(x, y - h)) / (2h)
double numerical_derivative(function<double(double, double)> f, double x, double y, int which)
{
    const double step = 1e-6;
    const double spread = (which == 0)
        ? f(x + step, y) - f(x - step, y)
        : f(x, y + step) - f(x, y - step);
    return spread / (2 * step);
}
#define CHECK(actual, expected) check((actual), (expected), __LINE__)
// Compares a computed value against a reference value and throws on mismatch.
//
// The pair is accepted when the values compare equal, when the absolute
// difference is at most 1e-10, or when the relative difference (absolute
// difference divided by the larger magnitude) is at most 1e-6. Anything else
// fails, including NaN in either argument or an infinity paired with a
// different value.
//
// @param actual    value produced by the code under test
// @param expected  reference value
// @param line      source line of the CHECK invocation, reported in the message
// @throws std::runtime_error (a std::exception) describing the mismatch
void check(double actual, double expected, int line)
{
    constexpr double kRelTol = 1e-6;
    constexpr double kAbsTol = 1e-10;

    // Exact equality also accepts matching infinities (inf - inf would be NaN below).
    if (actual == expected)
        return;

    // std::fabs rather than an unqualified abs, which may bind to int abs(int)
    // and truncate the difference to zero.
    const double absdiff = std::fabs(actual - expected);
    const double reldiff = absdiff / std::max(std::fabs(actual), std::fabs(expected));

    // Phrased as a *pass* condition: every comparison involving NaN is false,
    // so NaN differences fall through to the failure path instead of passing.
    if (absdiff <= kAbsTol || reldiff <= kRelTol)
        return;

    std::stringstream ss;
    ss << "Error Line=" << line;
    ss << ", Actual=" << std::setw(10) << actual << '\n';
    ss << ", Expected=" << std::setw(10) << expected << '\n';
    ss << ", AbsDiff=" << std::setw(10) << absdiff << '\n';
    ss << ", RelDiff=" << std::setw(10) << reldiff << '\n';
    // std::exception has no const char* constructor in standard C++ (that is an
    // MSVC extension); std::runtime_error derives from std::exception, so
    // handlers catching std::exception still work.
    throw std::runtime_error(ss.str());
}
void unit_tests()
{
{
ADEngine e;
ADDouble a(e, 3.);
auto f = [&](auto x) -> auto {
return ADDouble(e, 1.0);
};
CHECK(e.get_derivative(f(a), a), 0.0);
}
{
ADEngine e;
ADDouble a(e, 3.);
auto f = [](auto x) -> auto {
return x;
};
CHECK(e.get_derivative(f(a), a), 1.0);
}
{
ADEngine e;
ADDouble a(e, 3.);
auto f = [](auto x) -> auto {
return x + x + x + x;
};
CHECK(e.get_derivative(f(a), a), numerical_derivative(f, a.get_value()));
}
{
ADEngine e;
ADDouble a(e, 3.);
auto f = [](auto x) -> auto {
return (x + x) + (x + x);
};
CHECK( e.get_derivative(f(a), a), numerical_derivative(f, a.get_value()));
}
{
ADEngine e;
ADDouble a(e, 1.);
ADDouble b(e, 4.);
auto f = [](auto x, auto y) -> auto {
return x / y;
};
CHECK(e.get_derivative(f(a, b), b), numerical_derivative(f, a.get_value(), b.get_value(), 1));
}
{
ADEngine e;
ADDouble a(e, 3.);
ADDouble b(e, 4.);
auto f = [](auto x, auto y) -> auto {
return log(x) + log(x) + exp(y) + (x + y) * (2. * x - y) / (x - 0.5 * y) / y;
};
CHECK(e.get_derivative(f(a, b), a), numerical_derivative(f, a.get_value(), b.get_value(), 0));
}
}