Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,12 +774,12 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;

static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

Expand Down Expand Up @@ -864,7 +864,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
compute_pipeline_create_info.setPNext(&rci);
}

pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
try {
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
pipeline->compiled = true;

{
Expand Down Expand Up @@ -1560,8 +1566,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};

#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
Expand Down
Loading