Skip to content

feat: support bf16 from pytorch dataset#6342

Merged
eddyxu merged 6 commits into
mainfrom
lei/torch_bf16
Apr 1, 2026
Merged

feat: support bf16 from pytorch dataset#6342
eddyxu merged 6 commits into
mainfrom
lei/torch_bf16

Conversation

@eddyxu
Copy link
Copy Markdown
Member

@eddyxu eddyxu commented Mar 30, 2026

Summary

Support round-trip to use bf16 from PyTorch

Co-authored-by: Claude Opus 4.6 noreply@anthropic.com

@github-actions github-actions Bot added enhancement New feature or request A-python Python bindings labels Mar 30, 2026
@github-actions
Copy link
Copy Markdown
Contributor

PR Review

P0 Bug: uint64_as_int64 conversion broken for numeric types

The insertion of the elif _is_bfloat16_type(arr.type): branch has accidentally moved the uint64_as_int64 guard into the wrong branch.

Before (base):

        elif (
            pa.types.is_integer(arr.type)
            or pa.types.is_floating(arr.type)
            or pa.types.is_boolean(arr.type)
        ):
            tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False))

            if uint64_as_int64 and tensor.dtype == torch.uint64:  # ← inside numeric branch ✓
                tensor = tensor.to(torch.int64)

After (this PR):

        elif (
            pa.types.is_integer(arr.type)
            ...
        ):
            tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False))
        elif _is_bfloat16_type(arr.type):
            tensor = _bf16_to_tensor(arr)

            if uint64_as_int64 and tensor.dtype == torch.uint64:  # ← now inside bf16 branch ✗
                tensor = tensor.to(torch.int64)

The uint64_as_int64 conversion no longer runs for numeric uint64 columns (regression), and is dead code inside the bf16 branch (bf16 tensors will never have torch.uint64 dtype).

Fix: Move the if uint64_as_int64 block back into (or after) the numeric elif, e.g. by dedenting it to run after all branches, or by keeping it inside the numeric branch:

        ):
            tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False))

            if uint64_as_int64 and tensor.dtype == torch.uint64:
                tensor = tensor.to(torch.int64)
        elif _is_bfloat16_type(arr.type):
            tensor = _bf16_to_tensor(arr)
        elif hf_converter is not None:

Minor: No null handling in _bf16_to_tensor

If a bf16 array contains nulls, _bf16_to_tensor will silently produce garbage values for null slots. This is consistent with the existing numeric path (to_numpy(zero_copy_only=False) also fills nulls with 0), so not blocking — but worth a brief doc comment noting the assumption.


Overall the approach (reinterpreting uint16 storage bytes as bfloat16) is sound and the test coverage for the happy path is good.

@eddyxu eddyxu requested a review from westonpace March 30, 2026 16:22
Copy link
Copy Markdown
Member

@westonpace westonpace left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some nits

Null values are replaced with NaN.
"""
storage = arr.storage if isinstance(arr.type, pa.ExtensionType) else arr
buf = storage.buffers()[1]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do a sanity check that the data type of storage is a 16-bit type at this point?

Comment thread python/python/lance/torch/data.py Outdated
buf = storage.buffers()[1]
offset = storage.offset * 2 # 2 bytes per bf16 value
np_uint16 = np.frombuffer(buf, dtype=np.uint16, count=len(storage), offset=offset)
tensor = torch.from_numpy(np_uint16.copy()).view(torch.bfloat16)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the copy here so that the resulting buffer can be mutable?

np_uint16 = np.frombuffer(buf, dtype=np.uint16, count=len(storage), offset=offset)
tensor = torch.from_numpy(np_uint16.copy()).view(torch.bfloat16)
if arr.null_count > 0:
null_mask = torch.from_numpy(arr.is_null().to_numpy(zero_copy_only=False))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like there should be a way to do this without a copy but maybe not.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asked claude / codex to do double checks, to make this opportunist

Comment thread python/python/lance/torch/data.py Outdated
if uint64_as_int64 and tensor.dtype == torch.uint64:
if (
uint64_as_int64 and tensor.dtype == torch.uint64
): # ← inside numeric branch ✓
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of a strange comment. I'm not really sure what it means.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

@github-actions github-actions Bot added the A-java Java bindings + JNI label Mar 31, 2026
@eddyxu eddyxu merged commit 21d830a into main Apr 1, 2026
12 checks passed
@eddyxu eddyxu deleted the lei/torch_bf16 branch April 1, 2026 20:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

A-java Java bindings + JNI A-python Python bindings enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants