[CPP Graph] ChatGLM-2 Enabling (#210)
* chatglm-2 q4_j inference pass with correct accuracy

* unify convert scripts

* specify chatglm2, remove ambiguous chatglm

* initialize glm1

* initialize glm1

* Fix kernel issues for glm1

* adapt to the latest main and chatglm2 inference pass

* add parameters for all convert.py

Signed-off-by: Zhenzhong1 <zhenzhong.xu@intel.com>

* add parameters for the bloom

* update README and clean code

* disable chatglm1

---------

Signed-off-by: Zhenzhong1 <zhenzhong.xu@intel.com>
Zhenzhong1 committed Sep 8, 2023
1 parent d06e5d6 commit 9a2cfa5
Showing 29 changed files with 1,820 additions and 112 deletions.
@@ -25,6 +25,7 @@ We support the following models:
|[Falcon-7B](https://huggingface.co/tiiuae/falcon-7b), [Falcon-40B](https://huggingface.co/tiiuae/falcon-40b)|||
|[BLOOM-7B](https://huggingface.co/bigscience/bloomz-7b1)|||
|[OPT-125m](https://huggingface.co/facebook/opt-125m), [OPT-350m](https://huggingface.co/facebook/opt-350m), [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b), [OPT-13B](https://huggingface.co/facebook/opt-13b)|||
|[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)|||

### Code generation models
| model name | INT8 | INT4|
@@ -62,6 +62,9 @@ compile_quant(quant_starcoder quant_model.cpp starcoder starcoder)
compile_quant(quant_opt quant_model.cpp opt opt)
compile_quant(quant_bloom quant_model.cpp bloom bloom)

#compile_quant(quant_chatglm1 quant_model.cpp chatglm1 chatglm1)
compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)

# all models running
function(compile_run TARGET SRC MODEL_NAME MODEL_LIB)
add_executable_w_warning(${TARGET} ${SRC})
@@ -83,3 +86,4 @@ compile_run(run_mpt main_run.cpp mpt mpt)
compile_run(run_starcoder main_run.cpp starcoder starcoder)
compile_run(run_opt main_run.cpp opt opt)
compile_run(run_bloom main_run.cpp bloom bloom)
compile_run(run_chatglm2 main_run.cpp chatglm2 chatglm2)
@@ -28,6 +28,7 @@
#include <thread>
#include <unordered_map>
#include <utility>
#include <iostream>

#include "common.h"
#include "models/model_utils/model_types.h"
@@ -49,6 +50,17 @@ static model_context** g_ctx;

static bool is_interacting = false;

std::string build_prompt(const std::vector<std::string> &history) {
  std::ostringstream oss_prompt;
  for (size_t i = 0; i < history.size(); i += 2) {
    oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << history[i] << "\n\n答:";
    if (i < history.size() - 1) {
      oss_prompt << history[i + 1] << "\n\n";
    }
  }
  return oss_prompt.str();
}

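For reference, a minimal usage sketch of the function above (not part of the diff). The [Round N] / 问: / 答: markers follow ChatGLM2's conversation template; the history contents and the expected output are illustrative only.

// Hypothetical single-turn history fed to build_prompt() defined above.
std::vector<std::string> history = {"What does RoPE do?"};
std::string prompt = build_prompt(history);
// prompt == "[Round 1]\n\n问:What does RoPE do?\n\n答:"
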
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
void sigint_handler(int signo) {
if (signo == SIGINT) {
@@ -68,10 +80,12 @@ int main(int argc, char** argv) {
gpt_params params;
#ifdef MODEL_NAME
params.model_name = MODEL_NAME;
std::cout << "Welcome to use the " << params.model_name << " on the ITREX! "<< std::endl;
#endif
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}

model_archs mt = model_name_to_arch::init().find(params.model_name);
if (mt == MODEL_UNKNOWN) {
fprintf(stderr, "error, please set model_name \n");
@@ -196,7 +210,17 @@ int main(int argc, char** argv) {
if (params.model_arch == MODEL_LLAMA) {
add_bos = true;
}
auto embd_inp = ::model_tokenize(ctx, params.prompt, add_bos);

  std::vector<int> embd_inp;
  if (params.model_arch == MODEL_CHATGLM2 || params.model_arch == MODEL_CHATGLM1) {
    std::vector<std::string> prompts;
    prompts.push_back(params.prompt);
    std::string prompt = build_prompt(prompts);
    embd_inp = ::model_tokenize(ctx, prompt, false);
    embd_inp.insert(embd_inp.begin(), {64790, 64792});  // special prefix
  } else {
    embd_inp = ::model_tokenize(ctx, params.prompt, add_bos);
  }

const int n_ctx = model_n_ctx(ctx);

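A note on the ChatGLM tokenization branch above: the two hard-coded ids prepended to embd_inp are ChatGLM2's special prefix tokens — to the best of my knowledge 64790 is [gMASK] and 64792 is sop in its vocabulary — so the model always sees them ahead of the tokenized prompt. A hedged, self-contained sketch of the resulting layout; the helper name and the placeholder prompt ids are hypothetical, not from this diff.

#include <vector>
// Illustrative only: prompt_ids stands in for ::model_tokenize(ctx, prompt, false).
std::vector<int> make_chatglm2_input(const std::vector<int>& prompt_ids) {
  std::vector<int> embd_inp = prompt_ids;             // e.g. {30910, 31010, 31211}
  embd_inp.insert(embd_inp.begin(), {64790, 64792});  // assumed gMASK, sop ids
  return embd_inp;                                    // {64790, 64792, 30910, 31010, 31211}
}
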
200 changes: 114 additions & 86 deletions intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c
@@ -2929,7 +2929,7 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
// ne_rope

struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
bool inplace) {
int n_ctx, bool inplace) {
NE_ASSERT(n_past >= 0);
bool is_node = false;

@@ -2946,6 +2946,7 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
((int32_t*)b->data)[0] = n_past;
((int32_t*)b->data)[1] = n_dims;
((int32_t*)b->data)[2] = mode;
((int32_t*)b->data)[3] = n_ctx;

ne_scratch_load(ctx);

@@ -2957,12 +2958,12 @@
return result;
}

struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, false);
struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int n_ctx) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
}

struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, true);
struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode, int n_ctx) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
}

// ne_rope_back
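
The signature change above threads a context length into the rope op so that the GLM-style rotary mode (mode & 4, handled later in ne_compute_forward_rope_f32) can cap positions at n_ctx - 2; models that never set that mode do not read the value, which is why other call sites can pass 0. A hedged call-site sketch — the names ctx0, Qcur, head_dim and hparams.n_ctx are hypothetical, not taken from this diff.

// Hypothetical attention step: mode 4 selects the GLM rotary branch, so the
// real context length has to be forwarded; non-GLM models may keep passing 0.
struct ne_tensor* Qcur_rot =
    ne_rope_inplace(ctx0, Qcur, n_past, head_dim, /*mode=*/4, /*n_ctx=*/hparams.n_ctx);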
@@ -5790,11 +5791,19 @@ static void ne_compute_forward_gelu(const struct ne_compute_params* params, cons
}

// ne_compute_forward_silu
static inline bool ne_is_contiguous_except_dim_1(const struct ne_tensor * tensor) {
  static_assert(NE_MAX_DIMS == 4, "NE_MAX_DIMS is not 4 - update this function");

  return tensor->nb[0] == NE_TYPE_SIZE[tensor->type] &&
         tensor->nb[2] == tensor->nb[1] * tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2] * tensor->ne[2];
}

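The helper above relaxes the check used by ne_compute_forward_silu_f32 below: elements within a row (dim 0) and the two outer dims must be packed, but the row stride nb[1] may exceed ne[0] * element_size, which is exactly what happens when the op runs on a column slice viewed out of a wider tensor. A small worked check under assumed numbers (float32 view of shape {8, 4, 1, 1} cut from a parent that stores 16 floats per row):

// Hypothetical strides for the view described above.
size_t esize = sizeof(float);                             // 4 bytes
int64_t ne[4] = {8, 4, 1, 1};
size_t nb[4] = {esize, 16 * esize, 16 * esize * 4, 16 * esize * 4 * 1};
bool fully_contiguous = nb[1] == ne[0] * esize;           // false: 64 != 32
bool contiguous_except_dim_1 = nb[0] == esize &&          // packed within a row
                               nb[2] == nb[1] * ne[1] &&  // 256 == 64 * 4
                               nb[3] == nb[2] * ne[2];     // 256 == 256 * 1
// The relaxed form holds, so the row-wise SiLU kernel can still run on this view.
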
static void ne_compute_forward_silu_f32(const struct ne_compute_params* params, const struct ne_tensor* src0,
struct ne_tensor* dst) {
NE_ASSERT(ne_is_contiguous(src0));
NE_ASSERT(ne_is_contiguous(dst));
NE_ASSERT(ne_is_contiguous_except_dim_1(src0));
NE_ASSERT(ne_is_contiguous_except_dim_1(dst));
NE_ASSERT(ne_are_same_shape(src0, dst));

if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
@@ -7696,116 +7705,135 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
}

// ne_compute_forward_rope
#define NE_TENSOR_UNARY_OP_LOCALS \
NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
NE_TENSOR_LOCALS(size_t, nb0, src0, nb); \
NE_TENSOR_LOCALS(int64_t, ne, dst, ne); \
NE_TENSOR_LOCALS(size_t, nb, dst, nb);

static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, const struct ne_tensor* src0,
const struct ne_tensor* src1, struct ne_tensor* dst) {
NE_ASSERT(src1->type == NE_TYPE_I32);
NE_ASSERT(ne_nelements(src1) == 3);
static void ne_compute_forward_rope_f32(
const struct ne_compute_params * params,
const struct ne_tensor * src0,
const struct ne_tensor* src1,
struct ne_tensor * dst) {

if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
return;
}
if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
return;
}

const int n_past = ((int32_t*)src1->data)[0];
const int n_dims = ((int32_t*)src1->data)[1];
const int mode = ((int32_t*)src1->data)[2];
float freq_base = 10000.0f;
float freq_scale = 1.0f;

assert(n_past >= 0);
const int64_t n_past = ((int32_t*)src1->data)[0];
const int64_t n_dims = ((int32_t*)src1->data)[1];
const int64_t mode = ((int32_t*)src1->data)[2];
const int64_t n_ctx = ((int32_t*)src1->data)[3];

const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];
assert(n_past >= 0);

const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
NE_TENSOR_UNARY_OP_LOCALS;

const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
NE_ASSERT(nb00 == sizeof(float));

// printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
// printf("n_past = %d, ne2 = %d\n", n_past, ne2);
const int ith = params->ith;
const int nth = params->nth;

NE_ASSERT(nb00 == sizeof(float));
const int nr = ne_nrows(dst);

const int ith = params->ith;
const int nth = params->nth;
NE_ASSERT(n_dims <= ne0);
NE_ASSERT(n_dims % 2 == 0);

const int nr = ne_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;

NE_ASSERT(n_dims <= ne0);
NE_ASSERT(n_dims % 2 == 0);
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

// rows per thread
const int dr = (nr + nth - 1) / nth;
// row index used to determine which thread to use
int ir = 0;

// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
const float theta_scale = powf(freq_base, -2.0f/n_dims);

// row index used to determine which thread to use
int ir = 0;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

const float theta_scale = powf(10000.0, -2.0f / n_dims);
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;

const bool is_neox = mode & 2;
float theta = freq_scale * (float)p;

for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (is_glm) {
theta = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);

float theta = (float)p;
theta *= theta_scale;
block_theta *= theta_scale;

if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

theta *= theta_scale;
const float x0 = src[0];
const float x1 = src[n_dims/2];
const float x2 = src[n_dims];
const float x3 = src[n_dims/2*3];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
theta *= theta_scale;

const float x0 = src[0];
const float x1 = src[1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = x0 * cos_theta - x1 * sin_theta;
dst_data[1] = x0 * sin_theta + x1 * cos_theta;
}
} else {
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float x0 = src[0];
const float x1 = src[1];

theta *= theta_scale;
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
// TODO: this is probably wrong, but I can't figure it out ..
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);

const int64_t i0 = ib * n_dims + ic / 2;
theta *= theta_scale;

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
const int64_t i0 = ib*n_dims + ic/2;

const float x0 = src[0];
const float x1 = src[n_dims / 2];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = x0 * cos_theta - x1 * sin_theta;
dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
const float x0 = src[0];
const float x1 = src[n_dims/2];

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
}
}
}
}
}
}
}
}

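To summarize the is_glm branch added above: each position p produces two angles — a capped position MIN(p, n_ctx - 2) and a block position MAX(p - (n_ctx - 2), 0) — and, assuming the usual GLM setup where n_dims is half of the head size ne0, the first half of the head is rotated by the first angle while the second half is rotated by the second. A hedged scalar sketch of the i0 == 0 iteration; p, n_ctx, n_dims, src and dst refer to the same quantities as in the function above, but the concrete values are illustrative.

// One rotated pair from each half, mirroring the is_glm loop body before the
// theta_scale updates are applied.
float theta       = fminf((float)p, (float)(n_ctx - 2));    // capped absolute position
float block_theta = fmaxf((float)(p - (n_ctx - 2)), 0.0f);  // distance past the cap
// First half of the rotary dims: rotate (src[0], src[n_dims/2]) by theta.
dst[0]          = src[0] * cosf(theta) - src[n_dims / 2] * sinf(theta);
dst[n_dims / 2] = src[0] * sinf(theta) + src[n_dims / 2] * cosf(theta);
// Second half: rotate (src[n_dims], src[n_dims/2*3]) by block_theta.
dst[n_dims]         = src[n_dims] * cosf(block_theta) - src[n_dims / 2 * 3] * sinf(block_theta);
dst[n_dims / 2 * 3] = src[n_dims] * sinf(block_theta) + src[n_dims / 2 * 3] * cosf(block_theta);
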
static void ne_compute_forward_rope_f16(const struct ne_compute_params* params, const struct ne_tensor* src0,
@@ -10231,7 +10259,7 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor
const int n_past = ((int32_t*)src1->data)[0];
const int n_dims = ((int32_t*)src1->data)[1];
const int mode = ((int32_t*)src1->data)[2];
src0->grad = ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode), inplace);
src0->grad = ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0), inplace);
}
if (src1->grad) {
// noop
