In [1]:
from sympy import *

In [2]:
# The single free variable in which every reference expression is written.
x = Symbol("x")
# Absolute tolerance inserted into every generated Rust `assert!`.
tol = 1.0e-12

def check_zero(f):
    """Return the Rust subtraction suffix `` - <f>`` for a reference value.

    When ``f`` compares equal to zero, the suffix is the empty string. The
    generated assertion then reads ``(lhs).abs()`` instead of subtracting a
    zero. Presumably this also avoids emitting a bare integer literal ``0``,
    which Rust will not subtract from an ``f64`` (NOTE(review): confirm that is
    the intent).

    Args:
        f: the expected value (a SymPy number or plain Python number).

    Returns:
        ``""`` for zero, otherwise ``" - "`` followed by ``str(f)``.
    """
    return "" if f == 0 else f" - {f}"

def write_test_dual(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``Dual64`` test for ``method``.

    The generated test seeds a first-order dual number at ``x0`` and calls
    ``method(additional_param)`` on it. It then checks the real part against
    ``f(x0)`` and ``eps`` against ``f'(x0)``, both computed with SymPy.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """

    def at_x0(expr):
        return expr.evalf(subs={x: x0})

    def assertion(lhs, expected):
        return f"    assert!(({lhs}{check_zero(expected)}).abs() < {tol});"

    suffix = "" if index is None else f"_{index}"
    first = f.diff()
    lines = [
        "#[test]",
        f"fn test_dual_{method}{suffix}() {{",
        f"    let res = Dual64::from({x0}).derivative().{method}({additional_param});",
        assertion("res.re", at_x0(f)),
        assertion("res.eps", at_x0(first)),
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_dual_vec(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``DualSVec64`` test for ``method``.

    The derivative part is seeded with the vector ``[1.0, 1.0]``, so both of
    its entries are compared against ``f'(x0)``. The real part is compared
    against ``f(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """
    value = f.evalf(subs={x: x0})
    slope = f.diff().evalf(subs={x: x0})
    suffix = "" if index is None else f"_{index}"

    def assertion(lhs, expected):
        return f"    assert!(({lhs}{check_zero(expected)}).abs() < {tol});"

    lines = [
        "#[test]",
        f"fn test_dual_vec_{method}{suffix}() {{",
        f"    let res = DualSVec64::new({x0}, Derivative::some(Vector::from([1.0, 1.0]))).{method}({additional_param});",
        "    let eps = res.eps.unwrap_generic(Const::<2>, Const::<1>);",
        assertion("res.re", value),
        assertion("eps[0]", slope),
        assertion("eps[1]", slope),
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_hyperdual(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``HyperDual64`` test for ``method``.

    The generated test seeds both ``derivative1`` and ``derivative2`` at ``x0``.
    It compares ``re`` with ``f(x0)``, ``eps1`` and ``eps2`` with ``f'(x0)``,
    and ``eps1eps2`` with ``f''(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """
    first = f.diff()
    second = first.diff()
    v0, v1, v2 = (expr.evalf(subs={x: x0}) for expr in (f, first, second))
    name = f"test_hyperdual_{method}" + ("" if index is None else f"_{index}")

    lines = ["#[test]", f"fn {name}() {{"]
    lines.append(
        f"    let res = HyperDual64::from({x0}).derivative1().derivative2().{method}({additional_param});"
    )
    for field, expected in (
        ("res.re", v0),
        ("res.eps1", v1),
        ("res.eps2", v1),
        ("res.eps1eps2", v2),
    ):
        lines.append(f"    assert!(({field}{check_zero(expected)}).abs() < {tol});")
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def write_test_dual2_vec(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``Dual2SVec64`` test for ``method``.

    ``v1`` is seeded with the row vector ``[1.0, 1.0]`` and ``v2`` starts as
    ``Derivative::none()``. Both entries of ``v1`` are compared with
    ``f'(x0)``, and all four entries of the 2x2 ``v2`` with ``f''(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """

    def at_x0(expr):
        return expr.evalf(subs={x: x0})

    def assertion(lhs, expected):
        return f"    assert!(({lhs}{check_zero(expected)}).abs() < {tol});"

    first = f.diff()
    value, gradient, curvature = at_x0(f), at_x0(first), at_x0(first.diff())
    suffix = "" if index is None else f"_{index}"

    lines = [
        "#[test]",
        f"fn test_dual2_vec_{method}{suffix}() {{",
        "    let res = Dual2SVec64::new(",
        f"        {x0},",
        "        Derivative::some(RowSVector::from([1.0, 1.0])),",
        "        Derivative::none(),",
        "    )",
        f"    .{method}({additional_param});",
        "    let v1 = res.v1.unwrap_generic(Const::<1>, Const::<2>);",
        "    let v2 = res.v2.unwrap_generic(Const::<2>, Const::<2>);",
        assertion("res.re", value),
        assertion("v1[0]", gradient),
        assertion("v1[1]", gradient),
        # Row-major order: (0, 0), (0, 1), (1, 0), (1, 1).
        *(
            assertion(f"v2[({row}, {col})]", curvature)
            for row in range(2)
            for col in range(2)
        ),
        "}",
    ]
    return "\n".join(lines) + "\n\n"

def write_test_hyperdual_vec(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``HyperDualVec64`` test for ``method``.

    The generated test seeds ``eps1`` with the column vector ``[1.0, 1.0]``
    and ``eps2`` with the row vector ``[1.0, 1.0]``. ``eps1eps2`` starts as
    ``Derivative::none()``. Every entry of ``eps1``/``eps2`` is compared with
    ``f'(x0)``, and every entry of the 2x2 ``eps1eps2`` with ``f''(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """

    def assertion(lhs, expected):
        return f"    assert!(({lhs}{check_zero(expected)}).abs() < {tol});"

    d1 = f.diff()
    d2 = d1.diff()
    value = f.evalf(subs={x: x0})
    slope = d1.evalf(subs={x: x0})
    curvature = d2.evalf(subs={x: x0})
    suffix = "" if index is None else f"_{index}"

    lines = [
        "#[test]",
        f"fn test_hyperdual_vec_{method}{suffix}() {{",
        "    let res = HyperDualVec64::new(",
        f"        {x0},",
        "        Derivative::some(SVector::from([1.0, 1.0])),",
        "        Derivative::some(RowSVector::from([1.0, 1.0])),",
        "        Derivative::none(),",
        "    )",
        f"    .{method}({additional_param});",
        "    let eps1 = res.eps1.unwrap_generic(Const::<2>, Const::<1>);",
        "    let eps2 = res.eps2.unwrap_generic(Const::<1>, Const::<2>);",
        "    let eps1eps2 = res.eps1eps2.unwrap_generic(Const::<2>, Const::<2>);",
        assertion("res.re", value),
    ]
    lines += [assertion(f"{part}[{k}]", slope) for part in ("eps1", "eps2") for k in range(2)]
    lines += [
        assertion(f"eps1eps2[({row}, {col})]", curvature)
        for row in range(2)
        for col in range(2)
    ]
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def write_test_dual2(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``Dual2_64`` test for ``method``.

    The generated test seeds a second-order dual number at ``x0``. It compares
    ``re``, ``v1`` and ``v2`` with ``f(x0)``, ``f'(x0)`` and ``f''(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """
    # f, f', f'' in order.
    chain = [f]
    for _ in range(2):
        chain.append(chain[-1].diff())
    expected = [expr.evalf(subs={x: x0}) for expr in chain]

    suffix = "" if index is None else f"_{index}"
    lines = [
        "#[test]",
        f"fn test_dual2_{method}{suffix}() {{",
        f"    let res = Dual2_64::from({x0}).derivative().{method}({additional_param});",
    ]
    for field, val in zip(("res.re", "res.v1", "res.v2"), expected):
        lines.append(f"    assert!(({field}{check_zero(val)}).abs() < {tol});")
    lines.append("}")
    return "\n".join(lines) + "\n\n"

def write_test_dual3(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``Dual3_64`` test for ``method``.

    The generated test seeds a third-order dual number at ``x0``. It compares
    ``re``, ``v1``, ``v2`` and ``v3`` with the 0th to 3rd derivatives of ``f``
    at ``x0``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """
    # f, f', f'', f''' in order.
    chain = [f]
    for _ in range(3):
        chain.append(chain[-1].diff())
    expected = [expr.evalf(subs={x: x0}) for expr in chain]

    suffix = "" if index is None else f"_{index}"
    lines = [
        "#[test]",
        f"fn test_dual3_{method}{suffix}() {{",
        f"    let res = Dual3_64::from({x0}).derivative().{method}({additional_param});",
    ]
    for field, val in zip(("res.re", "res.v1", "res.v2", "res.v3"), expected):
        lines.append(f"    assert!(({field}{check_zero(val)}).abs() < {tol});")
    lines.append("}")
    return "\n".join(lines) + "\n\n"


def write_test_hyperhyperdual(f, x0, method, additional_param="", index=None):
    """Return the Rust source of one ``HyperHyperDual64`` test for ``method``.

    The generated test seeds ``derivative1``, ``derivative2`` and
    ``derivative3`` at ``x0``. It compares:

    * the real part with ``f(x0)``;
    * each first-order part (``eps1``, ``eps2``, ``eps3``) with ``f'(x0)``;
    * each mixed second-order part (``eps1eps2``, ``eps1eps3``, ``eps2eps3``)
      with ``f''(x0)``;
    * ``eps1eps2eps3`` with ``f'''(x0)``.

    Args:
        f: SymPy expression in ``x`` that ``method`` should reproduce.
        x0: evaluation point, written verbatim into the Rust source.
        method: name of the Rust method under test.
        additional_param: argument text placed inside the method call.
        index: optional suffix that disambiguates repeated test names.

    Returns:
        The test source, terminated by a blank line.
    """

    def at_x0(expr):
        return expr.evalf(subs={x: x0})

    first = f.diff()
    second = first.diff()
    third = second.diff()
    value, slope, curvature, jerk = (at_x0(e) for e in (f, first, second, third))

    # Exactly one assertion per field of the number.
    expected = [
        ("res.re", value),
        ("res.eps1", slope),
        ("res.eps2", slope),
        ("res.eps3", slope),
        ("res.eps1eps2", curvature),
        ("res.eps1eps3", curvature),
        ("res.eps2eps3", curvature),
        ("res.eps1eps2eps3", jerk),
    ]

    suffix = "" if index is None else f"_{index}"
    lines = [
        "#[test]",
        f"fn test_hyperhyperdual_{method}{suffix}() {{",
        f"    let res = HyperHyperDual64::from({x0})",
        "        .derivative1()",
        "        .derivative2()",
        "        .derivative3()",
        f"        .{method}({additional_param});",
    ]
    lines += [
        f"    assert!(({field}{check_zero(val)}).abs() < {tol});"
        for field, val in expected
    ]
    lines.append("}")
    return "\n".join(lines) + "\n\n"
    
def write_all_tests(number):
    """Write ``tests/test_<number>.rs`` for one num_dual number type.

    ``number`` selects the per-type writer (e.g. ``"dual"`` ->
    ``write_test_dual``) and the nalgebra items the generated file must
    import. Each entry of the case table below becomes one Rust ``#[test]``.
    Entries are emitted in table order.

    Args:
        number: one of ``"dual"``, ``"dual_vec"``, ``"hyperdual"``,
            ``"dual2_vec"``, ``"hyperdual_vec"``, ``"dual2"``, ``"dual3"``,
            ``"hyperhyperdual"``.

    Raises:
        KeyError: if ``number`` is not one of the supported keys.
    """
    write_test = {
        "dual": write_test_dual,
        "dual_vec": write_test_dual_vec,
        "hyperdual": write_test_hyperdual,
        "dual2_vec": write_test_dual2_vec,
        "hyperdual_vec": write_test_hyperdual_vec,
        "dual2": write_test_dual2,
        "dual3": write_test_dual3,
        "hyperhyperdual": write_test_hyperhyperdual,
    }[number]
    # Only the vector-valued types reference nalgebra items in their tests.
    # One name per list element, so the join below yields "A, B".
    nalgebra_imports = {
        "dual_vec": ["Const", "Vector"],
        "dual2_vec": ["Const", "RowSVector"],
        "hyperdual_vec": ["Const", "RowSVector", "SVector"],
    }.get(number)

    # (reference expression, evaluation point, Rust method, method argument,
    #  test-name index)
    cases = [
        (1 / x, 1.2, "recip", "", None),
        (exp(x), 1.2, "exp", "", None),
        (exp(x) - 1, 1.2, "exp_m1", "", None),
        (2**x, 1.2, "exp2", "", None),
        (ln(x), 1.2, "ln", "", None),
        (log(x, 4.2), 1.2, "log", 4.2, None),
        (ln(1 + x), 1.2, "ln_1p", "", None),
        (log(x, 2), 1.2, "log2", "", None),
        (log(x, 10), 1.2, "log10", "", None),
        (sqrt(x), 1.2, "sqrt", "", None),
        (cbrt(x), 1.2, "cbrt", "", None),
        (x**4.2, 1.2, "powf", 4.2, None),
    ]
    # Small integer exponents evaluated at x = 0; float(n) keeps the
    # argument written as "0.0", "1.0", ... for powf.
    cases += [(x**n, 0.0, "powf", float(n), n) for n in range(5)]
    cases.append((x**6, 1.2, "powi", 6, None))
    cases += [(x**n, 0.0, "powi", n, n) for n in range(5)]
    cases += [
        (sin(x), 1.2, "sin", "", None),
        (cos(x), 1.2, "cos", "", None),
        (tan(x), 1.2, "tan", "", None),
        (asin(x), 0.2, "asin", "", None),
        (acos(x), 0.2, "acos", "", None),
        (atan(x), 0.2, "atan", "", None),
        (sinh(x), 1.2, "sinh", "", None),
        (cosh(x), 1.2, "cosh", "", None),
        (tanh(x), 1.2, "tanh", "", None),
        (asinh(x), 1.2, "asinh", "", None),
        (acosh(x), 1.2, "acosh", "", None),
        (atanh(x), 0.2, "atanh", "", None),
    ]
    cases += [(jn(n, x), 1.2, f"sph_j{n}", "", None) for n in range(3)]
    # Bessel functions of orders 0-2 at each point. The position of the point
    # in this list becomes the test-name index.
    cases += [
        (besselj(n, x), x0, f"bessel_j{n}", "", index)
        for index, x0 in enumerate([0.0, 1.2, 7.2, -1.2, -7.2])
        for n in range(3)
    ]

    header = ""
    if nalgebra_imports is not None:
        header += f"use nalgebra::{{{', '.join(nalgebra_imports)}}};\n"
    header += "use num_dual::*;\n\n"

    body = "".join(
        write_test(f, x0, method, param, index)
        for f, x0, method, param, index in cases
    )
    # Every test ends with a blank line. Drop the final newline so the file
    # ends with exactly one.
    with open(f"tests/test_{number}.rs", "w", encoding="utf-8") as handle:
        handle.write(header + body[:-1])

In [3]:
# Regenerate the Rust test file of every supported number type, in this order.
for number in (
    "dual",
    "dual_vec",
    "dual2",
    "dual2_vec",
    "hyperdual",
    "hyperdual_vec",
    "dual3",
    "hyperhyperdual",
):
    write_all_tests(number)