Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
):
data_qtype = gguf.GGMLQuantizationType.F32

Expand Down Expand Up @@ -4183,6 +4183,21 @@ def set_vocab(self):
super().set_vocab()


@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1

def set_gguf_parameters(self):
super().set_gguf_parameters()

# RND1 specific parameters
# RND1 uses bidirectional attention
self.gguf_writer.add_causal_attention(False)

if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)


@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion convert_lora_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/batched/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
The example demonstrates batched generation from a given prompt

```bash
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified

...

Expand Down
50 changes: 48 additions & 2 deletions examples/diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,54 @@ More Info:
- https://github.com/ggml-org/llama.cpp/pull/14644
- https://github.com/ggml-org/llama.cpp/pull/14771

## Parameters
The diffusion CLI supports various parameters to control the generation process:

Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
- `1`: ENTROPY_BASED - Entropy-based selection
- `2`: MARGIN_BASED - Margin-based selection
- `3`: RANDOM - Random selection
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
- More documentation here https://github.com/DreamLM/Dream
- `--diffusion-visual`: Enable live visualization during generation

Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
### Scheduling Parameters
Choose one of the following scheduling methods:

**Timestep-based scheduling:**
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)

**Block-based scheduling:**
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)

### Sampling Parameters
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
- `--top-k`: Top-k filtering for sampling
- `--top-p`: Top-p (nucleus) filtering for sampling
- `--seed`: Random seed for reproducibility

### Model Parameters
- `-m`: Path to the GGUF model file
- `-p`: Input prompt text
- `-ub`: Maximum sequence length (ubatch size)
- `-c`: Context size
- `-b`: Batch size

### Examples
#### Dream architechture:
```
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
```

#### LLaDA architechture:
```
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
```

#### RND1 architecture:
```
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
```
9 changes: 5 additions & 4 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@ if(GIT_EXE)
)
endif()

# Build the version string with optional dirty flag
set(GGML_VERSION "${GGML_VERSION_BASE}")
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
set(GGML_VERSION "${GGML_VERSION}-dirty")
endif()

if(NOT GGML_BUILD_COMMIT)
set(GGML_BUILD_COMMIT "unknown")
endif()

# Build the commit string with optional dirty flag
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
endif()

include(CheckIncludeFileCXX)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
Expand Down
15 changes: 15 additions & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ function(ggml_add_cpu_backend_variant tag_name)
set(GGML_INTERNAL_${feat} OFF)
endforeach()

foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
foreach (feat RVV)
set(GGML_INTERNAL_${feat} OFF)
endforeach()

foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
Expand Down Expand Up @@ -402,6 +410,13 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(riscv64_0)
ggml_add_cpu_backend_variant(riscv64_v RVV)
else()
message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()
Expand Down
3 changes: 1 addition & 2 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2303,9 +2303,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;

bool cann_graph_update_required = false;
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;

static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
Expand Down Expand Up @@ -2336,7 +2336,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
}
#else
bool use_cann_graph = false;
bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);

Expand Down
41 changes: 27 additions & 14 deletions ggml/src/ggml-cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/spacemit/ime_kernels.h
)
endif()
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
if(NOT GGML_CPU_ALL_VARIANTS)
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif()
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
else()
# Begin with the lowest baseline
set(ARCH_DEFINITIONS "")

if (GGML_INTERNAL_RVV)
message(STATUS "RVV enabled")
list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)
list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)
endif()

ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES
Expand Down
2 changes: 0 additions & 2 deletions ggml/src/ggml-cpu/arch-fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
Expand Down
Loading