diff --git a/llmc/attention.cuh b/llmc/attention.cuh
index 6d4f0e2f..efad322e 100644
--- a/llmc/attention.cuh
+++ b/llmc/attention.cuh
@@ -149,8 +149,8 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
     }
 }
 
-__global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const floatX* datt, const floatX* att,
-                                                        int B, int T, int C, float scale) {
+__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, const floatX* att,
+                                                                int B, int T, int C, float scale) {
     constexpr const int BlockSize = 256;
     constexpr int T_per_block = 4;
 
@@ -160,14 +160,13 @@ __global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const fl
 
     att += idx * T * T;
     datt += idx * T * T;
-    dpreatt += idx * T * T;
 
     for(int to = 0; to < T_per_block; ++to) {
         int t = t0 - to;
         if(t < 0) return;
         const floatX* att_bth = att + t * T;
         const floatX* datt_bth = datt + t * T;
-        floatX* dpreatt_bth = dpreatt + t * T;
+        floatX* dpreatt_bth = datt + t * T;
 
         float local_sum = 0;
         for (int t2 = threadIdx.x; t2 <= t; t2 += BlockSize) {
@@ -176,11 +175,16 @@ __global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const fl
 
         local_sum = blockReduce(local_sum);
 
-        for (int t3 = threadIdx.x; t3 <= t; t3 += BlockSize) {
+        for (int t3 = threadIdx.x; t3 < T; t3 += BlockSize) {
             // don't touch the cache. Some parts will still be here from the previous loop, and
             // we want to exploit those.
-            float acc = (float)__ldcs(att_bth + t3) * ((float)__ldcs(datt_bth + t3) - local_sum);
-            __stcs(dpreatt_bth + t3, (floatX)(scale * acc));
+            if(t3 <= t) {
+                float acc = (float) __ldcs(att_bth + t3) * ((float) __ldcs(datt_bth + t3) - local_sum);
+                __stcs(dpreatt_bth + t3, (floatX) (scale * acc));
+            } else {
+                // explicitly set non-causal elements to zero
+                __stcs(dpreatt_bth + t3, (floatX)0.f);
+            }
         }
     }
 }
@@ -200,7 +204,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
     // inp is (B, T, 3C) QKV
     // preatt, att are (B, NH, T, T)
     // output is (B, T, C)
-    int HS = C / NH; // head size
+    const int HS = C / NH; // head size
 
     // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
     floatX *q, *k, *v;
@@ -223,7 +227,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
                                      B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
 
     // multiply all elements of preatt elementwise by scale
-    float scale = 1.0 / sqrtf(HS);
+    float scale = 1.f / sqrtf(HS);
     int grid_size = CEIL_DIV(B * NH * T * WARP_SIZE, block_size);
     softmax_forward_kernel5<<<grid_size, block_size, 0, stream>>>(att, scale, preatt, B * NH, T);
 
@@ -247,13 +251,13 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
 
 // the sequence of transformations in this compound op is:
 // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)
-void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* datt, floatX* scratch,
+void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch,
                         const floatX* dout,
                         const floatX* qkvr, const floatX* att,
                         int B, int T, int C, int NH, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
-    int HS = C / NH; // head size
+    const int HS = C / NH; // head size
     const float alpha = 1.0f, beta = 0.0f;
 
     // unpack convenience pointers into q, k, v
@@ -279,10 +283,10 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
     cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &alpha,
                                            scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T,
                                           &beta, dv, CUBLAS_LOWP, HS, T * HS, B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
-    // backward into preatt
-    int hs = C / NH; // head size
-    float scale = 1.0f / sqrtf(hs);
-    softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(dpreatt, datt, att, B, T, C, scale);
+    const float scale = 1.0f / sqrtf((float)HS);
+    // backward into preatt. this is an in-place operation; datt turns into dpreatt here
+    softmax_autoregressive_backward_inplace_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(datt, att, B, T, C, scale);
+    const floatX* dpreatt = datt;
     // backward into q
     cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha,
                                            k, CUBLAS_LOWP, HS, T * HS, dpreatt, CUBLAS_LOWP, T, T * T, &beta,
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 1d27408a..49e94c3f 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -192,7 +192,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
     return params_memory;
 }
 
-#define NUM_ACTIVATION_TENSORS 21
+constexpr int NUM_ACTIVATION_TENSORS = 23;
 typedef struct {
     floatX* encoded; // (B, T, C)
     floatX* ln1; // (L, B, T, C)
@@ -221,6 +221,10 @@ typedef struct {
     // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),
     // (B, NH, T, T), and (B, T, V) shaped tensors.
     floatX* output;
+
+    // some additional scratch buffers
+    floatX* scratch_bt4c; // (B, T, 4*C)
+    floatX* scratch_btc; // (B, T, C)
 } ActivationTensors;
 
 void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config, int recompute) {
@@ -257,34 +261,9 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
     act_sizes[18] = B * T; // losses
     act_sizes[19] = L * B * T * 3*C; // qkvr
     act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch
-}
-
-// Backward pass is conceptually quite different from forward, because we can discard
-// the activations of a layer as soon as we're done with it. This lets us aggressively
-// reuse memory, so that we need far fewer tensors for backward state.
-#ifdef ENABLE_CUDNN
-#define NUM_BACKWARD_TENSORS 2
-#else
-#define NUM_BACKWARD_TENSORS 3
-#endif
-typedef struct {
-    floatX* bt4c; // (B, T, 4*C)
-    floatX* residual3; // (B, T, C)
-    #ifndef ENABLE_CUDNN
-    floatX* preatt; // (B, NH, T, T)
-    #endif
-} GradActTensors;
-
-void fill_in_grad_act_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config) {
-    size_t C = config.channels;
-    act_sizes[0] = B * T * 4 * C; // bt4c
-    act_sizes[1] = B * T * C; // residual3
-
-    #ifndef ENABLE_CUDNN
-    size_t NH = config.num_heads;
-    act_sizes[2] = B * NH * T * T; // preatt
-    #endif
+    act_sizes[21] = B * T * 4 * C; // scratch_bt4c
+    act_sizes[22] = B * T * C; // scratch_btc
 }
 
 void* malloc_and_point(floatX** targets[], const size_t* act_sizes, size_t n) {
@@ -312,21 +291,12 @@ void* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_si
         &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty,
         &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
         &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,
-        &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output
+        &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output,
+        &acts->scratch_bt4c, &acts->scratch_btc
     };
     return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS);
 }
 
-void* malloc_and_point_backward(GradActTensors* acts, const size_t* act_sizes) {
-    floatX** ptrs[] = {
-        &acts->bt4c, &acts->residual3,
-        #ifndef ENABLE_CUDNN
-        &acts->preatt,
-        #endif
-    };
-    return malloc_and_point(ptrs, act_sizes, NUM_BACKWARD_TENSORS);
-}
-
 typedef struct {
     GPT2Config config;
     // the weights of the model, and their sizes
@@ -348,10 +318,6 @@ typedef struct {
     size_t act_sizes[NUM_ACTIVATION_TENSORS];
     void* acts_memory;
     size_t num_activations;
-    // gradients of the activations
-    GradActTensors grads_acts;
-    size_t num_grad_acts;
-    void* grads_acts_memory;
     // other run state configuration
     int batch_size; // the batch size (B) of current forward pass
     int seq_len; // the sequence length (T) of current forward pass
@@ -386,7 +352,6 @@ void gpt2_init_common(GPT2 *model) {
     model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward()
     // memory lazily initialized in backward()
     model->grads_memory = NULL;
-    model->grads_acts_memory = NULL;
     model->workload_indices = NULL; // on cpu, for encoder_backward
     model->bucket_info = NULL; // on cpu, for encoder_backward
     // memory lazily initialized in update()
@@ -770,17 +735,6 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
         // allocate buffers for weight gradients
         printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
         model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);
-        // we're going to be clever for the activations backward pass. we don't need to exactly
-        // mirror the forward pass activations and we will save memory.
-        size_t bw_act_sizes[NUM_BACKWARD_TENSORS];
-        fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, model->config);
-        // count up and allocate the space
-        model->num_grad_acts = 0;
-        for (size_t i = 0; i < NUM_BACKWARD_TENSORS; i++) {
-            model->num_grad_acts += bw_act_sizes[i];
-        }
-        printf0("allocating %d MiB for activation gradients\n", (int)round(model->num_grad_acts * sizeof(floatX) / (1024 * 1024)));
-        model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
         // init gradients of parameters and activations to zero
         gpt2_zero_grad(model);
         // initialise cpu scratch buffers for encoder backward
@@ -802,10 +756,10 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
     ParameterTensors params = model->params; // for brevity
     ParameterTensors grads = model->grads;
     ActivationTensors acts = model->acts;
-    GradActTensors grads_acts = model->grads_acts;
 
     // reset residual stream gradients (put here to work with gradient accumulation)
-    cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX)));
+    floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass
+    cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));
 
     // re-use the output buffer of the forward pass as a scratchpad during backward pass
     float* scratchF = (float*)acts.output;
@@ -816,11 +770,10 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
     // technically that is a small, inline backward() pass of calculating
     // total, final loss as the mean over all losses over all (B,T) positions in the batch
     // next: backward the classifier matmul
-    matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
+    matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
     // backward the final layernorm
     floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
-    floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass
-    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
+    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
 
     // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
     // scratch for backward computations
@@ -870,7 +823,7 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
 
         // notice that there is no l *, because we just have a single copy, and keep
         // re-using this memory in every Transformer block as we calculate backward pass
-        floatX* dl_bt4c = (floatX*)grads_acts.bt4c;
+        floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c;
 
         // start the backward pass for this layer
        if(model->recompute >= 1) {
@@ -897,8 +850,7 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
         // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
         floatX* buffer_a = l_atty;
         floatX* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need
-        floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation
-        attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
+        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
         #endif
         if(model->recompute >= 2) {
             layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
@@ -1154,7 +1106,6 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->v_memory));
     cudaCheck(cudaFree(model->master_weights));
     cudaCheck(cudaFree(model->acts_memory));
-    cudaCheck(cudaFree(model->grads_acts_memory));
     cudaCheck(cudaFree(model->inputs));
     cudaCheck(cudaFree(model->targets));
     cudaCheck(cudaFreeHost(model->cpu_losses));
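
Note on the key change above: for one row t of a (T, T) attention matrix, the softmax backward is dpreatt[t3] = scale * att[t3] * (datt[t3] - sum_{t2 <= t} att[t2] * datt[t2]). The row-wide sum is computed before anything is written, and each output index afterwards reads datt only at that same index, which is why the result can be written back into datt and the dedicated dpreatt buffer can be dropped. The t3 > t tail is zeroed explicitly because datt arrives from a dense GEMM (its upper triangle is not zero) and the following GEMMs into dq and dk consume the full T x T matrix. A minimal single-row CPU sketch, for illustration only (not part of the patch; plain float stands in for floatX, serial loops stand in for the block reduction, and the helper name is made up):

void softmax_autoregressive_backward_inplace_row_ref(float* datt_row, const float* att_row,
                                                     int T, int t, float scale) {
    // reduction over the causal prefix, done while datt_row is still unmodified
    float local_sum = 0.0f;
    for (int t2 = 0; t2 <= t; t2++) {
        local_sum += att_row[t2] * datt_row[t2];
    }
    // overwrite the row in place: datt_row now holds dpreatt for this row
    for (int t3 = 0; t3 < T; t3++) {
        if (t3 <= t) {
            datt_row[t3] = scale * att_row[t3] * (datt_row[t3] - local_sum);
        } else {
            datt_row[t3] = 0.0f; // zero the non-causal tail, mirroring the kernel's else branch
        }
    }
}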