common.hpp (8 changes: 6 additions & 2 deletions)
```diff
@@ -242,14 +242,18 @@ class FeedForward : public GGMLBlock {
         }
 
         // net_1 is nn.Dropout(), skip for inference
-        float scale = 1.f;
+        bool force_prec_f32 = false;
+        float scale         = 1.f;
         if (precision_fix) {
             scale = 1.f / 128.f;
+#ifdef SD_USE_VULKAN
+            force_prec_f32 = true;
+#endif
         }
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example, when using Vulkan without enabling force_prec_f32,
         // or when using CUDA but the weights are k-quants.
-        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
+        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32, scale));
     }
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
```
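The change pairs the existing `scale` workaround with a new `force_prec_f32` flag that is switched on only in Vulkan builds and forwarded to the `Linear` block. Below is a minimal sketch of how these two knobs typically act inside a linear forward pass. The helper name and parameter layout are assumptions, not the repository's actual `Linear` implementation; `ggml_scale`, `ggml_mul_mat`, `ggml_add`, and `ggml_mul_mat_set_prec` are real ggml calls.

```cpp
#include "ggml.h"

// Hypothetical sketch, not the actual Linear::forward in this repo.
struct ggml_tensor* linear_forward_sketch(struct ggml_context* ctx,
                                          struct ggml_tensor* w,  // weight
                                          struct ggml_tensor* b,  // bias, may be NULL
                                          struct ggml_tensor* x,  // input
                                          float scale,
                                          bool force_prec_f32) {
    if (scale != 1.f) {
        // Shrink the activations so an f16 accumulator cannot overflow ...
        x = ggml_scale(ctx, x, scale);
    }
    struct ggml_tensor* cur = ggml_mul_mat(ctx, w, x);
    if (force_prec_f32) {
        // ... and/or ask the backend to accumulate this matmul in f32,
        // which is what this PR enables for Vulkan builds.
        ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
    }
    if (scale != 1.f) {
        cur = ggml_scale(ctx, cur, 1.f / scale);  // undo the pre-scaling
    }
    if (b != NULL) {
        cur = ggml_add(ctx, cur, b);
    }
    return cur;
}
```

Pre-scaling keeps an f16 accumulator in range at the cost of some precision, while `GGML_PREC_F32` sidesteps the overflow entirely on backends that honor it.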
qwen_image.hpp (8 changes: 6 additions & 2 deletions)
```diff
@@ -94,10 +94,14 @@ namespace Qwen {
         blocks["norm_added_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
         blocks["norm_added_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));
 
-        float scale = 1.f / 32.f;
+        float scale         = 1.f / 32.f;
+        bool force_prec_f32 = false;
+#ifdef SD_USE_VULKAN
+        force_prec_f32 = true;
+#endif
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example when using CUDA but the weights are k-quants (not all prompts).
-        blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, false, scale));
+        blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, force_prec_f32, scale));
         // to_out.1 is nn.Dropout
 
         blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));
```
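For intuition on why scales like 1/32 and 1/128 help: f16's largest finite value is 65504, so a single oversized dot product overflows to inf, which then becomes NaN the moment it passes through a softmax or normalization. A standalone demonstration, independent of this PR (requires `_Float16` support, e.g. recent GCC or Clang on x86-64):

```cpp
#include <cstdio>

// Accumulate a dot-product-sized sum in half precision, with and without
// pre-scaling. The true sum is 4096 * 20 = 81920, which exceeds f16's
// maximum finite value of 65504.
int main() {
    const float scale      = 1.f / 128.f;
    _Float16    acc        = 0;
    _Float16    acc_scaled = 0;
    for (int i = 0; i < 4096; i++) {
        _Float16 a = (_Float16)20.f;  // a plausible activation magnitude
        _Float16 b = (_Float16)1.f;
        acc        += a * b;                             // blows up to inf
        acc_scaled += (_Float16)((float)a * scale) * b;  // stays in range
    }
    printf("unscaled: %f\n", (float)acc);                 // inf
    printf("scaled:   %f\n", (float)acc_scaled / scale);  // finite, though imprecise
    return 0;
}
```

The scaled sum stays finite but still accumulates rounding error as it grows, which is presumably why forcing f32 accumulation is preferable on backends that support it.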