Skip to content

Commit

Permalink
ggml : support ChatGLM-style RoPE (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Jun 26, 2023
1 parent c0b546b commit 2d95223
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/dolly-v2/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,8 @@ bool dollyv2_eval(
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));

// using mode = 2 for GPT-NeoX mode
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2);
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2);
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);

// store key and value to memory
{
Expand Down
4 changes: 2 additions & 2 deletions examples/gpt-j/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,8 @@ bool gptj_eval(

// self-attention
{
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);

// store key and value to memory
{
Expand Down
4 changes: 2 additions & 2 deletions examples/gpt-neox/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ bool gpt_neox_eval(
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));

// using mode = 2 for GPT-NeoX mode
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2);
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2);
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0);
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0);

// store key and value to memory
{
Expand Down
7 changes: 5 additions & 2 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1036,21 +1036,24 @@ extern "C" {
// rotary position embedding
// if (mode & 1) != 0, skip n_past elements
// if (mode & 2) != 0, GPT-NeoX style
// if (mode & 4) != 0, ChatGLM style (uses n_ctx)
// TODO: avoid creating a new tensor every time
GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
int mode,
int n_ctx);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
int mode,
int n_ctx);

// rotary position embedding backward, i.e. compute dx from dy
// a - dy
Expand Down
82 changes: 71 additions & 11 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl(
int n_past,
int n_dims,
int mode,
int n_ctx,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
Expand All @@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl(

ggml_scratch_save(ctx);

struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);

((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_dims;
((int32_t *) b->data)[2] = mode;
((int32_t *) b->data)[3] = n_ctx;

ggml_scratch_load(ctx);

Expand All @@ -6811,17 +6813,19 @@ struct ggml_tensor * ggml_rope(
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
}

struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
}

// ggml_rope_back
Expand Down Expand Up @@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 3);
GGML_ASSERT(ggml_nelements(src1) == 4);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
Expand All @@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32(
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];

assert(n_past >= 0);

Expand Down Expand Up @@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32(
const float theta_scale = powf(10000.0, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
Expand All @@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32(

float theta = (float)p;

if (!is_neox) {
if (is_glm) {
theta = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);

theta *= theta_scale;
block_theta *= theta_scale;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = src[0];
const float x1 = src[n_dims/2];
const float x2 = src[n_dims];
const float x3 = src[n_dims/2*3];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
Expand Down Expand Up @@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_nelements(src1) == 3);
GGML_ASSERT(ggml_nelements(src1) == 4);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
Expand All @@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16(
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];

assert(n_past >= 0);

Expand Down Expand Up @@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16(
const float theta_scale = powf(10000.0, -2.0f/n_dims);

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
Expand All @@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16(

float theta = (float)p;

if (!is_neox) {
if (is_glm) {
theta = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);

theta *= theta_scale;
block_theta *= theta_scale;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);

dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
Expand Down Expand Up @@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 3);
assert(ggml_nelements(src1) == 4);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rope(ctx,
tensor->grad,
n_past,
n_dims,
mode),
mode,
n_ctx),
inplace);
}
if (src1->grad) {
Expand Down
4 changes: 2 additions & 2 deletions tests/test-grad0.c
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ int main(int argc, const char ** argv) {
int n_rot = ne2[0];

for (int ndims = 3; ndims <= 4; ++ndims) {
for (int mode = 0; mode < 4; ++mode) {
for (int mode = 0; mode < 8; ++mode) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);

Expand All @@ -1154,7 +1154,7 @@ int main(int argc, const char ** argv) {
continue;
}

struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode));
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));

GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
Expand Down

0 comments on commit 2d95223

Please sign in to comment.