Support Adept Persimmon 8b #3410

Merged
Commits (30)
7cdc3ea
Produces garbage output
phillip-kravtsov Sep 21, 2023
4bcf412
wip: correct tensors up to RoPE
phillip-kravtsov Sep 26, 2023
c9e1446
correct tensors thru RoPE
phillip-kravtsov Sep 26, 2023
d1b40ef
Correct outputs through masked & softmax'd KQ
phillip-kravtsov Sep 26, 2023
db2181a
fp32 works
phillip-kravtsov Sep 26, 2023
3f31799
Rename adept->persimmon
phillip-kravtsov Sep 28, 2023
720503b
Merge branch 'master' of github.com:phillip-kravtsov/llama.cpp into p…
phillip-kravtsov Sep 28, 2023
d61eed0
Produces correct outputs
phillip-kravtsov Sep 29, 2023
d0a7143
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
fa92f6e
clean up convert scripts
phillip-kravtsov Sep 29, 2023
c28a6c5
remove printing logic from ggml.c
phillip-kravtsov Sep 29, 2023
47dcb9f
remove prints from llama.cpp & fix merge
phillip-kravtsov Sep 29, 2023
7473773
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
d904aff
trivial cleanups
phillip-kravtsov Sep 29, 2023
ec0ce97
Add offload funcs
phillip-kravtsov Sep 29, 2023
3db04db
update conversion script to directly take adept artifacts rather than…
phillip-kravtsov Sep 29, 2023
f28f52c
Fix norm eps bug
phillip-kravtsov Sep 29, 2023
d93cf1e
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 29, 2023
574a9e1
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Sep 30, 2023
2b56591
Support sqr and concat on metal, persimmon-8b-q4 runs correctly
phillip-kravtsov Sep 30, 2023
e6bf87f
Small changes from review
phillip-kravtsov Oct 2, 2023
cd4d3df
Formatting changes
phillip-kravtsov Oct 2, 2023
422b110
Minor changes to conversion script
phillip-kravtsov Oct 2, 2023
5a0990c
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 2, 2023
7a279fe
Remove old script
phillip-kravtsov Oct 2, 2023
c90ed9f
Fix editorconfig formatting
phillip-kravtsov Oct 3, 2023
5d259d3
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 5, 2023
1d518d6
Fix build
phillip-kravtsov Oct 5, 2023
0c1a8f6
Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-…
phillip-kravtsov Oct 6, 2023
485a471
add overlooked offload code ggml-ci
phillip-kravtsov Oct 6, 2023
133 changes: 133 additions & 0 deletions convert-persimmon-to-gguf.py
@@ -0,0 +1,133 @@
import torch
import os
from pprint import pprint
import sys
import argparse
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf

def _flatten_dict(dct, tensors, prefix=None):
    assert isinstance(dct, dict)
    for key in dct.keys():
        new_prefix = prefix + '.' + key if prefix is not None else key
        if isinstance(dct[key], torch.Tensor):
            tensors[new_prefix] = dct[key]
        elif isinstance(dct[key], dict):
            _flatten_dict(dct[key], tensors, new_prefix)
        else:
            raise ValueError(type(dct[key]))
    return None

def get_tokenizer_info(dir_model: Path):
    tokenizer_path = dir_model / 'adept_vocab.model'
    print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
    tokenizer = SentencePieceProcessor(str(tokenizer_path))
    print('gguf: adding tokens')
    tokens: list[bytes] = []
    scores: list[float] = []
    toktypes: list[int] = []

    for i in range(tokenizer.vocab_size()):
        text: bytes
        score: float

        piece = tokenizer.id_to_piece(i)
        text = piece.encode("utf-8")
        score = tokenizer.get_score(i)

        toktype = 1  # default to normal token type
        if tokenizer.is_unknown(i):
            toktype = 2
        if tokenizer.is_control(i):
            toktype = 3

        # toktype = 4 is user-defined = tokens from added_tokens.json

        if tokenizer.is_unused(i):
            toktype = 5
        if tokenizer.is_byte(i):
            toktype = 6

        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)
        pass
    return tokens, scores, toktypes

def main():
    parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
    parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
    parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
    args = parser.parse_args()
    sys.path.append(str(args.adept_inference_dir))
    persimmon_model = torch.load(args.ckpt_path)
    hparams = persimmon_model['args']
    pprint(hparams)
    tensors = {}
    _flatten_dict(persimmon_model['model'], tensors, None)

    arch = gguf.MODEL_ARCH.PERSIMMON
    gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])

    block_count = hparams.num_layers
    head_count = hparams.num_attention_heads
    head_count_kv = head_count
    ctx_length = hparams.seq_length
    hidden_size = hparams.hidden_size

    gguf_writer.add_name('persimmon-8b-chat')
    gguf_writer.add_context_length(ctx_length)
    gguf_writer.add_embedding_length(hidden_size)
    gguf_writer.add_block_count(block_count)
    gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
    gguf_writer.add_rope_dimension_count(hidden_size // head_count)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
    gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)
    tokens, scores, toktypes = get_tokenizer_info(args.model_dir)
    gguf_writer.add_tokenizer_model('llama')
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(toktypes)
    gguf_writer.add_bos_token_id(71013)
    gguf_writer.add_eos_token_id(71013)

    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    print(tensor_map)
    for name in tensors.keys():
        data = tensors[name]
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data.dtype
        # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
        data = data.to(torch.float32).squeeze().numpy()
        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()
        n_dims = len(data.shape)
        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.add_tensor(new_name, data)
    print("gguf: write header")
    gguf_writer.write_header_to_file()
    print("gguf: write metadata")
    gguf_writer.write_kv_data_to_file()
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()

    print(f"gguf: model successfully exported to '{args.outfile}'")
    print("")


if __name__ == '__main__':
    main()
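
For reference, here is a minimal sketch of what _flatten_dict in the script above produces for a nested checkpoint dict. The nested keys below are illustrative rather than real Persimmon checkpoint contents, and the snippet assumes it runs in the same module so that _flatten_dict is in scope:

import torch

nested = {
    'language_model': {
        'embedding': {
            'word_embeddings': {'weight': torch.zeros(4, 2)},
        },
    },
}
flat = {}
_flatten_dict(nested, flat, None)  # flattens nested dicts into dot-joined tensor names
print(list(flat.keys()))
# expected: ['language_model.embedding.word_embeddings.weight']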
42 changes: 42 additions & 0 deletions gguf-py/gguf/gguf.py
@@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum):
    GPTNEOX : int = auto()
    MPT : int = auto()
    STARCODER : int = auto()
    PERSIMMON : int = auto()


class MODEL_TENSOR(IntEnum):
@@ -105,6 +106,8 @@ class MODEL_TENSOR(IntEnum):
    FFN_DOWN : int = auto()
    FFN_UP : int = auto()
    FFN_NORM : int = auto()
    ATTN_Q_NORM : int = auto()
    ATTN_K_NORM : int = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -116,6 +119,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.GPTNEOX: "gptneox",
    MODEL_ARCH.MPT: "mpt",
    MODEL_ARCH.STARCODER: "starcoder",
    MODEL_ARCH.PERSIMMON: "persimmon",
}

MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
@@ -185,6 +189,20 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    },
    MODEL_ARCH.PERSIMMON: {
        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
        MODEL_TENSOR.OUTPUT: "output",
        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
        MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
        MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
        MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
    },
    MODEL_ARCH.GPT2: {
        # TODO
    },
@@ -201,6 +219,9 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_ROT_EMBD,
    ],
    MODEL_ARCH.PERSIMMON: [
        MODEL_TENSOR.ROPE_FREQS,
    ]
}


@@ -213,6 +234,7 @@ class TensorNameMap:
            "transformer.word_embeddings", # falcon
            "model.embed_tokens", # llama-hf
            "tok_embeddings", # llama-pth
            "language_model.embedding.word_embeddings", # persimmon
        ),

        # Position embeddings
@@ -225,6 +247,7 @@ class TensorNameMap:
            "embed_out", # gptneox
            "lm_head", # gpt2 mpt falcon llama-hf baichuan
            "output", # llama-pth
            "word_embeddings_for_head", # persimmon
        ),

        # Output norm
@@ -233,6 +256,7 @@ class TensorNameMap:
            "transformer.ln_f", # gpt2 falcon
            "model.norm", # llama-hf baichuan
            "norm", # llama-pth
            "language_model.encoder.final_layernorm", # persimmon
        ),

        # Rope frequencies
@@ -251,6 +275,7 @@ class TensorNameMap:
            "transformer.h.{bid}.ln_mlp", # falcon40b
            "model.layers.{bid}.input_layernorm", # llama-hf
            "layers.{bid}.attention_norm", # llama-pth
            "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
        ),

        # Attention norm 2
@@ -264,6 +289,7 @@ class TensorNameMap:
            "transformer.h.{bid}.attn.c_attn", # gpt2
            "transformer.blocks.{bid}.attn.Wqkv", # mpt
            "transformer.h.{bid}.self_attention.query_key_value", # falcon
            "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
        ),

        # Attention query
@@ -292,6 +318,7 @@ class TensorNameMap:
            "transformer.h.{bid}.self_attention.dense", # falcon
            "model.layers.{bid}.self_attn.o_proj", # llama-hf
            "layers.{bid}.attention.wo", # llama-pth
            "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
        ),

        # Rotary embeddings
@@ -307,6 +334,7 @@ class TensorNameMap:
            "transformer.blocks.{bid}.norm_2", # mpt
            "model.layers.{bid}.post_attention_layernorm", # llama-hf
            "layers.{bid}.ffn_norm", # llama-pth
            "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
        ),

        # Feed-forward up
@@ -317,6 +345,7 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
            "model.layers.{bid}.mlp.up_proj", # llama-hf
            "layers.{bid}.feed_forward.w3", # llama-pth
            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
        ),

        # Feed-forward gate
@@ -333,7 +362,20 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
            "model.layers.{bid}.mlp.down_proj", # llama-hf
            "layers.{bid}.feed_forward.w2", # llama-pth
            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
        ),

        MODEL_TENSOR.ATTN_K_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
        ),

        MODEL_TENSOR.ROPE_FREQS: (
            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
        )
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]
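
To illustrate how the new Persimmon entries above are consumed, here is a small sketch (assuming gguf-py is importable, as in the conversion script; the layer index and the block count of 36 are only illustrative) that resolves a raw checkpoint tensor name to its GGUF name:

import gguf

# Build the name map for the Persimmon architecture with an illustrative block count.
tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.PERSIMMON, 36)
name = "language_model.encoder.layers.0.self_attention.query_key_value.weight"
# try_suffixes lets the lookup match the base name and re-append ".weight"/".bias".
print(tensor_map.get_name(name, try_suffixes=(".weight", ".bias")))
# expected: blk.0.attn_qkv.weight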