From eebbab02ef1f0d3164c2f84c221f4e393e198862 Mon Sep 17 00:00:00 2001 From: Shintarou Okada Date: Sun, 24 Dec 2023 22:35:49 +0900 Subject: [PATCH] llama : add PLaMo model (#3557) * add plamo mock * add tensor loading * plamo convert * update norm * able to compile * fix norm_rms_eps hparam * runnable * use inp_pos * seems ok * update kqv code * remove develop code * update README * shuffle attn_q.weight and attn_output.weight for broadcasting * remove plamo_llm_build_kqv and use llm_build_kqv * fix style * update * llama : remove obsolete KQ_scale * plamo : fix tensor names for correct GPU offload --------- Co-authored-by: Georgi Gerganov --- README.md | 1 + convert-hf-to-gguf.py | 86 +++++++++++++++- gguf-py/gguf/constants.py | 17 ++++ gguf-py/gguf/tensor_mapping.py | 37 ++++--- llama.cpp | 181 +++++++++++++++++++++++++++++++++ 5 files changed, 307 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 649c3b3334387..09338d2264ca7 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek) - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral) +- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) **Multimodal models:** diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e71a96c483313..303d08170ecb0 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -184,6 +184,8 @@ def from_model_architecture(model_architecture): return MixtralModel if model_architecture == "PhiForCausalLM": return Phi2Model + if model_architecture == "PlamoForCausalLM": + return PlamoModel return Model def _is_model_safetensors(self) -> bool: @@ -225,6 +227,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH: return gguf.MODEL_ARCH.LLAMA if arch == "PhiForCausalLM": return gguf.MODEL_ARCH.PHI2 + if arch == "PlamoForCausalLM": + return gguf.MODEL_ARCH.PLAMO raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1002,11 +1006,91 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) +class PlamoModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_name("PLaMo") + self.gguf_writer.add_context_length(4096) # not in config.json + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + + def shuffle_attn_q_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(8, 5, 128, 5120) + data_torch = torch.permute(data_torch, (1, 0, 2, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def shuffle_attn_output_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(5120, 8, 5, 128) + data_torch = torch.permute(data_torch, (0, 2, 1, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def write_tensors(self): + block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, data_torch in self.get_tensors(): + if "self_attn.rotary_emb.inv_freq" in name: + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + # shuffle for broadcasting of gqa in ggml_mul_mat + if new_name.endswith("attn_q.weight"): + data_torch = self.shuffle_attn_q_weight(data_torch) + elif new_name.endswith("attn_output.weight"): + data_torch = self.shuffle_attn_output_weight(data_torch) + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file") + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") parser.add_argument( "--vocab-only", action="store_true", help="extract only the vocab", diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 390dca049ebee..4cd87cdda8b7e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum): STABLELM = auto() QWEN = auto() PHI2 = auto() + PLAMO = auto() class MODEL_TENSOR(IntEnum): @@ -142,6 +143,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PLAMO: "plamo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -349,6 +351,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PLAMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.GPT2: [ # TODO ], diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 6fcbdbc1c0d4c..446c6b6883be9 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -79,6 +79,7 @@ class TensorNameMap: "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi "transformer.h.{bid}.ln", # phi2 + "model.layers.layers.{bid}.norm", # plamo ), # Attention norm 2 @@ -99,26 +100,29 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf - "layers.{bid}.attention.wq", # llama-pth - "encoder.layer.{bid}.attention.self.query", # bert - "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.{bid}.self_attn.q_proj", # llama-hf + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.q_proj", # plamo ), # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf - "layers.{bid}.attention.wk", # llama-pth - "encoder.layer.{bid}.attention.self.key", # bert - "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.{bid}.self_attn.k_proj", # llama-hf + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.k_proj", # plamo ), # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf - "layers.{bid}.attention.wv", # llama-pth - "encoder.layer.{bid}.attention.self.value", # bert - "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.{bid}.self_attn.v_proj", # llama-hf + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.v_proj", # plamo ), # Attention output @@ -134,12 +138,14 @@ class TensorNameMap: "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "transformer.h.{bid}.mixer.out_proj", # phi2 + "model.layers.layers.{bid}.self_attn.o_proj", # plamo ), # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( - "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf - "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo ), # Feed-forward norm @@ -174,6 +180,7 @@ class TensorNameMap: "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.fc1", # phi2 + "model.layers.layers.{bid}.mlp.up_proj", # plamo ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -186,6 +193,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate_proj", # llama-hf refact "layers.{bid}.feed_forward.w1", # llama-pth "transformer.h.{bid}.mlp.w2", # qwen + "model.layers.layers.{bid}.mlp.gate_proj", # plamo ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -206,6 +214,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "transformer.h.{bid}.mlp.fc2", # phi2 + "model.layers.layers.{bid}.mlp.down_proj", # plamo ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/llama.cpp b/llama.cpp index a24621539f6bd..0b99f1e03f527 100644 --- a/llama.cpp +++ b/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_STABLELM, LLM_ARCH_QWEN, LLM_ARCH_PHI2, + LLM_ARCH_PLAMO, LLM_ARCH_UNKNOWN, }; @@ -216,6 +217,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PLAMO, "plamo" }, }; enum llm_kv { @@ -567,6 +569,24 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, @@ -2749,6 +2769,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PLAMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3630,6 +3659,51 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); } } break; + case LLM_ARCH_PLAMO: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5555,6 +5629,109 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_plamo() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + struct ggml_tensor * attention_norm = cur; + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + struct ggml_tensor * sa_out = cur; + + cur = attention_norm; + + // feed-forward network + { + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + cb(cur, "l_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; // @@ -6065,6 +6242,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_phi2(); } break; + case LLM_ARCH_PLAMO: + { + result = llm.build_plamo(); + } break; default: GGML_ASSERT(false); }