diff --git a/README.md b/README.md
index 5521ec88e..615d892f6 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,11 @@ API and command-line option may change frequently.***
 - Image Models
     - SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
     - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
-    - [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
+    - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
     - [SD3/SD3.5](./docs/sd3.md)
     - [Flux-dev/Flux-schnell](./docs/flux.md)
     - [Chroma](./docs/chroma.md)
+    - [Chroma1-Radiance](./docs/chroma_radiance.md)
     - [Qwen Image](./docs/qwen_image.md)
 - Image Edit Models
     - [FLUX.1-Kontext-dev](./docs/kontext.md)
diff --git a/assets/flux/chroma1-radiance.png b/assets/flux/chroma1-radiance.png
new file mode 100644
index 000000000..1dd4a524a
Binary files /dev/null and b/assets/flux/chroma1-radiance.png differ
diff --git a/docs/chroma_radiance.md b/docs/chroma_radiance.md
new file mode 100644
index 000000000..a343520bf
--- /dev/null
+++ b/docs/chroma_radiance.md
@@ -0,0 +1,21 @@
+# How to Use
+
+## Download weights
+
+- Download Chroma1-Radiance
+  - safetensors: https://huggingface.co/lodestones/Chroma1-Radiance/tree/main
+  - gguf: https://huggingface.co/silveroxides/Chroma1-Radiance-GGUF/tree/main
+
+- Download t5xxl
+  - safetensors: https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
+
+## Examples
+
+```
+.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Chroma1-Radiance-v0.4-Q8_0.gguf --t5xxl ..\..\ComfyUI\models\clip\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma radiance cpp'" --cfg-scale 4.0 --sampling-method euler -v
+```
+
+Chroma1-Radiance
+
+![chroma1-radiance](../assets/flux/chroma1-radiance.png)
+
diff --git a/flux.hpp b/flux.hpp
index 355184be2..867a4fafa 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -399,7 +399,7 @@ namespace Flux {
 
         ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
             int64_t offset = 3 * idx;
-            return {ctx, vec, offset};
+            return ModulationOut(ctx, vec, offset);
         }
 
         struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -549,7 +549,135 @@ namespace Flux {
         }
     };
 
+    struct NerfEmbedder : public GGMLBlock {
+        NerfEmbedder(int64_t in_channels,
+                     int64_t hidden_size_input,
+                     int64_t max_freqs) {
+            blocks["embedder.0"] = std::make_shared<Linear>(in_channels + max_freqs * max_freqs, hidden_size_input);
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* dct) {
+            // x: (B, P^2, C)
+            // dct: (1, P^2, max_freqs^2)
+            // return: (B, P^2, hidden_size_input)
+            auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
+
+            dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
+            x   = ggml_concat(ctx, x, dct, 0);
+            x   = embedder->forward(ctx, x);
+
+            return x;
+        }
+    };
+
+    struct NerfGLUBlock : public GGMLBlock {
+        int64_t mlp_ratio;
+        NerfGLUBlock(int64_t hidden_size_s,
+                     int64_t hidden_size_x,
+                     int64_t mlp_ratio)
+            : mlp_ratio(mlp_ratio) {
+            int64_t total_params      = 3 * hidden_size_x * hidden_size_x * mlp_ratio;
+            blocks["param_generator"] = std::make_shared<Linear>(hidden_size_s, total_params);
+            blocks["norm"]            = std::make_shared<RMSNorm>(hidden_size_x);
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* s) {
+            // x: (batch_size, n_token, hidden_size_x)
+            // s: (batch_size, hidden_size_s)
+            // return: (batch_size, n_token, hidden_size_x)
+            auto param_generator = std::dynamic_pointer_cast<Linear>(blocks["param_generator"]);
+            auto norm            = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
+
+            int64_t batch_size    = x->ne[2];
+            int64_t hidden_size_x = x->ne[0];
+
+            auto mlp_params = param_generator->forward(ctx, s);
+            auto fc_params  = ggml_chunk(ctx, mlp_params, 3, 0);
+            auto fc1_gate   = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
+            auto fc1_value  = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
+            auto fc2        = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
+
+            fc1_gate  = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_gate, 1, 0, 2, 3));   // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
+            fc1_gate  = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
+            fc1_value = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_value, 1, 0, 2, 3));  // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
+            fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
+            fc2       = ggml_cont(ctx, ggml_torch_permute(ctx, fc2, 1, 0, 2, 3));        // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
+            fc2       = ggml_l2_norm(ctx, fc2, 1e-12f);
+
+            auto res_x = x;
+            x          = norm->forward(ctx, x);  // [batch_size, n_token, hidden_size_x]
+
+            auto x1 = ggml_mul_mat(ctx, fc1_gate, x);  // [batch_size, n_token, hidden_size_x*mlp_ratio]
+            x1      = ggml_silu_inplace(ctx, x1);
+
+            auto x2 = ggml_mul_mat(ctx, fc1_value, x);  // [batch_size, n_token, hidden_size_x*mlp_ratio]
+
+            x = ggml_mul_inplace(ctx, x1, x2);  // [batch_size, n_token, hidden_size_x*mlp_ratio]
+
+            x = ggml_mul_mat(ctx, fc2, x);  // [batch_size, n_token, hidden_size_x]
+
+            x = ggml_add_inplace(ctx, x, res_x);
+
+            return x;
+        }
+    };
+
+    struct NerfFinalLayer : public GGMLBlock {
+        NerfFinalLayer(int64_t hidden_size,
+                       int64_t out_channels) {
+            blocks["norm"]   = std::make_shared<RMSNorm>(hidden_size);
+            blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels);
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x) {
+            auto norm   = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
+            auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
+
+            x = norm->forward(ctx, x);
+            x = linear->forward(ctx, x);
+
+            return x;
+        }
+    };
+
+    struct NerfFinalLayerConv : public GGMLBlock {
+        NerfFinalLayerConv(int64_t hidden_size,
+                           int64_t out_channels) {
+            blocks["norm"] = std::make_shared<RMSNorm>(hidden_size);
+            blocks["conv"] = std::make_shared<Conv2d>(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                    struct ggml_tensor* x) {
+            // x: [N, C, H, W]
+            auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
+            auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3));  // [N, H, W, C]
+            x = norm->forward(ctx, x);
+            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));  // [N, C, H, W]
+            x = conv->forward(ctx, x);
+
+            return x;
+        }
+    };
+
+    struct ChromaRadianceParams {
+        int64_t nerf_hidden_size = 64;
+        int64_t nerf_mlp_ratio   = 4;
+        int64_t nerf_depth       = 4;
+        int64_t nerf_max_freqs   = 8;
+    };
+
     struct FluxParams {
+        SDVersion version        = VERSION_FLUX;
+        bool is_chroma           = false;
+        int64_t patch_size       = 2;
         int64_t in_channels      = 64;
         int64_t out_channels     = 64;
         int64_t vec_in_dim       = 768;
@@ -565,8 +693,8 @@ namespace Flux {
         bool qkv_bias            = true;
         bool guidance_embed      = true;
         bool flash_attn          = true;
-        bool is_chroma           = false;
-        SDVersion version        = VERSION_FLUX;
+        int64_t in_dim           = 64;
+        ChromaRadianceParams chroma_radiance_params;
     };
 
     struct Flux : public GGMLBlock {
@@ -575,53 +703,89 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
+            if (params.version == VERSION_CHROMA_RADIANCE) {
+                std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
+                std::pair<int, int> stride      = kernel_size;
+
+                blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
+                                                                  params.hidden_size,
+                                                                  kernel_size,
+                                                                  stride);
+            } else {
+                blocks["img_in"] = std::make_shared<Linear>(params.in_channels, params.hidden_size, true);
+            }
             if (params.is_chroma) {
-                blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
+                blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(params.in_dim, params.hidden_size);
             } else {
-                blocks["time_in"]   = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
-                blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
+                blocks["time_in"]   = std::make_shared<MLPEmbedder>(256, params.hidden_size);
+                blocks["vector_in"] = std::make_shared<MLPEmbedder>(params.vec_in_dim, params.hidden_size);
                 if (params.guidance_embed) {
-                    blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
+                    blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size);
                 }
             }
-            blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
+            blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, true);
 
             for (int i = 0; i < params.depth; i++) {
-                blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
-                                                                                                                params.num_heads,
-                                                                                                                params.mlp_ratio,
-                                                                                                                i,
-                                                                                                                params.qkv_bias,
-                                                                                                                params.flash_attn,
-                                                                                                                params.is_chroma));
+                blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(params.hidden_size,
+                                                                                                   params.num_heads,
+                                                                                                   params.mlp_ratio,
+                                                                                                   i,
+                                                                                                   params.qkv_bias,
+                                                                                                   params.flash_attn,
+                                                                                                   params.is_chroma);
             }
 
             for (int i = 0; i < params.depth_single_blocks; i++) {
-                blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
-                                                                                                                params.num_heads,
-                                                                                                                params.mlp_ratio,
-                                                                                                                i,
-                                                                                                                0.f,
-                                                                                                                params.flash_attn,
-                                                                                                                params.is_chroma));
+                blocks["single_blocks." + std::to_string(i)] = std::make_shared<SingleStreamBlock>(params.hidden_size,
+                                                                                                   params.num_heads,
+                                                                                                   params.mlp_ratio,
+                                                                                                   i,
+                                                                                                   0.f,
+                                                                                                   params.flash_attn,
+                                                                                                   params.is_chroma);
             }
 
-            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
+            if (params.version == VERSION_CHROMA_RADIANCE) {
+                blocks["nerf_image_embedder"] = std::make_shared<NerfEmbedder>(params.in_channels,
+                                                                               params.chroma_radiance_params.nerf_hidden_size,
+                                                                               params.chroma_radiance_params.nerf_max_freqs);
+
+                for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
+                    blocks["nerf_blocks." + std::to_string(i)] = std::make_shared<NerfGLUBlock>(params.hidden_size,
+                                                                                                params.chroma_radiance_params.nerf_hidden_size,
+                                                                                                params.chroma_radiance_params.nerf_mlp_ratio);
+                }
+
+                blocks["nerf_final_layer_conv"] = std::make_shared<NerfFinalLayerConv>(params.chroma_radiance_params.nerf_hidden_size,
+                                                                                       params.in_channels);
+
+            } else {
+                blocks["final_layer"] = std::make_shared<LastLayer>(params.hidden_size, 1, params.out_channels, params.is_chroma);
+            }
+        }
+
+        struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+                                              struct ggml_tensor* x) {
+            int64_t W = x->ne[0];
+            int64_t H = x->ne[1];
+
+            int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
+            int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
+            x         = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);  // [N, C, H + pad_h, W + pad_w]
+            return x;
         }
 
         struct ggml_tensor* patchify(struct ggml_context* ctx,
-                                     struct ggml_tensor* x,
-                                     int64_t patch_size) {
+                                     struct ggml_tensor* x) {
             // x: [N, C, H, W]
             // return: [N, h*w, C * patch_size * patch_size]
             int64_t N = x->ne[3];
             int64_t C = x->ne[2];
             int64_t H = x->ne[1];
             int64_t W = x->ne[0];
-            int64_t p = patch_size;
-            int64_t h = H / patch_size;
-            int64_t w = W / patch_size;
+            int64_t p = params.patch_size;
+            int64_t h = H / params.patch_size;
+            int64_t w = W / params.patch_size;
 
             GGML_ASSERT(h * p == H && w * p == W);
 
@@ -633,18 +797,25 @@ namespace Flux {
             return x;
         }
 
+        struct ggml_tensor* process_img(struct ggml_context* ctx,
+                                        struct ggml_tensor* x) {
+            // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+            x = pad_to_patch_size(ctx, x);
+            x = patchify(ctx, x);
+            return x;
+        }
+
         struct ggml_tensor* unpatchify(struct ggml_context* ctx,
                                        struct ggml_tensor* x,
                                        int64_t h,
-                                       int64_t w,
-                                       int64_t patch_size) {
+                                       int64_t w) {
             // x: [N, h*w, C*patch_size*patch_size]
             // return: [N, C, H, W]
             int64_t N = x->ne[2];
-            int64_t C = x->ne[0] / patch_size / patch_size;
-            int64_t H = h * patch_size;
-            int64_t W = w * patch_size;
-            int64_t p = patch_size;
+            int64_t C = x->ne[0] / params.patch_size / params.patch_size;
+            int64_t H = h * params.patch_size;
+            int64_t W = w * params.patch_size;
+            int64_t p = params.patch_size;
 
             GGML_ASSERT(C * p * p == x->ne[0]);
 
@@ -671,7 +842,10 @@ namespace Flux {
             auto txt_in      = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
             auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
 
-            img = img_in->forward(ctx, img);
+            if (img_in) {
+                img = img_in->forward(ctx, img);
+            }
+
             struct ggml_tensor* vec;
             struct ggml_tensor* txt_img_mask = nullptr;
             if (params.is_chroma) {
@@ -682,7 +856,7 @@ namespace Flux {
 
                 // auto mod_index_arange  = ggml_arange(ctx, 0, (float)mod_index_length, 1);
                 // ggml_arange tot working on a lot of backends, precomputing it on CPU instead
-                GGML_ASSERT(arange != nullptr);
+                GGML_ASSERT(mod_index_arange != nullptr);
                 auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f);  // [1, 344, 32]
 
                 // Batch broadcast (will it ever be useful)
@@ -749,52 +923,96 @@ namespace Flux {
                                        txt_img->nb[2] * txt->ne[1]);   // [n_img_token, N, hidden_size]
             img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
 
-            img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
+            if (final_layer) {
+                img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
+            }
+
             return img;
         }
 
-        struct ggml_tensor* process_img(struct ggml_context* ctx,
-                                        struct ggml_tensor* x) {
+        struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx,
ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe, + struct ggml_tensor* mod_index_arange = nullptr, + struct ggml_tensor* dct = nullptr, + std::vector ref_latents = {}, + std::vector skip_layers = {}) { + GGML_ASSERT(x->ne[3] == 1); + int64_t W = x->ne[0]; int64_t H = x->ne[1]; - int64_t patch_size = 2; + int64_t C = x->ne[2]; + int64_t patch_size = params.patch_size; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] - return img; - } + auto img = pad_to_patch_size(ctx, x); + auto orig_img = img; - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, - struct ggml_tensor* x, - struct ggml_tensor* timestep, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - struct ggml_tensor* pe, - struct ggml_tensor* mod_index_arange = nullptr, - std::vector ref_latents = {}, - std::vector skip_layers = {}) { - // Forward pass of DiT. - // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) - // timestep: (N,) tensor of diffusion timesteps - // context: (N, L, D) - // c_concat: nullptr, or for (N,C+M, H, W) for Fill - // y: (N, adm_in_channels) tensor of class labels - // guidance: (N,) - // pe: (L, d_head/2, 2, 2) - // return: (N, C, H, W) + auto img_in_patch = std::dynamic_pointer_cast(blocks["img_in_patch"]); + + img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size] + img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size] + img = ggml_cont(ctx, ggml_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size] + + auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size] + + // nerf decode + auto nerf_image_embedder = std::dynamic_pointer_cast(blocks["nerf_image_embedder"]); + auto nerf_final_layer_conv = std::dynamic_pointer_cast(blocks["nerf_final_layer_conv"]); + auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size] + int64_t num_patches = nerf_pixels->ne[1]; + nerf_pixels = ggml_reshape_3d(ctx, + nerf_pixels, + nerf_pixels->ne[0] / C, + C, + nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size] + nerf_pixels = ggml_cont(ctx, ggml_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C] + + auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size] + auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size] + + for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) { + auto block = std::dynamic_pointer_cast(blocks["nerf_blocks." 
+ std::to_string(i)]); + + img_dct = block->forward(ctx, img_dct, nerf_hidden); + } + + img_dct = ggml_cont(ctx, ggml_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] + img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] + img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] + + out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] + + return out; + } + + struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe, + struct ggml_tensor* mod_index_arange = nullptr, + struct ggml_tensor* dct = nullptr, + std::vector ref_latents = {}, + std::vector skip_layers = {}) { GGML_ASSERT(x->ne[3] == 1); int64_t W = x->ne[0]; int64_t H = x->ne[1]; int64_t C = x->ne[2]; - int64_t patch_size = 2; + int64_t patch_size = params.patch_size; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; @@ -816,21 +1034,16 @@ namespace Flux { ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); - masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0); - mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0); - control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0); - - masked = patchify(ctx, masked, patch_size); - mask = patchify(ctx, mask, patch_size); - control = patchify(ctx, control, patch_size); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); + control = process_img(ctx, control); img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0); } else if (params.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); - ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0); - control = patchify(ctx, control, patch_size); - img = ggml_concat(ctx, img, control, 0); + auto control = process_img(ctx, c_concat); + img = ggml_concat(ctx, img, control, 0); } if (ref_latents.size() > 0) { @@ -849,10 +1062,63 @@ namespace Flux { } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) - out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] - + out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w] return out; } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe, + struct ggml_tensor* mod_index_arange = nullptr, + struct ggml_tensor* dct = nullptr, + std::vector ref_latents = {}, + std::vector skip_layers = {}) { + // Forward pass of DiT. 
+ // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // timestep: (N,) tensor of diffusion timesteps + // context: (N, L, D) + // c_concat: nullptr, or for (N,C+M, H, W) for Fill + // y: (N, adm_in_channels) tensor of class labels + // guidance: (N,) + // pe: (L, d_head/2, 2, 2) + // return: (N, C, H, W) + + if (params.version == VERSION_CHROMA_RADIANCE) { + return forward_chroma_radiance(ctx, + backend, + x, + timestep, + context, + c_concat, + y, + guidance, + pe, + mod_index_arange, + dct, + ref_latents, + skip_layers); + } else { + return forward_flux_chroma(ctx, + backend, + x, + timestep, + context, + c_concat, + y, + guidance, + pe, + mod_index_arange, + dct, + ref_latents, + skip_layers); + } + } }; struct FluxRunner : public GGMLRunner { @@ -860,7 +1126,8 @@ namespace Flux { FluxParams flux_params; Flux flux; std::vector pe_vec; - std::vector mod_index_arange_vec; // for cache + std::vector mod_index_arange_vec; + std::vector dct_vec; SDVersion version; bool use_mask = false; @@ -883,6 +1150,9 @@ namespace Flux { flux_params.in_channels = 128; } else if (version == VERSION_FLEX_2) { flux_params.in_channels = 196; + } else if (version == VERSION_CHROMA_RADIANCE) { + flux_params.in_channels = 3; + flux_params.patch_size = 16; } for (auto pair : tensor_types) { std::string tensor_name = pair.first; @@ -933,6 +1203,56 @@ namespace Flux { flux.get_param_tensors(tensors, prefix); } + std::vector fetch_dct_pos(int patch_size, int max_freqs) { + const float PI = 3.14159265358979323846f; + + std::vector pos(patch_size); + for (int i = 0; i < patch_size; ++i) { + pos[i] = static_cast(i) / static_cast(patch_size - 1); + } + + std::vector pos_x(patch_size * patch_size); + std::vector pos_y(patch_size * patch_size); + for (int i = 0; i < patch_size; ++i) { + for (int j = 0; j < patch_size; ++j) { + pos_x[i * patch_size + j] = pos[j]; + pos_y[i * patch_size + j] = pos[i]; + } + } + + std::vector freqs(max_freqs); + for (int i = 0; i < max_freqs; ++i) { + freqs[i] = static_cast(i); + } + + std::vector coeffs(max_freqs * max_freqs); + for (int fx = 0; fx < max_freqs; ++fx) { + for (int fy = 0; fy < max_freqs; ++fy) { + coeffs[fx * max_freqs + fy] = 1.0f / (1.0f + freqs[fx] * freqs[fy]); + } + } + + int num_positions = patch_size * patch_size; + int num_features = max_freqs * max_freqs; + std::vector dct(num_positions * num_features); + + for (int p = 0; p < num_positions; ++p) { + float px = pos_x[p]; + float py = pos_y[p]; + + for (int fx = 0; fx < max_freqs; ++fx) { + float cx = std::cos(px * freqs[fx] * PI); + for (int fy = 0; fy < max_freqs; ++fy) { + float cy = std::cos(py * freqs[fy] * PI); + float val = cx * cy * coeffs[fx * max_freqs + fy]; + dct[p * num_features + (fx * max_freqs + fy)] = val; + } + } + } + + return dct; + } + struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -946,6 +1266,7 @@ namespace Flux { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); struct ggml_tensor* mod_index_arange = nullptr; + struct ggml_tensor* dct = nullptr; // for chroma radiance x = to_backend(x); context = to_backend(context); @@ -976,7 +1297,7 @@ namespace Flux { pe_vec = Rope::gen_flux_pe(x->ne[1], x->ne[0], - 2, + flux_params.patch_size, x->ne[3], context->ne[1], ref_latents, @@ -991,6 +1312,17 @@ namespace Flux { // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); + if (version == VERSION_CHROMA_RADIANCE) { + int64_t patch_size = 
flux_params.patch_size; + int64_t nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs; + dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs); + dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size); + // dct->data = dct_vec.data(); + // print_ggml_tensor(dct); + // dct->data = nullptr; + set_backend_tensor_data(dct, dct_vec.data()); + } + struct ggml_tensor* out = flux.forward(compute_ctx, runtime_backend, x, @@ -1001,6 +1333,7 @@ namespace Flux { guidance, pe, mod_index_arange, + dct, ref_latents, skip_layers); @@ -1035,7 +1368,7 @@ namespace Flux { void test() { struct ggml_init_params params; - params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB + params.mem_size = static_cast(1024 * 1024) * 1024; // 1GB params.mem_buffer = nullptr; params.no_alloc = false; @@ -1046,22 +1379,25 @@ namespace Flux { // cpu f16: // cuda f16: nan // cuda q8_0: pass - auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); - ggml_set_f32(x, 0.01f); + // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); + // ggml_set_f32(x, 0.01f); + auto x = load_tensor_from_file(work_ctx, "chroma_x.bin"); // print_ggml_tensor(x); - std::vector timesteps_vec(1, 999.f); + std::vector timesteps_vec(1, 1.f); auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); - std::vector guidance_vec(1, 3.5f); + std::vector guidance_vec(1, 0.f); auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); - auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); - ggml_set_f32(context, 0.01f); + // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); + // ggml_set_f32(context, 0.01f); + auto context = load_tensor_from_file(work_ctx, "chroma_context.bin"); // print_ggml_tensor(context); - auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); - ggml_set_f32(y, 0.01f); + // auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); + // ggml_set_f32(y, 0.01f); + auto y = nullptr; // print_ggml_tensor(y); struct ggml_tensor* out = nullptr; @@ -1076,32 +1412,44 @@ namespace Flux { } static void load_from_file_and_test(const std::string& file_path) { - // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); - ggml_type model_data_type = GGML_TYPE_Q8_0; - std::shared_ptr flux = std::make_shared(backend, false); - { - LOG_INFO("loading from '%s'", file_path.c_str()); - - flux->alloc_params_buffer(); - std::map tensors; - flux->get_param_tensors(tensors, "model.diffusion_model"); + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } - ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { - LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); - return; + auto tensor_types = model_loader.tensor_storages_types; + for (auto& item : tensor_types) { + // LOG_DEBUG("%s %u", item.first.c_str(), item.second); + if (ends_with(item.first, "weight")) { + // item.second = model_data_type; } + } - bool success = model_loader.load_tensors(tensors); + std::shared_ptr flux = std::make_shared(backend, + false, + tensor_types, + "model.diffusion_model", + VERSION_CHROMA_RADIANCE, + false, + 
true); - if (!success) { - LOG_ERROR("load tensors from model loader failed"); - return; - } + flux->alloc_params_buffer(); + std::map tensors; + flux->get_param_tensors(tensors, "model.diffusion_model"); - LOG_INFO("flux model loaded"); + bool success = model_loader.load_tensors(tensors); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; } + + LOG_INFO("flux model loaded"); flux->test(); } }; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 02d82bc09..66797941d 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -954,7 +954,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx, if (scale != 1.f) { x = ggml_scale(ctx, x, scale); } - x = ggml_mul_mat(ctx, w, x); + if (x->ne[2] * x->ne[3] > 1024) { + // workaround: avoid ggml cuda error + int64_t ne2 = x->ne[2]; + int64_t ne3 = x->ne[3]; + x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]); + x = ggml_mul_mat(ctx, w, x); + x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3); + } else { + x = ggml_mul_mat(ctx, w, x); + } if (force_prec_f32) { ggml_mul_mat_set_prec(x, GGML_PREC_F32); } diff --git a/model.cpp b/model.cpp index 0a03627f9..da77afedd 100644 --- a/model.cpp +++ b/model.cpp @@ -1778,7 +1778,6 @@ bool ModelLoader::model_is_unet() { SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; - bool input_block_checked = false; bool has_multiple_encoders = false; bool is_unet = false; @@ -1791,12 +1790,12 @@ SDVersion ModelLoader::get_sd_version() { bool has_middle_block_1 = false; for (auto& tensor_storage : tensor_storages) { - if (!(is_xl || is_flux)) { + if (!(is_xl)) { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; - if (input_block_checked) { - break; - } + } + if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { + return VERSION_CHROMA_RADIANCE; } if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { return VERSION_SD3; @@ -1813,22 +1812,19 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) { has_img_emb = true; } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || + tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { is_unet = true; if (has_multiple_encoders) { is_xl = true; - if (input_block_checked) { - break; - } } } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) { + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || + tensor_storage.name.find("cond_stage_model.1") != std::string::npos || + tensor_storage.name.find("te.1") != std::string::npos) { has_multiple_encoders = true; if (is_unet) { is_xl = true; - if (input_block_checked) { - break; - } } } if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { @@ -1848,12 +1844,10 @@ SDVersion ModelLoader::get_sd_version() { token_embedding_weight = tensor_storage; // break; } - if (tensor_storage.name == 
"model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") { - input_block_weight = tensor_storage; - input_block_checked = true; - if (is_flux) { - break; - } + if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || + tensor_storage.name == "model.diffusion_model.img_in.weight" || + tensor_storage.name == "unet.conv_in.weight") { + input_block_weight = tensor_storage; } } if (is_wan) { diff --git a/model.h b/model.h index 65226cd02..f1711e67f 100644 --- a/model.h +++ b/model.h @@ -36,6 +36,7 @@ enum SDVersion { VERSION_FLUX_FILL, VERSION_FLUX_CONTROLS, VERSION_FLEX_2, + VERSION_CHROMA_RADIANCE, VERSION_WAN2, VERSION_WAN2_2_I2V, VERSION_WAN2_2_TI2V, @@ -72,7 +73,11 @@ static inline bool sd_version_is_sd3(SDVersion version) { } static inline bool sd_version_is_flux(SDVersion version) { - if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) { + if (version == VERSION_FLUX || + version == VERSION_FLUX_FILL || + version == VERSION_FLUX_CONTROLS || + version == VERSION_FLEX_2 || + version == VERSION_CHROMA_RADIANCE) { return true; } return false; diff --git a/qwen_image.hpp b/qwen_image.hpp index 2d3cd2307..045863021 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -649,7 +649,7 @@ namespace Qwen { static void load_from_file_and_test(const std::string& file_path) { // cuda q8: pass - // cuda q8 fa: nan + // cuda q8 fa: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_Q8_0; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 802608ea9..f9fe8a5a3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -41,6 +41,7 @@ const char* model_version_to_str[] = { "Flux Fill", "Flux Control", "Flex.2", + "Chroma Radiance", "Wan 2.x", "Wan 2.2 I2V", "Wan 2.2 TI2V", @@ -494,6 +495,9 @@ class StableDiffusionGGML { version); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); + } else if (version == VERSION_CHROMA_RADIANCE) { + first_stage_model = std::make_shared(vae_backend, + offload_params_to_cpu); } else if (!use_tiny_autoencoder) { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu, @@ -1041,7 +1045,7 @@ class StableDiffusionGGML { struct ggml_tensor* c_concat = nullptr; { if (zero_out_masked) { - c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1); + c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / get_vae_scale_factor(), height / get_vae_scale_factor(), 4, 1); ggml_set_f32(c_concat, 0.f); } else { ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); @@ -1375,6 +1379,53 @@ class StableDiffusionGGML { return x; } + int get_vae_scale_factor() { + int vae_scale_factor = 8; + if (version == VERSION_WAN2_2_TI2V) { + vae_scale_factor = 16; + } else if (version == VERSION_CHROMA_RADIANCE) { + vae_scale_factor = 1; + } + return vae_scale_factor; + } + + int get_latent_channel() { + int latent_channel = 4; + if (sd_version_is_dit(version)) { + if (version == VERSION_WAN2_2_TI2V) { + latent_channel = 48; + } else if (version == VERSION_CHROMA_RADIANCE) { + latent_channel = 3; + } else { + latent_channel = 16; + } + } + return latent_channel; + } + + ggml_tensor* generate_init_latent(ggml_context* work_ctx, + int width, + int height, + int 
frames = 1, + bool video = false) { + int vae_scale_factor = get_vae_scale_factor(); + int W = width / vae_scale_factor; + int H = height / vae_scale_factor; + int T = frames; + if (sd_version_is_wan(version)) { + T = ((T - 1) / 4) + 1; + } + int C = get_latent_channel(); + ggml_tensor* init_latent; + if (video) { + init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); + } else { + init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); + } + ggml_set_f32(init_latent, shift_factor); + return init_latent; + } + void process_latent_in(ggml_tensor* latent) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48); @@ -1410,6 +1461,8 @@ class StableDiffusionGGML { } } } + } else if (version == VERSION_CHROMA_RADIANCE) { + // pass } else { ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3); @@ -1454,6 +1507,8 @@ class StableDiffusionGGML { } } } + } else if (version == VERSION_CHROMA_RADIANCE) { + // pass } else { ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3); @@ -1495,11 +1550,11 @@ class StableDiffusionGGML { ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) { int64_t t0 = ggml_time_ms(); ggml_tensor* result = nullptr; - int W = x->ne[0] / 8; - int H = x->ne[1] / 8; + int W = x->ne[0] / get_vae_scale_factor(); + int H = x->ne[1] / get_vae_scale_factor(); + int C = get_latent_channel(); if (vae_tiling_params.enabled && !encode_video) { // TODO wan2.2 vae support? - int C = sd_version_is_dit(version) ? 
16 : 4; int ne2; int ne3; if (sd_version_is_qwen_image(version)) { @@ -1586,7 +1641,10 @@ class StableDiffusionGGML { ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) { ggml_tensor* latent; - if (use_tiny_autoencoder || sd_version_is_qwen_image(version) || sd_version_is_wan(version)) { + if (use_tiny_autoencoder || + sd_version_is_qwen_image(version) || + sd_version_is_wan(version) || + version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { latent = ggml_view_3d(work_ctx, @@ -1613,18 +1671,14 @@ class StableDiffusionGGML { } ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) { - int64_t W = x->ne[0] * 8; - int64_t H = x->ne[1] * 8; + int64_t W = x->ne[0] * get_vae_scale_factor(); + int64_t H = x->ne[1] * get_vae_scale_factor(); int64_t C = 3; ggml_tensor* result = nullptr; if (decode_video) { int T = x->ne[2]; if (sd_version_is_wan(version)) { T = ((T - 1) * 4) + 1; - if (version == VERSION_WAN2_2_TI2V) { - W = x->ne[0] * 16; - H = x->ne[1] * 16; - } } result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, @@ -2235,16 +2289,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode - int C = 4; - if (sd_version_is_sd3(sd_ctx->sd->version)) { - C = 16; - } else if (sd_version_is_flux(sd_ctx->sd->version)) { - C = 16; - } else if (sd_version_is_qwen_image(sd_ctx->sd->version)) { - C = 16; - } - int W = width / 8; - int H = height / 8; + int C = sd_ctx->sd->get_latent_channel(); + int W = width / sd_ctx->sd->get_vae_scale_factor(); + int H = height / sd_ctx->sd->get_vae_scale_factor(); LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); struct ggml_tensor* control_latent = nullptr; @@ -2422,51 +2469,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, return result_images; } -ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx, - ggml_context* work_ctx, - int width, - int height, - int frames = 1, - bool video = false) { - int C = 4; - int T = frames; - int W = width / 8; - int H = height / 8; - if (sd_version_is_sd3(sd_ctx->sd->version)) { - C = 16; - } else if (sd_version_is_flux(sd_ctx->sd->version)) { - C = 16; - } else if (sd_version_is_qwen_image(sd_ctx->sd->version)) { - C = 16; - } else if (sd_version_is_wan(sd_ctx->sd->version)) { - C = 16; - T = ((T - 1) / 4) + 1; - if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) { - C = 48; - W = width / 16; - H = height / 16; - } - } - ggml_tensor* init_latent; - if (video) { - init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); - } else { - init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - } - if (sd_version_is_sd3(sd_ctx->sd->version)) { - ggml_set_f32(init_latent, 0.0609f); - } else if (sd_version_is_flux(sd_ctx->sd->version)) { - ggml_set_f32(init_latent, 0.1159f); - } else { - ggml_set_f32(init_latent, 0.f); - } - return init_latent; -} - sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; int width = sd_img_gen_params->width; int height = sd_img_gen_params->height; + int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); if (sd_version_is_dit(sd_ctx->sd->version)) { if (width % 16 || height % 16) { LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. 
(Got %dx%d)", @@ -2562,20 +2569,20 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g 1); for (int ix = 0; ix < masked_latent->ne[0]; ix++) { for (int iy = 0; iy < masked_latent->ne[1]; iy++) { - int mx = ix * 8; - int my = iy * 8; + int mx = ix * vae_scale_factor; + int my = iy * vae_scale_factor; if (sd_ctx->sd->version == VERSION_FLUX_FILL) { for (int k = 0; k < masked_latent->ne[2]; k++) { float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); ggml_tensor_set_f32(concat_latent, v, ix, iy, k); } // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image - for (int x = 0; x < 8; x++) { - for (int y = 0; y < 8; y++) { + for (int x = 0; x < vae_scale_factor; x++) { + for (int y = 0; y < vae_scale_factor; y++) { float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); - // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?) - // python code was using "b (h 8) (w 8) -> b (8 8) h w" - ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y); + // TODO: check if the way the mask is flattened is correct (is it supposed to be x*vae_scale_factor+y or x+vae_scale_factor*y?) + // python code was using "b (h vae_scale_factor) (w vae_scale_factor) -> b (vae_scale_factor vae_scale_factor) h w" + ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * vae_scale_factor + y); } } } else if (sd_ctx->sd->version == VERSION_FLEX_2) { @@ -2598,11 +2605,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g { // LOG_WARN("Inpainting with a base model is not great"); - denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); + denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / vae_scale_factor, height / vae_scale_factor, 1, 1); for (int ix = 0; ix < denoise_mask->ne[0]; ix++) { for (int iy = 0; iy < denoise_mask->ne[1]; iy++) { - int mx = ix * 8; - int my = iy * 8; + int mx = ix * vae_scale_factor; + int my = iy * vae_scale_factor; float m = ggml_tensor_get_f32(mask_img, mx, my); ggml_tensor_set_f32(denoise_mask, m, ix, iy); } @@ -2613,7 +2620,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g if (sd_version_is_inpaint(sd_ctx->sd->version)) { LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); } - init_latent = generate_init_latent(sd_ctx, work_ctx, width, height); + init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height); } sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance; @@ -2741,6 +2748,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int sample_steps = sd_vid_gen_params->sample_params.sample_steps; LOG_INFO("generate_video %dx%dx%d", width, height, frames); + int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); + sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler); int high_noise_sample_steps = 0; @@ -2838,7 +2847,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s ggml_tensor_set_f32(image, value, i0, i1, i2, i3); }); - concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8] + concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor] int64_t t2 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); @@ -2848,7 +2857,7 @@ SD_API sd_image_t* 
generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s concat_latent->ne[0], concat_latent->ne[1], concat_latent->ne[2], - 4); // [b*4, t, w/8, h/8] + 4); // [b*4, t, w/vae_scale_factor, h/vae_scale_factor] ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = 0.0f; if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image @@ -2859,7 +2868,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3); }); - concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8] + concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/vae_scale_factor, w/vae_scale_factor] } else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) { LOG_INFO("IMG2VID"); @@ -2870,7 +2879,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16] - init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true); + init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true); denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); ggml_set_f32(denoise_mask, 1.f); @@ -2927,8 +2936,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3); }); - inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8] - reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8] + inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor] + reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor] int64_t length = inactive->ne[2]; if (ref_image_latent) { @@ -2936,7 +2945,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s frames = (length - 1) * 4 + 1; ref_image_num = 1; } - vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/8, w/8] + vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/vae_scale_factor, w/vae_scale_factor] ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value; if (i3 < 32) { @@ -2953,7 +2962,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s if (ref_image_latent && i2 == 0) { value = 0.f; } else { - int64_t vae_stride = 8; + int64_t vae_stride = vae_scale_factor; int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride; int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride; value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0); @@ -2966,7 +2975,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } if (init_latent == nullptr) { - init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true); + init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true); } // Get learned condition @@ -2997,16 +3006,10 @@ SD_API sd_image_t* 
generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_ctx->sd->cond_stage_model->free_params_buffer(); } - int W = width / 8; - int H = height / 8; + int W = width / vae_scale_factor; + int H = height / vae_scale_factor; int T = init_latent->ne[2]; - int C = 16; - - if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) { - W = width / 16; - H = height / 16; - C = 48; - } + int C = sd_ctx->sd->get_latent_channel(); struct ggml_tensor* final_latent; struct ggml_tensor* x_t = init_latent; diff --git a/vae.hpp b/vae.hpp index 455edae04..202ebe7c0 100644 --- a/vae.hpp +++ b/vae.hpp @@ -533,6 +533,30 @@ struct VAE : public GGMLRunner { virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; }; +struct FakeVAE : public VAE { + FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu) + : VAE(backend, offload_params_to_cpu) {} + void compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx) override { + if (*output == nullptr && output_ctx != nullptr) { + *output = ggml_dup_tensor(output_ctx, z); + } + ggml_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = ggml_tensor_get_f32(z, i0, i1, i2, i3); + ggml_tensor_set_f32(*output, value, i0, i1, i2, i3); + }); + } + + void get_param_tensors(std::map& tensors, const std::string prefix) override {} + + std::string get_desc() override { + return "fake_vae"; + } +}; + struct AutoEncoderKL : public VAE { bool decode_only = true; AutoencodingEngine ae;