-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
only save missing bits to reconstruct fp32 master weights #432
Open
ngc92
wants to merge
3
commits into
karpathy:master
Choose a base branch
from
ngc92:master-mantissa
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+317
−56
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
Utilities for manipulating at the bit level | ||
*/ | ||
#ifndef LLMC_BITS_CUH | ||
#define LLMC_BITS_CUH | ||
|
||
#include "cuda_bf16.h" | ||
|
||
// implementation of unreachable from C++23/C23 that works across compilers | ||
[[noreturn]] __host__ __device__ inline void unreachable() | ||
{ | ||
// Uses compiler specific extensions if possible. | ||
// Even if no extension is used, undefined behavior is still raised by | ||
// an empty function body and the noreturn attribute. | ||
#if defined(_MSC_VER) && !defined(__clang__) // MSVC | ||
__assume(false); | ||
#else // GCC, Clang | ||
__builtin_unreachable(); | ||
#endif | ||
} | ||
|
||
// ---------------------------------------------------------------------------- | ||
// bit-fiddling to reconstruct master weights from BF16 and missing bits | ||
// handling bf16 is _almost_ just storing the 16 bits of the mantissa that get | ||
// truncated when going from float -> bf16; except we do stochastic rounding. | ||
// if we end up rounding towards zero, we could just keep the bits, but if we | ||
// round away from zero, there is a chance that the other bits of the bf16 no | ||
// longer correspond to the bits in the original float. So we have to reserve | ||
// one bit to remember whether we rounded up or down, for an effective master | ||
// weight precision of fp31. That should still be more than sufficient. | ||
|
||
// Result of splitting a float into a stochastically-rounded bfloat16 and | ||
// additional reconstruction bits | ||
struct SplitFloatResult { | ||
nv_bfloat16 b_float; | ||
unsigned short bits; | ||
}; | ||
|
||
// UB-free bit-level conversions. A C++17 version for C++20 std::bit_cast | ||
template<class T, class S> | ||
__host__ __device__ T bit_cast(S v) { | ||
T dest; | ||
static_assert(sizeof(v) == sizeof(dest), "Size mismatch."); | ||
memcpy(&dest, &v, sizeof(v)); | ||
return dest; | ||
} | ||
|
||
// Splits a float into a bfloat16 and the remaining mantissa bits | ||
__host__ __device__ SplitFloatResult split_float(float value, unsigned short threshold) { | ||
unsigned int float_bits = bit_cast<unsigned int>(value); | ||
// IEEE 754: float: S E(8) M (23) bfloat: same, but mantissa 23-16 = 7 bits | ||
// ideally, we'd just store the cut-off 16 bits, but that doesn't work if rounding | ||
// is involved. | ||
unsigned int rounded_bits = float_bits & 0x0000FFFFu; | ||
if(rounded_bits > threshold) { | ||
SplitFloatResult result; | ||
result.b_float = __float2bfloat16_rn(bit_cast<float>(float_bits | 0xFFFFu)); | ||
// use last bit to remember that we rounded away from zero | ||
result.bits = rounded_bits & (~1u) | 1u; | ||
return result; | ||
} else { | ||
// truncation is easy | ||
SplitFloatResult result; | ||
result.b_float = bit_cast<__nv_bfloat16>((unsigned short)(float_bits >> 16u)); | ||
// use last bit to remember that we rounded towards zero | ||
result.bits = rounded_bits & (~1u) | 0u; | ||
return result; | ||
} | ||
} | ||
|
||
// Reassembles a float from the bfloat16 part and the missing mantissa | ||
__host__ __device__ float assemble_float(nv_bfloat16 b_float, unsigned short bits) { | ||
constexpr const unsigned short BF16_SIGN_MASK = 0b1'00000000'0000000u; | ||
constexpr const unsigned short BF16_EXPONENT_MASK = 0b0'11111111'0000000u; | ||
constexpr const unsigned short BF16_MANTISSA_MASK = 0b0'00000000'1111111u; | ||
unsigned short bf = bit_cast<unsigned short>(b_float); | ||
if(bits & 1u) { | ||
// if we rounded away from zero, we need to undo these changes. | ||
// first, check if the mantissa (7 bits) of bf16 is zero | ||
const unsigned short mantissa = bf & BF16_MANTISSA_MASK; | ||
if(mantissa == 0) { | ||
// mantissa overflowed, need to decrement the exponent | ||
const unsigned short exponent = (bf & BF16_EXPONENT_MASK) >> 7u; | ||
if(exponent == 0) { | ||
// zero, cannot be reached if we round away from zero | ||
unreachable(); | ||
} | ||
// decrement the exponent and set mantissa to all-ones | ||
bf = (bf & BF16_SIGN_MASK) | ((exponent-1u) << 7u) | BF16_MANTISSA_MASK; | ||
} else { | ||
// mantissa was incremented, decrement | ||
bf = (bf & (BF16_SIGN_MASK | BF16_EXPONENT_MASK)) | (mantissa - 1u); | ||
} | ||
} | ||
const unsigned int result = (bits & (~1u)) | (bf << 16u); | ||
return bit_cast<float>(result); | ||
} | ||
|
||
#endif // LLMC_BITS_CUH |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Makefile for building the unit tests | ||
|
||
# Find nvcc (NVIDIA CUDA compiler) | ||
NVCC := $(shell which nvcc 2>/dev/null) | ||
ifeq ($(NVCC),) | ||
$(error nvcc not found.) | ||
endif | ||
|
||
# Compiler flags | ||
CFLAGS = -O3 --use_fast_math | ||
NVCCFLAGS = -lcublas -lcublasLt -std=c++17 -I.. | ||
MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ | ||
|
||
# Default rule for our CUDA files | ||
%: %.cu | ||
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ | ||
|
||
# Build all targets | ||
TARGETS = bits | ||
all: $(TARGETS) | ||
|
||
# Individual targets | ||
bits: bits.cu | ||
|
||
run_all: all | ||
@for target in $(TARGETS); do \ | ||
echo "\n========================================"; \ | ||
echo "Running $$target ..."; \ | ||
echo "========================================\n"; \ | ||
./$$target; \ | ||
done | ||
|
||
# Clean up | ||
clean: | ||
rm -f $(TARGETS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#include "llmc/bits.cuh" | ||
|
||
#undef NDEBUG | ||
#include <assert.h> | ||
#include <float.h> | ||
#include <stdio.h> | ||
#include <math.h> | ||
|
||
float round_trip(float f, unsigned short threshold) { | ||
const SplitFloatResult split = split_float(f, threshold); | ||
const float r = assemble_float(split.b_float, split.bits); | ||
return r; | ||
} | ||
|
||
bool match_floats(float f1, float f2) { | ||
const unsigned int u1 = bit_cast<unsigned int>(f1); | ||
const unsigned int u2 = bit_cast<unsigned int>(f2); | ||
if((u1 & (~1u)) != (u2 & (~1u))) { | ||
printf("MISMATCH: %0x %0x\n", u1, u2); | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
#define ASSERT_ROUND_TRIP(f) \ | ||
assert(match_floats(f, round_trip(f, 0))); \ | ||
assert(match_floats(f, round_trip(f, 0xFFFF))); \ | ||
|
||
int main() { | ||
ASSERT_ROUND_TRIP(1.4623f) | ||
ASSERT_ROUND_TRIP(-63623.9f) | ||
ASSERT_ROUND_TRIP(FLT_TRUE_MIN) | ||
ASSERT_ROUND_TRIP(NAN) | ||
ASSERT_ROUND_TRIP(0) | ||
ASSERT_ROUND_TRIP(INFINITY) | ||
// make sure we trigger the "rounding increases exponent" code path | ||
const float increment_exponent = bit_cast<float>((unsigned int)(0x40ff'fff0)); | ||
ASSERT_ROUND_TRIP(increment_exponent) | ||
printf("PASS\n"); | ||
return EXIT_SUCCESS; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just do "| 1u" why clear the bit first? The end result is equivalent?
Maybe for consistency with the else branch? But in that case I'd add
| 0u
below as well.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
made this consistent, and added a comment explaining what the 1 and 0 mean for good measure