Skip to content

Commit

Permalink
llama : add PLaMo model (ggerganov#3557)
Browse files Browse the repository at this point in the history
* add plamo mock

* add tensor loading

* plamo convert

* update norm

* able to compile

* fix norm_rms_eps hparam

* runnable

* use inp_pos

* seems ok

* update kqv code

* remove develop code

* update README

* shuffle attn_q.weight and attn_output.weight for broadcasting

* remove plamo_llm_build_kqv and use llm_build_kqv

* fix style

* update

* llama : remove obsolete KQ_scale

* plamo : fix tensor names for correct GPU offload

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
2 people authored and jordankanter committed Feb 3, 2024
1 parent 1fc2dcc commit eebbab0
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)

**Multimodal models:**

Expand Down
86 changes: 85 additions & 1 deletion convert-hf-to-gguf.py
Expand Up @@ -184,6 +184,8 @@ def from_model_architecture(model_architecture):
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
if model_architecture == "PlamoForCausalLM":
return PlamoModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -225,6 +227,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -1002,11 +1006,91 @@ def set_gguf_parameters(self):
self.gguf_writer.add_add_bos_token(False)


class PlamoModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()

def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]

self.gguf_writer.add_name("PLaMo")
self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

def shuffle_attn_q_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(8, 5, 128, 5120)
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch

def shuffle_attn_output_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(5120, 8, 5, 128)
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch

def write_tensors(self):
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

for name, data_torch in self.get_tensors():
if "self_attn.rotary_emb.inv_freq" in name:
continue

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

# shuffle for broadcasting of gqa in ggml_mul_mat
if new_name.endswith("attn_q.weight"):
data_torch = self.shuffle_attn_q_weight(data_torch)
elif new_name.endswith("attn_output.weight"):
data_torch = self.shuffle_attn_output_weight(data_torch)

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
Expand Down
17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
Expand Up @@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum):
STABLELM = auto()
QWEN = auto()
PHI2 = auto()
PLAMO = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -142,6 +143,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PLAMO: "plamo",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -349,6 +351,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [
# TODO
],
Expand Down
37 changes: 23 additions & 14 deletions gguf-py/gguf/tensor_mapping.py
Expand Up @@ -79,6 +79,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
),

# Attention norm 2
Expand All @@ -99,26 +100,29 @@ class TensorNameMap:

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
),

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
),

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
),

# Attention output
Expand All @@ -134,12 +138,14 @@ class TensorNameMap:
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
),

# Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
),

# Feed-forward norm
Expand Down Expand Up @@ -174,6 +180,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen
"transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo
),

MODEL_TENSOR.FFN_UP_EXP: (
Expand All @@ -186,6 +193,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
),

MODEL_TENSOR.FFN_GATE_EXP: (
Expand All @@ -206,6 +214,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand Down

0 comments on commit eebbab0

Please sign in to comment.