This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 96fadd9

[FFN Fusion] Support FFN_fusion for the ChatGLM2 (#142)
1 parent 48c1913 commit 96fadd9

6 files changed: +137 -42 lines changed


neural_speed/convert/convert_chatglm.py

Lines changed: 94 additions & 14 deletions
@@ -50,6 +50,7 @@ def bytes_to_unicode():
 
 
 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
@@ -149,11 +150,11 @@ def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams
     print("ChatGLM-2.gguf converting: ")
     list_vars = model.state_dict()
     for name in list_vars.keys():
-        print(name, list_vars[name].shape, list_vars[name].dtype)
+        print("%-80s" % name, list_vars[name].shape, list_vars[name].dtype)
 
     print(hparams)
 
-    gguf_file = fname_out + '.gguf'
+    gguf_file = fname_out
     gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm2")
 
     arch = "chatglm2."
@@ -285,35 +286,68 @@ def write_vocab_gguf(dir_model):
     print("gguf: get tensor metadata")
     for name in list_vars.keys():
         data = list_vars[name].squeeze().numpy()
-
-        print("Processing variable: " + name + " with shape: ", data.shape)
         if 'inv_freq' in name:
+            print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape))
             continue
 
+        print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape), end=" ")
         n_dims = len(data.shape)
 
         # ftype == 0 -> float32, ftype == 1 -> float16
        ftype_cur = 0
         if ftype != 0:
             if name[-7:] == ".weight" and n_dims == 2:
-                print(" Converting to float16")
+                print(" to float16".rjust(15))
                 data = data.astype(np.float16)
                 ftype_cur = 1
             else:
-                print(" Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0
         else:
             if data.dtype != np.float32:
-                print(" Converting to float32")
+                print(" to float32".rjust(15))
                 data = data.astype(np.float32)
                 ftype_cur = 0
 
-        # print(f"[{i+1:{padi}d}/{len(model)}]
-        # Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}")
-
         gguf_writer.add_tensor(name, data)
 
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: %-15s" % str(data_0.shape))
+            print("Converting: %-75s" % name_1, " shape: %-15s" % str(data_1.shape))
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            gguf_writer.add_tensor(name_0, data_0)
+            gguf_writer.add_tensor(name_1, data_1)
+
     print("gguf: write header")
     gguf_writer.write_header_to_file()
     print("gguf: write metadata")
@@ -363,9 +397,9 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     fout.write(struct.pack("f", 10000.0))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor
 
-    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
-    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
-    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
 
     fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
     fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
@@ -419,10 +453,56 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
         for i in range(n_dims):
             fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
         fout.write(str)
-
         # data
         data.tofile(fout)
 
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
+            print("Converting: %-75s" % name_1, " shape: ", data_1.shape)
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            str_0 = name_0.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+            fout.write(str_0)
+            data_0.tofile(fout)
+
+            str_1 = name_1.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+            fout.write(str_1)
+            data_1.tofile(fout)
+
     fout.close()
 
     print("Done. Output file: " + fname_out)

neural_speed/models/chatglm/chatglm2.cpp

Lines changed: 22 additions & 18 deletions
@@ -150,11 +150,11 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
 
     // self-attention
     cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
+    cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
     {
       // compute QKV
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
-      cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
+      cur = ne_add(ctx0, cur, model.layers[il].attn[1]);
 
       struct ne_tensor* query_layer =
           ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
@@ -298,19 +298,25 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
 
     // mlp.forward
     struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
-    ne_set_name(mlp_output, "mlp_output");
-    // mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
-    mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
-
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
-    struct ne_tensor* x0 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
-    x0 = ne_silu(ctx0, x0);
-    struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
-                                      mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
-    ne_set_name(x0, "x0");
-    ne_set_name(x1, "x1");
-    mlp_output = ne_mul(ctx0, x0, x1);
-    mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
+
+    if (model.layers[il].ffn_fusion &&
+        bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[2]->data, model.layers[il].ffn[1]->data,
+                                              model.layers[il].ffn[3]->data, N, int(cur->ne[0] / 2),
+                                              model.layers[il].ffn[2]->ne[1], model.layers[il].ffn[1]->ne[1])) {
+      mlp_output =
+          ne_ffn_silu(ctx0, model.layers[il].ffn[2], model.layers[il].ffn[1], model.layers[il].ffn[3], mlp_output);
+    } else {
+      // mlp.forward
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[0], mlp_output);
+      struct ne_tensor* x0 =
+          ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1], 0);
+      x0 = ne_silu(ctx0, x0);
+      struct ne_tensor* x1 = ne_view_2d(ctx0, mlp_output, mlp_output->ne[0] / 2, mlp_output->ne[1], mlp_output->nb[1],
+                                        mlp_output->ne[0] / 2 * ne_element_size(mlp_output));
+      mlp_output = ne_mul(ctx0, x0, x1);
+      mlp_output = ne_mul_mat(ctx0, model.layers[il].ffn[1], mlp_output);
+    }
 
 #ifdef NS_TP_MODEL
     if (enable_tp) {
@@ -327,9 +333,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
   // norm
   {
     inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
-    ne_set_name(inpL, "inpL");
-    // inpL = ne_mul(ctx0, inpL, model.others[1]);
-    inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
+    inpL = ne_mul(ctx0, inpL, model.others[1]);
   }
 
   lctx.use_buf(ctx0, -1);
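
Both branches of the new if/else compute the same SiLU-gated FFN; only the kernel changes (one fused bestla call versus separate matmul, view, SiLu, and multiply ops). A rough numpy sketch of that math, with illustrative names rather than the bestla or ne_* APIs:

# Reference sketch of the SiLU-gated FFN the fused and fallback paths both implement.
# w_gate/w_up correspond to the two halves of dense_h_to_4h, w_down to dense_4h_to_h.
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def chatglm2_ffn(x, w_gate, w_up, w_down):
    # x: [tokens, n_embd]; w_gate, w_up: [ffn_hidden, n_embd]; w_down: [n_embd, ffn_hidden]
    return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8), dtype=np.float32)
w_gate = rng.standard_normal((16, 8), dtype=np.float32)
w_up = rng.standard_normal((16, 8), dtype=np.float32)
w_down = rng.standard_normal((8, 16), dtype=np.float32)
print(chatglm2_ffn(x, w_gate, w_up, w_down).shape)  # (2, 8)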

neural_speed/models/chatglm/chatglm2.h

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ enum chatglm2_model {
 static const model_scratch chatglm_mem_req(int n_layers) {
   switch (n_layers) {
     case 28:
-      return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+      return {4096ull * MB, 4096ull * MB, 8192ull * MB};
     default:
       MODEL_ASSERT(false);
   }

neural_speed/models/chatglm/chatglm2_utils.cpp

Lines changed: 17 additions & 8 deletions
@@ -67,14 +67,13 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = 4 * hparams.n_embd;
-  fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
-  fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
-  fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
-  fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
-  fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
-  fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+  fprintf(stderr, "%s: hparams.n_vocab = %u\n", __func__, hparams.n_vocab);
+  fprintf(stderr, "%s: hparams.n_embd = %u\n", __func__, hparams.n_embd);
+  fprintf(stderr, "%s: hparams.n_mult = %u\n", __func__, hparams.n_mult);
+  fprintf(stderr, "%s: hparams.n_head = %u\n", __func__, hparams.n_head);
+  fprintf(stderr, "%s: hparams.n_layer = %u\n", __func__, hparams.n_layer);
+  fprintf(stderr, "%s: hparams.n_rot = %u\n", __func__, hparams.n_rot);
+  fprintf(stderr, "%s: hparams.ffn_hidden_size = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
@@ -149,6 +148,15 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
       layer.ffn[1] =
          ml->get_tensor(layers_i + ".mlp.dense_4h_to_h.weight", {uint32_t(hparams.ffn_hidden_size), n_embd}, backend);
 
+      if (ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight") &&
+          ml->verify_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight")) {
+        layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_0.weight",
+                                      {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+        layer.ffn[3] = ml->get_tensor(layers_i + ".mlp.dense_h_to_4h_1.weight",
+                                      {n_embd, uint32_t(hparams.ffn_hidden_size)}, backend);
+        layer.ffn_fusion = true;
+      }
+
       // kv-cache
       layer.k_cache = nullptr;  // kv-cache will be init later in model_utils
       layer.v_cache = nullptr;  // kv-cache will be init later in model_utils
@@ -161,6 +169,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
     }
   }
 
+  fprintf(stderr, "%s: layers[0].ffn_fusion = %u\n", __func__, model.layers[0].ffn_fusion);
   // print memory requirements
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
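
The loader change makes fusion opt-in: ffn[2]/ffn[3] and the ffn_fusion flag are set only when the split tensors written by the updated converter are present, so files converted before this commit still load through the unfused path. An illustrative Python sketch of that lookup logic (names and dict layout are assumptions, not the C++ loader API):

# Sketch: enable FFN fusion only if the split gate/up tensors exist in the file.
def load_layer_ffn(tensors, layers_i):
    layer = {"ffn_fusion": False}
    name_0 = layers_i + ".mlp.dense_h_to_4h_0.weight"
    name_1 = layers_i + ".mlp.dense_h_to_4h_1.weight"
    if name_0 in tensors and name_1 in tensors:
        layer["ffn_2"] = tensors[name_0]
        layer["ffn_3"] = tensors[name_1]
        layer["ffn_fusion"] = True
    return layer

# A file converted before this commit simply lacks the *_0/*_1 tensors:
print(load_layer_ffn({}, "layers.0")["ffn_fusion"])  # False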

neural_speed/models/model_utils/model_types.h

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ struct model_layer {
 
   struct ne_tensor* k_cache;
   struct ne_tensor* v_cache;
+
+  bool ffn_fusion = false;
 };
 
 typedef int32_t model_pos;

tests/model-test/cpp_graph_inference.sh

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ function main() {
        infer_cmd="./build/bin/run_dolly"
    elif [[ "${model}" == "chatglm2" ]]; then
        quant_script="./build/bin/quant_chatglm2"
-       convert_script="${convert_script}/convert_chatglm.py"
+       convert_script="${convert_script}/convert_chatglm.py --format=GGUF"
        infer_cmd="./build/bin/run_chatglm2"
    elif [[ "${model}" == "chatglm-6b" ]]; then
        quant_script="./build/bin/quant_chatglm"
