In [1]:
from sympy import *

In [2]:
# The single symbolic variable every reference expression is written in.
x = Symbol('x')
# Absolute tolerance spliced into every generated Rust assertion.
tol = 1.0e-12

def check_zero(f):
    """Return the Rust text that subtracts the reference value ``f``.

    The result is spliced into assertions of the form
    ``(res.<field><suffix>).abs() < tol``. For a nonzero ``f`` the suffix is
    ``" - <f>"``. For a zero ``f`` the suffix is the empty string, so the
    assertion simply reads ``(res.<field>).abs() < tol``.

    Zero is detected both with ``==`` and numerically via ``float(f)``.
    SymPy's equality between ``Float`` and ``Integer`` has changed across
    releases (1.13 made ``Float(2.0) != Integer(2)``), so ``f == 0`` alone is
    not a dependable zero test for an ``evalf`` result.

    Args:
        f: Reference value, typically a SymPy number returned by ``evalf``.

    Returns:
        ``""`` if ``f`` is zero, otherwise ``f" - {f}"``.
    """
    try:
        is_zero = f == 0 or float(f) == 0.0
    except (TypeError, ValueError):
        # Complex or non-numeric values cannot be converted to float; they are
        # not zero, so keep the explicit subtraction.
        is_zero = False
    if is_zero:
        return ""
    return f" - {f}"

def write_test_dual(f, x0, method, additional_param = "", index=None):
    """Render a Rust ``#[test]`` checking ``Dual64::<method>`` against SymPy.

    The generated test evaluates ``Dual64::from(x0).derive().<method>(...)``.
    It then asserts that ``res.re`` matches ``f(x0)`` and ``res.eps`` matches
    ``f'(x0)``, each to within the module-level absolute tolerance ``tol``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64, so the reference literal
        # adds essentially no error on top of what `tol` is meant to measure.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    lines = [
        "#[test]",
        f"fn test_dual_{method}{suffix}() {{",
        f"    let res = Dual64::from({x0}).derive().{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
        f"    assert!((res.eps{check_zero(at_x0(f.diff(x)))}).abs() < {tol});",
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_dual_n(f, x0, method, additional_param = "", index=None):
    """Render a Rust ``#[test]`` checking ``DualN64::<2>::<method>`` against SymPy.

    The generated test evaluates
    ``DualN64::<2>::from(x0).derive(0).derive(1).<method>(...)``. It asserts
    that ``res.re`` matches ``f(x0)`` and that both ``res.eps[0]`` and
    ``res.eps[1]`` match ``f'(x0)``, each to within the module-level absolute
    tolerance ``tol``. The expected values assume that seeding both
    directions makes each ``eps`` entry equal to ``f'(x0)``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64, so the reference literal
        # adds essentially no error on top of what `tol` is meant to measure.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    # The first derivative is shared by both eps entries; compute it once.
    first = check_zero(at_x0(f.diff(x)))
    lines = [
        "#[test]",
        f"fn test_dual_n_{method}{suffix}() {{",
        f"    let res = DualN64::<2>::from({x0}).derive(0).derive(1).{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
    ]
    lines += [f"    assert!((res.eps[{i}]{first}).abs() < {tol});" for i in range(2)]
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def write_test_hyperdual(f, x0, method, additional_param = "", index=None):
    """Render a Rust ``#[test]`` checking ``HyperDual64::<method>`` against SymPy.

    The generated test evaluates
    ``HyperDual64::from(x0).derive1().derive2().<method>(...)``. It asserts:

    - ``res.re`` matches ``f(x0)``;
    - ``res.eps1`` and ``res.eps2`` match ``f'(x0)``;
    - ``res.eps1eps2`` matches ``f''(x0)``.

    Each comparison uses the module-level absolute tolerance ``tol``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64, so the reference literal
        # adds essentially no error on top of what `tol` is meant to measure.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    # The first derivative is shared by eps1 and eps2; compute it once.
    first = check_zero(at_x0(f.diff(x)))
    second = check_zero(at_x0(f.diff(x, 2)))
    lines = [
        "#[test]",
        f"fn test_hyperdual_{method}{suffix}() {{",
        f"    let res = HyperDual64::from({x0}).derive1().derive2().{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
        f"    assert!((res.eps1{first}).abs() < {tol});",
        f"    assert!((res.eps2{first}).abs() < {tol});",
        f"    assert!((res.eps1eps2{second}).abs() < {tol});",
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_hyperdual_n(f, x0, method, additional_param = "", index=None):
    """Render a Rust ``#[test]`` checking ``HyperDualN64::<2>::<method>`` against SymPy.

    The generated test evaluates
    ``HyperDualN64::<2>::from(x0).derive(0).derive(1).<method>(...)``. It
    asserts:

    - ``res.re`` matches ``f(x0)``;
    - both ``res.gradient`` entries match ``f'(x0)``;
    - all four ``res.hessian`` entries match ``f''(x0)``.

    Each comparison uses the module-level absolute tolerance ``tol``. The
    expected values assume that seeding both directions makes every gradient
    entry equal to ``f'`` and every Hessian entry equal to ``f''``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64, so the reference literal
        # adds essentially no error on top of what `tol` is meant to measure.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    # Each derivative is reused across several entries; compute it only once.
    grad = check_zero(at_x0(f.diff(x)))
    hess = check_zero(at_x0(f.diff(x, 2)))
    lines = [
        "#[test]",
        f"fn test_hyperdual_n_{method}{suffix}() {{",
        f"    let res = HyperDualN64::<2>::from({x0}).derive(0).derive(1).{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
    ]
    lines += [f"    assert!((res.gradient[{i}]{grad}).abs() < {tol});" for i in range(2)]
    # Row-major order: (0,0), (0,1), (1,0), (1,1).
    lines += [
        f"    assert!((res.hessian[({i},{j})]{hess}).abs() < {tol});"
        for i in range(2)
        for j in range(2)
    ]
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def write_test_hd2(f, x0, method, additional_param="", index=None):
    """Render a Rust ``#[test]`` checking ``HD2_64::<method>`` against SymPy.

    The generated test evaluates ``HD2_64::from(x0).derive().<method>(...)``.
    It asserts that ``res.re``, ``res.v1`` and ``res.v2`` match ``f(x0)``,
    ``f'(x0)`` and ``f''(x0)`` respectively, each to within the module-level
    absolute tolerance ``tol``.

    NOTE(review): this assumes ``v2`` stores ``f''(x0)`` itself rather than a
    Taylor coefficient such as ``f''(x0)/2``; confirm against ``HD2_64``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64, so the reference literal
        # adds essentially no error on top of what `tol` is meant to measure.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    lines = [
        "#[test]",
        f"fn test_hd2_{method}{suffix}() {{",
        f"    let res = HD2_64::from({x0}).derive().{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
        f"    assert!((res.v1{check_zero(at_x0(f.diff(x)))}).abs() < {tol});",
        f"    assert!((res.v2{check_zero(at_x0(f.diff(x, 2)))}).abs() < {tol});",
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_hd3(f, x0, method, additional_param="", index=None):
    """Render a Rust ``#[test]`` checking ``HD3_64::<method>`` against SymPy.

    The generated test evaluates ``HD3_64::from(x0).derive().<method>(...)``.
    It asserts that ``res.re``, ``res.v1``, ``res.v2`` and ``res.v3`` match
    ``f(x0)`` and its first three derivatives at ``x0``, each to within the
    module-level absolute tolerance ``tol``.

    NOTE(review): this assumes ``v2``/``v3`` store the plain derivatives
    rather than Taylor coefficients (``f''/2``, ``f'''/6``); confirm against
    ``HD3_64``.

    Args:
        f: SymPy expression in the module-level symbol ``x``.
        x0: Evaluation point; its ``str`` is spliced into the Rust source.
        method: Name of the Rust method under test; also names the test.
        additional_param: Argument text passed to ``method`` ("" for unary
            methods).
        index: Optional suffix distinguishing several tests of one method.

    Returns:
        The Rust source of one test function followed by a blank line.
    """
    def at_x0(expr):
        # 17 significant digits round-trip an f64. Third derivatives can be in
        # the hundreds (e.g. tan at 1.2), where 15 digits would already cost
        # ~5e-13 of the absolute tolerance.
        return expr.evalf(17, subs={x: x0})

    suffix = '' if index is None else f'_{index}'
    lines = [
        "#[test]",
        f"fn test_hd3_{method}{suffix}() {{",
        f"    let res = HD3_64::from({x0}).derive().{method}({additional_param});",
        f"    assert!((res.re{check_zero(at_x0(f))}).abs() < {tol});",
        f"    assert!((res.v1{check_zero(at_x0(f.diff(x)))}).abs() < {tol});",
        f"    assert!((res.v2{check_zero(at_x0(f.diff(x, 2)))}).abs() < {tol});",
        f"    assert!((res.v3{check_zero(at_x0(f.diff(x, 3)))}).abs() < {tol});",
        "}",
    ]
    return "\n".join(lines) + "\n\n"
    
def write_all_tests(number, directory="tests"):
    """Generate and write the Rust test file for one dual-number flavour.

    Every reference case below is rendered with the flavour's ``write_test_*``
    function, prefixed with the required ``use`` lines. The result is written
    to ``<directory>/test_<number>.rs``.

    Args:
        number: Flavour key: one of 'dual', 'dual_n', 'hyperdual',
            'hyperdual_n', 'hd2', 'hd3'. Despite its name it is a type key,
            not a number. Unknown keys raise ``KeyError``.
        directory: Output directory. It is created if it does not exist.
            Defaults to ``"tests"``, i.e. ``tests/test_<number>.rs``.
    """
    import os

    # A single table keeps each writer paired with its Rust type name.
    write_test, dual_type = {
        'dual': (write_test_dual, 'Dual64'),
        'dual_n': (write_test_dual_n, 'DualN64'),
        'hyperdual': (write_test_hyperdual, 'HyperDual64'),
        'hyperdual_n': (write_test_hyperdual_n, 'HyperDualN64'),
        'hd2': (write_test_hd2, 'HD2_64'),
        'hd3': (write_test_hd3, 'HD3_64'),
    }[number]

    # Each case: (reference expression, x0, Rust method, extra argument, name suffix).
    cases = [
        (1/x, 1.2, "recip", "", None),
        (exp(x), 1.2, "exp", "", None),
        (exp(x)-1, 1.2, "exp_m1", "", None),
        (2**x, 1.2, "exp2", "", None),
        (ln(x), 1.2, "ln", "", None),
        (log(x, 4.2), 1.2, "log", 4.2, None),
        (ln(1+x), 1.2, "ln_1p", "", None),
        (log(x, 2), 1.2, "log2", "", None),
        (log(x, 10), 1.2, "log10", "", None),
        (sqrt(x), 1.2, "sqrt", "", None),
        (cbrt(x), 1.2, "cbrt", "", None),
        (x**4.2, 1.2, "powf", 4.2, None),
    ]
    # Integer exponents 0..4 evaluated at x0 = 0.0, one numbered test each.
    cases += [(x**n, 0.0, "powf", float(n), n) for n in range(5)]
    cases.append((x**6, 1.2, "powi", 6, None))
    cases += [(x**n, 0.0, "powi", n, n) for n in range(5)]
    cases += [
        (sin(x), 1.2, "sin", "", None),
        (cos(x), 1.2, "cos", "", None),
        (tan(x), 1.2, "tan", "", None),
        (asin(x), 0.2, "asin", "", None),
        (acos(x), 0.2, "acos", "", None),
        (atan(x), 0.2, "atan", "", None),
        (sinh(x), 1.2, "sinh", "", None),
        (cosh(x), 1.2, "cosh", "", None),
        (tanh(x), 1.2, "tanh", "", None),
        (asinh(x), 1.2, "asinh", "", None),
        (acosh(x), 1.2, "acosh", "", None),
        (atanh(x), 0.2, "atanh", "", None),
        (jn(0, x), 1.2, "sph_j0", "", None),
        (jn(1, x), 1.2, "sph_j1", "", None),
        (jn(2, x), 1.2, "sph_j2", "", None),
    ]

    header = f"use num_hyperdual::{dual_type};\nuse num_hyperdual::DualNum;\n\n"
    body = "".join(write_test(*case) for case in cases)

    # Create the output directory on demand so a fresh checkout does not fail.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"test_{number}.rs")
    with open(path, 'w', encoding='utf-8') as out:
        out.write(header + body)

In [3]:
# Regenerate the Rust test file for every supported dual-number flavour.
flavours = ('dual', 'dual_n', 'hyperdual', 'hyperdual_n', 'hd2', 'hd3')
for flavour in flavours:
    write_all_tests(flavour)