diff --git a/clang/lib/Headers/__clang_spirv_math.h b/clang/lib/Headers/__clang_spirv_math.h index 761083c10fda1..72eec8f4e4bd9 100644 --- a/clang/lib/Headers/__clang_spirv_math.h +++ b/clang/lib/Headers/__clang_spirv_math.h @@ -714,6 +714,107 @@ double nan(const char *__tagp) { return __spirv_ocl_nan(__make_mantissa(__tagp)); } +// normcdfinv implementation using Acklam's Inverse Normal CDF approximation +// algorithm +__DEVICE__ double normcdfinv(double __p) { + if (__p <= 0.0) + return -__builtin_inf(); + if (__p >= 1.0) + return __builtin_inf(); + + const double a1 = -3.969683028665376e+01; + const double a2 = 2.209460984245205e+02; + const double a3 = -2.759285104469687e+02; + const double a4 = 1.383577518672690e+02; + const double a5 = -3.066479806614716e+01; + const double a6 = 2.506628277459239e+00; + const double b1 = -5.447609879822406e+01; + const double b2 = 1.615858368580409e+02; + const double b3 = -1.556989798598866e+02; + const double b4 = 6.680131188771972e+01; + const double b5 = -1.328068155288572e+01; + const double c1 = -7.784894002430293e-03; + const double c2 = -3.223964580411365e-01; + const double c3 = -2.400758277161838e+00; + const double c4 = -2.549732539343734e+00; + const double c5 = 4.374664141464968e+00; + const double c6 = 2.938163982698783e+00; + const double d1 = 7.784695709041462e-03; + const double d2 = 3.224671290700398e-01; + const double d3 = 2.445134137142996e+00; + const double d4 = 3.754408661907416e+00; + const double p_low = 0.02425; + const double p_high = 1.0 - p_low; + double q, r, result; + if (__p < p_low) { + q = sqrt(-2.0 * log(__p)); + result = fma(fma(fma(fma(fma(c1, q, c2), q, c3), q, c4), q, c5), q, c6) / + fma(fma(fma(fma(d1, q, d2), q, d3), q, d4), q, 1.0); + } else if (__p <= p_high) { + q = __p - 0.5; + r = q * q; + result = fma(fma(fma(fma(fma(a1, r, a2), r, a3), r, a4), r, a5), r, a6) * + q / + fma(fma(fma(fma(fma(b1, r, b2), r, b3), r, b4), r, b5), r, 1.0); + } else { + q = sqrt(-2.0 * log(1.0 - __p)); + result = -fma(fma(fma(fma(fma(c1, q, c2), q, c3), q, c4), q, c5), q, c6) / + fma(fma(fma(fma(d1, q, d2), q, d3), q, d4), q, 1.0); + } + + return result; +} + +__DEVICE__ float normcdfinvf(float __p) { + if (__p <= 0.0f) + return -__builtin_inff(); + if (__p >= 1.0f) + return __builtin_inff(); + const float a1 = -3.969683028665376e+01f; + const float a2 = 2.209460984245205e+02f; + const float a3 = -2.759285104469687e+02f; + const float a4 = 1.383577518672690e+02f; + const float a5 = -3.066479806614716e+01f; + const float a6 = 2.506628277459239e+00f; + const float b1 = -5.447609879822406e+01f; + const float b2 = 1.615858368580409e+02f; + const float b3 = -1.556989798598866e+02f; + const float b4 = 6.680131188771972e+01f; + const float b5 = -1.328068155288572e+01f; + const float c1 = -7.784894002430293e-03f; + const float c2 = -3.223964580411365e-01f; + const float c3 = -2.400758277161838e+00f; + const float c4 = -2.549732539343734e+00f; + const float c5 = 4.374664141464968e+00f; + const float c6 = 2.938163982698783e+00f; + const float d1 = 7.784695709041462e-03f; + const float d2 = 3.224671290700398e-01f; + const float d3 = 2.445134137142996e+00f; + const float d4 = 3.754408661907416e+00f; + const float p_low = 0.02425f; + const float p_high = 1.0f - p_low; + float q, r, result; + if (__p < p_low) { + q = sqrtf(-2.0f * logf(__p)); + result = + fmaf(fmaf(fmaf(fmaf(fmaf(c1, q, c2), q, c3), q, c4), q, c5), q, c6) / + fmaf(fmaf(fmaf(fmaf(d1, q, d2), q, d3), q, d4), q, 1.0f); + } else if (__p <= p_high) { + q = __p - 0.5f; + r = q * q; + result = + fmaf(fmaf(fmaf(fmaf(fmaf(a1, r, a2), r, a3), r, a4), r, a5), r, a6) * + q / + fmaf(fmaf(fmaf(fmaf(fmaf(b1, r, b2), r, b3), r, b4), r, b5), r, 1.0f); + } else { + q = sqrtf(-2.0f * logf(1.0f - __p)); + result = + -fmaf(fmaf(fmaf(fmaf(fmaf(c1, q, c2), q, c3), q, c4), q, c5), q, c6) / + fmaf(fmaf(fmaf(fmaf(d1, q, d2), q, d3), q, d4), q, 1.0f); + } + + return result; +} #pragma pop_macro("__DEVICE__") #endif // __CLANG_GPU_DISABLE_MATH_WRAPPERS #endif // __CLANG_SPIRV_MATH_H__ diff --git a/clang/test/Headers/spirv_normcdfinv.c b/clang/test/Headers/spirv_normcdfinv.c new file mode 100644 index 0000000000000..731bc3a777690 --- /dev/null +++ b/clang/test/Headers/spirv_normcdfinv.c @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -internal-isystem %S/Inputs/include -fopenmp -triple x86_64-unknown-unknown -fopenmp-targets=spirv64-intel-unknown -emit-llvm-bc %s -o %t-host.bc +// RUN: %clang_cc1 -include __clang_openmp_device_functions.h -internal-isystem %S/../../lib/Headers/openmp_wrappers -internal-isystem %S/Inputs/include -disable-llvm-passes -fopenmp -triple spirv64 -fopenmp-targets=spirv64 -emit-llvm %s -fopenmp-is-target-device -fopenmp-host-ir-file-path %t-host.bc -o - | FileCheck %s +// RUN: %clang_cc1 -include __clang_openmp_device_functions.h -internal-isystem %S/../../lib/Headers/openmp_wrappers -internal-isystem %S/Inputs/include -disable-llvm-passes -fopenmp -triple spirv64-intel -fopenmp-targets=spirv64-intel -emit-llvm %s -fopenmp-is-target-device -fopenmp-host-ir-file-path %t-host.bc -o - | FileCheck %s +// expected-no-diagnostics +#include + +// Test that normcdfinvf is properly defined and uses SPIRV OCL builtins +// CHECK-LABEL: define {{.*}} @{{.*}}test_normcdfinvf +void test_normcdfinvf(float x, float *result) { + #pragma omp target map(from: result[0:1]) + { + // CHECK: call {{.*}} float @{{.*}}normcdfinvf{{.*}}(float + result[0] = normcdfinvf(x); + } +} + +// Test that normcdfinv is properly defined and uses SPIRV OCL builtins +// CHECK-LABEL: define {{.*}} @{{.*}}test_normcdfinv +void test_normcdfinv(double x, double *result) { + #pragma omp target + { + // CHECK: call {{.*}} double @{{.*}}normcdfinv{{.*}}(double + result[0] = normcdfinv(x); + } +} + + diff --git a/offload/test/offloading/normcdfinv_accuracy.cpp b/offload/test/offloading/normcdfinv_accuracy.cpp new file mode 100644 index 0000000000000..bf6aa1b60a0fc --- /dev/null +++ b/offload/test/offloading/normcdfinv_accuracy.cpp @@ -0,0 +1,240 @@ +// RUN: %libomptarget-compilexx-generic -fopenmp-offload-mandatory && +// %libomptarget-run-generic REQUIRES: gpu + +#include +#include +#include + +#define TOLERANCE_F32 1e-3f +#define TOLERANCE_F64 1e-3 +#pragma omp declare target +static constexpr __attribute__((always_inline, nothrow)) float +normcdfinvf(float __a); +static constexpr __attribute__((always_inline, nothrow)) double +normcdfinv(double __a); +static constexpr __attribute__((always_inline, nothrow)) float +normcdff(float __a); +static constexpr __attribute__((always_inline, nothrow)) double +normcdf(double __a); +#pragma omp end declare target +// Test normcdfinv accuracy for float +bool test_normcdfinvf() { + bool passed = true; + + // Test known values + struct TestCase { + float input; + float expected; + const char *name; + } test_cases[] = { + {0.5f, 0.0f, "median (0.5)"}, {0.1587f, -1.0f, "1 sigma below"}, + {0.8413f, 1.0f, "1 sigma above"}, {0.0228f, -2.0f, "2 sigma below"}, + {0.9772f, 2.0f, "2 sigma above"}, {0.00135f, -3.0f, "3 sigma below"}, + {0.99865f, 3.0f, "3 sigma above"}, + }; + + int num_tests = sizeof(test_cases) / sizeof(test_cases[0]); + + for (int i = 0; i < num_tests; i++) { + float result = 0.0f; + float input = test_cases[i].input; + +#pragma omp target map(tofrom : result) map(to : input) + { + result = normcdfinvf(input); + } + + float error = fabsf(result - test_cases[i].expected); + if (error > TOLERANCE_F32) { + printf("FAIL: normcdfinvf(%s): normcdfinvf(%f) = %f, expected %f (error: " + "%e)\n", + test_cases[i].name, input, result, test_cases[i].expected, error); + passed = false; + } else { + printf("PASS: normcdfinvf(%s): normcdfinvf(%f) = %f (error: %e)\n", + test_cases[i].name, input, result, error); + } + } + + return passed; +} + +// Test normcdfinv accuracy for double +bool test_normcdfinv() { + bool passed = true; + + struct TestCase { + double input; + double expected; + const char *name; + } test_cases[] = { + {0.5, 0.0, "median (0.5)"}, {0.1587, -1.0, "1 sigma below"}, + {0.8413, 1.0, "1 sigma above"}, {0.0228, -2.0, "2 sigma below"}, + {0.9772, 2.0, "2 sigma above"}, {0.00135, -3.0, "3 sigma below"}, + {0.99865, 3.0, "3 sigma above"}, + }; + + int num_tests = sizeof(test_cases) / sizeof(test_cases[0]); + + for (int i = 0; i < num_tests; i++) { + double result = 0.0; + double input = test_cases[i].input; + +#pragma omp target map(tofrom : result) map(to : input) + { + result = normcdfinv(input); + } + + double error = fabs(result - test_cases[i].expected); + if (error > TOLERANCE_F64) { + printf("FAIL: normcdfinv(%s): normcdfinv(%f) = %f, expected %f (error: " + "%e)\n", + test_cases[i].name, input, result, test_cases[i].expected, error); + passed = false; + } else { + printf("PASS: normcdfinv(%s): normcdfinv(%f) = %f (error: %e)\n", + test_cases[i].name, input, result, error); + } + } + + return passed; +} + +// Test inverse property: normcdfinv(normcdf(x)) ≈ x +bool test_inverse_property() { + bool passed = true; + + double test_values[] = {-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0}; + int num_values = sizeof(test_values) / sizeof(test_values[0]); + + for (int i = 0; i < num_values; i++) { + double x = test_values[i]; + double result = 0.0; + +#pragma omp target map(tofrom : result) map(to : x) + { + double cdf_val = normcdf(x); + result = normcdfinv(cdf_val); + } + + double error = fabs(result - x); + if (error > TOLERANCE_F64) { + printf("FAIL: Inverse property at x=%f: normcdfinv(normcdf(%f)) = %f " + "(error: %e)\n", + x, x, result, error); + passed = false; + } else { + printf("PASS: Inverse property at x=%f: error = %e\n", x, error); + } + } + + return passed; +} + +// Test symmetry property: normcdfinv(1-p) ≈ -normcdfinv(p) +bool test_symmetry_property() { + bool passed = true; + + double test_probs[] = {0.1, 0.2, 0.3, 0.4}; + int num_probs = sizeof(test_probs) / sizeof(test_probs[0]); + + for (int i = 0; i < num_probs; i++) { + double p = test_probs[i]; + double result1 = 0.0, result2 = 0.0; + +#pragma omp target map(tofrom : result1, result2) map(to : p) + { + result1 = normcdfinv(p); + result2 = normcdfinv(1.0 - p); + } + + double expected = -result1; + double error = fabs(result2 - expected); + + if (error > TOLERANCE_F64) { + printf("FAIL: Symmetry at p=%f: normcdfinv(%f) = %f, normcdfinv(%f) = %f " + "(error: %e)\n", + p, p, result1, 1.0 - p, result2, error); + passed = false; + } else { + printf("PASS: Symmetry at p=%f: error = %e\n", p, error); + } + } + + return passed; +} + +// Test all three regions of the approximation +bool test_three_regions() { + bool passed = true; + + struct TestCase { + double input; + const char *region; + } test_cases[] = { + {0.001, "low tail (p < 0.02425)"}, + {0.01, "low tail"}, + {0.5, "central region (0.02425 <= p <= 0.97575)"}, + {0.99, "high tail (p > 0.97575)"}, + {0.999, "high tail"}, + }; + + int num_tests = sizeof(test_cases) / sizeof(test_cases[0]); + + for (int i = 0; i < num_tests; i++) { + double p = test_cases[i].input; + double result = 0.0; + +#pragma omp target map(tofrom : result) map(to : p) + { + result = normcdfinv(p); + } + + // Verify by checking that normcdf(result) ≈ p + double verify = 0.0; +#pragma omp target map(tofrom : verify) map(to : result) + { + verify = normcdf(result); + } + + double error = fabs(verify - p); + if (error > 1e-6) { + printf("FAIL: Region test %s: normcdf(normcdfinv(%f)) = %f (error: %e)\n", + test_cases[i].region, p, verify, error); + passed = false; + } else { + printf("PASS: Region test %s: normcdfinv(%f) = %f, verify error = %e\n", + test_cases[i].region, p, result, error); + } + } + + return passed; +} + +int main() { + bool all_passed = true; + + printf("=== Testing normcdfinvf (float) ===\n"); + all_passed &= test_normcdfinvf(); + + printf("\n=== Testing normcdfinv (double) ===\n"); + all_passed &= test_normcdfinv(); + + printf("\n=== Testing inverse property ===\n"); + all_passed &= test_inverse_property(); + + printf("\n=== Testing symmetry property ===\n"); + all_passed &= test_symmetry_property(); + + printf("\n=== Testing three regions ===\n"); + all_passed &= test_three_regions(); + + if (all_passed) { + printf("\n=== ALL TESTS PASSED ===\n"); + // CHECK: ALL TESTS PASSED + return 0; + } else { + printf("\n=== SOME TESTS FAILED ===\n"); + return 1; + } +}