feat: support bf16 from pytorch dataset#6342
Conversation
PR ReviewP0 Bug:
|
| Null values are replaced with NaN. | ||
| """ | ||
| storage = arr.storage if isinstance(arr.type, pa.ExtensionType) else arr | ||
| buf = storage.buffers()[1] |
There was a problem hiding this comment.
Should we do a sanity check that the data type of storage is a 16-bit type at this point?
| buf = storage.buffers()[1] | ||
| offset = storage.offset * 2 # 2 bytes per bf16 value | ||
| np_uint16 = np.frombuffer(buf, dtype=np.uint16, count=len(storage), offset=offset) | ||
| tensor = torch.from_numpy(np_uint16.copy()).view(torch.bfloat16) |
There was a problem hiding this comment.
Is the copy here so that the resulting buffer can be mutable?
| np_uint16 = np.frombuffer(buf, dtype=np.uint16, count=len(storage), offset=offset) | ||
| tensor = torch.from_numpy(np_uint16.copy()).view(torch.bfloat16) | ||
| if arr.null_count > 0: | ||
| null_mask = torch.from_numpy(arr.is_null().to_numpy(zero_copy_only=False)) |
There was a problem hiding this comment.
Seems like there should be a way to do this without a copy but maybe not.
There was a problem hiding this comment.
asked claude / codex to do double checks, to make this opportunist
| if uint64_as_int64 and tensor.dtype == torch.uint64: | ||
| if ( | ||
| uint64_as_int64 and tensor.dtype == torch.uint64 | ||
| ): # ← inside numeric branch ✓ |
There was a problem hiding this comment.
Kind of a strange comment. I'm not really sure what it means.
Summary
Support round-trip to use bf16 from PyTorch
Co-authored-by: Claude Opus 4.6 noreply@anthropic.com