diff --git a/z_image.hpp b/z_image.hpp
index b692a14a5..888a895e2 100644
--- a/z_image.hpp
+++ b/z_image.hpp
@@ -30,7 +30,12 @@ namespace ZImage {
         JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
             : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
             blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false);
+            float scale = 1.f;
+#if GGML_USE_HIP
+            // Prevent NaN issues with certain ROCm setups
+            scale = 1.f / 16.f;
+#endif
+            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, scale);
             if (qk_norm) {
                 blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
                 blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);
@@ -93,7 +98,7 @@ namespace ZImage {
 #endif
             // The purpose of the scale here is to prevent NaN issues in certain situations.
             // For example, when using CUDA but the weights are k-quants.
-            blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
+            blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, scale);
             blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
         }
 
@@ -667,4 +672,4 @@ namespace ZImage {
 
 } // namespace ZImage
 
-#endif // __Z_IMAGE_HPP__
\ No newline at end of file
+#endif // __Z_IMAGE_HPP__
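
Background on why the extra `scale` argument helps (not part of the patch): when a Linear projection runs its matmul in half precision, large activations multiplied by large weights can push intermediate values past the fp16 maximum (about 65504), producing Inf and then NaN downstream. Pre-scaling one operand by a small factor such as 1/16 keeps the accumulation in range, and undoing the factor on the output leaves the result numerically unchanged. The sketch below is a minimal standalone illustration of that pattern; `to_fp16_range` and `dot_fp16` are made-up helpers that only simulate an fp16 accumulator, not the ggml/stable-diffusion.cpp code path, and the exact point where the library applies `scale` inside its Linear block may differ.

```cpp
// Standalone illustration of the pre-scale / post-scale trick the patch relies on.
// The helpers below are hypothetical and only simulate an fp16 accumulator;
// they are not the library's API.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Clamp a value to what an fp16 accumulator can hold: anything beyond
// +/-65504 overflows to infinity, which later becomes NaN downstream.
static float to_fp16_range(float v) {
    const float FP16_MAX = 65504.0f;
    if (v > FP16_MAX) return INFINITY;
    if (v < -FP16_MAX) return -INFINITY;
    return v;
}

// Dot product whose partial sums are forced through the fp16 range,
// mimicking a half-precision matmul kernel.
static float dot_fp16(const std::vector<float>& a, const std::vector<float>& b) {
    float acc = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        acc = to_fp16_range(acc + a[i] * b[i]);
    }
    return acc;
}

int main() {
    // Large activations times large weights: the exact sum is 131072, which
    // does not fit in fp16, so the naive accumulation overflows to inf.
    std::vector<float> x(4096, 8.0f), w(4096, 4.0f);
    float naive = dot_fp16(x, w);

    // Pre-scale the input by 1/16 (the factor the diff uses on HIP), run the
    // same kernel, then undo the scaling on the output.
    const float scale = 1.0f / 16.0f;
    std::vector<float> xs = x;
    for (float& v : xs) v *= scale;
    float rescued = dot_fp16(xs, w) / scale;

    std::printf("naive:   %f\n", naive);    // inf
    std::printf("rescued: %f\n", rescued);  // 131072
    return 0;
}
```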