diff --git a/src/unittests/f32_interval.spec.ts b/src/unittests/f32_interval.spec.ts index 987e968844..2fbcc7641a 100644 --- a/src/unittests/f32_interval.spec.ts +++ b/src/unittests/f32_interval.spec.ts @@ -13,6 +13,7 @@ import { additionInterval, atanInterval, atan2Interval, + atanhInterval, ceilInterval, clampMedianInterval, clampMinMaxInterval, @@ -727,6 +728,34 @@ g.test('atanInterval') ); }); +g.test('atanhInterval') + .paramsSubcasesOnly( + // prettier-ignore + [ + // Some of these are hard coded, since the error intervals are difficult to express in a closed human readable + // form due to the inherited nature of the errors. + { input: kValue.f32.infinity.negative, expected: kAny }, + { input: kValue.f32.negative.min, expected: kAny }, + { input: -1, expected: kAny }, + { input: -0.1, expected: [hexToF64(0xbfb9af9a, 0x60000000), hexToF64(0xbfb9af8c, 0xc0000000)] }, // ~-0.1003... + { input: 0, expected: [hexToF64(0xbe960000, 0x20000000), hexToF64(0x3e980000, 0x00000000)] }, // ~0 + { input: 0.1, expected: [hexToF64(0x3fb9af8b, 0x80000000), hexToF64(0x3fb9af9b, 0x00000000)] }, // ~0.1003... + { input: 1, expected: kAny }, + { input: kValue.f32.positive.max, expected: kAny }, + { input: kValue.f32.infinity.positive, expected: kAny }, + ] + ) + .fn(t => { + const input = t.params.input; + const expected = new F32Interval(...t.params.expected); + + const got = atanhInterval(input); + t.expect( + objectEquals(expected, got), + `atanhInterval(${input}) returned ${got}. Expected ${expected}` + ); + }); + g.test('ceilInterval') .paramsSubcasesOnly( // prettier-ignore diff --git a/src/webgpu/shader/execution/expression/call/builtin/atanh.spec.ts b/src/webgpu/shader/execution/expression/call/builtin/atanh.spec.ts index a60eee324f..d8d02169fc 100644 --- a/src/webgpu/shader/execution/expression/call/builtin/atanh.spec.ts +++ b/src/webgpu/shader/execution/expression/call/builtin/atanh.spec.ts @@ -12,7 +12,12 @@ Note: The result is not mathematically meaningful when abs(e) >= 1. import { makeTestGroup } from '../../../../../../common/framework/test_group.js'; import { GPUTest } from '../../../../../gpu_test.js'; -import { allInputSources } from '../../expression.js'; +import { TypeF32 } from '../../../../../util/conversion.js'; +import { atanhInterval } from '../../../../../util/f32_interval.js'; +import { biasedRange, fullF32Range } from '../../../../../util/math.js'; +import { allInputSources, Case, makeUnaryF32IntervalCase, run } from '../../expression.js'; + +import { builtin } from './builtin.js'; export const g = makeTestGroup(GPUTest); @@ -30,7 +35,18 @@ g.test('f32') .params(u => u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const) ) - .unimplemented(); + .fn(async t => { + const makeCase = (n: number): Case => { + return makeUnaryF32IntervalCase(n, atanhInterval); + }; + + const cases = [ + ...biasedRange(-1, -0.9, 20), // discontinuity at x = -1 + ...biasedRange(1, 0.9, 20), // discontinuity at x = 1 + ...fullF32Range(), + ].map(makeCase); + run(t, builtin('atanh'), [TypeF32], TypeF32, t.params, cases); + }); g.test('f16') .specURL('https://www.w3.org/TR/WGSL/#float-builtin-functions') diff --git a/src/webgpu/util/f32_interval.ts b/src/webgpu/util/f32_interval.ts index f03f8b3743..ada1f55576 100644 --- a/src/webgpu/util/f32_interval.ts +++ b/src/webgpu/util/f32_interval.ts @@ -577,6 +577,21 @@ export function atan2Interval(y: number | F32Interval, x: number | F32Interval): return runBinaryOp(toInterval(y), toInterval(x), Atan2IntervalOp); } +const AtanhIntervalOp: PointToIntervalOp = { + impl: (n: number) => { + // atanh(x) = log((1.0 + x) / (1.0 - x)) * 0.5 + const numerator = additionInterval(1.0, n); + const denominator = subtractionInterval(1.0, n); + const log_interval = logInterval(divisionInterval(numerator, denominator)); + return multiplicationInterval(log_interval, 0.5); + }, +}; + +/** Calculate an acceptance interval of atanh(x) */ +export function atanhInterval(n: number): F32Interval { + return runPointOp(toInterval(n), AtanhIntervalOp); +} + const CeilIntervalOp: PointToIntervalOp = { impl: (n: number): F32Interval => { return correctlyRoundedInterval(Math.ceil(n));