From 3bab484f491d2a74f0896e18e79ecfb6c5fd9fd7 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Sep 2025 21:20:21 +0800 Subject: [PATCH 1/2] simplify the logic of pm id image loading --- README.md | 2 +- docs/photo_maker.md | 9 +- examples/cli/main.cpp | 166 +++++++++++++++++++------------ ggml_extend.hpp | 36 +------ pmid.hpp | 221 +----------------------------------------- stable-diffusion.cpp | 131 ++++++++++--------------- stable-diffusion.h | 12 ++- upscaler.cpp | 2 +- util.cpp | 104 -------------------- util.h | 5 - wan.hpp | 2 +- 11 files changed, 178 insertions(+), 512 deletions(-) diff --git a/README.md b/README.md index afa0ec357..c4f169fa8 100644 --- a/README.md +++ b/README.md @@ -299,7 +299,7 @@ arguments: --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model --embd-dir [EMBEDDING_PATH] path to embeddings - --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings + --photo-maker path to PHOTOMAKER model --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir --normalize-input normalize PHOTOMAKER input id images --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now diff --git a/docs/photo_maker.md b/docs/photo_maker.md index 8305a33bd..dae2c9b2a 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -6,16 +6,15 @@ You can use [PhotoMaker](https://github.com/TencentARC/PhotoMaker) to personaliz Download PhotoMaker model file (in safetensor format) [here](https://huggingface.co/bssrdf/PhotoMaker). The official release of the model file (in .bin format) does not work with ```stablediffusion.cpp```. -- Specify the PhotoMaker model path using the `--stacked-id-embd-dir PATH` parameter. -- Specify the input images path using the `--input-id-images-dir PATH` parameter. 
- input images **must** have the same width and height for preprocessing (to be improved) +- Specify the PhotoMaker model path using the `--photo-maker PATH` parameter. +- Specify the input images path using the `--pm-id-images-dir PATH` parameter. In prompt, make sure you have a class word followed by the trigger word ```"img"``` (hard-coded for now). The class word could be one of ```"man, woman, girl, boy"```. If input ID images contain asian faces, add ```Asian``` before the class word. Another PhotoMaker specific parameter: -- ```--style-ratio (0-100)%```: default is 20 and 10-20 typically gets good results. Lower ratio means more faithfully following input ID (not necessarily better quality). +- ```--pm-style-strength (0-100)%```: default is 20 and 10-20 typically gets good results. A lower style strength means more faithfully following input ID (not necessarily better quality). Other parameters recommended for running Photomaker: @@ -28,7 +27,7 @@ If on low memory GPUs (<= 8GB), recommend running with ```--vae-on-cpu``` option Example: ```bash -bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png +bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --photo-maker ../models/photomaker-v1.safetensors --pm-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, 
intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --pm-style-strength 10 --vae-on-cpu --steps 50 ``` ## PhotoMaker Version 2 diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index af03c1579..0ba3acb75 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -66,8 +66,6 @@ struct SDParams { std::string esrgan_path; std::string control_net_path; std::string embedding_dir; - std::string stacked_id_embed_dir; - std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; std::string tensor_type_rules; std::string lora_model_dir; @@ -82,11 +80,10 @@ struct SDParams { std::string prompt; std::string negative_prompt; - float style_ratio = 20.f; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; std::vector skip_layers = {7, 8, 9}; sd_sample_params_t sample_params; @@ -116,6 +113,12 @@ struct SDParams { bool color = false; int upscale_repeats = 1; + // Photo Maker + std::string photo_maker_path; + std::string pm_id_images_dir; + std::string pm_id_embed_path; + float pm_style_strength = 20.f; + bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; @@ -149,9 +152,10 @@ void print_params(SDParams params) { printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" control_net_path: %s\n", params.control_net_path.c_str()); printf(" embedding_dir: %s\n", params.embedding_dir.c_str()); - printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str()); - printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); - printf(" style ratio: %.2f\n", 
params.style_ratio); + printf(" photo_maker_path: %s\n", params.photo_maker_path.c_str()); + printf(" pm_id_images_dir: %s\n", params.pm_id_images_dir.c_str()); + printf(" pm_id_embed_path: %s\n", params.pm_id_embed_path.c_str()); + printf(" pm_style_strength: %.2f\n", params.pm_style_strength); printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_image_path: %s\n", params.init_image_path.c_str()); @@ -217,9 +221,6 @@ void print_usage(int argc, const char* argv[]) { printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); - printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n"); - printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); - printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. 
Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); @@ -266,7 +267,6 @@ void print_usage(int argc, const char* argv[]) { printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n"); printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); - printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); @@ -301,6 +301,11 @@ void print_usage(int argc, const char* argv[]) { printf(" only enabled if `--high-noise-steps` is set to -1\n"); printf(" --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto)\n"); printf(" --vace-strength wan vace strength\n"); + printf(" --photo-maker path to PHOTOMAKER model\n"); + printf(" --pm-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); + printf(" --pm-id-embed-path [PATH] path to PHOTOMAKER v2 id embed\n"); + printf(" --pm-style-strength strength for keeping PHOTOMAKER input identity (default: 20)\n"); + printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" -v, --verbose print extra info\n"); } @@ -487,12 +492,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--taesd", "", ¶ms.taesd_path}, {"", "--control-net", "", ¶ms.control_net_path}, {"", "--embd-dir", "", ¶ms.embedding_dir}, - {"", "--stacked-id-embd-dir", "", ¶ms.stacked_id_embed_dir}, {"", "--lora-model-dir", "", ¶ms.lora_model_dir}, {"-i", 
"--init-img", "", ¶ms.init_image_path}, {"", "--end-img", "", ¶ms.end_image_path}, {"", "--tensor-type-rules", "", ¶ms.tensor_type_rules}, - {"", "--input-id-images-dir", "", ¶ms.input_id_images_path}, + {"", "--photo-maker", "", ¶ms.photo_maker_path}, + {"", "--pm-id-images-dir", "", ¶ms.pm_id_images_dir}, + {"", "--pm-id-embed-path", "", ¶ms.pm_id_embed_path}, {"", "--mask", "", ¶ms.mask_image_path}, {"", "--control-image", "", ¶ms.control_image_path}, {"", "--control-video", "", ¶ms.control_video_path}, @@ -532,7 +538,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--high-noise-skip-layer-end", "", ¶ms.high_noise_sample_params.guidance.slg.layer_end}, {"", "--high-noise-eta", "", ¶ms.high_noise_sample_params.eta}, {"", "--strength", "", ¶ms.strength}, - {"", "--style-ratio", "", ¶ms.style_ratio}, + {"", "--pm-style-strength", "", ¶ms.pm_style_strength}, {"", "--control-strength", "", ¶ms.control_strength}, {"", "--moe-boundary", "", ¶ms.moe_boundary}, {"", "--flow-shift", "", ¶ms.flow_shift}, @@ -1075,14 +1081,58 @@ uint8_t* load_image(const char* image_path, int& width, int& height, int expecte STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_COLORSPACE_SRGB, nullptr); - - // Save resized result + width = resized_width; + height = resized_height; free(image_buffer); image_buffer = resized_image_buffer; } return image_buffer; } +bool load_images_from_dir(const std::string dir, + std::vector& images, + int expected_width = 0, + int expected_height = 0, + int max_image_num = 0, + bool verbose = false) { + if (!fs::exists(dir) || !fs::is_directory(dir)) { + fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str()); + return false; + } + + for (const auto& entry : fs::directory_iterator(dir)) { + if (!entry.is_regular_file()) + continue; + + std::string path = entry.path().string(); + std::string ext = entry.path().extension().string(); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + 
+ if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") { + if (verbose) { + printf("load image %zu from '%s'\n", images.size(), path.c_str()); + } + int width = 0; + int height = 0; + uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return false; + } + + images.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + + if (max_image_num > 0 && images.size() >= max_image_num) { + break; + } + } + } + return true; +} + int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); @@ -1122,6 +1172,7 @@ int main(int argc, const char* argv[]) { sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL}; std::vector ref_images; + std::vector pmid_images; std::vector control_frames; auto release_all_resources = [&]() { @@ -1129,14 +1180,19 @@ int main(int argc, const char* argv[]) { free(end_image.data); free(control_image.data); free(mask_image.data); - for (auto ref_image : ref_images) { - free(ref_image.data); - ref_image.data = NULL; + for (auto image : ref_images) { + free(image.data); + image.data = NULL; } ref_images.clear(); - for (auto frame : control_frames) { - free(frame.data); - frame.data = NULL; + for (auto image : pmid_images) { + free(image.data); + image.data = NULL; + } + pmid_images.clear(); + for (auto image : control_frames) { + free(image.data); + image.data = NULL; } control_frames.clear(); }; @@ -1225,44 +1281,26 @@ int main(int argc, const char* argv[]) { } if (!params.control_video_path.empty()) { - std::string dir = params.control_video_path; - - if (!fs::exists(dir) || !fs::is_directory(dir)) { - fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str()); + if (!load_images_from_dir(params.control_video_path, + 
control_frames, + params.width, + params.height, + params.video_frames, + params.verbose)) { release_all_resources(); return 1; } + } - for (const auto& entry : fs::directory_iterator(dir)) { - if (!entry.is_regular_file()) - continue; - - std::string path = entry.path().string(); - std::string ext = entry.path().extension().string(); - std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); - - if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") { - if (params.verbose) { - printf("load control frame %zu from '%s'\n", control_frames.size(), path.c_str()); - } - int width = 0; - int height = 0; - uint8_t* image_buffer = load_image(path.c_str(), width, height, params.width, params.height); - if (image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", path.c_str()); - release_all_resources(); - return 1; - } - - control_frames.push_back({(uint32_t)params.width, - (uint32_t)params.height, - 3, - image_buffer}); - - if (control_frames.size() >= params.video_frames) { - break; - } - } + if (!params.pm_id_images_dir.empty()) { + if (!load_images_from_dir(params.pm_id_images_dir, + pmid_images, + 0, + 0, + 0, + params.verbose)) { + release_all_resources(); + return 1; } } @@ -1283,7 +1321,7 @@ int main(int argc, const char* argv[]) { params.control_net_path.c_str(), params.lora_model_dir.c_str(), params.embedding_dir.c_str(), - params.stacked_id_embed_dir.c_str(), + params.photo_maker_path.c_str(), vae_decode_only, true, params.n_threads, @@ -1334,9 +1372,13 @@ int main(int argc, const char* argv[]) { params.batch_count, control_image, params.control_strength, - params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str(), + { + pmid_images.data(), + (int)pmid_images.size(), + params.pm_id_embed_path.c_str(), + params.pm_style_strength, + }, // pm_params params.vae_tiling_params, }; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 9d712772f..a5f61ea46 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp 
@@ -193,17 +193,9 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i return value; } -static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* name) { - struct ggml_tensor* res = NULL; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - struct ggml_tensor* node = ggml_graph_node(gf, i); - // printf("%d, %s \n", i, ggml_get_name(node)); - if (strcmp(ggml_get_name(node), name) == 0) { - res = node; - break; - } - } - return res; +__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) { + float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic); + return value; } __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false, const char* mark = "") { @@ -454,28 +446,6 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, } } -__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, - struct ggml_tensor* output, - int idx, - float* mean = NULL, - float* std = NULL) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - int value = *(image_data + iy * width * channels + ix * channels + k); - float pixel_val = value / 255.0f; - if (mean != NULL && std != NULL) - pixel_val = (pixel_val - mean[k]) / std[k]; - ggml_tensor_set_f32(output, pixel_val, ix, iy, k, idx); - } - } - } -} - __STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, struct ggml_tensor* output, bool scale = true) { diff --git a/pmid.hpp b/pmid.hpp index 5e9b0d5b2..d7daa4196 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -42,41 +42,6 @@ struct FuseBlock : public GGMLBlock { } }; -/* -class QFormerPerceiver(nn.Module): - def __init__(self, id_embeddings_dim, cross_attention_dim, 
num_tokens, embedding_dim=1024, use_residual=True, ratio=4): - super().__init__() - - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.use_residual = use_residual - print(cross_attention_dim*num_tokens) - self.token_proj = nn.Sequential( - nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio), - nn.GELU(), - nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens), - ) - self.token_norm = nn.LayerNorm(cross_attention_dim) - self.perceiver_resampler = FacePerceiverResampler( - dim=cross_attention_dim, - depth=4, - dim_head=128, - heads=cross_attention_dim // 128, - embedding_dim=embedding_dim, - output_dim=cross_attention_dim, - ff_mult=4, - ) - - def forward(self, x, last_hidden_state): - x = self.token_proj(x) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - x = self.token_norm(x) # cls token - out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens - if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out - return out -*/ - struct PMFeedForward : public GGMLBlock { // network hparams int dim; @@ -122,17 +87,8 @@ struct PerceiverAttention : public GGMLBlock { int64_t ne[4]; for (int i = 0; i < 4; ++i) ne[i] = x->ne[i]; - // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); - // printf("heads = %d \n", heads); - // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, - // x->nb[1], x->nb[2], x->nb[3], 0); x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]); - // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], - // x->nb[1], x->nb[2], x->nb[3], 0); - // x = ggml_cont(ctx, x); x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); - // print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: "); - // x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads); return x; } @@ -269,17 +225,6 @@ struct QFormerPerceiver : public GGMLBlock { 4)); } - /* - def forward(self, x, 
last_hidden_state): - x = self.token_proj(x) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - x = self.token_norm(x) # cls token - out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens - if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out - return out - */ - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* last_hidden_state) { @@ -299,113 +244,6 @@ struct QFormerPerceiver : public GGMLBlock { } }; -/* -class FacePerceiverResampler(torch.nn.Module): - def __init__( - self, - *, - dim=768, - depth=4, - dim_head=64, - heads=16, - embedding_dim=1280, - output_dim=768, - ff_mult=4, - ): - super().__init__() - - self.proj_in = torch.nn.Linear(embedding_dim, dim) - self.proj_out = torch.nn.Linear(dim, output_dim) - self.norm_out = torch.nn.LayerNorm(output_dim) - self.layers = torch.nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - torch.nn.ModuleList( - [ - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - FeedForward(dim=dim, mult=ff_mult), - ] - ) - ) - - def forward(self, latents, x): - x = self.proj_in(x) - for attn, ff in self.layers: - latents = attn(x, latents) + latents - latents = ff(latents) + latents - latents = self.proj_out(latents) - return self.norm_out(latents) -*/ - -/* - -def FeedForward(dim, mult=4): - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - -def reshape_tensor(x, heads): - bs, length, width = x.shape - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs, heads, length, -1) - return x - -class PerceiverAttention(nn.Module): - def 
__init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - latents = self.norm2(latents) - - b, l, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - - return self.to_out(out) - -*/ - struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -425,31 +263,13 @@ struct FuseModule : public GGMLBlock { auto mlp2 = std::dynamic_pointer_cast(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); - // print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: "); - // print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: "); - - // auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); - // auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - // print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: "); - // print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: "); - // concat is along dim 2 - // auto stacked_id_embeds = 
ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: "); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); - // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds); - // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); - // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds); - // stacked_id_embeds = ggml_nn_layer_norm(ctx, stacked_id_embeds, ln_w, ln_b); stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds); stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); - return stacked_id_embeds; } @@ -464,21 +284,14 @@ struct FuseModule : public GGMLBlock { struct ggml_tensor* valid_id_embeds = id_embeds; // # slice out the image token embeddings - // print_ggml_tensor(class_tokens_mask_pos, false); ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos"); ggml_set_name(prompt_embeds, "prompt_embeds"); - // print_ggml_tensor(valid_id_embeds, true, "valid_id_embeds"); - // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); - // 
print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); - // print_ggml_tensor(left, true, "AA left"); - // print_ggml_tensor(right, true, "AA right"); if (left && right) { stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); @@ -487,15 +300,12 @@ struct FuseModule : public GGMLBlock { } else if (right) { stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } - // print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds"); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); - // print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds"); + class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); - // print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: "); return updated_prompt_embeds; } }; @@ -551,34 +361,11 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo num_tokens(2) { blocks["visual_projection_2"] = std::shared_ptr(new Linear(1024, 1280, false)); blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); - /* - cross_attention_dim = 2048 - # projection - self.num_tokens = 2 - self.cross_attention_dim = cross_attention_dim - self.qformer_perceiver = QFormerPerceiver( - id_embeddings_dim, - cross_attention_dim, - self.num_tokens, - )*/ - blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, - cross_attention_dim, - num_tokens)); + blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); } - /* - def forward(self, id_pixel_values, prompt_embeds, 
class_tokens_mask, id_embeds): - b, num_inputs, c, h, w = id_pixel_values.shape - id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) - - last_hidden_state = self.vision_model(id_pixel_values)[0] - id_embeds = id_embeds.view(b * num_inputs, -1) - - id_embeds = self.qformer_perceiver(id_embeds, last_hidden_state) - id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1) - updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) - */ - struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* id_pixel_values, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index f2d1e36ee..f257710d0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -412,7 +412,7 @@ class StableDiffusionGGML { clip_vision->get_param_tensors(tensors); } } else { // SD1.x SD2.x SDXL - if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, model_loader.tensor_storages_types, @@ -510,7 +510,7 @@ class StableDiffusionGGML { } } - if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { pmid_model = std::make_shared(backend, offload_params_to_cpu, model_loader.tensor_storages_types, @@ -525,15 +525,15 @@ class StableDiffusionGGML { "pmid", version); } - if (strlen(SAFE_STR(sd_ctx_params->stacked_id_embed_dir)) > 0) { - pmid_lora = std::make_shared(backend, sd_ctx_params->stacked_id_embed_dir, ""); + if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { + pmid_lora = std::make_shared(backend, sd_ctx_params->photo_maker_path, ""); if (!pmid_lora->load_from_file(true)) { - LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->stacked_id_embed_dir); + LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); return false; } - 
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->stacked_id_embed_dir); - if (!model_loader.init_from_file(sd_ctx_params->stacked_id_embed_dir, "pmid.")) { - LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->stacked_id_embed_dir); + LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path); + if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) { + LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); } else { stacked_id = true; } @@ -1644,7 +1644,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "control_net_path: %s\n" "lora_model_dir: %s\n" "embedding_dir: %s\n" - "stacked_id_embed_dir: %s\n" + "photo_maker_path: %s\n" "vae_decode_only: %s\n" "vae_tiling: %s\n" "free_params_immediately: %s\n" @@ -1671,7 +1671,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->control_net_path), SAFE_STR(sd_ctx_params->lora_model_dir), SAFE_STR(sd_ctx_params->embedding_dir), - SAFE_STR(sd_ctx_params->stacked_id_embed_dir), + SAFE_STR(sd_ctx_params->photo_maker_path), BOOL_STR(sd_ctx_params->vae_decode_only), BOOL_STR(sd_ctx_params->free_params_immediately), sd_ctx_params->n_threads, @@ -1747,8 +1747,8 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->seed = -1; sd_img_gen_params->batch_count = 1; sd_img_gen_params->control_strength = 0.9f; - sd_img_gen_params->style_strength = 20.f; sd_img_gen_params->normalize_input = false; + sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; } @@ -1769,15 +1769,13 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "sample_params: %s\n" "strength: %.2f\n" "seed: %" PRId64 - "VAE tiling:" - "\n" "batch_count: %d\n" "ref_images_count: %d\n" 
"increase_ref_index: %s\n" "control_strength: %.2f\n" - "style_strength: %.2f\n" "normalize_input: %s\n" - "input_id_images_path: %s\n", + "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" + "VAE tiling: %s\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, @@ -1786,14 +1784,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { SAFE_STR(sample_params_str), sd_img_gen_params->strength, sd_img_gen_params->seed, - BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), sd_img_gen_params->batch_count, sd_img_gen_params->ref_images_count, BOOL_STR(sd_img_gen_params->increase_ref_index), sd_img_gen_params->control_strength, - sd_img_gen_params->style_strength, BOOL_STR(sd_img_gen_params->normalize_input), - SAFE_STR(sd_img_gen_params->input_id_images_path)); + sd_img_gen_params->pm_params.style_strength, + sd_img_gen_params->pm_params.id_images_count, + SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), + BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); free(sample_params_str); return buf; } @@ -1872,9 +1871,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int batch_count, sd_image_t control_image, float control_strength, - float style_ratio, bool normalize_input, - std::string input_id_images_path, + sd_pm_params_t pm_params, std::vector ref_latents, bool increase_ref_index, ggml_tensor* concat_latent = NULL, @@ -1915,67 +1913,46 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, } } // preprocess input id images - std::vector input_id_images; bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; - if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { - std::vector img_files = get_files_from_dir(input_id_images_path); - for (std::string img_file : img_files) { - int c = 0; - int width, height; - if (ends_with(img_file, "safetensors")) { - continue; - } - uint8_t* input_image_buffer = 
stbi_load(img_file.c_str(), &width, &height, &c, 3); - if (input_image_buffer == NULL) { - LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); - continue; - } else { - LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str()); - } - sd_image_t* input_image = NULL; - input_image = new sd_image_t{(uint32_t)width, - (uint32_t)height, - 3, - input_image_buffer}; - input_image = preprocess_id_image(input_image); - if (input_image == NULL) { - LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str()); - continue; - } - input_id_images.push_back(input_image); - } - } - if (input_id_images.size() > 0) { - sd_ctx->sd->pmid_model->style_strength = style_ratio; - int32_t w = input_id_images[0]->width; - int32_t h = input_id_images[0]->height; - int32_t channels = input_id_images[0]->channel; - int32_t num_input_images = (int32_t)input_id_images.size(); - init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images); - // TODO: move these to somewhere else and be user settable - float mean[] = {0.48145466f, 0.4578275f, 0.40821073f}; - float std[] = {0.26862954f, 0.26130258f, 0.27577711f}; - for (int i = 0; i < num_input_images; i++) { - sd_image_t* init_image = input_id_images[i]; - if (normalize_input) - sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std); - else - sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); + if (pm_params.id_images_count > 0) { + int clip_image_size = 224; + sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength; + + init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count); + + std::vector processed_id_images; + for (int i = 0; i < pm_params.id_images_count; i++) { + sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]); + sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size); + free(id_image.data); + id_image.data = NULL; + 
processed_id_images.push_back(processed_id_image); } + + ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2); + ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3); + }); + + for (auto& image : processed_id_images) { + free(image.data); + image.data = NULL; + } + processed_id_images.clear(); + int64_t t0 = ggml_time_ms(); auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, sd_ctx->sd->n_threads, prompt, clip_skip, width, height, - num_input_images, + pm_params.id_images_count, sd_ctx->sd->diffusion_model->get_adm_in_channels()); id_cond = std::get<0>(cond_tup); class_tokens_mask = std::get<1>(cond_tup); // struct ggml_tensor* id_embeds = NULL; - if (pmv2) { - // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); - id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); + if (pmv2 && pm_params.id_embed_path != nullptr) { + id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); // print_ggml_tensor(id_embeds, true, "id_embeds:"); } id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); @@ -1988,19 +1965,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt); // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); prompt = prompt_text_only; // - // if (sample_steps < 50) { - // LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps); - // sample_steps = 50; - // } + if (sample_steps < 50) { + LOG_WARN("It's recommended to use >= 50 steps for photo maker!"); + } } else { LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); LOG_WARN("Turn off PhotoMaker"); sd_ctx->sd->stacked_id = false; } - for (sd_image_t* img : input_id_images) { - free(img->data); - 
} - input_id_images.clear(); } // Get learned condition @@ -2248,7 +2220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g } struct ggml_init_params params; - params.mem_size = static_cast(1024 * 1024) * 1024; // 1G + params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -2430,9 +2402,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->batch_count, sd_img_gen_params->control_image, sd_img_gen_params->control_strength, - sd_img_gen_params->style_strength, sd_img_gen_params->normalize_input, - SAFE_STR(sd_img_gen_params->input_id_images_path), + sd_img_gen_params->pm_params, ref_latents, sd_img_gen_params->increase_ref_index, concat_latent, diff --git a/stable-diffusion.h b/stable-diffusion.h index 1f8c7c259..d1c3c7171 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -136,7 +136,7 @@ typedef struct { const char* control_net_path; const char* lora_model_dir; const char* embedding_dir; - const char* stacked_id_embed_dir; + const char* photo_maker_path; bool vae_decode_only; bool free_params_immediately; int n_threads; @@ -185,6 +185,13 @@ typedef struct { float eta; } sd_sample_params_t; +typedef struct { + sd_image_t* id_images; + int id_images_count; + const char* id_embed_path; + float style_strength; +} sd_pm_params_t; // photo maker + typedef struct { const char* prompt; const char* negative_prompt; @@ -202,9 +209,8 @@ typedef struct { int batch_count; sd_image_t control_image; float control_strength; - float style_strength; bool normalize_input; - const char* input_id_images_path; + sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; } sd_img_gen_params_t; diff --git a/upscaler.cpp b/upscaler.cpp index 7d09d86d2..7e765d77a 100644 --- a/upscaler.cpp +++ b/upscaler.cpp @@ -69,7 +69,7 @@ struct UpscalerGGML { input_image.width, input_image.height, 
output_width, output_height); struct ggml_init_params params; - params.mem_size = static_cast(1024 * 1024) * 1024; // 1G + params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; diff --git a/util.cpp b/util.cpp index b9142e606..5af6b1ec1 100644 --- a/util.cpp +++ b/util.cpp @@ -110,56 +110,6 @@ std::string get_full_path(const std::string& dir, const std::string& filename) { } } -std::vector get_files_from_dir(const std::string& dir) { - std::vector files; - - WIN32_FIND_DATA findFileData; - HANDLE hFind; - - char currentDirectory[MAX_PATH]; - GetCurrentDirectory(MAX_PATH, currentDirectory); - - char directoryPath[MAX_PATH]; // this is absolute path - sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str()); - - // Find the first file in the directory - hFind = FindFirstFile(directoryPath, &findFileData); - bool isAbsolutePath = false; - // Check if the directory was found - if (hFind == INVALID_HANDLE_VALUE) { - printf("Unable to find directory. 
Try with original path \n"); - - char directoryPathAbsolute[MAX_PATH]; - sprintf(directoryPathAbsolute, "%s*", dir.c_str()); - - hFind = FindFirstFile(directoryPathAbsolute, &findFileData); - isAbsolutePath = true; - if (hFind == INVALID_HANDLE_VALUE) { - printf("Absolute path was also wrong.\n"); - return files; - } - } - - // Loop through all files in the directory - do { - // Check if the found file is a regular file (not a directory) - if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - if (isAbsolutePath) { - files.push_back(dir + "\\" + std::string(findFileData.cFileName)); - } else { - files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); - } - } - } while (FindNextFile(hFind, &findFileData) != 0); - - // Close the handle - FindClose(hFind); - - sort(files.begin(), files.end()); - - return files; -} - #else // Unix #include #include @@ -194,27 +144,6 @@ std::string get_full_path(const std::string& dir, const std::string& filename) { return ""; } -std::vector get_files_from_dir(const std::string& dir) { - std::vector files; - - DIR* dp = opendir(dir.c_str()); - - if (dp != nullptr) { - struct dirent* entry; - - while ((entry = readdir(dp)) != nullptr) { - std::string fname = dir + "/" + entry->d_name; - if (!is_directory(fname)) - files.push_back(fname); - } - closedir(dp); - } - - sort(files.begin(), files.end()); - - return files; -} - #endif // get_num_physical_cores is copy from @@ -318,39 +247,6 @@ std::vector split_string(const std::string& str, char delimiter) { return result; } -sd_image_t* preprocess_id_image(sd_image_t* img) { - int shortest_edge = 224; - int size = shortest_edge; - sd_image_t* resized = NULL; - uint32_t w = img->width; - uint32_t h = img->height; - uint32_t c = img->channel; - - // 1. 
do resize using stb_resize functions - - unsigned char* buf = (unsigned char*)malloc(sizeof(unsigned char) * 3 * size * size); - if (!stbir_resize_uint8(img->data, w, h, 0, - buf, size, size, 0, - c)) { - fprintf(stderr, "%s: resize operation failed \n ", __func__); - return resized; - } - - // 2. do center crop (likely unnecessary due to step 1) - - // 3. do rescale - - // 4. do normalize - - // 3 and 4 will need to be done in float format. - - resized = new sd_image_t{(uint32_t)shortest_edge, - (uint32_t)shortest_edge, - 3, - buf}; - return resized; -} - void pretty_progress(int step, int steps, float time) { if (sd_progress_cb) { sd_progress_cb(step, steps, time, sd_progress_cb_data); diff --git a/util.h b/util.h index 89a990c82..1e8db6e3b 100644 --- a/util.h +++ b/util.h @@ -24,14 +24,9 @@ bool file_exists(const std::string& filename); bool is_directory(const std::string& path); std::string get_full_path(const std::string& dir, const std::string& filename); -std::vector get_files_from_dir(const std::string& dir); - std::u32string utf8_to_utf32(const std::string& utf8_str); std::string utf32_to_utf8(const std::u32string& utf32_str); std::u32string unicode_value_to_utf32(int unicode_value); - -sd_image_t* preprocess_id_image(sd_image_t* img); - // std::string sd_basename(const std::string& path); typedef struct { diff --git a/wan.hpp b/wan.hpp index cd4d7a591..7e3510a1d 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1219,7 +1219,7 @@ namespace WAN { void test() { struct ggml_init_params params; - params.mem_size = static_cast(1024 * 1024) * 1024; // 1G + params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; From 5837419676e592ebc1100c12350929982f970655 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Sep 2025 22:12:29 +0800 Subject: [PATCH 2/2] update docs --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c4f169fa8..451388aa8 100644 --- a/README.md 
+++ b/README.md @@ -299,9 +299,6 @@ arguments: --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model --embd-dir [EMBEDDING_PATH] path to embeddings - --photo-maker path to PHOTOMAKER model - --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir - --normalize-input normalize PHOTOMAKER input id images --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now --upscale-repeats Run the ESRGAN upscaler this many times (default 1) --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K) @@ -348,7 +345,6 @@ arguments: --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto) SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END]) --strength STRENGTH strength for noising/unnoising (default: 0.75) - --style-ratio STYLE-RATIO strength for keeping input identity (default: 20) --control-strength STRENGTH strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image -H, --height H image height, in pixel space (default: 512) @@ -383,6 +379,11 @@ arguments: only enabled if `--high-noise-steps` is set to -1 --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto) --vace-strength wan vace strength + --photo-maker path to PHOTOMAKER model + --pm-id-images-dir [DIR] path to PHOTOMAKER input id images dir + --pm-id-embed-path [PATH] path to PHOTOMAKER v2 id embed + --pm-style-strength strength for keeping PHOTOMAKER input identity (default: 20) + --normalize-input normalize PHOTOMAKER input id images -v, --verbose print extra info ```