|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# |
| 4 | +# Copyright (c) 2023 Intel Corporation |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | + |
| 18 | +import struct |
| 19 | +import argparse |
| 20 | +from pathlib import Path |
| 21 | +from typing import List, Optional |
| 22 | +import numpy as np |
| 23 | +import torch |
| 24 | +from common import * |
| 25 | +from sentencepiece import SentencePieceProcessor |
| 26 | +from transformers import AutoTokenizer |
| 26 | + |
| 27 | + |
| 28 | +def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab: |
| 29 | + # Be extra-friendly and accept either a file or a directory. Also, if it's |
| 30 | + # a directory, it might be the model directory, and tokenizer.model might |
| 31 | + # be in the parent of that. |
| 32 | + if path.is_dir(): |
| 33 | + path2 = path / "tokenizer.model" |
| 34 | + # Use `.parent` instead of /.. to handle the symlink case better. |
| 35 | + path3 = path.parent / "tokenizer.model" |
| 36 | + if path2.exists(): |
| 37 | + path = path2 |
| 38 | + elif path3.exists(): |
| 39 | + path = path3 |
| 40 | + else: |
| 41 | +            raise FileNotFoundError( |
| 42 | +                f"Could not find tokenizer.model in {path} or its parent; " |
| 43 | +                "pass the directory that contains it as the model directory") |
| 44 | + added_tokens_path = path.parent / "added_tokens.json" |
| 45 | + print(f"Loading vocab file {path}") |
| 46 | + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) |
| 47 | + |
| 48 | + |
| 49 | +def main(args_in: Optional[List[str]] = None) -> None: |
| 50 | +    parser = argparse.ArgumentParser(description="Convert a quantized Baichuan model to an NE-compatible file") |
| 51 | +    parser.add_argument("--outtype", choices=["f32", "f16"], default="f32", help="output format (default: f32)") |
| 52 | +    parser.add_argument("--outfile", type=Path, required=True, help="path to write the converted model to") |
| 53 | +    parser.add_argument("--model_hub", |
| 54 | +                        choices=["huggingface", "modelscope"], |
| 55 | +                        default="huggingface", |
| 56 | +                        help="hub to load the model from") |
| 57 | +    parser.add_argument("model", type=Path, help="directory containing the model files") |
| 58 | + args = parser.parse_args(args_in) |
| 59 | + |
| 60 | + out_path = args.outfile.as_posix() |
| 61 | + model_path = args.model.as_posix() |
| 62 | + |
| 63 | + model, hparams, quantize_config = load_quantized_safetensors(model_path) |
| 64 | + list_vars = model |
| 65 | + |
| 66 | + print(hparams) |
| 67 | + |
| 68 | + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) |
| 69 | + fout = open(out_path, "wb") |
| 70 | + |
| 71 | + # possible data types |
| 72 | + # ftype == 0 -> float32, ftype == 1 -> float16 |
| 73 | + ftype = 0 |
| 74 | + if args.outtype == "f16": |
| 75 | + ftype = 1 |
| 76 | + |
| 77 | + # 1. write hparams |
| 78 | +    ne_file_magic = 0x67676d66 |
| 79 | +    fout.write(struct.pack("i", ne_file_magic))  # magic number ('ggmf' in hex) |
| 80 | +    fout.write(struct.pack("i", 1))  # file version |
| 82 | + |
| 83 | + fout.write(struct.pack("i", hparams["vocab_size"])) |
| 84 | + fout.write(struct.pack("i", hparams["hidden_size"])) |
| 85 | + fout.write(struct.pack("i", 0)) |
| 86 | + fout.write(struct.pack("i", hparams["num_attention_heads"])) |
| 87 | + fout.write(struct.pack("i", 0)) |
| 88 | + fout.write(struct.pack("i", hparams["num_hidden_layers"])) |
| 89 | + fout.write(struct.pack("i", 0)) |
| 90 | + fout.write(struct.pack("i", ftype)) |
| 91 | + fout.write(struct.pack("i", hparams["model_max_length"])) |
| 92 | + fout.write(struct.pack("f", 0)) |
| 93 | + fout.write(struct.pack("f", 0)) |
| 94 | + fout.write(struct.pack("i", 0)) |
| 95 | + |
| 96 | + fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) |
| 97 | + fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) |
| 98 | + |
| 99 | + fout.write(struct.pack("i", 0)) |
| 100 | + fout.write(struct.pack("i", 0)) |
| 101 | + fout.write(struct.pack("i", hparams["intermediate_size"])) |
| 102 | + fout.write(struct.pack("i", 0)) # n_experts |
| 103 | + fout.write(struct.pack("i", 0)) # n_expert_used |
| 104 | + fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps |
| 105 | + fout.write(struct.pack("f", 10000.0)) # freq_base |
| 106 | + fout.write(struct.pack("f", 1.0)) # rope_factor |
| 107 | + |
| 108 | + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled |
| 109 | + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings |
| 110 | + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) |
| 111 | + |
| 112 | + fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) |
| 113 | + fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) |
| 114 | + fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) |
| 115 | + fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1)) |
| 116 | + |
| 117 | + # 2. vocab |
| 118 | +    tokenizer_path = Path(tokenizer.vocab_file).parent |
| 119 | +    vocab = load_vocab_for_baichuan(tokenizer_path) |
| 120 | +    # write each vocab entry as (byte length, utf-8 bytes, score) |
| 121 | +    counter = 0 |
| 121 | + for text, score in vocab.all_tokens(): |
| 122 | + fout.write(struct.pack("i", len(text))) |
| 123 | + fout.write(text) |
| 124 | + fout.write(struct.pack("f", score)) |
| 125 | + counter += 1 |
| 126 | + |
| 127 | +    # pad with the last token so the table holds exactly vocab_size entries |
| 128 | +    while counter < hparams["vocab_size"]: |
| 128 | + fout.write(struct.pack("i", len(text))) |
| 129 | + fout.write(text) |
| 130 | + fout.write(struct.pack("f", 0)) |
| 131 | + counter += 1 |
| 132 | + |
| 133 | + def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout): |
| 134 | +        # GPTQ checkpoints keep most of these tensors in torch.bfloat16, so cast to float32 first. |
| 135 | + if model[src_name].dtype == torch.float32: |
| 136 | + data = model[src_name].squeeze().numpy() |
| 137 | + else: |
| 138 | + data = model[src_name].squeeze().to(torch.float32).numpy() |
| 139 | + data = data.astype(np.float32) |
| 140 | + shape = data.shape |
| 141 | + n_dims = len(shape) |
| 142 | + print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ", |
| 143 | + data.dtype) |
| 144 | + |
| 145 | + #ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype] |
| 146 | + # default type is fp32 |
| 147 | + ftype_cur = 0 |
| 148 | + if ftype == 1 and n_dims > 1: |
| 149 | + data = data.astype(np.float16) |
| 150 | + ftype_cur = 1 |
| 151 | + else: |
| 152 | + data = data.astype(np.float32) |
| 153 | + |
| 154 | +        # header: (n_dims, name length, ftype), then the shape with the innermost dim first, then the name |
| 155 | +        # write_header(fout, shape, dst_name, ftype_cur) |
| 156 | +        name_bytes = src_name.encode('utf-8')  # avoid shadowing the built-in `str` |
| 157 | +        fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur)) |
| 158 | +        for i in range(n_dims): |
| 159 | +            fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) |
| 160 | +        fout.write(name_bytes) |
| 161 | + |
| 162 | + # data |
| 163 | + data.tofile(fout) |
| 164 | + |
| 165 | +    # 3. write tensors |
| 166 | + convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, fout) |
| 167 | + convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, fout) |
| 168 | + convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout) |
| 169 | + |
| 170 | + for i in range(hparams["num_hidden_layers"]): |
| 171 | + prefix = "model.layers." + str(i) |
| 172 | + |
| 173 | + convert_qwen_to_fp32_tensor(f"{prefix}.input_layernorm.weight", f"{prefix}.input_layernorm.weight", list_vars, |
| 174 | + fout) |
| 175 | + convert_qwen_to_fp32_tensor(f"{prefix}.post_attention_layernorm.weight", |
| 176 | + f"{prefix}.post_attention_layernorm.weight", list_vars, fout) |
| 177 | + # qkv GEMM |
| 178 | + convert_to_qx_bestla_tensor(f"{prefix}.self_attn.W_pack.weight", f"{prefix}.self_attn.W_pack.weight", list_vars, |
| 179 | + fout, quantize_config) |
| 180 | + convert_to_qx_bestla_tensor(f"{prefix}.self_attn.o_proj.weight", f"{prefix}.self_attn.o_proj.weight", list_vars, |
| 181 | + fout, quantize_config) |
| 182 | + |
| 183 | + # ffn GEMM |
| 184 | + convert_to_qx_bestla_tensor(f"{prefix}.mlp.gate_proj", f"{prefix}.mlp.gate_proj.weight", list_vars, fout, |
| 185 | + quantize_config) |
| 186 | + convert_to_qx_bestla_tensor(f"{prefix}.mlp.down_proj", f"{prefix}.mlp.down_proj.weight", list_vars, fout, |
| 187 | + quantize_config) |
| 188 | + convert_to_qx_bestla_tensor(f"{prefix}.mlp.up_proj", f"{prefix}.mlp.up_proj.weight", list_vars, fout, |
| 189 | + quantize_config) |
| 190 | + |
| 191 | + fout.close() |
| 192 | + print(f"Success! saved as {out_path}") |
| 193 | + |
| 194 | + |
| 195 | +if __name__ == '__main__': |
| 196 | + main() |
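Since main() takes an explicit argument list, the converter can also be exercised directly from Python. A minimal usage sketch (the module name and all paths below are hypothetical placeholders, not part of this change):

    # hypothetical module name; adjust to the actual file name of this script
    from convert_quantized_baichuan import main

    # placeholder paths for a local Baichuan GPTQ checkpoint directory
    main(["--outtype", "f32",
          "--outfile", "ne-baichuan-f32.bin",
          "/path/to/Baichuan2-13B-GPTQ"])

Running the script from the shell with the same flags is equivalent.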