Add support for ArcticForCausalLM (#7020)
* common : increase max number of experts to 128

* common : add tensor LLM_TENSOR_FFN_NORM_EXPS for normalization before MoE that runs in parallel to attention + ffn

* gguf-py : add architecture-specific block mappings that override selected general block mappings

* convert-hf : add model conversion support for ArcticForCausalLM

* convert-hf : use added_tokens_decoder from tokenizer_config.json to redefine tokens from SentencePiece model (only for ArcticForCausalLM); a short sketch of this lookup appears after the changed-files summary below

* llama : add inference support for LLM_ARCH_ARCTIC

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
fairydreaming and sszymczy committed May 24, 2024
1 parent 0df0aa8 commit fbca2f2
Showing 4 changed files with 456 additions and 43 deletions.
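For context on the tokenizer handling in this commit: the conversion reads the added_tokens_decoder field of tokenizer_config.json and uses it to override entries that already exist in the SentencePiece tokenizer.model (ids 31998/31999 in snowflake-arctic-instruct). The following is a minimal sketch of that lookup; the local directory name is an assumption, not part of the commit, and the fields used (content, special) are the same ones ArcticModel.set_vocab() reads in the diff further down.

import json
from pathlib import Path

dir_model = Path("snowflake-arctic-instruct")  # hypothetical local checkout of the HF model

with open(dir_model / "tokenizer_config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# added_tokens_decoder maps token ids (as strings) to their content and flags;
# ids that already exist in tokenizer.model are redefined rather than appended.
for token_id, token_json in cfg.get("added_tokens_decoder", {}).items():
    kind = "special" if token_json.get("special") else "user-defined"
    print(int(token_id), repr(token_json["content"]), kind)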
convert-hf-to-gguf.py: 151 additions, 0 deletions
@@ -2466,6 +2466,157 @@ def set_vocab(self, *args, **kwargs):
        self.gguf_writer.add_add_eos_token(True)


@Model.register("ArcticForCausalLM")
class ArcticModel(Model):
    model_arch = gguf.MODEL_ARCH.ARCTIC

    def set_vocab(self):
        # The reason for using a custom implementation here is that the
        # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from
        # tokenizer.model and used them as BOS and EOS instead of adding new tokens.
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / 'tokenizer.model'

        if not tokenizer_path.is_file():
            logger.error(f'Error: Missing {tokenizer_path}')
            sys.exit(1)

        # Read the whole vocabulary from the tokenizer.model file
        tokenizer = SentencePieceProcessor()
        tokenizer.LoadFromFile(str(tokenizer_path))

        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
        scores: list[float] = [-10000.0] * vocab_size
        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

        for token_id in range(tokenizer.vocab_size()):

            piece = tokenizer.IdToPiece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.GetScore(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.IsUnknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.IsControl(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.IsUnused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.IsByte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens[token_id] = text
            scores[token_id] = score
            toktypes[token_id] = toktype

        # Use the added_tokens_decoder field from tokenizer_config.json as the source
        # of information about added/redefined tokens and modify them accordingly.
        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
        if tokenizer_config_file.is_file():
            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
                tokenizer_config_json = json.load(f)

            if "added_tokens_decoder" in tokenizer_config_json:
                added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"]
                for token_id, token_json in added_tokens_decoder.items():
                    token_id = int(token_id)
                    if (token_id >= vocab_size):
                        logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
                        continue

                    token_content = token_json["content"]
                    token_type = SentencePieceTokenTypes.USER_DEFINED
                    token_score = -10000.0

                    # Map unk_token to UNKNOWN, other special tokens to CONTROL
                    # Set the score to 0.0 as in the original tokenizer.model
                    if ("special" in token_json) and token_json["special"]:
                        if token_content == tokenizer_config_json["unk_token"]:
                            token_type = SentencePieceTokenTypes.UNKNOWN
                        else:
                            token_type = SentencePieceTokenTypes.CONTROL
                        token_score = 0.0

                    logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})")
                    tokens[token_id] = token_content.encode("utf-8")
                    toktypes[token_id] = token_type
                    scores[token_id] = token_score

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_tokenizer_pre("default")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        n_head = self.hparams["num_attention_heads"]
        n_kv_head = self.hparams.get("num_key_value_heads")

        if name.endswith("q_proj.weight"):
            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
        if name.endswith("k_proj.weight"):
            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

        # process the experts separately
        if name.find("block_sparse_moe.experts") != -1:
            n_experts = self.hparams["num_local_experts"]

            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for wid in ["w1", "w2", "w3"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))
                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def write_tensors(self):
        super().write_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


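The expert-merging step in ArcticModel.modify_tensors above collects the per-expert 2-D weights of a block and stacks them along a new leading dimension, so each of w1/w2/w3 becomes a single [n_expert, rows, cols] tensor. A toy, self-contained sketch of that stacking, with deliberately small shapes rather than Arctic's real sizes:

import torch

n_experts, rows, cols = 4, 8, 16  # toy sizes; the real model has 128 experts and much larger matrices

# pretend these arrived one tensor at a time, as modify_tensors() sees them
experts = {
    f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight": torch.randn(rows, cols)
    for xid in range(n_experts)
}

# merge into a single 3-D tensor, mirroring torch.stack(datas, dim=0) in the diff
stacked = torch.stack(
    [experts[f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight"] for xid in range(n_experts)],
    dim=0,
)
print(stacked.shape)  # torch.Size([4, 8, 16])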
gguf-py/gguf/constants.py: 25 additions, 0 deletions
@@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
    COMMAND_R = auto()
    DBRX = auto()
    OLMO = auto()
    ARCTIC = auto()


class MODEL_TENSOR(IntEnum):
@@ -167,6 +168,7 @@ class MODEL_TENSOR(IntEnum):
    FFN_DOWN = auto()
    FFN_UP = auto()
    FFN_ACT = auto()
    FFN_NORM_EXP = auto()
    FFN_GATE_EXP = auto()
    FFN_DOWN_EXP = auto()
    FFN_UP_EXP = auto()
@@ -218,6 +220,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.COMMAND_R: "command-r",
    MODEL_ARCH.DBRX: "dbrx",
    MODEL_ARCH.OLMO: "olmo",
    MODEL_ARCH.ARCTIC: "arctic",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -251,6 +254,7 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
    MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
    MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
    MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
    MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
    MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
    MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
@@ -732,6 +736,27 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.ARCTIC: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_NORM_EXP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
    ],
    # TODO
}

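The new FFN_NORM_EXP entry follows the same naming pattern as the other per-block tensors: TENSOR_NAMES stores a template that is expanded once per block. A small sketch of that expansion, using the template string from the diff above:

# template taken from TENSOR_NAMES[MODEL_TENSOR.FFN_NORM_EXP] above
template = "blk.{bid}.ffn_norm_exps"

# expanded per block id, as TensorNameMap does for every block tensor
names = [template.format(bid=bid) for bid in range(3)]
print(names)  # ['blk.0.ffn_norm_exps', 'blk.1.ffn_norm_exps', 'blk.2.ffn_norm_exps']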
gguf-py/gguf/tensor_mapping.py: 18 additions, 1 deletion
@@ -244,6 +244,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
),

        MODEL_TENSOR.FFN_UP_EXP: (
@@ -272,6 +273,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
),

        MODEL_TENSOR.FFN_GATE_EXP: (
@@ -306,6 +308,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
"model.layers.{bid}.mlp.c_proj", # starcoder2
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w2", # arctic
),

        MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -382,6 +385,18 @@ class TensorNameMap:
        ),
    }

    # architecture-specific block mappings
    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
        MODEL_ARCH.ARCTIC: {
            MODEL_TENSOR.FFN_NORM: (
                "model.layers.{bid}.residual_layernorm",
            ),
            MODEL_TENSOR.FFN_NORM_EXP: (
                "model.layers.{bid}.post_attention_layernorm",
            ),
        },
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]

    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
@@ -393,12 +408,14 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
            self.mapping[tensor_name] = (tensor, tensor_name)
            for key in keys:
                self.mapping[key] = (tensor, tensor_name)
        if arch in self.arch_block_mappings_cfg:
            self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
        for bid in range(n_blocks):
            for tensor, keys in self.block_mappings_cfg.items():
                if tensor not in MODEL_TENSORS[arch]:
                    continue
                # TODO: make this configurable
                n_experts = 60
                n_experts = 128
                for xid in range(n_experts):
                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                    self.mapping[tensor_name] = (tensor, tensor_name)
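The net effect of arch_block_mappings_cfg is that, for Arctic only, model.layers.{bid}.post_attention_layernorm now feeds the new FFN_NORM_EXP tensor while model.layers.{bid}.residual_layernorm takes over FFN_NORM. A reduced sketch of the same override logic, with plain strings standing in for the MODEL_TENSOR enum and the TensorNameMap class:

# general mapping: for llama-style HF checkpoints post_attention_layernorm is the FFN norm
block_mappings = {
    "FFN_NORM": ("model.layers.{bid}.post_attention_layernorm",),
}

# architecture-specific overrides, as added for ARCTIC in this commit
arch_block_mappings = {
    "ARCTIC": {
        "FFN_NORM": ("model.layers.{bid}.residual_layernorm",),
        "FFN_NORM_EXP": ("model.layers.{bid}.post_attention_layernorm",),
    },
}

arch = "ARCTIC"
if arch in arch_block_mappings:
    block_mappings.update(arch_block_mappings[arch])  # mirrors TensorNameMap.__init__

print(block_mappings["FFN_NORM"][0].format(bid=0))      # model.layers.0.residual_layernorm
print(block_mappings["FFN_NORM_EXP"][0].format(bid=0))  # model.layers.0.post_attention_layernorm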