consolidate memory #590

Merged
merged 1 commit on Jun 14, 2024
34 changes: 19 additions & 15 deletions llmc/attention.cuh
@@ -149,8 +149,8 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
}
}

__global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const floatX* datt, const floatX* att,
int B, int T, int C, float scale) {
__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, const floatX* att,
int B, int T, int C, float scale) {
constexpr const int BlockSize = 256;
constexpr int T_per_block = 4;

@@ -160,14 +160,13 @@ __global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const fl

att += idx * T * T;
datt += idx * T * T;
dpreatt += idx * T * T;

for(int to = 0; to < T_per_block; ++to) {
int t = t0 - to;
if(t < 0) return;
const floatX* att_bth = att + t * T;
const floatX* datt_bth = datt + t * T;
floatX* dpreatt_bth = dpreatt + t * T;
floatX* dpreatt_bth = datt + t * T;

float local_sum = 0;
for (int t2 = threadIdx.x; t2 <= t; t2 += BlockSize) {
@@ -176,11 +175,16 @@ __global__ void softmax_autoregressive_backward_kernel(floatX* dpreatt, const fl

local_sum = blockReduce<warpReduceSum>(local_sum);

for (int t3 = threadIdx.x; t3 <= t; t3 += BlockSize) {
for (int t3 = threadIdx.x; t3 < T; t3 += BlockSize) {
// don't touch the cache. Some parts will still be here from the previous loop, and
// we want to exploit those.
float acc = (float)__ldcs(att_bth + t3) * ((float)__ldcs(datt_bth + t3) - local_sum);
__stcs(dpreatt_bth + t3, (floatX)(scale * acc));
if(t3 <= t) {
float acc = (float) __ldcs(att_bth + t3) * ((float) __ldcs(datt_bth + t3) - local_sum);
__stcs(dpreatt_bth + t3, (floatX) (scale * acc));
} else {
// explicitly set non-causal elements to zero
__stcs(dpreatt_bth + t3, (floatX)0.f);
}
}
}
}
@@ -200,7 +204,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
// inp is (B, T, 3C) QKV
// preatt, att are (B, NH, T, T)
// output is (B, T, C)
int HS = C / NH; // head size
const int HS = C / NH; // head size

// permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
floatX *q, *k, *v;
@@ -223,7 +227,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));

// multiply all elements of preatt elementwise by scale
float scale = 1.0 / sqrtf(HS);
float scale = 1.f / sqrtf(HS);
int grid_size = CEIL_DIV(B * NH * T * WARP_SIZE, block_size);
softmax_forward_kernel5<<<grid_size, block_size, 0, stream>>>(att, scale, preatt, B * NH, T);

@@ -247,13 +251,13 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,

// the sequence of transformations in this compound op is:
// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)
void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* datt, floatX* scratch,
void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch,
const floatX* dout,
const floatX* qkvr, const floatX* att,
int B, int T, int C, int NH, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
int HS = C / NH; // head size
const int HS = C / NH; // head size
const float alpha = 1.0f, beta = 0.0f;

// unpack convenience pointers into q, k, v
@@ -279,10 +283,10 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &alpha,
scratch, CUBLAS_LOWP, HS, T * HS, att, CUBLAS_LOWP, T, T * T, &beta,
dv, CUBLAS_LOWP, HS, T * HS, B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
// backward into preatt
int hs = C / NH; // head size
float scale = 1.0f / sqrtf(hs);
softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(dpreatt, datt, att, B, T, C, scale);
const float scale = 1.0f / sqrtf((float)HS);
// backward into preatt. this is an in-place operation; datt turns into dpreatt here
softmax_autoregressive_backward_inplace_kernel<<<dim3(T / 4, B * NH), 256>>>(datt, att, B, T, C, scale);
const floatX* dpreatt = datt;
// backward into q
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha,
k, CUBLAS_LOWP, HS, T * HS, dpreatt, CUBLAS_LOWP, T, T * T, &beta,
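
Note on the attention.cuh change: the softmax backward kernel now writes its result into the same buffer it reads the incoming gradient from, so the separate dpreatt tensor disappears from both the kernel and attention_backward's signature. Entries beyond the causal position t are explicitly zeroed so the full row can be fed straight into the subsequent GEMMs. A minimal CPU sketch of the same computation for one (batch, head) slice (the helper name here is made up for illustration, and plain float stands in for floatX):

void softmax_autoregressive_backward_inplace_cpu(float* datt, const float* att, int T, float scale) {
    // datt holds dL/datt on entry and dL/dpreatt on exit, row by row
    for (int t = 0; t < T; t++) {
        const float* att_t = att + t * T;
        float* row = datt + t * T;
        // dot product of the softmax row with its incoming gradient (causal part only)
        float local_sum = 0.0f;
        for (int t2 = 0; t2 <= t; t2++) {
            local_sum += att_t[t2] * row[t2];
        }
        // overwrite the row in place; positions past t are set to zero explicitly
        for (int t3 = 0; t3 < T; t3++) {
            row[t3] = (t3 <= t) ? scale * att_t[t3] * (row[t3] - local_sum) : 0.0f;
        }
    }
}

The in-place overwrite is safe because local_sum is fully reduced before any element of the row is written, and each element is read only once afterwards; the CUDA kernel gets the same ordering from the block-wide reduction that completes before its store loop.
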
79 changes: 15 additions & 64 deletions train_gpt2.cu
@@ -192,7 +192,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
return params_memory;
}

#define NUM_ACTIVATION_TENSORS 21
constexpr int NUM_ACTIVATION_TENSORS = 23;
typedef struct {
floatX* encoded; // (B, T, C)
floatX* ln1; // (L, B, T, C)
@@ -221,6 +221,10 @@ typedef struct {
// general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),
// (B, NH, T, T), and (B, T, V) shaped tensors.
floatX* output;

// some additional scratch buffers
floatX* scratch_bt4c; // (B, T, 4*C)
floatX* scratch_btc; // (B, T, C)
} ActivationTensors;

void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config, int recompute) {
Expand Down Expand Up @@ -257,34 +261,9 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
act_sizes[18] = B * T; // losses
act_sizes[19] = L * B * T * 3*C; // qkvr
act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch
}

// Backward pass is conceptually quite different from forward, because we can discard
// the activations of a layer as soon as we're done with it. This lets us aggressively
// reuse memory, so that we need far fewer tensors for backward state.
#ifdef ENABLE_CUDNN
#define NUM_BACKWARD_TENSORS 2
#else
#define NUM_BACKWARD_TENSORS 3
#endif

typedef struct {
floatX* bt4c; // (B, T, 4*C)
floatX* residual3; // (B, T, C)
#ifndef ENABLE_CUDNN
floatX* preatt; // (B, NH, T, T)
#endif
} GradActTensors;

void fill_in_grad_act_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config) {
size_t C = config.channels;
act_sizes[0] = B * T * 4 * C; // bt4c
act_sizes[1] = B * T * C; // residual3

#ifndef ENABLE_CUDNN
size_t NH = config.num_heads;
act_sizes[2] = B * NH * T * T; // preatt
#endif
act_sizes[21] = B * T * 4 * C; // scratch_bt4c
act_sizes[22] = B * T * C; // scratch_btc
}

void* malloc_and_point(floatX** targets[], const size_t* act_sizes, size_t n) {
Expand Down Expand Up @@ -312,21 +291,12 @@ void* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_si
&acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty,
&acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
&acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,
&acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output
&acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output,
&acts->scratch_bt4c, &acts->scratch_btc
};
return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS);
}

void* malloc_and_point_backward(GradActTensors* acts, const size_t* act_sizes) {
floatX** ptrs[] = {
&acts->bt4c, &acts->residual3,
#ifndef ENABLE_CUDNN
&acts->preatt,
#endif
};
return malloc_and_point(ptrs, act_sizes, NUM_BACKWARD_TENSORS);
}

typedef struct {
GPT2Config config;
// the weights of the model, and their sizes
@@ -348,10 +318,6 @@ typedef struct {
size_t act_sizes[NUM_ACTIVATION_TENSORS];
void* acts_memory;
size_t num_activations;
// gradients of the activations
GradActTensors grads_acts;
size_t num_grad_acts;
void* grads_acts_memory;
// other run state configuration
int batch_size; // the batch size (B) of current forward pass
int seq_len; // the sequence length (T) of current forward pass
@@ -386,7 +352,6 @@ void gpt2_init_common(GPT2 *model) {
model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward()
// memory lazily initialized in backward()
model->grads_memory = NULL;
model->grads_acts_memory = NULL;
model->workload_indices = NULL; // on cpu, for encoder_backward
model->bucket_info = NULL; // on cpu, for encoder_backward
// memory lazily initialized in update()
@@ -770,17 +735,6 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
// allocate buffers for weight gradients
printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);
// we're going to be clever for the activations backward pass. we don't need to exactly
// mirror the forward pass activations and we will save memory.
size_t bw_act_sizes[NUM_BACKWARD_TENSORS];
fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, model->config);
// count up and allocate the space
model->num_grad_acts = 0;
for (size_t i = 0; i < NUM_BACKWARD_TENSORS; i++) {
model->num_grad_acts += bw_act_sizes[i];
}
printf0("allocating %d MiB for activation gradients\n", (int)round(model->num_grad_acts * sizeof(floatX) / (1024 * 1024)));
model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
// init gradients of parameters and activations to zero
gpt2_zero_grad(model);
// initialise cpu scratch buffers for encoder backward
@@ -802,10 +756,10 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
ParameterTensors params = model->params; // for brevity
ParameterTensors grads = model->grads;
ActivationTensors acts = model->acts;
GradActTensors grads_acts = model->grads_acts;

// reset residual stream gradients (put here to work with gradient accumulation)
cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX)));
floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass
cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));

// re-use the output buffer of the forward pass as a scratchpad during backward pass
float* scratchF = (float*)acts.output;
@@ -816,11 +770,10 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
// technically that is a small, inline backward() pass of calculating
// total, final loss as the mean over all losses over all (B,T) positions in the batch
// next: backward the classifier matmul
matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
// backward the final layernorm
floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);

// from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
// scratch for backward computations
@@ -870,7 +823,7 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
// notice that there is no l *, because we just have a single copy, and keep
// re-using this memory in every Transformer block as we calculate backward pass

floatX* dl_bt4c = (floatX*)grads_acts.bt4c;
floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c;

// start the backward pass for this layer
if(model->recompute >= 1) {
@@ -897,8 +850,7 @@ void gpt2_backward(GPT2 *model, int* inputs, bool last_step) {
// we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
floatX* buffer_a = l_atty;
floatX* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need
floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation
attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
#endif
if(model->recompute >= 2) {
layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
@@ -1154,7 +1106,6 @@ void gpt2_free(GPT2 *model) {
cudaCheck(cudaFree(model->v_memory));
cudaCheck(cudaFree(model->master_weights));
cudaCheck(cudaFree(model->acts_memory));
cudaCheck(cudaFree(model->grads_acts_memory));
cudaCheck(cudaFree(model->inputs));
cudaCheck(cudaFree(model->targets));
cudaCheck(cudaFreeHost(model->cpu_losses));
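
Note on the train_gpt2.cu change: the separate GradActTensors struct, its size and allocation helpers, and the grads_acts_memory field are all removed. The two backward scratch buffers (scratch_bt4c and scratch_btc) become entries 21 and 22 of the single activation table, which grows from 21 to 23 tensors, and the non-cuDNN path no longer needs a dedicated (B, NH, T, T) preatt gradient buffer at all thanks to the in-place attention backward. The consolidation leans on the malloc_and_point pattern, in which one allocation is carved into per-tensor slices, so adding two tensors costs two more table entries rather than a second allocation. A minimal CPU analogue of that pattern (plain malloc stands in for the cudaMalloc the real code uses, and the function name here is illustrative):

#include <stdlib.h>

typedef float floatX;  // assumption: stand-in for the repo's precision-configurable floatX

void* malloc_and_point_cpu(floatX** targets[], const size_t* sizes, size_t n) {
    // sum up the total element count, make one allocation, then carve it into slices
    size_t total = 0;
    for (size_t i = 0; i < n; i++) { total += sizes[i]; }
    floatX* memory = (floatX*)malloc(total * sizeof(floatX));
    if (memory == NULL) { return NULL; }
    floatX* cursor = memory;
    for (size_t i = 0; i < n; i++) {
        *(targets[i]) = cursor;  // hand each tensor a slice of the single block
        cursor += sizes[i];
    }
    return memory;
}

With GradActTensors gone, the backward pass reads its working buffers straight out of acts (dresidual from scratch_btc, the generic (B, T, 4C) gradient from scratch_bt4c), and gpt2_free no longer has a grads_acts_memory buffer to release.
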