Merged

20 commits
52de2e5
tts : remove printfs (#12640)
marcoStocchi Mar 31, 2025
f52d59d
llava : fix clip loading GGUFs with missing description (#12660)
CISC Mar 31, 2025
6c02a03
SYCL: Remove misleading ggml_sycl_op_flatten function (#12387)
qnixsynapse Mar 31, 2025
1a85949
llava : proper description fix (#12668)
CISC Mar 31, 2025
a772448
cmake: improve Vulkan cooperative matrix support checks (whisper/2966)
sandrohanea Mar 31, 2025
0114a32
sync : ggml
ggerganov Mar 31, 2025
1790e73
cmake : fix whitespace (#0)
ggerganov Mar 31, 2025
a8a1f33
Vulkan: Add DP4A MMQ and Q8_1 quantization shader (#12135)
0cc4m Mar 31, 2025
403fbac
convert : Qwerky : use lora_rank_tokenshift and lora_rank_decay if pr…
CISC Mar 31, 2025
250d795
ggml : faster ssm scan (#10558)
A3shTnT Mar 31, 2025
c80a775
vocab : add special infill tokens for CodeLlama (#11850)
danbev Mar 31, 2025
35782ae
convert : BailingMoE : avoid setting rope_dim to 0 (#12678)
CISC Mar 31, 2025
8bbf260
SYCL: switch to SYCL namespace (#12674)
qnixsynapse Apr 1, 2025
8293970
SYCL: Rename oneMKL to oneMath (#12192)
Rbiessy Apr 1, 2025
2bb3597
vulkan: fix build when glslc doesn't support coopmat (#12683)
wbruna Apr 1, 2025
a6f32f0
Fix clang warning in gguf_check_reserved_keys (#12686)
yeahdongcn Apr 1, 2025
3fd072a
metal : use F32 prec in FA kernels (#12688)
ggerganov Apr 1, 2025
5936a61
convert : BailingMoE : fix qkv split when head_dim is 0 (#12687)
CISC Apr 1, 2025
e39e727
llama : use LLM_KV_GENERAL_FILE_TYPE instead of gguf_find_key (#12672)
jklincn Apr 1, 2025
f423981
opencl : fix memory allocation size (#12649)
sparkleholic Apr 1, 2025
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -803,7 +803,7 @@ jobs:
env:
OPENBLAS_VERSION: 0.3.23
SDE_VERSION: 9.33.0-2024-01-07
VULKAN_VERSION: 1.4.304.1
VULKAN_VERSION: 1.4.309.0

strategy:
matrix:
4 changes: 2 additions & 2 deletions common/minja/minja.hpp
@@ -240,7 +240,7 @@ class Value : public std::enable_shared_from_this<Value> {
auto index = key.get<int>();
return array_->at(index < 0 ? array_->size() + index : index);
} else if (object_) {
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
auto it = object_->find(key.primitive_);
if (it == object_->end()) return Value();
return it->second;
@@ -249,7 +249,7 @@
}
void set(const Value& key, const Value& value) {
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
(*object_)[key.primitive_] = value;
}
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
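Note: the minja.hpp change above only corrects the spelling of the error text ("Unashable" → "Unhashable"). The wording mirrors Python/Jinja semantics, where only hashable values may serve as mapping keys; the following Python sketch is illustrative only and not part of the patch:

```python
# Python raises the analogous error that minja's Value::at / Value::set guard
# against: only hashable values (str, int, tuple, ...) can be used as dict keys.
d = {}
d[("a", 1)] = "ok"        # a tuple is hashable
try:
    d[["a", 1]] = "boom"  # a list is not
except TypeError as err:
    print(err)            # unhashable type: 'list'
```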
11 changes: 4 additions & 7 deletions convert_hf_to_gguf.py
@@ -3557,8 +3557,8 @@ def set_gguf_parameters(self):
head_size = hidden_size // num_attention_heads
rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"]
time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32)
time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64)

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
@@ -5146,10 +5146,7 @@ def set_vocab(self):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]

self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
@@ -5175,7 +5172,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
n_embd = self.hparams["hidden_size"]
head_dim = self.hparams.get("head_dim", n_embd // n_head)
head_dim = self.hparams.get("head_dim") or n_embd // n_head

output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)

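Note on the two lookup styles used above: `hparams.get(key, default)` only falls back when the key is absent, while `hparams.get(key) or default` also falls back when the key is present but 0 (or None) — which is exactly the BailingMoE case where `head_dim` can be 0 and must not end up as the rope dimension or the q/k/v split size. A short, illustrative Python sketch (the values are placeholders, not from a real config):

```python
# Illustrative only - why the patch switches head_dim handling to `... or fallback`.
hparams = {"hidden_size": 4096, "num_attention_heads": 32, "head_dim": 0}

fallback = hparams["hidden_size"] // hparams["num_attention_heads"]  # 128

# .get(key, default) keeps an explicit 0 from the config:
assert hparams.get("head_dim", fallback) == 0       # would set rope_dim = 0

# .get(key) or default also treats a present-but-zero value as "not set":
assert (hparams.get("head_dim") or fallback) == 128

# The Qwerky change keeps the plain-default form, since 0 is not a meaningful
# value for the LoRA ranks:
time_mix_extra_dim = hparams.get("lora_rank_tokenshift",
                                 64 if hparams["hidden_size"] >= 4096 else 32)
```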
39 changes: 7 additions & 32 deletions docs/backend/SYCL.md
@@ -20,7 +20,7 @@
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:

- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*.
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*.
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.

@@ -227,16 +227,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices,

**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.


**oneMKL for cuBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* do not contain the cuBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *cuBLAS* backend enabled is thus required to run it on Nvidia GPUs.

```sh
git clone https://github.com/oneapi-src/oneMKL
cd oneMKL
cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_CUBLAS_BACKEND=ON -DTARGET_DOMAINS=blas
cmake --build buildWithCublas --config Release
```

**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:

```sh
@@ -250,16 +240,6 @@ cmake --build build-nvidia --config Release

**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.

**oneMKL for rocBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* doesn't contain the rocBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *rocBLAS* backend enabled is thus required to run it on AMD GPUs.

```sh
git clone https://github.com/oneapi-src/oneMKL
cd oneMKL
# Find your HIPTARGET with rocminfo, under the key 'Name:'
cmake -B buildWithrocBLAS -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_ROCBLAS_BACKEND=ON -DHIPTARGETS=${HIPTARGET} -DTARGET_DOMAINS=blas
cmake --build buildWithrocBLAS --config Release
```

3. **Verify installation and environment**

In order to check the available SYCL devices on the machine, please use the `sycl-ls` command.
@@ -324,13 +304,10 @@ cmake --build build --config Release -j -v

#### Nvidia GPU

```sh
# Export relevant ENV variables
export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LD_LIBRARY_PATH
export LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LIBRARY_PATH
export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithCublas/include:$CPLUS_INCLUDE_DIR
export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.

```sh
# Build LLAMA with Nvidia BLAS acceleration through SYCL
# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
@@ -347,12 +324,10 @@ cmake --build build --config Release -j -v

#### AMD GPU

```sh
# Export relevant ENV variables
export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LD_LIBRARY_PATH
export LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LIBRARY_PATH
export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithrocBLAS/include:$CPLUS_INCLUDE_DIR
The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.

```sh
# Build LLAMA with rocBLAS acceleration through SYCL

## AMD
8 changes: 5 additions & 3 deletions examples/llava/clip.cpp
@@ -1396,14 +1396,16 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
const int n_kv = gguf_get_n_kv(ctx);
const int ftype = get_u32(ctx, KEY_FTYPE);
const std::string ftype_str = get_ftype(ftype);
const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
const std::string description = gguf_get_val_str(ctx, idx_desc);
const int idx_name = gguf_find_key(ctx, KEY_NAME);
if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
const std::string name = gguf_get_val_str(ctx, idx_name);
LOG_INF("%s: model name: %s\n", __func__, name.c_str());
}
LOG_INF("%s: description: %s\n", __func__, description.c_str());
const int idx_desc = gguf_find_key(ctx, KEY_DESCRIPTION);
if (idx_desc != -1) { // ditto
const std::string description = gguf_get_val_str(ctx, idx_desc);
LOG_INF("%s: description: %s\n", __func__, description.c_str());
}
LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
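The clip.cpp change treats `general.description` the same way as `general.name`: look it up with `gguf_find_key` and skip the log line when the key is missing, instead of failing on GGUFs produced before the description was written. The same optional-key check can be done from Python with gguf-py; this sketch is illustrative only, and the file path is a placeholder:

```python
# Illustrative only: tolerate missing general.name / general.description,
# mirroring the clip.cpp fix. Requires the gguf-py package.
from gguf import GGUFReader

reader = GGUFReader("mmproj.gguf")  # placeholder path
for key in ("general.name", "general.description"):
    field = reader.fields.get(key)  # None when the key is absent
    if field is None:
        print(f"{key}: missing (tolerated)")
    else:
        print(f"{key}: present")
```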
8 changes: 5 additions & 3 deletions examples/tts/tts.cpp
@@ -699,11 +699,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const std::string voice_data = audio_data;

auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");

std::ostringstream tokens_oss;
for (size_t i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
tokens_oss << tmp[i] << ", ";
}
printf("\n\n");
LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());

prompt_add(prompt_inp, tmp);
#else
prompt_add(prompt_inp, llama_tokens {
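The tts.cpp change collects the token ids into a single string and routes it through `LOG_INF` once, rather than `printf`-ing each id. The equivalent pattern in Python, shown only for illustration (the ids are placeholders):

```python
# Illustrative only: emit the whole token list as one log message instead of
# printing each id separately.
import logging

logging.basicConfig(level=logging.INFO)
tokens = [151667, 198, 1782, 155780]  # placeholder ids
logging.info("llama tokens: %s", ", ".join(str(t) for t in tokens))
```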
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -31,6 +31,8 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh"
@@ -2296,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
@@ -3193,6 +3201,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
151 changes: 151 additions & 0 deletions ggml/src/ggml-cuda/ssm-conv.cu
@@ -0,0 +1,151 @@
#include "ssm-conv.cuh"

template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);

const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);

float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

for (int i = 0; i < n_t; i++) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}

template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const int bidz = blockIdx.z;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);

float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

#pragma unroll
for (int i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
}

static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
const int n_s, cudaStream_t stream) {
const int threads = 128;
GGML_ASSERT(nr % threads == 0);

if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
n_s);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
}
} else {
if (nc == 4) {
const int split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t>
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
}
}
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch

GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
}
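For reference, the new kernels compute, per channel, a causal sliding-window dot product of width `d_conv` (only `d_conv == 4` is handled); the rolling `x[]` registers just avoid re-reading the whole window for every token, and the long-token variant additionally splits the token dimension across blocks. A minimal NumPy sketch of the same result, using an explicit `(n_seqs, d_inner, n_t + d_conv - 1)` layout for `conv_x` and `(d_inner, d_conv)` for the weight — the layout and names are illustrative, not the ggml API:

```python
# Illustrative reference for ssm_conv_f32 / ssm_conv_long_token_f32.
import numpy as np

def ssm_conv_ref(conv_x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    n_seqs, d_inner, ncs = conv_x.shape
    d_conv = weight.shape[1]
    n_t = ncs - (d_conv - 1)
    y = np.empty((n_seqs, n_t, d_inner), dtype=np.float32)
    for s in range(n_seqs):
        for t in range(n_t):
            # each output token is a per-channel dot product over a window of
            # d_conv inputs; the CUDA kernel keeps this window in registers
            y[s, t] = np.einsum("cj,cj->c", conv_x[s, :, t:t + d_conv], weight)
    return y

# tiny smoke test with d_conv = 4 (the only kernel size supported above)
x = np.random.rand(2, 8, 4 - 1 + 5).astype(np.float32)  # 2 seqs, 8 channels, 5 tokens
w = np.random.rand(8, 4).astype(np.float32)
print(ssm_conv_ref(x, w).shape)  # (2, 5, 8)
```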
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/ssm-conv.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);