diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bb3eb977cecfe..75b76e593bbc0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3793,8 +3793,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index 4aef724e907a5..ff2812d3d7527 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,5 +1,6 @@ #version 450 +#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" @@ -12,6 +13,6 @@ void main() { return; } - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + const float val = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8a2ce321dc42d..9c207f1e46cff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -802,9 +802,6 @@ void process_shaders() { string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -819,6 +816,9 @@ void process_shaders() { std::string suffix = rte ? "_rte" : ""; string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + + string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); } string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});