Skip to content

Commit

Permalink
Add support for BERT embedding models (ggerganov#5423)
Browse files Browse the repository at this point in the history
* BERT model graph construction (build_bert)
* WordPiece tokenizer (llm_tokenize_wpm)
* Add flag for non-causal attention models
* Allow for models that only output embeddings
* Support conversion of BERT models to GGUF
* Based on prior work by @xyzhang626 and @skeskinen

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
4 people authored and jordankanter committed Mar 13, 2024
1 parent 5398b10 commit eea4ecc
Show file tree
Hide file tree
Showing 8 changed files with 616 additions and 52 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[flake8]
max-line-length = 125
ignore = W503
94 changes: 94 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def from_model_architecture(model_architecture):
return InternLM2Model
if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel
if model_architecture == "BertModel":
return BertModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -264,6 +266,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM
if arch == "BertModel":
return gguf.MODEL_ARCH.BERT

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -1629,6 +1633,96 @@ def write_tensors(self):
self.post_write_tensors(tensor_map, name, data_torch)


class BertModel(Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"]

def set_gguf_parameters(self):
# TODO(cebtenzzre): merge with parent class
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_file_type(self.ftype)

def set_vocab(self):
path = self.dir_model
added_tokens_path = self.dir_model if self.dir_model.exists() else None

# use huggingface vocab to get all tokens
vocab = HfVocab(path, added_tokens_path)
tokens, scores, toktypes = zip(*vocab.all_tokens())
assert len(tokens) == vocab.vocab_size

# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
n_token_types = len(set(toktypes))
self.gguf_writer.add_token_type_count(n_token_types)

# convert to phantom space vocab
def phantom(tok, typ):
if tok.startswith(b"[") and tok.endswith(b"]"):
return tok
if tok.startswith(b"##"):
return tok[2:]
return b"\xe2\x96\x81" + tok
tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]

# set up bos and eos tokens (cls and sep)
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)

# add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

# handle special tokens
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def write_tensors(self):
tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
tensors = dict(self.get_tensors())
for name, data_torch in tensors.items():
# we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
continue # we don't need these

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

data = data_torch.squeeze().numpy()
n_dims = len(data.shape)
new_dtype: type[np.floating[Any]]

if (
self.ftype == 1 and name.endswith(".weight") and n_dims == 2
and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32
):
# if f16 desired, convert any float32 2-dim weight tensors to float16
new_dtype = np.float16
else:
# if f32 desired, convert any float16 to float32
new_dtype = np.float32

print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")

if data.dtype != new_dtype:
data = data.astype(new_dtype)

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


Expand Down
12 changes: 11 additions & 1 deletion examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
}

const int n_embd = llama_n_embd(model);
const auto * embeddings = llama_get_embeddings(ctx);
auto * embeddings = llama_get_embeddings(ctx);

// l2-normalize embeddings
float norm = 0;
for (int i = 0; i < n_embd; i++) {
norm += embeddings[i] * embeddings[i];
}
norm = sqrt(norm);
for (int i = 0; i < n_embd; i++) {
embeddings[i] /= norm;
}

for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);
Expand Down
43 changes: 25 additions & 18 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Attention:
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand All @@ -60,22 +61,23 @@ class Rope:
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens"
TOKEN_TYPE = "tokenizer.ggml.token_type"
SCORES = "tokenizer.ggml.scores"
MERGES = "tokenizer.ggml.merges"
BOS_ID = "tokenizer.ggml.bos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id"
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens"
TOKEN_TYPE = "tokenizer.ggml.token_type"
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
SCORES = "tokenizer.ggml.scores"
MERGES = "tokenizer.ggml.merges"
BOS_ID = "tokenizer.ggml.bos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id"
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"


#
Expand Down Expand Up @@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_OUT = auto()
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_OUT_NORM = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE_INP = auto()
FFN_NORM = auto()
Expand All @@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
FFN_UP_EXP = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
Expand Down Expand Up @@ -178,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
Expand All @@ -187,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
Expand Down Expand Up @@ -262,17 +268,18 @@ class MODEL_TENSOR(IntEnum):
],
MODEL_ARCH.BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ def add_layer_norm_eps(self, value: float) -> None:
def add_layer_norm_rms_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

Expand Down Expand Up @@ -387,6 +390,9 @@ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[by
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

def add_token_type_count(self, value: int) -> None:
self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)

def add_token_scores(self, scores: Sequence[float]) -> None:
self.add_array(Keys.Tokenizer.SCORES, scores)

Expand Down
13 changes: 10 additions & 3 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TensorNameMap:
# Normalization of token embeddings
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
),

# Position embeddings
Expand All @@ -54,7 +55,6 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
Expand All @@ -79,7 +79,6 @@ class TensorNameMap:
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2
Expand Down Expand Up @@ -155,6 +154,11 @@ class TensorNameMap:
"model.layers.{bid}.attention.wo", # internlm2
),

# Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
),

# Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
Expand All @@ -171,7 +175,6 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2
Expand Down Expand Up @@ -266,6 +269,10 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
),

MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm", # bert
)
}

mapping: dict[str, tuple[MODEL_TENSOR, str]]
Expand Down
Loading

0 comments on commit eea4ecc

Please sign in to comment.