Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add support for xverse #6301

Merged
merged 12 commits into from
Mar 29, 2024
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -115,6 +115,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)
- [x] [Xverse](https://huggingface.co/models?search=xverse)
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)

**Multimodal models:**
Expand Down
140 changes: 140 additions & 0 deletions convert-hf-to-gguf.py
Expand Up @@ -778,7 +778,147 @@ def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
r = weights.shape[0] // 3
return weights[r * n_part:r * n_part + r, ...]

@Model.register("XverseForCausalLM")
class XverseModel(Model):
model_arch = gguf.MODEL_ARCH.XVERSE

def set_vocab(self):
assert (self.dir_model / "tokenizer.json").is_file()
dir_model = self.dir_model
hparams = self.hparams

tokens: list[bytearray] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()

for token_id in range(vocab_size):
token_text = reverse_vocab[token_id].encode('utf-8')
# replace "\x00" to string with length > 0
if token_text == b"\x00":
toktype = gguf.TokenType.BYTE # special
token_text = f"<{token_text}>".encode('utf-8')
elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
toktype = gguf.TokenType.BYTE # special
elif reverse_vocab[token_id] in added_vocab:
if tokenizer.added_tokens_decoder[token_id].special:
toktype = gguf.TokenType.CONTROL
else:
toktype = gguf.TokenType.USER_DEFINED
else:
toktype = gguf.TokenType.NORMAL

tokens.append(token_text)
toktypes.append(toktype)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
hf_repo = self.hparams.get("_name_or_path", "")

ctx_length = 0
if "max_sequence_length" in self.hparams:
ctx_length = self.hparams["max_sequence_length"]
elif "max_position_embeddings" in self.hparams:
ctx_length = self.hparams["max_position_embeddings"]
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def write_tensors(self):
# Collect tensors from generator object
model_kv = dict(self.get_tensors())
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
head_count_kv = self.hparams.get("num_key_value_heads", head_count)

for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

# HF models permute some of the tensors, so we need to undo that
if name.endswith(("q_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
if name.endswith(("k_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)

def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head

return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)

@Model.register("FalconForCausalLM", "RWForCausalLM")
class FalconModel(Model):
model_arch = gguf.MODEL_ARCH.FALCON
Expand Down
22 changes: 22 additions & 0 deletions gguf-py/gguf/constants.py
Expand Up @@ -123,6 +123,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()


Expand Down Expand Up @@ -191,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
}

Expand Down Expand Up @@ -606,6 +608,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.COMMAND_R: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down Expand Up @@ -650,6 +668,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#
Expand Down