
Add CPU Cast op support for Float8E8M0 #28435

Merged

tianleiwu merged 8 commits into main from tlwu/float8e8m0_ops on May 12, 2026

Conversation

@tianleiwu (Contributor) commented May 9, 2026

Description

This PR adds Float8E8M0 support to the CPU Cast operator, implementing follow-up items 3 (Cast op support) and 4 (rounding mode) from PR #28381.

Changes

Float8E8M0 RoundMode (include/onnxruntime/core/common/float8.h)

  • Added RoundMode enum (Up, Down, Nearest) to Float8E8M0 struct
  • Extended the constructor with an optional round_mode parameter (default: Up)
  • Up: ceiling (rounds up to the next power of 2 unless the input is an exact power of 2)
  • Down: floor (rounds down to the nearest power of 2 at or below the input)
  • Nearest: rounds to the nearest power of 2, with ties going to the higher one
  • Source-compatible: existing callers compile unchanged with the default Up mode (see the review below for the behavioral caveat)
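
For reference, the addition has roughly this shape (a sketch, not the verbatim header; the saturate default and member layout are assumptions, while the enumerators and the defaulted round_mode parameter are from this PR):

```cpp
#include <cstdint>

// Sketch of the extended struct; the real definition lives in
// include/onnxruntime/core/common/float8.h.
struct Float8E8M0 {
  enum class RoundMode { Up, Down, Nearest };

  uint8_t val{0};  // E8M0 stores only a biased exponent

  // The trailing default keeps existing call sites compiling unchanged.
  explicit Float8E8M0(float v, bool saturate = true,
                      RoundMode round_mode = RoundMode::Up);
};
```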

CPU Cast op (onnxruntime/core/providers/cpu/tensor/cast_op.cc)

  • Added Float8E8M0 to the Cast kernel's enabled type list (AllIRv10WithInt2)
  • Added FLOAT8E8M0 to the saturate attribute validation
  • Added round_mode attribute parsing ("up", "down", "nearest") per the ONNX opset 25 Cast schema
  • Added CastToE8M0Dispatcher template that handles casting from any source type (float, double, int, MLFloat16, BFloat16, Int4, Int2, string, other Float8 types) to Float8E8M0 with proper saturate and round_mode support
  • E8M0 destination dispatch is handled separately from the other Float8 types to thread round_mode through
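
Conceptually, the E8M0 destination path reduces to a per-element loop like this (a simplified sketch; CastBufferToE8M0 is a hypothetical name, and the real CastToE8M0Dispatcher is wired through the kernel's type-list machinery):

```cpp
#include <cstddef>

// Simplified sketch: convert a buffer of any arithmetic source type to
// Float8E8M0, forwarding saturate and round_mode to the constructor.
template <typename SrcT>
void CastBufferToE8M0(const SrcT* src, Float8E8M0* dst, size_t n,
                      bool saturate, Float8E8M0::RoundMode round_mode) {
  for (size_t i = 0; i < n; ++i) {
    // Each source element is widened to float, then converted once.
    dst[i] = Float8E8M0(static_cast<float>(src[i]), saturate, round_mode);
  }
}
```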

Tests (onnxruntime/test/providers/cpu/tensor/cast_op_test.cc)

  • 10 new tests:
    • FloatToFloat8E8M0_Saturate / _NoSaturate — basic float→E8M0 with saturate on/off
    • FloatToFloat8E8M0_RoundModeUp / _RoundModeDown / _RoundModeNearest — all three rounding modes
    • Float8E8M0ToFloat / Float8E8M0ToDouble — E8M0→float/double conversion
    • MLFloat16ToFloat8E8M0 / DoubleToFloat8E8M0 / Int32ToFloat8E8M0 — various source types
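
As an illustration, one of the round-mode tests might look like the following (a sketch assuming the OpTester conventions of cast_op_test.cc and the TensorProto_DataType_FLOAT8E8M0 enum name; the expected values match the rounding examples discussed later in this thread):

```cpp
TEST(CastOpTest, FloatToFloat8E8M0_RoundModeUp) {
  OpTester test("Cast", 25);
  test.AddAttribute<int64_t>("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0);
  test.AddAttribute("round_mode", "up");
  test.AddInput<float>("input", {2}, {1.3f, 2.5f});
  // Expected bits are pinned with FromBits() rather than recomputed via
  // the constructor under test: 1.3 -> 2.0 (0x80), 2.5 -> 4.0 (0x81).
  test.AddOutput<Float8E8M0>("output", {2},
                             {Float8E8M0::FromBits(0x80),
                              Float8E8M0::FromBits(0x81)});
  test.Run();
}
```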

Motivation and Context

PR #28381 added the Float8E8M0 data type to ORT but deferred Cast op support. This PR completes the Cast op integration so that models can use Cast(to=FLOAT8E8M0) with proper saturate and round_mode semantics as defined in the ONNX opset 25 spec.

Testing

All 10 new Cast op tests pass. All 19 existing Float8E8M0 unit tests and all 4 existing Float8 Cast tests continue to pass.

./onnxruntime_provider_test --gtest_filter="CastOpTest.*Float8E8M0*"
./onnxruntime_test_all --gtest_filter="Float8E8M0_Tests.*"
./onnxruntime_provider_test --gtest_filter="CastOpTest.ToFloat8*"

@justinchuby (Contributor) commented May 9, 2026

@copilot please cross review with https://github.com/onnx/onnx/blob/3a8bb46b4787e9d03417c6d777ffb6d31cb58c6c/onnx/numpy_helper.py#L20

Copilot AI commented May 9, 2026

Cross-reviewed against both the ONNX numpy_helper.py to_float8e8m0 reference and the Cast opset 25 schema in defs.cc. Found that all three rounding modes had wrong semantics; fixed in de57195.

Before → After (for normal floats):

  • up: was round-half-up (mantissa >= 0x400000) → now ceiling (mantissa > 0). Per ONNX: has_fraction = (f_bits & 0x7FFFFF) > 0. Example: 1.3 now correctly maps to 2.0, not 1.0; 2.5 maps to 4.0, not 2.0.
  • down: was round-half-down (mantissa > 0x400000) → now floor (never increment). Per ONNX: pass # No rounding needed. Example: 1.7 now correctly maps to 1.0, not 2.0.
  • nearest: was already correct (mantissa >= 0x400000 = guard bit), unchanged.
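
In bit terms, the corrected normal-float logic comes down to something like this (a sketch of the semantics above, not the exact ORT code; subnormals take the separate G/R/S path, and saturation handling is omitted):

```cpp
#include <cstdint>
#include <cstring>

enum class RoundMode { Up, Down, Nearest };

// Normal (non-zero exponent field) floats only; E8M0 keeps just the
// biased exponent, so rounding decides whether to bump it by one.
uint8_t NormalFloatToE8M0Bits(float f, RoundMode mode) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t exponent = (bits >> 23) & 0xFFu;  // biased exponent
  uint32_t mantissa = bits & 0x7FFFFFu;      // 23-bit fraction
  bool round_up = false;
  switch (mode) {
    case RoundMode::Up:      round_up = mantissa > 0; break;           // ceiling
    case RoundMode::Down:    round_up = false; break;                  // floor
    case RoundMode::Nearest: round_up = mantissa >= 0x400000u; break;  // guard bit
  }
  if (round_up) ++exponent;  // saturation/NaN handling omitted in this sketch
  return static_cast<uint8_t>(exponent);
}
```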

Also fixed subnormal rounding thresholds to match the ONNX G/R/S scheme, and replaced self-referential expected values in the three round-mode tests with hardcoded FromBits() values so the tests actually validate the semantics.

Copilot AI requested a review from justinchuby May 9, 2026 16:27
tianleiwu added 4 commits May 9, 2026 14:25
…tations

- Fix denorm Up (ceiling) path: mantissa == 0x400000 is exactly 2^(-127),
  which is representable as val=0. Changed threshold from (mantissa > 0)
  to (mantissa > 0x400000) so exact 2^(-127) values don't incorrectly
  round up to val=1.
- Fix Rounding test: 1.25 with Up (ceiling) mode rounds up to 2.0
  (val=128), not down to 1.0 (val=127).
- Fix SubnormalRounding test: subnormals are valid positive values within
  E8M0 range, so saturate=false does not produce NaN for them.
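
Worked out, the boundary case behind the first fix (IEEE 754 binary32 layout):

```cpp
// float bits 0x00400000: exponent field 0 (subnormal), mantissa 0x400000
//   value = (0x400000 / 2^23) * 2^-126 = 0.5 * 2^-126 = 2^-127
// E8M0 val=0 encodes 2^(0-127) = 2^-127 exactly, so this input needs no
// round-up; the ceiling threshold is therefore (mantissa > 0x400000),
// not (mantissa > 0).
```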
…with DISABLE_FLOAT8_TYPES

When DISABLE_FLOAT8_TYPES is defined (minimal/wasm/Android builds),
Float8E8M0 is not declared, so boost::mp11::mp_remove<..., Float8E8M0>
fails to compile. Since AllIRv10WithInt2 already excludes Float8E8M0
in that case, the PreOpset24 type lists are equivalent to the full ones.

Fixes CI failures in:
- Linux CPU Minimal Build E2E (6a/6b/6c/7)
- Android CI (AndroidBinarySizeCheckJob_MinimalBaseline, android_nnapi_ep)
- Wasm builds (wasm_Debug, wasm_Release)
- WebGPU builds
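
The guard described here amounts to something like the following (a sketch; the exact alias names in cast_op.cc may differ):

```cpp
#if !defined(DISABLE_FLOAT8_TYPES)
// Full builds: strip Float8E8M0 from the opset-24+ list for pre-24 kernels.
using EnabledSrcTypesPreOpset24 =
    boost::mp11::mp_remove<AllIRv10WithInt2, Float8E8M0>;
#else
// Minimal/wasm/Android builds: Float8E8M0 is never declared, so naming it
// in mp_remove would not compile; the full list already excludes it.
using EnabledSrcTypesPreOpset24 = AllIRv10WithInt2;
#endif
```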
tianleiwu changed the title from "Add CPU Cast op support for Float8E8M0 with round_mode" to "Add CPU Cast op support for Float8E8M0" on May 11, 2026
@titaiwangms (Contributor) commented

Review

Overall a well-scoped, well-tested change. The new dispatcher cleanly threads round_mode for the E8M0 destination only, and the pre-opset-24 type-list split keeps the new type from leaking into earlier registrations. Found a couple of items worth addressing before merge.

Major

1. The "backward-compatible" claim is misleading. The ctor signature is compatible (default arg), but Float8E8M0(v) now produces different bits than before — e.g. 1.25f was val=127 (1.0), now val=128 (2.0). Two public helpers in float8.h still use the default:

  • inline Float8E8M0 operator""_f8e8m0f(long double v) → Float8E8M0(static_cast<float>(v), true) (float8.h:844)
  • FloatToFloat8E8M0(const float*, Float8E8M0*, size_t, bool saturate) → new (d) Float8E8M0(*src, saturate) (float8.h:858)

Both are in a public header. Suggest one of:

  • Keep the ctor default as Nearest (preserves old behavior) and pass RoundMode::Up explicitly from the Cast kernel — Up is the ONNX op's default, not the ctor's.
  • Drop the default entirely so callers opt in.
  • Or update FloatToFloat8E8M0 to take a RoundMode parameter and call out the breaking change in the PR description / changelog.

2. Subnormal nearest semantics — matches the ONNX reference, but worth confirming intent. The implementation uses mantissa > 0x400000 for subnormal Nearest. A linear-distance reading of "nearest" would put the threshold at 0x600000 (midpoint between 2^-127 and 2^-126). I checked: ONNX's reference to_float8e8m0 in numpy_helper.py uses the same G/R/S scheme with lsb = (exponent > 0) → for subnormals lsb=0, giving exactly mantissa > 0x400000. So the PR matches the reference even though it disagrees with a linear reading of the prose. Not a bug per the reference — but there is no test for subnormal nearest. Recommend adding one (e.g. mantissa 0x500000 → val=1) to lock in the chosen semantics.
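
A hedged sketch of such a test (bit construction per IEEE 754; the test name, OpTester usage, and enum name are assumed to follow the existing tests):

```cpp
TEST(CastOpTest, FloatToFloat8E8M0_SubnormalNearest) {
  // Subnormal float with mantissa 0x500000: value = 1.25 * 2^-127, which
  // sits between val=0 (2^-127) and val=1 (2^-126). The ONNX G/R/S scheme
  // rounds it up to val=1. std::memcpy needs <cstring>.
  const uint32_t kBits = 0x00500000u;
  float subnormal;
  std::memcpy(&subnormal, &kBits, sizeof(subnormal));

  OpTester test("Cast", 25);
  test.AddAttribute<int64_t>("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E8M0);
  test.AddAttribute("round_mode", "nearest");
  test.AddInput<float>("input", {1}, {subnormal});
  test.AddOutput<Float8E8M0>("output", {1}, {Float8E8M0::FromBits(0x01)});
  test.Run();
}
```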

Minor

3. CastToE8M0Dispatcher is a serial scalar loop and bypasses the optimized casters (TensorCaster, MlasConvertHalfToFloatBufferInParallel for MLFloat16, threadpool parallelism). No prior path is regressed (E8M0 dest is new), but for large MLFloat16/BFloat16 → E8M0 casts this could be a perf cliff. Consider threadpool parallelization or specializing the MLFloat16 / BFloat16 paths.

4. CUDA Cast registration not updated. Opset 24+ Cast on CUDA still omits Float8E8M0 and ignores round_mode. Models with Cast(to=FLOAT8E8M0) will silently fall back to CPU. Either out-of-scope follow-up or worth a note in the PR description.

5. Duplicate Up: and Nearest: case bodies in the subnormal switch. Both set round_up = (mantissa > 0x00400000). If they really must be identical (per the ONNX reference, they do), merge with a fall-through and a comment ("for subnormals, Up and Nearest coincide because only one representable value lies in the subnormal range"). Currently looks like copy-paste; a sketch follows after this list.

6. default: arm on an exhaustive enum class switch (both branches) silently treats unknown values as Nearest. Suggest dropping default: so the compiler warns if a future case is added.

7. Test oracles use the function under test. _Saturate, _NoSaturate, MLFloat16ToFloat8E8M0, DoubleToFloat8E8M0, Int32ToFloat8E8M0 all build expected outputs via Float8E8M0(v, saturate, round_mode) — the very ctor being exercised. The _RoundMode* and edge-case tests already use explicit FromBits() constants, which is the right pattern; please extend it to the others.

8. Missing negative tests:

  • Cast at opset 23 with to=FLOAT8E8M0 should fail to find a kernel.
  • round_mode="down" (or "nearest") with to != FLOAT8E8M0 should throw — code does it, but no test covers the rejection.

Nits

  • EnabledSrcTypesPreOpset24 — describes the consumer, not the content. Consider EnabledSrcTypesWithoutE8M0.
  • CastToE8M0Dispatcher breaks the existing SrcDispatcher / SrcDispatcherNoSat naming pattern. DstE8M0Dispatcher would stay parallel.
  • Stale comment 1.5 (tie) in FloatToFloat8E8M0_Saturate is nearest-mode vocabulary; with the new Up default 1.5 is not a tie.

Praise

  • The G/R/S explanation comment on the subnormal branch is excellent reference material.
  • Splitting opset 24+ registrations from the pre-24 ones with separate type lists is the right approach to avoid retroactively advertising E8M0 to earlier schemas.
  • round_mode validation is fail-closed and scoped to E8M0.
  • The from == to_ early-copy preserves identity for E8M0→E8M0 (no spurious re-rounding).

Suggested action: Major #1 is the only item I'd consider a blocker — at minimum a documentation/changelog issue, at most warrants reverting the ctor default. Major #2 deserves a quick confirmation that matching the ONNX reference over the prose is intentional. The rest is polish.

Copilot AI left a comment

Pull request overview

This PR extends ONNX Runtime’s CPU Cast operator to support the FLOAT8E8M0 destination type, including parsing/handling of the round_mode attribute and updating Float8E8M0 conversion semantics and tests accordingly.

Changes:

  • Added Float8E8M0::RoundMode and threaded round_mode through CPU Cast when casting to FLOAT8E8M0.
  • Extended CPU Cast kernel type constraints/dispatch to include FLOAT8E8M0 for opset 24+ while keeping earlier opsets’ constraints unchanged.
  • Added/updated unit tests and docs to reflect FLOAT8E8M0 Cast support and revised conversion behavior.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Summary per file:

  • include/onnxruntime/core/common/float8.h: Adds RoundMode and updates float→E8M0 conversion logic to support multiple rounding modes.
  • onnxruntime/core/providers/cpu/tensor/cast_op.cc: Enables CPU Cast to/from FLOAT8E8M0, validates attributes, and adds a dedicated dispatcher to pass round_mode.
  • onnxruntime/test/providers/cpu/tensor/cast_op_test.cc: Adds CPU Cast tests covering float8e8m0 conversions, saturate behavior, and round_mode variants.
  • onnxruntime/test/framework/float8e8m0_test.cc: Updates existing Float8E8M0 unit test expectations/comments for the revised default rounding behavior.
  • docs/OperatorKernels.md: Updates the kernel support table to include tensor(float8e8m0) for Cast in opset 24 and 25+.


tianleiwu merged commit 8cfb63e into main on May 12, 2026. 88 checks passed.
tianleiwu deleted the tlwu/float8e8m0_ops branch on May 12, 2026 at 04:01.