Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8943,6 +8943,13 @@ def set_vocab(self):
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS

# TODO: remove once MXFP4 is supported more generally
def dequant_model(self):
quant_config = self.hparams.get("quantization_config")
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
return
return super().dequant_model()

def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
Expand Down
87 changes: 7 additions & 80 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

#define GGML_VK_MAX_NODES 8192

#define MAX_VK_BUFFERS 256

#define VK_CHECK(err, msg) \
do { \
vk::Result err_ = (err); \
Expand Down Expand Up @@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector {
std::vector<vk_semaphore> tl_semaphores;
std::vector<vk_semaphore> semaphores;
std::vector<vk::Event> events;
std::vector<vk_buffer> temp_buffers;
std::vector<vk_context> contexts;
};

Expand Down Expand Up @@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context {
// and set to true after the buffer contents are consumed.
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;

vk_buffer buffer_pool[MAX_VK_BUFFERS];

vk_context_ref compute_ctx;
vk_context_ref transfer_ctx;

Expand Down Expand Up @@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}

ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);

Expand Down Expand Up @@ -5144,71 +5144,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
}

static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
VK_LOG_MEMORY("ggml_vk_pool_malloc");

int best_i = -1;
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
int worst_i = -1;
size_t worst_size = 0; //largest unused buffer seen so far
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer &b = ctx->buffer_pool[i];
if (b != nullptr && b->size >= size && b->size < best_size) {
best_i = i;
best_size = b->size;
}
if (b != nullptr && b->size > worst_size) {
worst_i = i;
worst_size = b->size;
}
}
if(best_i != -1) {
//found the smallest buffer that fits our needs
vk_buffer b = ctx->buffer_pool[best_i];
ctx->buffer_pool[best_i].reset();
return b;
}
if(worst_i != -1) {
//no buffer that fits our needs, resize largest one to save memory
vk_buffer& b = ctx->buffer_pool[worst_i];
ggml_vk_destroy_buffer(b);
}

return ggml_vk_create_buffer_device(ctx->device, size);
}

static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer& b = ctx->buffer_pool[i];
if (b == nullptr) {
b = buffer;
return;
}
}
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
ggml_vk_destroy_buffer(buffer);
}

// Returns an available temporary buffer that may only be used temporarily, it will be reused
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
// Try to find existing temp buffer with enough capacity
for (auto& buffer : ctx->gc.temp_buffers) {
if (buffer->size >= size) {
return buffer;
}
}

VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");

// Otherwise create new buffer
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
ctx->gc.temp_buffers.push_back(buf);

return buf;
}

static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
Expand Down Expand Up @@ -11789,10 +11724,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
for (auto& buffer : ctx->gc.temp_buffers) {
ggml_vk_pool_free(ctx, buffer);
}
ctx->gc.temp_buffers.clear();
ctx->prealloc_y_last_pipeline_used = {};

ctx->unsynced_nodes_written.clear();
Expand Down Expand Up @@ -11835,10 +11766,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
ctx->prealloc_y_last_pipeline_used = nullptr;

for (auto& buffer : ctx->buffer_pool) {
ggml_vk_destroy_buffer(buffer);
}

ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
ctx->prealloc_size_split_k = 0;
Expand Down
51 changes: 33 additions & 18 deletions ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#version 450

#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#include "types.glsl"

Expand Down Expand Up @@ -84,35 +87,47 @@ void main() {
}

barrier();
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % (w >> 1)) +
(D_STATE * (tid / (w >> 1))) +
j * D_STATE * (D_STATE / (w >> 1));
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + (w >> 1)];
[[unroll]]
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + w];
}
}
barrier();
}

[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
const uint max_idx = SUBGROUP_SIZE - 1 +
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);

uint lane = tid % SUBGROUP_SIZE;

[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
if (idx < SPLIT_H * D_STATE ||
max_idx < SPLIT_H * D_STATE) {
float sc;
#if USE_SUBGROUP_ADD
sc = stateC[idx];
sc = subgroupAdd(sc);
#else
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
barrier();
}
if (tid % SUBGROUP_SIZE == 0) {
sc = stateC[idx];
}
#endif

if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = stateC[idx];
if (tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = sc;
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,8 @@ void process_shaders() {
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});

string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});

Expand Down
2 changes: 2 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17965,6 +17965,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);

res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
Expand Down