Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ else()
set(LLAMA_STANDALONE OFF)
endif()

option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)

if (EMSCRIPTEN)
set(BUILD_SHARED_LIBS_DEFAULT OFF)

Expand Down Expand Up @@ -145,7 +147,13 @@ endif()
# 3rd-party
#

if (NOT TARGET ggml)
if (LLAMA_USE_SYSTEM_GGML)
message(STATUS "Using system-provided libggml, skipping ggml build")
find_package(ggml REQUIRED)
add_library(ggml ALIAS ggml::ggml)
endif()

if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
add_subdirectory(ggml)
# ... otherwise assume ggml is added by a parent CMakeLists.txt
endif()
Expand Down
2 changes: 2 additions & 0 deletions cmake/common.cmake
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
include("ggml/cmake/common.cmake")

function(llama_add_compile_flags)
if (LLAMA_FATAL_WARNINGS)
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
Expand Down
229 changes: 197 additions & 32 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def _set_vocab_rwkv_world(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)

tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]

with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)

self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)

def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
Expand Down Expand Up @@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV6

def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)

tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]

with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)

self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
self._set_vocab_rwkv_world()

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
Expand Down Expand Up @@ -3565,6 +3568,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield (new_name, data)


@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV7

def set_vocab(self):
self._set_vocab_rwkv_world()

def calc_lora_rank(self, hidden_size, exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try:
head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
except KeyError:
head_size = self.hparams["head_dim"]
layer_norm_eps = self.hparams["norm_eps"]
hidden_size = self.hparams["hidden_size"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)

# ICLR: In-Context-Learning-Rate
try:
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
except KeyError:
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)

# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)

lerp_weights: dict[int, dict[str, Tensor]] = {}
lora_needs_transpose: bool = True

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# unify tensor names here to make life easier
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
name = name.replace("self_attn", "attention").replace("attn", "attention")
name = name.replace("time_mixer.", "")
# lora layer names in fla-hub's impl
if "_lora.lora" in name:
self.lora_needs_transpose = False
name = name.replace("_lora.lora.0.weight", "1.weight")
name = name.replace("_lora.lora.2.weight", "2.weight")
name = name.replace("_lora.lora.2.bias", "0.weight")

name = name.replace("feed_forward_norm", "ln2")
name = name.replace("g_norm", "ln_x")

if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
# some models have dummy v0/v1/v2 on first layer while others don't
# ignore them all since they are not used
return

wkv_has_gate = self.hparams.get("wkv_has_gate", True)
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]

if bid is not None and "attention.x_" in name:
if "attention.x_x" in name:
# already concatenated
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = data_torch.reshape(len(lerp_list), 1, 1, -1)
yield (new_name, data)
else:
try:
self.lerp_weights[bid][name] = data_torch
except KeyError:
self.lerp_weights[bid] = {name: data_torch}
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
yield (new_name, data)
return
else:
data_torch = data_torch.squeeze()
new_name = self.map_tensor_name(name)

if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"

if self.lora_needs_transpose and any(
new_name.endswith(t) for t in [
"time_mix_w1.weight", "time_mix_w2.weight",
"time_mix_a1.weight", "time_mix_a2.weight",
"time_mix_v1.weight", "time_mix_v2.weight",
"time_mix_g1.weight", "time_mix_g2.weight",
]
):
data_torch = data_torch.transpose(0, 1)

if 'r_k' in new_name:
data_torch = data_torch.flatten()

if bid == 0 and "time_mix_a" in new_name:
# dummy v0/v1/v2 on first layer
# easist way to make llama happy
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)

yield (new_name, data_torch)


@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Rwkv7Model):
model_arch = gguf.MODEL_ARCH.ARWKV7

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"]
wkv_has_gate = self.hparams["wkv_has_gate"]
assert self.hparams["wkv_version"] == 7

# ICLR: In-Context-Learning-Rate
lora_rank_decay = 64
lora_rank_iclr = 64
lora_rank_value_residual_mix = 32
lora_rank_gate = 128 if wkv_has_gate else 0

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_token_shift_count(1)

# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)


@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
Expand Down
41 changes: 36 additions & 5 deletions examples/main/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,24 @@ Once downloaded, place your model in the models folder in llama.cpp.
##### Input prompt (One-and-done)

```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
```
##### Conversation mode (Allow for continuous interaction with the model)

```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
```

##### Conversation mode using built-in jinja chat template

```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja
```

##### One-and-done query using jinja with custom system prompt and a starting prompt

```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
```

##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
Expand All @@ -44,12 +56,24 @@ Once downloaded, place your model in the models folder in llama.cpp.

##### Input prompt (One-and-done)
```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --prompt "Once upon a time"
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -no-cnv --prompt "Once upon a time"
```
##### Conversation mode (Allow for continuous interaction with the model)

```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf -cnv --chat-template gemma
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --chat-template gemma
```

##### Conversation mode using built-in jinja chat template

```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja
```

##### One-and-done query using jinja with custom system prompt and a starting prompt

```powershell
./llama-cli.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --jinja --single-turn -sys "You are a helpful assistant" -p "Hello"
```

#### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
Expand Down Expand Up @@ -77,6 +101,8 @@ The `llama-cli` program provides several ways to interact with the LLaMA models

- `--prompt PROMPT`: Provide a prompt directly as a command-line option.
- `--file FNAME`: Provide a file containing a prompt or multiple prompts.
- `--system-prompt PROMPT`: Provide a system prompt (will otherwise use the default one in the chat template (if provided)).
- `--system-prompt-file FNAME`: Provide a file containing a system prompt.
- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)

## Interaction
Expand All @@ -89,7 +115,10 @@ In interactive mode, users can participate in text generation by injecting their

- `-i, --interactive`: Run the program in interactive mode, allowing users to engage in real-time conversations or provide specific instructions to the model.
- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: false)
- `-cnv, --conversation`: Run the program in conversation mode (does not print special tokens and suffix/prefix, use default or provided chat template) (default: true if chat template found)
- `-no-cnv`: Disable conversation mode (default: false)
- `-st, --single-turn`: Only process a single conversation turn (user input) and then exit.
- `--jinja`: Enable jinja chat template parser, will use the model's built-in template or a user-provided one (default: false)
- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text.

By understanding and utilizing these interaction options, you can create engaging and dynamic experiences with the LLaMA models, tailoring the text generation process to your specific needs.
Expand Down Expand Up @@ -125,6 +154,8 @@ When --in-prefix or --in-suffix options are enabled the chat template ( --chat-t

Example usage: `--chat-template gemma`

`--chat-template-file FNAME`: Load a custom jinja chat template from an external file, useful if the model contains outdated or incompatible template, some examples can be found in models/templates. Up-to-date chat templates can be downloaded from Hugging Face using scripts/get_chat_template.py

## Context Management

During text generation, LLaMA models have a limited context size, which means they can only consider a certain number of tokens from the input and generated text. When the context fills up, the model resets internally, potentially losing some information from the beginning of the conversation or instructions. Context management options help maintain continuity and coherence in these situations.
Expand Down
26 changes: 26 additions & 0 deletions ggml/cmake/common.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
function(ggml_get_flags CCID CCVER)
set(C_FLAGS "")
set(CXX_FLAGS "")

if (CCID MATCHES "Clang")
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)

if (
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
)
list(APPEND C_FLAGS -Wdouble-promotion)
endif()
elseif (CCID STREQUAL "GNU")
set(C_FLAGS -Wdouble-promotion)
set(CXX_FLAGS -Wno-array-bounds)

if (CCVER VERSION_GREATER_EQUAL 8.1.0)
list(APPEND CXX_FLAGS -Wextra-semi)
endif()
endif()

set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
endfunction()
Loading
Loading