diff --git a/common.hpp b/common.hpp
index d32167145..4a891bc8b 100644
--- a/common.hpp
+++ b/common.hpp
@@ -28,7 +28,7 @@ class DownSampleBlock : public GGMLBlock {
         if (vae_downsample) {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
 
-            x = ggml_pad(ctx, x, 1, 1, 0, 0);
+            x = sd_pad(ctx, x, 1, 1, 0, 0);
             x = conv->forward(ctx, x);
         } else {
             auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index ff36cea25..cd6310736 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -113,6 +113,7 @@ struct SDParams {
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
+    bool circular_pad = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
@@ -183,6 +184,7 @@ void print_params(SDParams params) {
     printf("    diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
     printf("    diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
     printf("    vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false");
+    printf("    circular padding: %s\n", params.circular_pad ? "true" : "false");
     printf("    control_strength: %.2f\n", params.control_strength);
     printf("    prompt: %s\n", params.prompt.c_str());
     printf("    negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -304,6 +306,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("                                     This might crash if it is not supported by the backend.\n");
     printf("  --vae-conv-direct                  use Conv2d direct in the vae model (should improve the performance)\n");
     printf("                                     This might crash if it is not supported by the backend.\n");
+    printf("  --circular                         use circular padding for convolutions and pad ops\n");
     printf("  --control-net-cpu                  keep controlnet in cpu (for low vram)\n");
     printf("  --canny                            apply canny preprocessor (edge detection)\n");
     printf("  --color                            colors the logging tags according to level\n");
@@ -573,6 +576,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         {"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
         {"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
         {"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
+        {"", "--circular", "", true, &params.circular_pad},
         {"", "--canny", "", true, &params.canny_preprocess},
         {"-v", "--verbose", "", true, &params.verbose},
         {"", "--color", "", true, &params.color},
@@ -1386,6 +1390,7 @@ int main(int argc, const char* argv[]) {
         params.diffusion_flash_attn,
         params.diffusion_conv_direct,
         params.vae_conv_direct,
+        params.circular_pad,
         params.force_sdxl_vae_conv_scale,
         params.chroma_use_dit_mask,
         params.chroma_use_t5_mask,
diff --git a/flux.hpp b/flux.hpp
index 2ed410419..355184be2 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -1,6 +1,7 @@
 #ifndef __FLUX_HPP__
 #define __FLUX_HPP__
 
+#include 
 #include 
 
 #include "ggml_extend.hpp"
@@ -18,7 +19,7 @@ namespace Flux {
             blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
         }
 
-        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
             // x: [..., in_dim]
             // return: [..., hidden_dim]
             auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
@@ -36,7 +37,7 @@ namespace Flux {
         int64_t hidden_size;
         float eps;
 
-        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
             ggml_type wtype = GGML_TYPE_F32;
             params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
         }
@@ -47,7 +48,7 @@ namespace Flux {
             : hidden_size(hidden_size),
               eps(eps) {}
 
-        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
             struct ggml_tensor* w = params["scale"];
             x = ggml_rms_norm(ctx, x, eps);
             x = ggml_mul(ctx, x, w);
@@ -136,11 +137,11 @@ namespace Flux {
     };
 
     struct ModulationOut {
-        ggml_tensor* shift = NULL;
-        ggml_tensor* scale = NULL;
-        ggml_tensor* gate = NULL;
+        ggml_tensor* shift = nullptr;
+        ggml_tensor* scale = nullptr;
+        ggml_tensor* gate = nullptr;
 
-        ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL)
+        ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr)
             : shift(shift), scale(scale), gate(gate) {}
 
         ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
@@ -259,7 +260,7 @@ namespace Flux {
                                     struct ggml_tensor* txt,
                                     struct ggml_tensor* vec,
                                     struct ggml_tensor* pe,
-                                    struct ggml_tensor* mask = NULL) {
+                                    struct ggml_tensor* mask = nullptr) {
             // img: [N, n_img_token, hidden_size]
             // txt: [N, n_txt_token, hidden_size]
             // pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
@@ -398,7 +399,7 @@ namespace Flux {
 
         ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
             int64_t offset = 3 * idx;
-            return ModulationOut(ctx, vec, offset);
+            return {ctx, vec, offset};
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -406,7 +407,7 @@ namespace Flux {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* vec,
                                     struct ggml_tensor* pe,
-                                    struct ggml_tensor* mask = NULL) {
+                                    struct ggml_tensor* mask = nullptr) {
             // x: [N, n_token, hidden_size]
             // pe: [n_token, d_head/2, 2, 2]
             // return: [N, n_token, hidden_size]
@@ -485,7 +486,7 @@ namespace Flux {
             auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0));  // [N, dim]
             auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1));  // [N, dim]
             // No gate
-            return ModulationOut(shift, scale, NULL);
+            return {shift, scale, nullptr};
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -664,7 +665,7 @@ namespace Flux {
                                  struct ggml_tensor* y,
                                  struct ggml_tensor* guidance,
                                  struct ggml_tensor* pe,
-                                 struct ggml_tensor* mod_index_arange = NULL,
+                                 struct ggml_tensor* mod_index_arange = nullptr,
                                  std::vector<int> skip_layers = {}) {
             auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
             auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
 
             img = img_in->forward(ctx, img);
 
             struct ggml_tensor* vec;
-            struct ggml_tensor* txt_img_mask = NULL;
+            struct ggml_tensor* txt_img_mask = nullptr;
             if (params.is_chroma) {
                 int64_t mod_index_length = 344;
                 auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]);
@@ -681,7 +682,7 @@ namespace Flux {
 
                 // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
                 // ggml_arange tot working on a lot of backends, precomputing it on CPU instead
-                GGML_ASSERT(arange != NULL);
+                GGML_ASSERT(arange != nullptr);
                 auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f);  // [1, 344, 32]
 
                 // Batch broadcast (will it ever be useful)
@@ -695,7 +696,7 @@ namespace Flux {
                 vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3));  // [344, N, 64]
                 vec = approx->forward(ctx, vec);  // [344, N, hidden_size]
 
-                if (y != NULL) {
+                if (y != nullptr) {
                     txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
                 }
             } else {
@@ -703,7 +704,7 @@ namespace Flux {
                 auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
                 vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
                 if (params.guidance_embed) {
-                    GGML_ASSERT(guidance != NULL);
+                    GGML_ASSERT(guidance != nullptr);
                     auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
                     // bf16 and fp16 result is different
                     auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
@@ -775,14 +776,14 @@ namespace Flux {
                                  struct ggml_tensor* y,
                                  struct ggml_tensor* guidance,
                                  struct ggml_tensor* pe,
-                                 struct ggml_tensor* mod_index_arange = NULL,
+                                 struct ggml_tensor* mod_index_arange = nullptr,
                                  std::vector<ggml_tensor*> ref_latents = {},
                                  std::vector<int> skip_layers = {}) {
             // Forward pass of DiT.
             // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
             // timestep: (N,) tensor of diffusion timesteps
             // context: (N, L, D)
-            // c_concat: NULL, or for (N,C+M, H, W) for Fill
+            // c_concat: nullptr, or for (N,C+M, H, W) for Fill
             // y: (N, adm_in_channels) tensor of class labels
             // guidance: (N,)
             // pe: (L, d_head/2, 2, 2)
@@ -801,7 +802,7 @@ namespace Flux {
             uint64_t img_tokens = img->ne[1];
 
             if (params.version == VERSION_FLUX_FILL) {
-                GGML_ASSERT(c_concat != NULL);
+                GGML_ASSERT(c_concat != nullptr);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@@ -810,7 +811,7 @@ namespace Flux {
 
                 img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
             } else if (params.version == VERSION_FLEX_2) {
-                GGML_ASSERT(c_concat != NULL);
+                GGML_ASSERT(c_concat != nullptr);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
                 ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
@@ -825,7 +826,7 @@ namespace Flux {
 
                 img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
             } else if (params.version == VERSION_FLUX_CONTROLS) {
-                GGML_ASSERT(c_concat != NULL);
+                GGML_ASSERT(c_concat != nullptr);
                 ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
                 control = patchify(ctx, control, patch_size);
@@ -924,7 +925,7 @@ namespace Flux {
             flux.init(params_ctx, tensor_types, prefix);
         }
 
-        std::string get_desc() {
+        std::string get_desc() override {
            return "flux";
         }
 
@@ -944,18 +945,18 @@ namespace Flux {
             GGML_ASSERT(x->ne[3] == 1);
             struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
 
-            struct ggml_tensor* mod_index_arange = NULL;
+            struct ggml_tensor* mod_index_arange = nullptr;
 
             x = to_backend(x);
             context = to_backend(context);
-            if (c_concat != NULL) {
+            if (c_concat != nullptr) {
                 c_concat = to_backend(c_concat);
             }
             if (flux_params.is_chroma) {
                 guidance = ggml_set_f32(guidance, 0);
 
                 if (!use_mask) {
-                    y = NULL;
+                    y = nullptr;
                 }
 
                 // ggml_arange is not working on some backends, precompute it
@@ -987,7 +988,7 @@ namespace Flux {
             auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
             // pe->data = pe_vec.data();
             // print_ggml_tensor(pe);
-            // pe->data = NULL;
+            // pe->data = nullptr;
             set_backend_tensor_data(pe, pe_vec.data());
 
             struct ggml_tensor* out = flux.forward(compute_ctx,
@@ -1017,8 +1018,8 @@ namespace Flux {
                      struct ggml_tensor* guidance,
                      std::vector<ggml_tensor*> ref_latents = {},
                      bool increase_ref_index = false,
-                     struct ggml_tensor** output = NULL,
-                     struct ggml_context* output_ctx = NULL,
+                     struct ggml_tensor** output = nullptr,
+                     struct ggml_context* output_ctx = nullptr,
                      std::vector<int> skip_layers = std::vector<int>()) {
             // x: [N, in_channels, h, w]
             // timesteps: [N, ]
@@ -1035,11 +1036,11 @@ namespace Flux {
         void test() {
             struct ggml_init_params params;
             params.mem_size = static_cast<size_t>(20 * 1024 * 1024);  // 20 MB
-            params.mem_buffer = NULL;
+            params.mem_buffer = nullptr;
             params.no_alloc = false;
 
             struct ggml_context* work_ctx = ggml_init(params);
-            GGML_ASSERT(work_ctx != NULL);
+            GGML_ASSERT(work_ctx != nullptr);
 
             {
                 // cpu f16:
@@ -1063,10 +1064,10 @@ namespace Flux {
                 ggml_set_f32(y, 0.01f);
                 // print_ggml_tensor(y);
 
-                struct ggml_tensor* out = NULL;
+                struct ggml_tensor* out = nullptr;
 
                 int t0 = ggml_time_ms();
-                compute(8, x, timesteps, context, NULL, y, guidance, {}, false, &out, work_ctx);
+                compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx);
                 int t1 = ggml_time_ms();
 
                 print_ggml_tensor(out);
@@ -1078,7 +1079,7 @@ namespace Flux {
             // ggml_backend_t backend = ggml_backend_cuda_init(0);
             ggml_backend_t backend = ggml_backend_cpu_init();
             ggml_type model_data_type = GGML_TYPE_Q8_0;
-            std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend, false));
+            std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend, false);
 
             {
                 LOG_INFO("loading from '%s'", file_path.c_str());
diff --git a/ggml b/ggml
index 7bffd79a4..55c79c624 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 7bffd79a4bec72e9a3bfbedb582a218b84401c13
+Subproject commit 55c79c6249dbc5e3ac8cd82556861608a6fd425e
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index d8df0d8f6..9699b12cd 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -19,6 +19,7 @@
 #include 
 #include 
 #include 
+#include <atomic>
 
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
@@ -586,6 +587,51 @@ __STATIC_INLINE__ void ggml_tensor_clamp(struct ggml_tensor* src, float min, flo
     }
 }
 
+
+inline std::atomic<bool>& sd_circular_padding_flag() {
+    static std::atomic<bool> enabled{false};
+    return enabled;
+}
+
+inline void sd_set_circular_padding_enabled(bool enabled) {
+    sd_circular_padding_flag().store(enabled, std::memory_order_relaxed);
+}
+
+inline bool sd_is_circular_padding_enabled() {
+    return sd_circular_padding_flag().load(std::memory_order_relaxed);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad(struct ggml_context* ctx,
+                                             struct ggml_tensor* a,
+                                             int p0,
+                                             int p1,
+                                             int p2,
+                                             int p3) {
+    if (sd_is_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+    }
+    else {
+        return ggml_pad(ctx, a, p0, p1, p2, p3);
+    }
+}
+
+__STATIC_INLINE__ struct ggml_tensor* sd_pad_ext(struct ggml_context* ctx,
+                                                 struct ggml_tensor* a,
+                                                 int lp0,
+                                                 int rp0,
+                                                 int lp1,
+                                                 int rp1,
+                                                 int lp2,
+                                                 int rp2,
+                                                 int lp3,
+                                                 int rp3) {
+    if (sd_is_circular_padding_enabled()) {
+        return ggml_pad_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+    return ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+}
+
 __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ctx,
                                                          struct ggml_tensor* a,
                                                          struct ggml_tensor* b,
@@ -986,10 +1032,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     if (scale != 1.f) {
         x = ggml_scale(ctx, x, scale);
     }
+    const bool use_circular = sd_is_circular_padding_enabled();
+    LOG_DEBUG("use circular conv %d", use_circular ? 1 : 0);
+    const bool is_depthwise = (w->ne[2] == 1 && x->ne[2] == w->ne[3]);
     if (direct) {
-        x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        if (use_circular) {
+            if (is_depthwise) {
+                x = ggml_conv_2d_dw_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+            } else {
+                x = ggml_conv_2d_direct_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+            }
+        } else {
+            x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        }
     } else {
-        x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        if (use_circular) {
+            x = ggml_conv_2d_circular(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        } else {
+            x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+        }
     }
     if (scale != 1.f) {
         x = ggml_scale(ctx, x, 1.f / scale);
@@ -1190,7 +1251,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
         if (kv_pad != 0) {
-            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
+            k_in = sd_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
             k_in = ggml_scale(ctx, k_in, kv_scale);
@@ -1200,7 +1261,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
         v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
-            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
+            v_in = sd_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
         if (kv_scale != 1.0f) {
             v_in = ggml_scale(ctx, v_in, kv_scale);
@@ -1223,7 +1284,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                 mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
             }
             if (mask_pad > 0) {
-                mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
+                mask_in = sd_pad(ctx, mask_in, 0, mask_pad, 0, 0);
             }
             mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
         }
diff --git a/mmdit.hpp b/mmdit.hpp
index d9d19340c..1b3f2276f 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -80,7 +80,7 @@ struct PatchEmbed : public GGMLBlock {
             int64_t H = x->ne[1];
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // TODO: reflect pad mode
+            x = sd_pad(ctx, x, pad_w, pad_h, 0, 0);  // TODO: reflect pad mode
         }
 
         x = proj->forward(ctx, x);
@@ -997,4 +997,4 @@ struct MMDiTRunner : public GGMLRunner {
     }
 };
 
-#endif
\ No newline at end of file
+#endif
diff --git a/qwen_image.hpp b/qwen_image.hpp
index ce4e62dce..cc336ff28 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -363,7 +363,7 @@ namespace Qwen {
             int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
             int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
 
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+            x = sd_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
             return x;
         }
 
@@ -691,4 +691,4 @@ namespace Qwen {
     }
 }  // namespace name
 
-#endif  // __QWEN_IMAGE_HPP__
\ No newline at end of file
+#endif  // __QWEN_IMAGE_HPP__
diff --git a/rope.hpp b/rope.hpp
index 295c9a217..82084403e 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -1,6 +1,8 @@
 #ifndef __ROPE_HPP__
 #define __ROPE_HPP__
 
+#include 
+#include 
 #include 
 
 #include "ggml_extend.hpp"
@@ -39,32 +41,51 @@ namespace Rope {
         return flat_vec;
     }
 
-    __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> rope(const std::vector<float>& pos,
+                                                           int dim,
+                                                           int theta,
+                                                           const std::vector<int>* wrap_dims = nullptr) {
         assert(dim % 2 == 0);
         int half_dim = dim / 2;
 
+        std::vector<std::vector<float>> result(pos.size(), std::vector<float>(half_dim * 4));
+
         std::vector<float> scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim);
 
         std::vector<float> omega(half_dim);
         for (int i = 0; i < half_dim; ++i) {
-            omega[i] = 1.0 / std::pow(theta, scale[i]);
-        }
-
-        int pos_size = pos.size();
-        std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
-        for (int i = 0; i < pos_size; ++i) {
-            for (int j = 0; j < half_dim; ++j) {
-                out[i][j] = pos[i] * omega[j];
-            }
+            omega[i] = 1.0f / std::pow(theta, scale[i]);
         }
 
-        std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
-        for (int i = 0; i < pos_size; ++i) {
+        for (size_t i = 0; i < pos.size(); ++i) {
+            float position = pos[i];
             for (int j = 0; j < half_dim; ++j) {
-                result[i][4 * j] = std::cos(out[i][j]);
-                result[i][4 * j + 1] = -std::sin(out[i][j]);
-                result[i][4 * j + 2] = std::sin(out[i][j]);
-                result[i][4 * j + 3] = std::cos(out[i][j]);
+                float omega_val = omega[j];
+                float original_angle = position * omega_val;
+                float angle = original_angle;
+                int wrap_dim = 0;
+                if (wrap_dims != nullptr && !wrap_dims->empty()) {
+                    size_t wrap_size = wrap_dims->size();
+                    // mod batch size since we only store this for one item in the batch
+                    size_t wrap_idx = wrap_size > 0 ? (i % wrap_size) : 0;
+                    wrap_dim = (*wrap_dims)[wrap_idx];
+                }
+                if (wrap_dim > 0) {
+                    constexpr float TWO_PI = 6.28318530717958647692f;
+                    float wrap_f = static_cast<float>(wrap_dim);
+                    float cycles = omega_val * wrap_f / TWO_PI;
+                    // closest periodic harmonic, necessary to ensure things neatly tile
+                    // without this round, things don't tile at the boundaries and you end up
+                    // with the model knowing what is "center"
+                    float rounded = std::round(cycles);
+                    angle = position * TWO_PI * rounded / wrap_f;
+                }
+                float sin_val = std::sin(angle);
+                float cos_val = std::cos(angle);
+                result[i][4 * j] = cos_val;
+                result[i][4 * j + 1] = -sin_val;
+                result[i][4 * j + 2] = sin_val;
+                result[i][4 * j + 3] = cos_val;
             }
         }
 
@@ -122,7 +143,8 @@ namespace Rope {
     __STATIC_INLINE__ std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
                                                   int bs,
                                                   int theta,
-                                                  const std::vector<int>& axes_dim) {
+                                                  const std::vector<int>& axes_dim,
+                                                  const std::vector<std::vector<int>>* wrap_dims = nullptr) {
         std::vector<std::vector<float>> trans_ids = transpose(ids);
         size_t pos_len = ids.size() / bs;
         int num_axes = axes_dim.size();
@@ -137,7 +159,12 @@ namespace Rope {
         std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
         int offset = 0;
         for (int i = 0; i < num_axes; ++i) {
-            std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
+            const std::vector<int>* axis_wrap_dims = nullptr;
+            if (wrap_dims != nullptr && i < (int)wrap_dims->size()) {
+                axis_wrap_dims = &(*wrap_dims)[i];
+            }
+            std::vector<std::vector<float>> rope_emb =
+                rope(trans_ids[i], axes_dim[i], theta, axis_wrap_dims);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
             for (int b = 0; b < bs; ++b) {
                 for (int j = 0; j < pos_len; ++j) {
                     for (int k = 0; k < rope_emb[0].size(); ++k) {
@@ -252,7 +279,46 @@ namespace Rope {
                                                     int theta,
                                                     const std::vector<int>& axes_dim) {
         std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
-        return embed_nd(ids, bs, theta, axes_dim);
+        std::vector<std::vector<int>> wrap_dims;
+        // This logic simply stores the (pad and patch_adjusted) sizes of images so we can make sure rope correctly tiles
+        if (sd_is_circular_padding_enabled() && bs > 0 && axes_dim.size() >= 3) {
+            int pad_h = (patch_size - (h % patch_size)) % patch_size;
+            int pad_w = (patch_size - (w % patch_size)) % patch_size;
+            int h_len = (h + pad_h) / patch_size;
+            int w_len = (w + pad_w) / patch_size;
+            if (h_len > 0 && w_len > 0) {
+                const size_t total_tokens = ids.size();
+                // Track per-token wrap lengths for the row/column axes so only spatial tokens become periodic.
+                wrap_dims.assign(axes_dim.size(), std::vector<int>(total_tokens / bs, 0));
+                size_t cursor = context_len;  // ignore text tokens
+                const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
+                for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
+                    wrap_dims[1][cursor + token_i] = h_len;
+                    wrap_dims[2][cursor + token_i] = w_len;
+                }
+                cursor += img_tokens;
+                // For each reference image, store wrap sizes as well
+                for (ggml_tensor* ref : ref_latents) {
+                    if (ref == nullptr) {
+                        continue;
+                    }
+                    int ref_h = static_cast<int>(ref->ne[1]);
+                    int ref_w = static_cast<int>(ref->ne[0]);
+                    int ref_pad_h = (patch_size - (ref_h % patch_size)) % patch_size;
+                    int ref_pad_w = (patch_size - (ref_w % patch_size)) % patch_size;
+                    int ref_h_len = (ref_h + ref_pad_h) / patch_size;
+                    int ref_w_len = (ref_w + ref_pad_w) / patch_size;
+                    size_t ref_n_tokens = static_cast<size_t>(ref_h_len) * static_cast<size_t>(ref_w_len);
+                    for (size_t token_i = 0; token_i < ref_n_tokens; ++token_i) {
+                        wrap_dims[1][cursor + token_i] = ref_h_len;
+                        wrap_dims[2][cursor + token_i] = ref_w_len;
+                    }
+                    cursor += ref_n_tokens;
+                }
+            }
+        }
+        const std::vector<std::vector<int>>* wraps_ptr = wrap_dims.empty() ? nullptr : &wrap_dims;
+        return embed_nd(ids, bs, theta, axes_dim, wraps_ptr);
     }
 
     __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 87b6a3779..adc007be9 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -114,6 +114,7 @@ class StableDiffusionGGML {
     bool use_tiny_autoencoder = false;
     sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
     bool offload_params_to_cpu = false;
+    bool circular_pad = false;
     bool stacked_id = false;
 
     bool is_using_v_parameterization = false;
@@ -187,6 +188,11 @@ class StableDiffusionGGML {
         taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
         use_tiny_autoencoder = taesd_path.size() > 0;
         offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
+        circular_pad = sd_ctx_params->circular_pad;
+        sd_set_circular_padding_enabled(circular_pad);
+        if (circular_pad) {
+            LOG_INFO("Using circular padding for convolutions");
+        }
 
         if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) {
             rng = std::make_shared<STDDefaultRNG>();
@@ -1820,6 +1826,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     sd_ctx_params->keep_control_net_on_cpu = false;
     sd_ctx_params->keep_vae_on_cpu = false;
     sd_ctx_params->diffusion_flash_attn = false;
+    sd_ctx_params->circular_pad = false;
    sd_ctx_params->chroma_use_dit_mask = true;
    sd_ctx_params->chroma_use_t5_mask = false;
    sd_ctx_params->chroma_t5_mask_pad = 1;
@@ -1860,6 +1867,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              "keep_control_net_on_cpu: %s\n"
              "keep_vae_on_cpu: %s\n"
              "diffusion_flash_attn: %s\n"
+             "circular_pad: %s\n"
              "chroma_use_dit_mask: %s\n"
              "chroma_use_t5_mask: %s\n"
              "chroma_t5_mask_pad: %d\n",
@@ -1889,6 +1897,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
              BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
              BOOL_STR(sd_ctx_params->diffusion_flash_attn),
+             BOOL_STR(sd_ctx_params->circular_pad),
              BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
              BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
              sd_ctx_params->chroma_t5_mask_pad);
diff --git a/stable-diffusion.h b/stable-diffusion.h
index a891a58f1..1512c7192 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -164,6 +164,7 @@ typedef struct {
     bool diffusion_flash_attn;
     bool diffusion_conv_direct;
    bool vae_conv_direct;
+    bool circular_pad;
     bool force_sdxl_vae_conv_scale;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
diff --git a/wan.hpp b/wan.hpp
index 31fa90b3a..8d2e29641 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -73,7 +73,7 @@ namespace WAN {
                 lp2 -= (int)cache_x->ne[2];
             }
 
-            x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
+            x = sd_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
             return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
                                    std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                                    0, 0, 0,
@@ -172,7 +172,7 @@ namespace WAN {
                                2);
             }
             if (chunk_idx == 1 && cache_x->ne[2] < 2) {  // Rep
-                cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
+                cache_x = sd_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
                 // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
             }
             if (chunk_idx == 1) {
@@ -198,9 +198,9 @@ namespace WAN {
             } else if (mode == "upsample3d") {
                 x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
             } else if (mode == "downsample2d") {
-                x = ggml_pad(ctx, x, 1, 1, 0, 0);
+                x = sd_pad(ctx, x, 1, 1, 0, 0);
             } else if (mode == "downsample3d") {
-                x = ggml_pad(ctx, x, 1, 1, 0, 0);
+                x = sd_pad(ctx, x, 1, 1, 0, 0);
             }
             x = resample_1->forward(ctx, x);
             x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
@@ -260,7 +260,7 @@ namespace WAN {
 
             int64_t pad_t = (factor_t - T % factor_t) % factor_t;
 
-            x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
+            x = sd_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0);
             T = x->ne[2];
 
             x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C);  // [C, T/factor_t, factor_t, H*W]
@@ -1838,7 +1838,7 @@ namespace WAN {
             int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
             int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
             int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
-            x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0);  // [N*C, T + pad_t, H + pad_h, W + pad_w]
+            x = sd_pad(ctx, x, pad_w, pad_h, pad_t, 0);  // [N*C, T + pad_t, H + pad_h, W + pad_w]
             return x;
         }
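
Reviewer note: a minimal, standalone sketch (not part of the patch; snapped_angle and the constants below are illustrative names, assuming a wrap length measured in latent tokens) of the frequency-snapping idea that rope() applies when circular padding is enabled. Each RoPE frequency is replaced by the closest harmonic of 2*pi / wrap_dim, so positions p and p + wrap_dim produce identical cos/sin pairs and the positional embedding tiles across the circular-padding seam.

// Illustration only; mirrors the inline math in rope() above.
#include <cassert>
#include <cmath>
#include <cstdio>

// Hypothetical helper: compute the RoPE angle after snapping the frequency
// to the nearest integer number of cycles over wrap_dim tokens.
static float snapped_angle(float position, float omega, int wrap_dim) {
    constexpr float TWO_PI = 6.28318530717958647692f;
    if (wrap_dim <= 0) {
        return position * omega;  // non-spatial token: plain RoPE angle
    }
    float wrap_f  = static_cast<float>(wrap_dim);
    float cycles  = omega * wrap_f / TWO_PI;  // cycles per full wrap at this frequency
    float rounded = std::round(cycles);       // closest periodic harmonic
    return position * TWO_PI * rounded / wrap_f;
}

int main() {
    const int wrap_dim = 64;  // e.g. w_len tokens across the (patchified) latent width
    const float omega  = 1.0f / std::pow(10000.0f, 0.25f);
    for (int p = 0; p < 4; ++p) {
        float a0 = snapped_angle((float)p, omega, wrap_dim);
        float a1 = snapped_angle((float)(p + wrap_dim), omega, wrap_dim);
        // The two angles differ by an integer multiple of 2*pi, so cos/sin match
        // and the embedding has no detectable "center" or seam.
        assert(std::fabs(std::cos(a0) - std::cos(a1)) < 1e-4f);
        assert(std::fabs(std::sin(a0) - std::sin(a1)) < 1e-4f);
        printf("p=%d cos=%f (identical at p+%d)\n", p, std::cos(a0), wrap_dim);
    }
    return 0;
}

Without the std::round step the wrapped angle would jump discontinuously at the boundary, which is why the patch snaps each frequency rather than simply taking positions modulo wrap_dim.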