This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8d5fe2d

[Model Enhance] Add Baichuan-7B architecture and refactor Baichuan-13B. (#177)
1 parent eed9b30 commit 8d5fe2d

5 files changed: +208 additions, -48 deletions

neural_speed/convert/convert_baichuan.py

Lines changed: 116 additions & 8 deletions
@@ -48,6 +48,7 @@ def bytes_to_unicode():
 
 
 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
@@ -116,8 +117,7 @@ def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab:
     else:
         raise FileNotFoundError(
             f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
-            pass the directory as --vocab-dir"
-        )
+            pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
     print(f"Loading vocab file {path}")
     return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
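
A minimal usage sketch of this loader, assuming a placeholder model directory that contains tokenizer.model (or has it in the parent, per the lookup above); it is illustrative only:

    from pathlib import Path

    # Hypothetical path; all_tokens() yields (token bytes, SentencePiece score) pairs,
    # which is exactly what the converters below serialize.
    vocab = load_vocab_for_baichuan(Path("/path/to/Baichuan-7B"))
    for text, score in vocab.all_tokens():
        print(len(text), score)
        break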
@@ -161,9 +161,112 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     fout.write(struct.pack("f", 10000.0))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor
 
-    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
-    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
-    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+
+    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+    tokenizer_path = Path(tokenizer.vocab_file).parent
+    vocab = load_vocab_for_baichuan(Path(tokenizer_path))
+    counter = 0
+    for text, score in vocab.all_tokens():
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        fout.write(struct.pack("f", score))
+        counter += 1
+
+    while counter < hparams["vocab_size"]:
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        fout.write(struct.pack("f", 0))
+        counter += 1
+
+    for name in list_vars.keys():
+        data = list_vars[name].squeeze().numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+        if 'inv_freq' in name:
+            continue
+
+        n_dims = len(data.shape)
+
+        # ftype == 0 -> float32, ftype == 1 -> float16
+        ftype_cur = 0
+        if ftype != 0:
+            if name[-7:] == ".weight" and n_dims == 2:
+                print(" Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 14
+            else:
+                print(" Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+        else:
+            if data.dtype != np.float32:
+                print(" Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+
+        # header
+        str = name.encode("utf-8")
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+
+    fout.close()
+
+    print("Done. Output file: " + fname_out)
+    print("")
+
+
+def baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
+    print("Baichuan-7B converting: ")
+    list_vars = model.state_dict()
+    for name in list_vars.keys():
+        print(name, list_vars[name].shape, list_vars[name].dtype)
+
+    fout = open(fname_out, "wb")
+
+    print(hparams)
+
+    fout.write(struct.pack("i", 0x67676d66))
+    fout.write(struct.pack("i", 1))
+
+    fout.write(struct.pack("i", hparams["vocab_size"]))
+    fout.write(struct.pack("i", hparams["hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_attention_heads"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_hidden_layers"]))
+    fout.write(struct.pack("i", 128))
+    fout.write(struct.pack("i", ftype))
+    fout.write(struct.pack("i", hparams["model_max_length"]))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("i", 0))
+
+    fout.write(struct.pack("i", 0))  # word_embed_proj_dim (for opt)
+    fout.write(struct.pack("i", 0))  # do_layer_norm_before (for opt)
+
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["intermediate_size"]))
+    fout.write(struct.pack("i", 0))  # n_experts
+    fout.write(struct.pack("i", 0))  # n_expert_used
+    fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6)))  # rms_norm_eps or layer_norm_eps
+    fout.write(struct.pack("f", 10000.0))  # freq_base
+    fout.write(struct.pack("f", 1.0))  # rope_factor
+
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
 
     fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
     fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
@@ -230,8 +333,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
     parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
-    parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
-                        default="huggingface", help="hub to load model")
+    parser.add_argument("--model_hub",
+                        choices=["huggingface", "modelscope"],
+                        default="huggingface",
+                        help="hub to load model")
     parser.add_argument("model", type=Path, help="directory containing model file")
     args = parser.parse_args(args_in)
 
@@ -255,7 +360,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
 
     hparams = config.to_dict()
 
-    baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+    if hparams['hidden_size'] == 4096:
+        baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+    else:
+        baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
 
 
 if __name__ == '__main__':
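
The dispatch works because the two architectures differ in width: Baichuan-7B ships with hidden_size 4096 (32 layers), while Baichuan-13B uses hidden_size 5120 (40 layers), so hidden_size alone picks the converter. Since main() accepts an explicit argument list, the script can also be driven programmatically; a sketch with placeholder paths:

    # Hypothetical invocation; the hidden_size in the model's config.json decides
    # which converter runs (4096 -> baichuan7B_convert, otherwise baichuan13B_convert).
    main([
        "--outtype", "f16",
        "--outfile", "baichuan-7b-ne-f16.bin",
        "--model_hub", "huggingface",
        "/path/to/Baichuan-7B",
    ])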

neural_speed/models/baichuan/baichuan.cpp

Lines changed: 19 additions & 3 deletions
@@ -74,8 +74,14 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
   int n_head = hparams.n_head;
   const int n_vocab = hparams.n_vocab;
   const int head_size = n_embd / n_head;
-  const int n_rot = n_embd / n_head / 2;
   const float attn_scale = 1.f / std::sqrt(head_size);
+  const int n_rot = hparams.n_rot;
+  int baichuan_version = 0;
+  if (hparams.n_embd == 4096) {
+    baichuan_version = 7;
+  } else {
+    baichuan_version = 13;
+  }
 
   bool enable_tp = false;
 #ifdef NS_TP_MODEL
@@ -131,6 +137,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
   }
 
   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
+
   for (int i = 0; i < batch_size; ++i) {
     memcpy(static_cast<model_token*>(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd));
   }
@@ -152,7 +159,6 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
     {
       // Linear::forward compute QKV
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
-
       ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
                                           0);  // [N, hidden]
 
@@ -162,6 +168,12 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
       ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
                                           2 * hidden_size * ne_element_size(cur));  // [N, heads, head_size]
 
+      // using mode = 2 for neox mode
+      if (baichuan_version == 7) {
+        query_layer = ne_rope_inplace(ctx0, query_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+        key_layer = ne_rope_inplace(ctx0, key_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+      }
+
       if (!run_mha_reordered) {
         query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3);  // [heads, N, head_size]
         key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3);  // [heads, N, head_size]
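
For the 7B path the positional signal is rotary embedding applied to query and key in NeoX mode (mode = 2). A NumPy sketch of that rotation for a single head vector, assuming the usual NeoX convention of pairing dimension i with i + n_rot/2; it illustrates the math and is not the ne_rope_inplace implementation:

    import numpy as np

    def neox_rope(x, pos, n_rot, freq_base=10000.0, freq_scale=1.0):
        # x: (head_size,) vector for one token at absolute position `pos`.
        half = n_rot // 2
        inv_freq = freq_scale / (freq_base ** (np.arange(half) / half))
        theta = pos * inv_freq
        cos, sin = np.cos(theta), np.sin(theta)
        out = x.astype(np.float64).copy()
        x1, x2 = out[:half].copy(), out[half:n_rot].copy()  # NeoX pairs dim i with dim i + half
        out[:half] = x1 * cos - x2 * sin
        out[half:n_rot] = x1 * sin + x2 * cos
        return out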
@@ -193,7 +205,11 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
         // attention
         struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer);  // [heads, N, klen]
         attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, attn_scale));
-        attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
+
+        if (baichuan_version == 13) {
+          attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
+        }
+
         if (n_past == 0) {
           attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
         }
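
For the 13B path the positional signal instead comes from ALiBi: each head adds a linearly growing bias to the raw attention scores before the softmax. A NumPy sketch of the textbook ALiBi bias, assuming a power-of-two head count (Baichuan-13B's 40 heads use an interpolated slope schedule in practice, omitted here); the max bias of 8 corresponds to the last argument of the ne_alibi call above:

    import numpy as np

    def alibi_bias(n_head, n_kv):
        # Textbook slopes for a power-of-two head count: 2 ** (-8/n_head * (h+1)).
        slopes = 2.0 ** (-8.0 / n_head * np.arange(1, n_head + 1))
        # Bias grows linearly with key position; broadcasts against attention
        # scores of shape (n_head, n_query, n_kv).
        return slopes[:, None, None] * np.arange(n_kv)[None, None, :]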

neural_speed/models/baichuan/baichuan.h

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,12 @@ static const model_scratch baichuan_mem_req(int n_layers, float scratch_size_rat
          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
      };
+    case 32:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
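
The new case 32 covers Baichuan-7B's 32 decoder layers. A small sketch of what those entries evaluate to, assuming MB is the usual mebibyte constant (1024 * 1024 bytes):

    MB = 1024 * 1024  # assumption: mebibyte, matching the MB constant used in baichuan.h

    def baichuan_7b_scratch(scratch_size_ratio=1.0):
        # Mirrors the `case 32` entry: three scratch buffer sizes, in bytes.
        return (int(scratch_size_ratio * 4096) * MB,
                int(scratch_size_ratio * 2048) * MB,
                int(scratch_size_ratio * 4096) * MB)

    print([size // MB for size in baichuan_7b_scratch()])  # [4096, 2048, 4096]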

neural_speed/models/baichuan/baichuan_utils.cpp

Lines changed: 65 additions & 35 deletions
@@ -63,15 +63,15 @@ void BAICHUAN::init(const char* path_model, model_context* ctx, int n_gpu_layer_
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = 4 * hparams.n_embd;
   fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
   fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
   fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
   fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
   fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
   fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+  fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
+  fprintf(stderr, "%s: inner_hidden_size = %u\n", __func__, hparams.inner_hidden_size);
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
@@ -92,10 +92,6 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
   fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
 
   const auto& hparams = model.hparams;
-  const int head_dim = n_embd / hparams.n_head;
-  const int kv_heads = hparams.n_head;  // 1 if MQA else hparams.n_head
-  const int kv_dim = kv_heads * head_dim;
-  const int max_len = 4096;
 
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -116,37 +112,71 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
   }
 
   ml->ne_ctx = ne_ctx;
-
-  model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
-  model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
-  model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
-  const int i_gpu_start = n_layer - n_gpu_layer;
-
-  model.layers.resize(n_layer);
   size_t vram_total = 0;
-  for (uint32_t i = 0; i < n_layer; ++i) {
-    const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
-    auto& layer = model.layers[i];
-    std::string layers_i = "model.layers." + std::to_string(i);
-    layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
-
-    // qkv GEMM
-    layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
-    layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
-
-    layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
 
-    // ffn GEMM
-    layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
-                                  {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
-
-    layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
-                                  {uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
-    layer.ffn[2] =
-        ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
-
-    layer.v_cache = nullptr;
-    layer.k_cache = nullptr;
+  if (ml->verify_tensor("token_embd.weight")) {  // for gguf
+    model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    const int i_gpu_start = n_layer - n_gpu_layer;
+
+    model.layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "blk." + std::to_string(i);
+      layer.norm[0] = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend);
+
+      // qkv GEMM
+      std::string w_pack = "model.layers." + std::to_string(i);
+      layer.attn[0] = ml->get_tensor(w_pack + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+
+      layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] =
+          ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);
+
+      layer.ffn[1] =
+          ml->get_tensor(layers_i + ".ffn_down.weight", {uint32_t(model.hparams.ffn_hidden_size), n_embd}, backend);
+      layer.ffn[2] =
+          ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);
+
+      layer.v_cache = nullptr;
+      layer.k_cache = nullptr;
+    }
+  } else {
+    model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    const int i_gpu_start = n_layer - n_gpu_layer;
+
+    model.layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "model.layers." + std::to_string(i);
+      layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+
+      // qkv GEMM
+      layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
+
+      layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
+                                    {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
+
+      layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
+                                    {uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
+      layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight",
+                                    {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
+
+      layer.v_cache = nullptr;
+      layer.k_cache = nullptr;
+    }
   }
 
   // print memory requirements
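
The load path now branches on whether the file uses GGUF-style tensor names: if token_embd.weight is present it probes the blk.N.* scheme, otherwise the original HF-style model.layers.N.* scheme. A small sketch of the two per-layer name sets, purely as an illustration of the mapping (not of the loader API); note the qkv pack keeps its HF name even in the GGUF branch:

    def baichuan_layer_names(i, gguf_style):
        # Per-layer tensor names probed above, in load order.
        if gguf_style:
            base = f"blk.{i}"
            return [
                f"{base}.attn_norm.weight",
                f"model.layers.{i}.self_attn.W_pack.weight",  # W_pack stays HF-named
                f"{base}.attn_output.weight",
                f"{base}.ffn_norm.weight",
                f"{base}.ffn_gate.weight",
                f"{base}.ffn_down.weight",
                f"{base}.ffn_up.weight",
            ]
        base = f"model.layers.{i}"
        return [
            f"{base}.input_layernorm.weight",
            f"{base}.self_attn.W_pack.weight",
            f"{base}.self_attn.o_proj.weight",
            f"{base}.post_attention_layernorm.weight",
            f"{base}.mlp.gate_proj.weight",
            f"{base}.mlp.down_proj.weight",
            f"{base}.mlp.up_proj.weight",
        ]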

neural_speed/models/model_utils/model_files.h

Lines changed: 2 additions & 2 deletions
@@ -1090,7 +1090,7 @@ struct model_file_loader {
     printf("%-16s %d.hparams.n_head = %-30d\n", __func__, count++, hparams.n_head);
     printf("%-16s %d.hparams.n_head_kv = %-30d\n", __func__, count++, hparams.n_head_kv);
     printf("%-16s %d.hparams.n_layer = %-30d\n", __func__, count++, hparams.n_layer);
-    printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_vocab);
+    printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_rot);
 
     hparams.ftype = (enum ne_ftype)file.read_u32();
     hparams.max_seq_len = file.read_u32();
@@ -1122,7 +1122,7 @@ struct model_file_loader {
     file.read_raw(&hparams.norm_eps, sizeof(float));
     file.read_raw(&hparams.freq_base, sizeof(float));
     file.read_raw(&hparams.freq_scale, sizeof(float));
-    printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
+    printf("%-16s %d.hparams.norm_eps = %-30f\n", __func__, count++, hparams.norm_eps);
     printf("%-16s %d.hparams.freq_base = %-30.3f\n", __func__, count++, hparams.freq_base);
     printf("%-16s %d.hparams.freq_scale = %-30.3f\n", __func__, count++, hparams.freq_scale);
 
0 commit comments