only save missing bits to reconstruct fp32 master weights #432

Open · wants to merge 3 commits into base: master

Changes from all commits
15 changes: 15 additions & 0 deletions .github/workflows/ci.yml
@@ -219,3 +219,18 @@ jobs:

- name: Build project
run: make -j4 -C dev/cuda

run-cpu-unit-tests:
runs-on: ubuntu-latest
container:
image: nvidia/cuda:12.4.1-devel-ubuntu22.04

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Build project
run: make -j4 -C test

- name: Run tests
run: ./test/bits
46 changes: 35 additions & 11 deletions llmc/adamw.cuh
@@ -5,6 +5,7 @@ AdamW kernel
// llmc internal imports
#include "cuda_common.h"
#include "cuda_utils.cuh"
#include "bits.cuh"

// ----------------------------------------------------------------------------
// CUDA kernels
@@ -16,7 +17,7 @@ __device__ float lerp(float start, float end, float weight) {
}

template <typename Tp, typename Tg>
__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
__device__ void adamw_update(Tp* params_memory, unsigned short* mantissas, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
float grad_scale, unsigned int seed) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -35,24 +36,47 @@ __device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg*
m /= beta1_correction; // m_hat
v /= beta2_correction; // v_hat
// fetch the old value of this parameter as a float, from either source
float old_param = (master_params_memory != NULL) ? master_params_memory[idx] : (float)params_memory[idx];
// update this parameter
float old_param;
if (mantissas != NULL) {
if constexpr (std::is_same_v<Tp, __nv_bfloat16>) {
old_param = assemble_float(params_memory[idx], mantissas[idx]);
} else {
assert(false && "Master params are only implemented for bf16.");
}
} else {
old_param = (float)params_memory[idx];
}
// update this parameter in 32-bit precision
float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));

// update our low precision version of the parameters using stochastic rounding
// this will be used in the next forward pass
stochastic_rounding(param, &params_memory[idx], seed);
// write the full, float version of the param into our master copy, if we maintain one
// this will be used in the next update
if (master_params_memory != NULL) { master_params_memory[idx] = param; }
// TODO: simply doing `params_memory[i] = (floatX)param;` breaks everything (why?)

// If we keep master parameter "copies", make sure to store the missing bits in the 'mantissas' array,
// otherwise we can directly go for stochastic rounding.
if (mantissas != NULL) {
if constexpr (std::is_same_v<Tp, __nv_bfloat16>) {
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed);
unsigned int threshold = random & 0xFFFFu;
SplitFloatResult split = split_float(param, threshold);
mantissas[idx] = split.bits;
params_memory[idx] = split.b_float;
} else {
assert(false && "Master params are only implemented for bf16.");
}
} else {
stochastic_rounding(param, &params_memory[idx], seed);
}
}

template <typename Tp, typename Tg>
__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
__global__ void adamw_kernel3(Tp* params_memory, unsigned short* mantissas, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
float grad_scale, unsigned int seed) {
adamw_update(params_memory + blockIdx.y * w_stride,
master_params_memory ? master_params_memory + blockIdx.y * s_stride : NULL,
mantissas ? mantissas + blockIdx.y * s_stride : NULL,
grads_memory + blockIdx.y * g_stride,
m_memory + blockIdx.y * s_stride,
v_memory + blockIdx.y * s_stride,
@@ -62,15 +86,15 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
}

template <typename Tp, typename Tg>
void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
void adamw_update(Tp* params_memory, unsigned short* mantissas, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
float grad_scale, unsigned int seed, cudaStream_t stream) {
// AdamW update
int block_size = 512;
int num_blocks = CEIL_DIV(num_parameters, block_size);
float beta1_correction = 1.0f - powf(beta1, t);
float beta2_correction = 1.0f - powf(beta2, t);
adamw_kernel3<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>(params_memory, master_params_memory, grads_memory,
adamw_kernel3<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>(params_memory, mantissas, grads_memory,
m_memory, v_memory, num_parameters, w_stride, g_stride, s_stride,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
grad_scale, seed);
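For a sense of what the signature change buys, here is a rough, editorial accounting sketch (not part of the PR; it assumes bf16 parameters with master weights enabled and uses a hypothetical round parameter count): per parameter, the optimizer previously kept a 4-byte fp32 master copy next to the 2-byte bf16 weight, and now keeps only 2 bytes of missing mantissa bits. Note also that with the threshold drawn from random & 0xFFFFu, a parameter whose discarded low bits equal rounded_bits is rounded away from zero with probability roughly rounded_bits / 65536.

    // editorial sketch, compiled with nvcc; the parameter count is a hypothetical round number
    #include <cstddef>
    #include <cstdio>
    #include <cuda_bf16.h>

    int main() {
        size_t num_parameters = 124000000;  // ~GPT-2 124M scale, illustrative only
        size_t per_param_before = sizeof(__nv_bfloat16) + sizeof(float);           // 2 + 4 bytes
        size_t per_param_after  = sizeof(__nv_bfloat16) + sizeof(unsigned short);  // 2 + 2 bytes
        printf("master-weight state saved: %zu MiB\n",
               (num_parameters * (per_param_before - per_param_after)) >> 20);
        return 0;
    }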
99 changes: 99 additions & 0 deletions llmc/bits.cuh
@@ -0,0 +1,99 @@
/*
Utilities for manipulating data at the bit level
*/
#ifndef LLMC_BITS_CUH
#define LLMC_BITS_CUH

#include "cuda_bf16.h"

// implementation of unreachable from C++23/C23 that works across compilers
[[noreturn]] __host__ __device__ inline void unreachable()
{
// Uses compiler specific extensions if possible.
// Even if no extension is used, undefined behavior is still raised by
// an empty function body and the noreturn attribute.
#if defined(_MSC_VER) && !defined(__clang__) // MSVC
__assume(false);
#else // GCC, Clang
__builtin_unreachable();
#endif
}

// ----------------------------------------------------------------------------
// bit-fiddling to reconstruct master weights from BF16 and missing bits
// handling bf16 is _almost_ just storing the 16 bits of the mantissa that get
// truncated when going from float -> bf16; except we do stochastic rounding.
// if we end up rounding towards zero, we could just keep the bits, but if we
// round away from zero, there is a chance that the other bits of the bf16 no
// longer correspond to the bits in the original float. So we have to reserve
// one bit to remember whether we rounded up or down, for an effective master
// weight precision of fp31. That should still be more than sufficient.

// Result of splitting a float into a stochastically-rounded bfloat16 and
// additional reconstruction bits
struct SplitFloatResult {
nv_bfloat16 b_float;
unsigned short bits;
};

// UB-free bit-level conversions. A C++17 stand-in for C++20's std::bit_cast
template<class T, class S>
__host__ __device__ T bit_cast(S v) {
T dest;
static_assert(sizeof(v) == sizeof(dest), "Size mismatch.");
memcpy(&dest, &v, sizeof(v));
return dest;
}

// Splits a float into a bfloat16 and the remaining mantissa bits
__host__ __device__ SplitFloatResult split_float(float value, unsigned short threshold) {
unsigned int float_bits = bit_cast<unsigned int>(value);
// IEEE 754: float: S E(8) M (23) bfloat: same, but mantissa 23-16 = 7 bits
// ideally, we'd just store the cut-off 16 bits, but that doesn't work if rounding
// is involved.
unsigned int rounded_bits = float_bits & 0x0000FFFFu;
if(rounded_bits > threshold) {
SplitFloatResult result;
result.b_float = __float2bfloat16_rn(bit_cast<float>(float_bits | 0xFFFFu));
// use last bit to remember that we rounded away from zero
result.bits = rounded_bits & (~1u) | 1u;
Review comment · gordicaleksa (Contributor) · Jun 15, 2024:
Why not just do "| 1u"? Why clear the bit first? The end result is equivalent.
Maybe for consistency with the else branch? But in that case I'd add "| 0u" below as well.

Reply · Contributor Author:
Made this consistent, and added a comment explaining what the 1 and 0 mean for good measure.
return result;
} else {
// truncation is easy
SplitFloatResult result;
result.b_float = bit_cast<__nv_bfloat16>((unsigned short)(float_bits >> 16u));
// use last bit to remember that we rounded towards zero
result.bits = rounded_bits & (~1u) | 0u;
return result;
}
}

// Reassembles a float from the bfloat16 part and the missing mantissa
__host__ __device__ float assemble_float(nv_bfloat16 b_float, unsigned short bits) {
constexpr const unsigned short BF16_SIGN_MASK = 0b1'00000000'0000000u;
constexpr const unsigned short BF16_EXPONENT_MASK = 0b0'11111111'0000000u;
constexpr const unsigned short BF16_MANTISSA_MASK = 0b0'00000000'1111111u;
unsigned short bf = bit_cast<unsigned short>(b_float);
if(bits & 1u) {
// if we rounded away from zero, we need to undo these changes.
// first, check if the mantissa (7 bits) of bf16 is zero
const unsigned short mantissa = bf & BF16_MANTISSA_MASK;
if(mantissa == 0) {
// mantissa overflowed, need to decrement the exponent
const unsigned short exponent = (bf & BF16_EXPONENT_MASK) >> 7u;
if(exponent == 0) {
// zero, cannot be reached if we round away from zero
unreachable();
}
// decrement the exponent and set mantissa to all-ones
bf = (bf & BF16_SIGN_MASK) | ((exponent-1u) << 7u) | BF16_MANTISSA_MASK;
} else {
// mantissa was incremented, decrement
bf = (bf & (BF16_SIGN_MASK | BF16_EXPONENT_MASK)) | (mantissa - 1u);
}
}
const unsigned int result = (bits & (~1u)) | (bf << 16u);
return bit_cast<float>(result);
}

#endif // LLMC_BITS_CUH
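To make the round trip concrete, here is a small host-side walkthrough (an editorial sketch, not part of the PR; it assumes the repository root is on the include path, as in the test Makefile added below, and the value 0x3FC01234 is arbitrary):

    #include <cstdio>
    #include "llmc/bits.cuh"

    int main() {
        // 0x3FC01234 is roughly 1.50056f; the low 16 bits (0x1234) are what bf16 discards
        const float x = bit_cast<float>(0x3FC01234u);

        // threshold 0xFFFF forces the truncation path: bf16 keeps the top 16 bits (0x3FC0),
        // stored bits are 0x1234 with the low bit overwritten by the flag 0
        SplitFloatResult trunc = split_float(x, 0xFFFF);

        // threshold 0 forces the round-away path: bf16 becomes 0x3FC1, stored bits 0x1235 (flag 1);
        // assemble_float uses the flag to undo the mantissa increment
        SplitFloatResult away = split_float(x, 0);

        printf("truncate: %08x  round-away: %08x  original: %08x\n",
               bit_cast<unsigned int>(assemble_float(trunc.b_float, trunc.bits)),
               bit_cast<unsigned int>(assemble_float(away.b_float, away.bits)),
               bit_cast<unsigned int>(x));
        // both reconstructions should print 3fc01234, matching the original up to the sacrificed low bit
        return 0;
    }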
14 changes: 13 additions & 1 deletion llmc/cuda_common.h
@@ -42,14 +42,26 @@ extern cudaDeviceProp deviceProp;
// Error checking

// CUDA error checking
void inline cudaCheck(cudaError_t error, const char *file, int line) {
inline void cudaCheck(cudaError_t error, const char *file, int line) {
if (error != cudaSuccess) {
printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
exit(EXIT_FAILURE);
}
};
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// like cudaFree, but checks for errors _and_ resets the pointer.
template<class T>
inline void cudaFreeCheck(T** ptr, const char *file, int line) {
cudaError_t error = cudaFree(*ptr);
if (error != cudaSuccess) {
printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
exit(EXIT_FAILURE);
}
*ptr = nullptr;
}
#define cudaFreeCheck(ptr) (cudaFreeCheck(ptr, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// CUDA Precision settings and defines

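A minimal usage sketch for the new helper (editorial illustration with a hypothetical buffer; it assumes llmc/cuda_common.h can be included directly, as llm.c's translation units do):

    #include "llmc/cuda_common.h"

    int main() {
        float* d_scratch = nullptr;
        cudaCheck(cudaMalloc(&d_scratch, 1024 * sizeof(float)));
        // ... kernel launches would use d_scratch here ...
        cudaFreeCheck(&d_scratch);   // frees and resets the pointer to nullptr
        cudaFreeCheck(&d_scratch);   // now harmless: cudaFree(nullptr) returns cudaSuccess
        return 0;
    }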
1 change: 1 addition & 0 deletions profile_gpt2.cu
@@ -33,6 +33,7 @@ int main(int argc, char *argv[]) {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_init_common(&model);
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");

int B = 24; // if program OOMs decrease this number, e.g. all the way down to 4 or etc
35 changes: 35 additions & 0 deletions test/Makefile
@@ -0,0 +1,35 @@
# Makefile for building the unit tests

# Find nvcc (NVIDIA CUDA compiler)
NVCC := $(shell which nvcc 2>/dev/null)
ifeq ($(NVCC),)
$(error nvcc not found.)
endif

# Compiler flags
CFLAGS = -O3 --use_fast_math
NVCCFLAGS = -lcublas -lcublasLt -std=c++17 -I..
MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/

# Default rule for our CUDA files
%: %.cu
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@

# Build all targets
TARGETS = bits
all: $(TARGETS)

# Individual targets
bits: bits.cu

run_all: all
@for target in $(TARGETS); do \
echo "\n========================================"; \
echo "Running $$target ..."; \
echo "========================================\n"; \
./$$target; \
done

# Clean up
clean:
rm -f $(TARGETS)
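Presumably the local invocation mirrors the CI job added above; run from the repository root:

    make -j4 -C test     # builds every target in TARGETS (currently just 'bits')
    ./test/bits          # or, from inside test/: make run_all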
41 changes: 41 additions & 0 deletions test/bits.cu
@@ -0,0 +1,41 @@
#include "llmc/bits.cuh"

#undef NDEBUG
#include <assert.h>
#include <float.h>
#include <stdio.h>
#include <math.h>

float round_trip(float f, unsigned short threshold) {
const SplitFloatResult split = split_float(f, threshold);
const float r = assemble_float(split.b_float, split.bits);
return r;
}

bool match_floats(float f1, float f2) {
const unsigned int u1 = bit_cast<unsigned int>(f1);
const unsigned int u2 = bit_cast<unsigned int>(f2);
if((u1 & (~1u)) != (u2 & (~1u))) {
printf("MISMATCH: %0x %0x\n", u1, u2);
return false;
}
return true;
}

#define ASSERT_ROUND_TRIP(f) \
assert(match_floats(f, round_trip(f, 0))); \
assert(match_floats(f, round_trip(f, 0xFFFF))); \

int main() {
ASSERT_ROUND_TRIP(1.4623f)
ASSERT_ROUND_TRIP(-63623.9f)
ASSERT_ROUND_TRIP(FLT_TRUE_MIN)
ASSERT_ROUND_TRIP(NAN)
ASSERT_ROUND_TRIP(0)
ASSERT_ROUND_TRIP(INFINITY)
// make sure we trigger the "rounding increases exponent" code path
const float increment_exponent = bit_cast<float>((unsigned int)(0x40ff'fff0));
ASSERT_ROUND_TRIP(increment_exponent)
printf("PASS\n");
return EXIT_SUCCESS;
}
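A short editorial walkthrough of why the bit pattern 0x40ff'fff0 exercises the exponent-increment path (worked out by hand, not part of the test):

    // float 0x40fffff0: top 16 bits 0x40ff form a bf16 whose 7 mantissa bits are all ones;
    // the discarded low bits are 0xfff0
    // with threshold 0, split_float rounds away from zero: __float2bfloat16_rn on 0x40ffffff
    // carries the mantissa over into the exponent, giving bf16 = 0x4100 and flag bit = 1
    // assemble_float then sees flag = 1 with a zero bf16 mantissa, decrements the exponent and
    // refills the mantissa with ones -> 0x40ff -> reconstructed float 0x40fffff0
    // (match_floats masks out the lowest bit on both sides, since that bit is sacrificed for the flag)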
36 changes: 34 additions & 2 deletions test_gpt2.cu
@@ -101,6 +101,7 @@ int main(int argc, char *argv[]) {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_init_common(&model);

gpt2_build_from_checkpoint(&model, load_filename);
size_t V = model.config.vocab_size;
@@ -304,11 +305,42 @@ int main(int argc, char *argv[]) {

// compare
for (int i = 0; i < 10; i++) {
if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {
if (fabsf(losses[i] - expected_losses[i]) < loss_diff_threshold && isfinite(losses[i])) {
printf("loss ok at step %d: %f %f\n", i+1, losses[i], expected_losses[i]);
} else {
printf("LOSS MISMATCH AT STEP %d: %f %f\n", i+1, losses[i], expected_losses[i]);
allok = 0;
}
}

// Finally, let's check determinism
gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt");
save_state("test_gpt2cu_state.ckpt", 10, &model, nullptr);
for (int step = 10; step < 20; step++) {
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model, x, true);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+1, &multi_gpu_config);
losses[step - 10] = model.mean_loss;
}

// reload
gpt2_free(&model);
gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt");
int ld_step;
load_state(&ld_step, &model, nullptr, "test_gpt2cu_state.ckpt");
for (int step = 10; step < 20; step++) {
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model, x, true);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+1, &multi_gpu_config);

if(losses[step-10] != model.mean_loss) {
printf("Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\n", step, losses[step-10], model.mean_loss);
allok = false;
break;
} else {
printf("loss ok at step %d: %f %f\n", i+1, losses[i], expected_losses[i]);
printf("loss ok at step %d: %f %f\n", step, losses[step-10], model.mean_loss);
}
}
