Experimenting with global instantiation for the layouts #347

Draft · wants to merge 1 commit into master
112 changes: 81 additions & 31 deletions train_gpt2.cu
@@ -75,6 +75,30 @@ enum PrecisionMode {
PRECISION_BF16
};

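// Globally shared cuBLASLt handles for matmul_forward_cublaslt(). The matrix layouts and the
// matmul preference are created once in main() and reused across calls. Suffixes: no suffix /
// "c" = width C, "3c" = 3*C (QKV), "4c" = 4*C (MLP hidden), "Vp" = padded vocab size;
// weightLayout1..5 cover the attention projection, MLP up/down projections, logits, and QKV matmuls.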
cublasLtMatrixLayout_t biasLayout;
cublasLtMatrixLayout_t biasLayoutVp;
cublasLtMatrixLayout_t biasLayout4c;
cublasLtMatrixLayout_t biasLayout3c;

cublasLtMatrixLayout_t weightLayout1;
cublasLtMatrixLayout_t weightLayout2;
cublasLtMatrixLayout_t weightLayout3;
cublasLtMatrixLayout_t weightLayout4;
cublasLtMatrixLayout_t weightLayout5;

cublasLtMatrixLayout_t inputLayoutbtc;
cublasLtMatrixLayout_t inputLayoutbt4c;

cublasLtMatrixLayout_t outputLayoutc;
cublasLtMatrixLayout_t outputLayoutvp;
cublasLtMatrixLayout_t outputLayout4c;
cublasLtMatrixLayout_t outputLayout3c;

cublasLtMatmulPreference_t preference;

cublasLtMatmulDesc_t operationDesc;


// Default Properties
typedef float floatN;
#define CUBLAS_LOWP_COMPUTE cublas_compute_type
@@ -1389,7 +1413,8 @@ void layernorm_forward(floatX* out, floatX* mean, floatX* rstd,
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu
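// the matrix layouts and the matmul preference are now created once in main() (the layouts are
// passed in as parameters, the preference is global) instead of being re-created on every call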
void matmul_forward_cublaslt(floatX* out,
floatX* inp, floatX* weight, floatX* bias,
int B, int T, int C, int OC) {
int B, int T, int C, int OC, cublasLtMatrixLayout_t biasLayout, cublasLtMatrixLayout_t weightLayout,
cublasLtMatrixLayout_t inputLayout, cublasLtMatrixLayout_t outputLayout) {
NVTX_RANGE_FN();
int has_bias = (bias != NULL);

@@ -1409,11 +1434,6 @@ void matmul_forward_cublaslt(floatX* out,

int returnedResults = 0;
cublasLtMatmulDesc_t operationDesc;
cublasLtMatmulPreference_t preference;
cublasLtMatrixLayout_t weightLayout;
cublasLtMatrixLayout_t inputLayout;
cublasLtMatrixLayout_t outputLayout;
cublasLtMatrixLayout_t biasLayout;
cublasLtMatmulHeuristicResult_t heuristic;

// create the operation descriptor
@@ -1431,18 +1451,6 @@ }
}
cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));

// define matrix layouts
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUBLAS_LOWP, C, OC, C));
cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUBLAS_LOWP, C, B*T, C));
cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUBLAS_LOWP, OC, B*T, OC));
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUBLAS_LOWP, OC, 1, OC));

// create a preference handle with specified max workspace
cublasCheck(cublasLtMatmulPreferenceCreate(&preference));
cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));

// find a suitable algorithm
cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc,
weightLayout, inputLayout, outputLayout, outputLayout,
@@ -1455,16 +1463,11 @@
// call the matmul; passing NULL for the algo lets cuBLASLt pick one via an implicit heuristic query
cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc,
alpha_ptr, weight, weightLayout, inp, inputLayout, beta_ptr,
out, outputLayout, out, outputLayout, &heuristic.algo,
out, outputLayout, out, outputLayout, NULL,
cublaslt_workspace, cublaslt_workspace_size, 0));

// cleanups
cublasCheck(cublasLtMatmulPreferenceDestroy(preference));
cublasCheck(cublasLtMatmulDescDestroy(operationDesc));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
}

void attention_forward(floatX* out, floatX* qkvr, floatX* att,
@@ -2107,29 +2110,28 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

#ifdef ENABLE_CUDNN
float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, biasLayout3c, weightLayout5, inputLayoutbtc, outputLayout3c);
attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C);
#else
floatX* l_att = acts.att + l * B * NH * T * T;
// these are only needed as scratchpads for the forward pass, but
// need not be stored for backward
floatX* scratch = (floatX*)acts.output;
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, biasLayout3c, weightLayout5, inputLayoutbtc, outputLayout3c);
attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH);
#endif

matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C, biasLayout, weightLayout1, inputLayoutbtc, outputLayoutc);
residual_forward(l_residual2, residual, l_attproj, B*T*C);
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, biasLayout4c, weightLayout2, inputLayoutbtc, outputLayout4c);
gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, biasLayout, weightLayout3, inputLayoutbt4c, outputLayoutc);
residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);
}

residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp);
matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, biasLayoutVp, weightLayout4, inputLayoutbtc, outputLayoutvp);

// also forward the cross-entropy loss function if we have the targets
if (targets != NULL) {
@@ -2148,6 +2150,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
// if we don't have targets, we don't have loss
model->mean_loss = -1.0f;
}

}

void gpt2_zero_grad(GPT2 *model) {
@@ -2665,6 +2668,35 @@ int main(int argc, char *argv[]) {
floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));
float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float));

size_t Vp = model.config.padded_vocab_size;
size_t NH = model.config.num_heads;
size_t C = model.config.channels;

// create a preference handle with specified max workspace
cublasCheck(cublasLtMatmulPreferenceCreate(&preference));
cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));

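// create the shared matrix layouts once up front; B, T, C and Vp are fixed for the rest of the
// run, so the same layouts can be reused by every forward-pass matmul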
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUBLAS_LOWP, C, 1, C));
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayoutVp, CUBLAS_LOWP, Vp, 1, Vp));
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout4c, CUBLAS_LOWP, 4*C, 1, 4*C));
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout3c, CUBLAS_LOWP, 3*C, 1, 3*C));

cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout1, CUBLAS_LOWP, C, C, C));
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout2, CUBLAS_LOWP, C, 4*C, C));
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout3, CUBLAS_LOWP, 4*C, C, 4*C));
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout4, CUBLAS_LOWP, C, Vp, C));
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout5, CUBLAS_LOWP, C, 3*C, C));

cublasCheck(cublasLtMatrixLayoutCreate(&inputLayoutbtc, CUBLAS_LOWP, C, B*T, C));
cublasCheck(cublasLtMatrixLayoutCreate(&inputLayoutbt4c, CUBLAS_LOWP, 4*C, B*T, 4*C));

cublasCheck(cublasLtMatrixLayoutCreate(&outputLayoutc, CUBLAS_LOWP, C, B*T, C));
cublasCheck(cublasLtMatrixLayoutCreate(&outputLayoutvp, CUBLAS_LOWP, Vp, B*T, Vp));
cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout4c, CUBLAS_LOWP, 4*C, B*T, 4*C));
cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout3c, CUBLAS_LOWP, 3*C, B*T, 3*C));

// train
cudaEvent_t start, end;
cudaCheck(cudaEventCreate(&start));
@@ -2776,6 +2808,24 @@ int main(int argc, char *argv[]) {
// add a total average, for optimizations that are only mild improvements (excluding 1st batch as warmup)
printf0("total average iteration time: %f ms\n", total_sum_iteration_time_s / (train_num_batches-1) * 1000);

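// release the globally shared cuBLASLt layouts and the matmul preference handle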
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayoutVp));
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout4c));
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout3c));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout1));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout2));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout3));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout4));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout5));
cublasCheck(cublasLtMatrixLayoutDestroy(inputLayoutbtc));
cublasCheck(cublasLtMatrixLayoutDestroy(inputLayoutbt4c));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayoutc));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayoutvp));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout4c));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout3c));

cublasCheck(cublasLtMatmulPreferenceDestroy(preference));

// free and destroy everything
cudaCheck(cudaEventDestroy(end));
cudaCheck(cudaEventDestroy(start));