diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9552cfc4..9e38ae6e 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2532,14 +2532,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img); sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img); - init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); - if (sd_version_is_inpaint(sd_ctx->sd->version)) { int64_t mask_channels = 1; if (sd_ctx->sd->version == VERSION_FLUX_FILL) { - mask_channels = 8 * 8; // flatten the whole mask + mask_channels = vae_scale_factor * vae_scale_factor; // flatten the whole mask } else if (sd_ctx->sd->version == VERSION_FLEX_2) { - mask_channels = 1 + init_latent->ne[2]; + mask_channels = 1 + sd_ctx->sd->get_latent_channel(); } ggml_tensor* masked_latent = nullptr; @@ -2548,8 +2546,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_ext_tensor_apply_mask(init_img, mask_img, masked_img); masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } else { // mask after vae + init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); ggml_ext_tensor_apply_mask(init_latent, mask_img, masked_latent, 0.); } @@ -2590,9 +2590,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g for (int k = 0; k < masked_latent->ne[2]; k++) { ggml_ext_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k); } + } else { + float m = ggml_ext_tensor_get_f32(mask_img, mx, my); + ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, 0); + for (int k = 0; k < masked_latent->ne[2]; k++) { + float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k); + ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels); + } } } } + } else { + init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } {