@@ -3391,7 +3391,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
33913391 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
33923392 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
33933393
3394- CREATE_UNARY(exp)
33953394 CREATE_UNARY(gelu)
33963395 CREATE_UNARY(gelu_erf)
33973396 CREATE_UNARY(gelu_quick)
@@ -3403,6 +3402,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
34033402 CREATE_UNARY(hardswish)
34043403#undef CREATE_UNARY
34053404
3405+ #define CREATE_UNARY_RTE(name) \
3406+ if (device->float_controls_rte_fp16) { \
3407+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3408+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3409+ } else { \
3410+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3411+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3412+ }
3413+ CREATE_UNARY_RTE(exp)
3414+ #undef CREATE_UNARY_RTE
3415+
34063416#define CREATE_GLU(name) \
34073417 if (device->float_controls_rte_fp16) { \
34083418 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
0 commit comments