diff --git a/README.md b/README.md
index 5521ec88e..615d892f6 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,11 @@ API and command-line option may change frequently.***
- Image Models
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- - [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
+ - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [Chroma](./docs/chroma.md)
+ - [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
diff --git a/assets/flux/chroma1-radiance.png b/assets/flux/chroma1-radiance.png
new file mode 100644
index 000000000..1dd4a524a
Binary files /dev/null and b/assets/flux/chroma1-radiance.png differ
diff --git a/docs/chroma_radiance.md b/docs/chroma_radiance.md
new file mode 100644
index 000000000..a343520bf
--- /dev/null
+++ b/docs/chroma_radiance.md
@@ -0,0 +1,21 @@
+# How to Use
+
+## Download weights
+
+- Download Chroma1-Radiance
+ - safetensors: https://huggingface.co/lodestones/Chroma1-Radiance/tree/main
+ - gguf: https://huggingface.co/silveroxides/Chroma1-Radiance-GGUF/tree/main
+
+- Download t5xxl
+ - safetensors: https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
+
+## Examples
+
+```
+.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Chroma1-Radiance-v0.4-Q8_0.gguf --t5xxl ..\..\ComfyUI\models\clip\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma radiance cpp'" --cfg-scale 4.0 --sampling-method euler -v
+```
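+
+On Linux or macOS the equivalent invocation looks like the following (a sketch; adjust the executable and model paths to match your build directory and download locations):
+
+```
+./build/bin/sd --diffusion-model ./models/Chroma1-Radiance-v0.4-Q8_0.gguf --t5xxl ./models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma radiance cpp'" --cfg-scale 4.0 --sampling-method euler -v
+```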
+
+
+
+
+
diff --git a/flux.hpp b/flux.hpp
index 355184be2..867a4fafa 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -399,7 +399,7 @@ namespace Flux {
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
int64_t offset = 3 * idx;
- return {ctx, vec, offset};
+ return ModulationOut(ctx, vec, offset);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -549,7 +549,135 @@ namespace Flux {
}
};
+ struct NerfEmbedder : public GGMLBlock {
+ NerfEmbedder(int64_t in_channels,
+ int64_t hidden_size_input,
+ int64_t max_freqs) {
+ blocks["embedder.0"] = std::make_shared(in_channels + max_freqs * max_freqs, hidden_size_input);
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* dct) {
+ // x: (B, P^2, C)
+ // dct: (1, P^2, max_freqs^2)
+ // return: (B, P^2, hidden_size_input)
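+ // Concatenate the per-pixel DCT positional features to the raw pixel values
+ // along the channel dim, then project to hidden_size_input.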
+ auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
+
+ dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
+ x = ggml_concat(ctx, x, dct, 0);
+ x = embedder->forward(ctx, x);
+
+ return x;
+ }
+ };
+
+ struct NerfGLUBlock : public GGMLBlock {
+ int64_t mlp_ratio;
+ NerfGLUBlock(int64_t hidden_size_s,
+ int64_t hidden_size_x,
+ int64_t mlp_ratio)
+ : mlp_ratio(mlp_ratio) {
+ int64_t total_params = 3 * hidden_size_x * hidden_size_x * mlp_ratio;
+ blocks["param_generator"] = std::make_shared(hidden_size_s, total_params);
+ blocks["norm"] = std::make_shared(hidden_size_x);
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* s) {
+ // x: (batch_size, n_token, hidden_size_x)
+ // s: (batch_size, hidden_size_s)
+ // return: (batch_size, n_token, hidden_size_x)
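+ // The conditioning vector s generates the weights of a per-patch gated MLP
+ // (fc1_gate, fc1_value, fc2, each L2-normalized); x is then passed through
+ // norm -> SiLU(x @ fc1_gate) * (x @ fc1_value) -> @ fc2, with a residual connection.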
+ auto param_generator = std::dynamic_pointer_cast<Linear>(blocks["param_generator"]);
+ auto norm = std::dynamic_pointer_cast(blocks["norm"]);
+
+ int64_t batch_size = x->ne[2];
+ int64_t hidden_size_x = x->ne[0];
+
+ auto mlp_params = param_generator->forward(ctx, s);
+ auto fc_params = ggml_chunk(ctx, mlp_params, 3, 0);
+ auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
+ auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
+ auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
+
+ fc1_gate = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
+ fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
+ fc1_value = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
+ fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
+ fc2 = ggml_cont(ctx, ggml_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
+ fc2 = ggml_l2_norm(ctx, fc2, 1e-12f);
+
+ auto res_x = x;
+ x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x]
+
+ auto x1 = ggml_mul_mat(ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+ x1 = ggml_silu_inplace(ctx, x1);
+
+ auto x2 = ggml_mul_mat(ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+
+ x = ggml_mul_inplace(ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+
+ x = ggml_mul_mat(ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
+
+ x = ggml_add_inplace(ctx, x, res_x);
+
+ return x;
+ }
+ };
+
+ struct NerfFinalLayer : public GGMLBlock {
+ NerfFinalLayer(int64_t hidden_size,
+ int64_t out_channels) {
+ blocks["norm"] = std::make_shared(hidden_size);
+ blocks["linear"] = std::make_shared(hidden_size, out_channels);
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* x) {
+ auto norm = std::dynamic_pointer_cast(blocks["norm"]);
+ auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
+
+ x = norm->forward(ctx, x);
+ x = linear->forward(ctx, x);
+
+ return x;
+ }
+ };
+
+ struct NerfFinalLayerConv : public GGMLBlock {
+ NerfFinalLayerConv(int64_t hidden_size,
+ int64_t out_channels) {
+ blocks["norm"] = std::make_shared(hidden_size);
+ blocks["conv"] = std::make_shared(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* x) {
+ // x: [N, C, H, W]
+ auto norm = std::dynamic_pointer_cast(blocks["norm"]);
+ auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+
+ x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
+ x = norm->forward(ctx, x);
+ x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
+ x = conv->forward(ctx, x);
+
+ return x;
+ }
+ };
+
+ struct ChromaRadianceParams {
+ int64_t nerf_hidden_size = 64;
+ int64_t nerf_mlp_ratio = 4;
+ int64_t nerf_depth = 4;
+ int64_t nerf_max_freqs = 8;
+ };
+
struct FluxParams {
+ SDVersion version = VERSION_FLUX;
+ bool is_chroma = false;
+ int64_t patch_size = 2;
int64_t in_channels = 64;
int64_t out_channels = 64;
int64_t vec_in_dim = 768;
@@ -565,8 +693,8 @@ namespace Flux {
bool qkv_bias = true;
bool guidance_embed = true;
bool flash_attn = true;
- bool is_chroma = false;
- SDVersion version = VERSION_FLUX;
+ int64_t in_dim = 64;
+ ChromaRadianceParams chroma_radiance_params;
};
struct Flux : public GGMLBlock {
@@ -575,53 +703,89 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
- blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true));
+ if (params.version == VERSION_CHROMA_RADIANCE) {
+ std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
+ std::pair<int, int> stride = kernel_size;
+
+ blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
+ params.hidden_size,
+ kernel_size,
+ stride);
+ } else {
+ blocks["img_in"] = std::make_shared(params.in_channels, params.hidden_size, true);
+ }
if (params.is_chroma) {
- blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size));
+ blocks["distilled_guidance_layer"] = std::make_shared(params.in_dim, params.hidden_size);
} else {
- blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size));
- blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
+ blocks["time_in"] = std::make_shared(256, params.hidden_size);
+ blocks["vector_in"] = std::make_shared(params.vec_in_dim, params.hidden_size);
if (params.guidance_embed) {
- blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size));
+ blocks["guidance_in"] = std::make_shared(256, params.hidden_size);
}
}
- blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size, true));
+ blocks["txt_in"] = std::make_shared(params.context_in_dim, params.hidden_size, true);
for (int i = 0; i < params.depth; i++) {
- blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size,
- params.num_heads,
- params.mlp_ratio,
- i,
- params.qkv_bias,
- params.flash_attn,
- params.is_chroma));
+ blocks["double_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size,
+ params.num_heads,
+ params.mlp_ratio,
+ i,
+ params.qkv_bias,
+ params.flash_attn,
+ params.is_chroma);
}
for (int i = 0; i < params.depth_single_blocks; i++) {
- blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size,
- params.num_heads,
- params.mlp_ratio,
- i,
- 0.f,
- params.flash_attn,
- params.is_chroma));
+ blocks["single_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size,
+ params.num_heads,
+ params.mlp_ratio,
+ i,
+ 0.f,
+ params.flash_attn,
+ params.is_chroma);
}
- blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
+ if (params.version == VERSION_CHROMA_RADIANCE) {
+ blocks["nerf_image_embedder"] = std::make_shared(params.in_channels,
+ params.chroma_radiance_params.nerf_hidden_size,
+ params.chroma_radiance_params.nerf_max_freqs);
+
+ for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
+ blocks["nerf_blocks." + std::to_string(i)] = std::make_shared(params.hidden_size,
+ params.chroma_radiance_params.nerf_hidden_size,
+ params.chroma_radiance_params.nerf_mlp_ratio);
+ }
+
+ blocks["nerf_final_layer_conv"] = std::make_shared(params.chroma_radiance_params.nerf_hidden_size,
+ params.in_channels);
+
+ } else {
+ blocks["final_layer"] = std::make_shared(params.hidden_size, 1, params.out_channels, params.is_chroma);
+ }
+ }
+
+ struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
+ struct ggml_tensor* x) {
+ int64_t W = x->ne[0];
+ int64_t H = x->ne[1];
+
+ int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
+ int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
+ x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
+ return x;
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
- struct ggml_tensor* x,
- int64_t patch_size) {
+ struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, C * patch_size * patch_size]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
- int64_t p = patch_size;
- int64_t h = H / patch_size;
- int64_t w = W / patch_size;
+ int64_t p = params.patch_size;
+ int64_t h = H / params.patch_size;
+ int64_t w = W / params.patch_size;
GGML_ASSERT(h * p == H && w * p == W);
@@ -633,18 +797,25 @@ namespace Flux {
return x;
}
+ struct ggml_tensor* process_img(struct ggml_context* ctx,
+ struct ggml_tensor* x) {
+ // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+ x = pad_to_patch_size(ctx, x);
+ x = patchify(ctx, x);
+ return x;
+ }
+
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
- int64_t w,
- int64_t patch_size) {
+ int64_t w) {
// x: [N, h*w, C*patch_size*patch_size]
// return: [N, C, H, W]
int64_t N = x->ne[2];
- int64_t C = x->ne[0] / patch_size / patch_size;
- int64_t H = h * patch_size;
- int64_t W = w * patch_size;
- int64_t p = patch_size;
+ int64_t C = x->ne[0] / params.patch_size / params.patch_size;
+ int64_t H = h * params.patch_size;
+ int64_t W = w * params.patch_size;
+ int64_t p = params.patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
@@ -671,7 +842,10 @@ namespace Flux {
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
- img = img_in->forward(ctx, img);
+ if (img_in) {
+ img = img_in->forward(ctx, img);
+ }
+
struct ggml_tensor* vec;
struct ggml_tensor* txt_img_mask = nullptr;
if (params.is_chroma) {
@@ -682,7 +856,7 @@ namespace Flux {
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
// ggml_arange not working on a lot of backends, precomputing it on CPU instead
- GGML_ASSERT(arange != nullptr);
+ GGML_ASSERT(mod_index_arange != nullptr);
auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
// Batch broadcast (will it ever be useful)
@@ -749,52 +923,96 @@ namespace Flux {
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
- img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
+ if (final_layer) {
+ img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
+ }
+
return img;
}
- struct ggml_tensor* process_img(struct ggml_context* ctx,
- struct ggml_tensor* x) {
+ struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx,
+ ggml_backend_t backend,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timestep,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ struct ggml_tensor* guidance,
+ struct ggml_tensor* pe,
+ struct ggml_tensor* mod_index_arange = nullptr,
+ struct ggml_tensor* dct = nullptr,
+ std::vector<ggml_tensor*> ref_latents = {},
+ std::vector<int> skip_layers = {}) {
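+ // Chroma1-Radiance works directly in pixel space: the image is patch-embedded with a
+ // strided conv (img_in_patch), processed by the transformer backbone (forward_orig), and
+ // the transformer output then conditions the NeRF-style per-pixel decoder below, which
+ // produces the final RGB output.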
+ GGML_ASSERT(x->ne[3] == 1);
+
int64_t W = x->ne[0];
int64_t H = x->ne[1];
- int64_t patch_size = 2;
+ int64_t C = x->ne[2];
+ int64_t patch_size = params.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
- x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
- // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
- auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
- return img;
- }
+ auto img = pad_to_patch_size(ctx, x);
+ auto orig_img = img;
- struct ggml_tensor* forward(struct ggml_context* ctx,
- ggml_backend_t backend,
- struct ggml_tensor* x,
- struct ggml_tensor* timestep,
- struct ggml_tensor* context,
- struct ggml_tensor* c_concat,
- struct ggml_tensor* y,
- struct ggml_tensor* guidance,
- struct ggml_tensor* pe,
- struct ggml_tensor* mod_index_arange = nullptr,
- std::vector<ggml_tensor*> ref_latents = {},
- std::vector<int> skip_layers = {}) {
- // Forward pass of DiT.
- // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
- // timestep: (N,) tensor of diffusion timesteps
- // context: (N, L, D)
- // c_concat: nullptr, or for (N,C+M, H, W) for Fill
- // y: (N, adm_in_channels) tensor of class labels
- // guidance: (N,)
- // pe: (L, d_head/2, 2, 2)
- // return: (N, C, H, W)
+ auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
+
+ img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
+ img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
+ img = ggml_cont(ctx, ggml_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
+
+ auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
+
+ // nerf decode: refine the raw patch pixels, conditioning on the transformer output for each patch
+ auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
+ auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
+ auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
+ int64_t num_patches = nerf_pixels->ne[1];
+ nerf_pixels = ggml_reshape_3d(ctx,
+ nerf_pixels,
+ nerf_pixels->ne[0] / C,
+ C,
+ nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
+ nerf_pixels = ggml_cont(ctx, ggml_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
+
+ auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
+ auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
+
+ for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
+ auto block = std::dynamic_pointer_cast<NerfGLUBlock>(blocks["nerf_blocks." + std::to_string(i)]);
+
+ img_dct = block->forward(ctx, img_dct, nerf_hidden);
+ }
+
+ img_dct = ggml_cont(ctx, ggml_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
+ img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
+ img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
+
+ out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
+
+ return out;
+ }
+
+ struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx,
+ ggml_backend_t backend,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timestep,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ struct ggml_tensor* guidance,
+ struct ggml_tensor* pe,
+ struct ggml_tensor* mod_index_arange = nullptr,
+ struct ggml_tensor* dct = nullptr,
+ std::vector<ggml_tensor*> ref_latents = {},
+ std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
- int64_t patch_size = 2;
+ int64_t patch_size = params.patch_size;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -816,21 +1034,16 @@ namespace Flux {
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
- masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
- mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
- control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
-
- masked = patchify(ctx, masked, patch_size);
- mask = patchify(ctx, mask, patch_size);
- control = patchify(ctx, control, patch_size);
+ masked = process_img(ctx, masked);
+ mask = process_img(ctx, mask);
+ control = process_img(ctx, control);
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != nullptr);
- ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
- control = patchify(ctx, control, patch_size);
- img = ggml_concat(ctx, img, control, 0);
+ auto control = process_img(ctx, c_concat);
+ img = ggml_concat(ctx, img, control, 0);
}
if (ref_latents.size() > 0) {
@@ -849,10 +1062,63 @@ namespace Flux {
}
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
- out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
-
+ out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
return out;
}
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ ggml_backend_t backend,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timestep,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ struct ggml_tensor* guidance,
+ struct ggml_tensor* pe,
+ struct ggml_tensor* mod_index_arange = nullptr,
+ struct ggml_tensor* dct = nullptr,
+ std::vector<ggml_tensor*> ref_latents = {},
+ std::vector<int> skip_layers = {}) {
+ // Forward pass of DiT.
+ // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ // timestep: (N,) tensor of diffusion timesteps
+ // context: (N, L, D)
+ // c_concat: nullptr, or (N, C+M, H, W) for Fill
+ // y: (N, adm_in_channels) tensor of class labels
+ // guidance: (N,)
+ // pe: (L, d_head/2, 2, 2)
+ // return: (N, C, H, W)
+
+ if (params.version == VERSION_CHROMA_RADIANCE) {
+ return forward_chroma_radiance(ctx,
+ backend,
+ x,
+ timestep,
+ context,
+ c_concat,
+ y,
+ guidance,
+ pe,
+ mod_index_arange,
+ dct,
+ ref_latents,
+ skip_layers);
+ } else {
+ return forward_flux_chroma(ctx,
+ backend,
+ x,
+ timestep,
+ context,
+ c_concat,
+ y,
+ guidance,
+ pe,
+ mod_index_arange,
+ dct,
+ ref_latents,
+ skip_layers);
+ }
+ }
};
struct FluxRunner : public GGMLRunner {
@@ -860,7 +1126,8 @@ namespace Flux {
FluxParams flux_params;
Flux flux;
std::vector<float> pe_vec;
- std::vector<float> mod_index_arange_vec; // for cache
+ std::vector<float> mod_index_arange_vec;
+ std::vector<float> dct_vec;
SDVersion version;
bool use_mask = false;
@@ -883,6 +1150,9 @@ namespace Flux {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ flux_params.in_channels = 3;
+ flux_params.patch_size = 16;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
@@ -933,6 +1203,56 @@ namespace Flux {
flux.get_param_tensors(tensors, prefix);
}
+ std::vector<float> fetch_dct_pos(int patch_size, int max_freqs) {
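+ // Precompute the DCT-style positional features for one patch: for a normalized pixel
+ // position (px, py) in [0, 1] and frequency pair (fx, fy),
+ // feature = cos(px * fx * PI) * cos(py * fy * PI) / (1 + fx * fy).
+ // The result is row-major [patch_size * patch_size, max_freqs * max_freqs].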
+ const float PI = 3.14159265358979323846f;
+
+ std::vector<float> pos(patch_size);
+ for (int i = 0; i < patch_size; ++i) {
+ pos[i] = static_cast<float>(i) / static_cast<float>(patch_size - 1);
+ }
+
+ std::vector<float> pos_x(patch_size * patch_size);
+ std::vector<float> pos_y(patch_size * patch_size);
+ for (int i = 0; i < patch_size; ++i) {
+ for (int j = 0; j < patch_size; ++j) {
+ pos_x[i * patch_size + j] = pos[j];
+ pos_y[i * patch_size + j] = pos[i];
+ }
+ }
+
+ std::vector<float> freqs(max_freqs);
+ for (int i = 0; i < max_freqs; ++i) {
+ freqs[i] = static_cast<float>(i);
+ }
+
+ std::vector<float> coeffs(max_freqs * max_freqs);
+ for (int fx = 0; fx < max_freqs; ++fx) {
+ for (int fy = 0; fy < max_freqs; ++fy) {
+ coeffs[fx * max_freqs + fy] = 1.0f / (1.0f + freqs[fx] * freqs[fy]);
+ }
+ }
+
+ int num_positions = patch_size * patch_size;
+ int num_features = max_freqs * max_freqs;
+ std::vector<float> dct(num_positions * num_features);
+
+ for (int p = 0; p < num_positions; ++p) {
+ float px = pos_x[p];
+ float py = pos_y[p];
+
+ for (int fx = 0; fx < max_freqs; ++fx) {
+ float cx = std::cos(px * freqs[fx] * PI);
+ for (int fy = 0; fy < max_freqs; ++fy) {
+ float cy = std::cos(py * freqs[fy] * PI);
+ float val = cx * cy * coeffs[fx * max_freqs + fy];
+ dct[p * num_features + (fx * max_freqs + fy)] = val;
+ }
+ }
+ }
+
+ return dct;
+ }
+
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
@@ -946,6 +1266,7 @@ namespace Flux {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
struct ggml_tensor* mod_index_arange = nullptr;
+ struct ggml_tensor* dct = nullptr; // for chroma radiance
x = to_backend(x);
context = to_backend(context);
@@ -976,7 +1297,7 @@ namespace Flux {
pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
- 2,
+ flux_params.patch_size,
x->ne[3],
context->ne[1],
ref_latents,
@@ -991,6 +1312,17 @@ namespace Flux {
// pe->data = nullptr;
set_backend_tensor_data(pe, pe_vec.data());
+ if (version == VERSION_CHROMA_RADIANCE) {
+ int64_t patch_size = flux_params.patch_size;
+ int64_t nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs;
+ dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs);
+ dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size);
+ // dct->data = dct_vec.data();
+ // print_ggml_tensor(dct);
+ // dct->data = nullptr;
+ set_backend_tensor_data(dct, dct_vec.data());
+ }
+
struct ggml_tensor* out = flux.forward(compute_ctx,
runtime_backend,
x,
@@ -1001,6 +1333,7 @@ namespace Flux {
guidance,
pe,
mod_index_arange,
+ dct,
ref_latents,
skip_layers);
@@ -1035,7 +1368,7 @@ namespace Flux {
void test() {
struct ggml_init_params params;
- params.mem_size = static_cast<size_t>(20 * 1024 * 1024); // 20 MB
+ params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1GB
params.mem_buffer = nullptr;
params.no_alloc = false;
@@ -1046,22 +1379,25 @@ namespace Flux {
// cpu f16:
// cuda f16: nan
// cuda q8_0: pass
- auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
- ggml_set_f32(x, 0.01f);
+ // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
+ // ggml_set_f32(x, 0.01f);
+ auto x = load_tensor_from_file(work_ctx, "chroma_x.bin");
// print_ggml_tensor(x);
- std::vector<float> timesteps_vec(1, 999.f);
+ std::vector<float> timesteps_vec(1, 1.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
- std::vector<float> guidance_vec(1, 3.5f);
+ std::vector<float> guidance_vec(1, 0.f);
auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
- auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1);
- ggml_set_f32(context, 0.01f);
+ // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1);
+ // ggml_set_f32(context, 0.01f);
+ auto context = load_tensor_from_file(work_ctx, "chroma_context.bin");
// print_ggml_tensor(context);
- auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
- ggml_set_f32(y, 0.01f);
+ // auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
+ // ggml_set_f32(y, 0.01f);
+ auto y = nullptr;
// print_ggml_tensor(y);
struct ggml_tensor* out = nullptr;
@@ -1076,32 +1412,44 @@ namespace Flux {
}
static void load_from_file_and_test(const std::string& file_path) {
- // ggml_backend_t backend = ggml_backend_cuda_init(0);
- ggml_backend_t backend = ggml_backend_cpu_init();
- ggml_type model_data_type = GGML_TYPE_Q8_0;
- std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend, false);
- {
- LOG_INFO("loading from '%s'", file_path.c_str());
-
- flux->alloc_params_buffer();
- std::map<std::string, ggml_tensor*> tensors;
- flux->get_param_tensors(tensors, "model.diffusion_model");
+ // ggml_backend_t backend = ggml_backend_cuda_init(0);
+ ggml_backend_t backend = ggml_backend_cpu_init();
+ ggml_type model_data_type = GGML_TYPE_Q8_0;
+
+ ModelLoader model_loader;
+ if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
+ LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
+ return;
+ }
- ModelLoader model_loader;
- if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
- LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
- return;
+ auto tensor_types = model_loader.tensor_storages_types;
+ for (auto& item : tensor_types) {
+ // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
+ if (ends_with(item.first, "weight")) {
+ // item.second = model_data_type;
}
+ }
- bool success = model_loader.load_tensors(tensors);
+ std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
+ false,
+ tensor_types,
+ "model.diffusion_model",
+ VERSION_CHROMA_RADIANCE,
+ false,
+ true);
- if (!success) {
- LOG_ERROR("load tensors from model loader failed");
- return;
- }
+ flux->alloc_params_buffer();
+ std::map<std::string, ggml_tensor*> tensors;
+ flux->get_param_tensors(tensors, "model.diffusion_model");
- LOG_INFO("flux model loaded");
+ bool success = model_loader.load_tensors(tensors);
+
+ if (!success) {
+ LOG_ERROR("load tensors from model loader failed");
+ return;
}
+
+ LOG_INFO("flux model loaded");
flux->test();
}
};
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 02d82bc09..66797941d 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -954,7 +954,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
if (scale != 1.f) {
x = ggml_scale(ctx, x, scale);
}
- x = ggml_mul_mat(ctx, w, x);
+ if (x->ne[2] * x->ne[3] > 1024) {
+ // workaround: flatten the extra batch dims into 2D before mul_mat to avoid a ggml CUDA error, then restore the original 4D shape
+ int64_t ne2 = x->ne[2];
+ int64_t ne3 = x->ne[3];
+ x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
+ x = ggml_mul_mat(ctx, w, x);
+ x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
+ } else {
+ x = ggml_mul_mat(ctx, w, x);
+ }
if (force_prec_f32) {
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
}
diff --git a/model.cpp b/model.cpp
index 0a03627f9..da77afedd 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1778,7 +1778,6 @@ bool ModelLoader::model_is_unet() {
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
- bool input_block_checked = false;
bool has_multiple_encoders = false;
bool is_unet = false;
@@ -1791,12 +1790,12 @@ SDVersion ModelLoader::get_sd_version() {
bool has_middle_block_1 = false;
for (auto& tensor_storage : tensor_storages) {
- if (!(is_xl || is_flux)) {
+ if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
- if (input_block_checked) {
- break;
- }
+ }
+ if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
+ return VERSION_CHROMA_RADIANCE;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
@@ -1813,22 +1812,19 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
has_img_emb = true;
}
- if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
+ if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
+ tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true;
if (has_multiple_encoders) {
is_xl = true;
- if (input_block_checked) {
- break;
- }
}
}
- if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
+ if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos ||
+ tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
+ tensor_storage.name.find("te.1") != std::string::npos) {
has_multiple_encoders = true;
if (is_unet) {
is_xl = true;
- if (input_block_checked) {
- break;
- }
}
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
@@ -1848,12 +1844,10 @@ SDVersion ModelLoader::get_sd_version() {
token_embedding_weight = tensor_storage;
// break;
}
- if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
- input_block_weight = tensor_storage;
- input_block_checked = true;
- if (is_flux) {
- break;
- }
+ if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" ||
+ tensor_storage.name == "model.diffusion_model.img_in.weight" ||
+ tensor_storage.name == "unet.conv_in.weight") {
+ input_block_weight = tensor_storage;
}
}
if (is_wan) {
diff --git a/model.h b/model.h
index 65226cd02..f1711e67f 100644
--- a/model.h
+++ b/model.h
@@ -36,6 +36,7 @@ enum SDVersion {
VERSION_FLUX_FILL,
VERSION_FLUX_CONTROLS,
VERSION_FLEX_2,
+ VERSION_CHROMA_RADIANCE,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
@@ -72,7 +73,11 @@ static inline bool sd_version_is_sd3(SDVersion version) {
}
static inline bool sd_version_is_flux(SDVersion version) {
- if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
+ if (version == VERSION_FLUX ||
+ version == VERSION_FLUX_FILL ||
+ version == VERSION_FLUX_CONTROLS ||
+ version == VERSION_FLEX_2 ||
+ version == VERSION_CHROMA_RADIANCE) {
return true;
}
return false;
diff --git a/qwen_image.hpp b/qwen_image.hpp
index 2d3cd2307..045863021 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -649,7 +649,7 @@ namespace Qwen {
static void load_from_file_and_test(const std::string& file_path) {
// cuda q8: pass
- // cuda q8 fa: nan
+ // cuda q8 fa: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 802608ea9..f9fe8a5a3 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -41,6 +41,7 @@ const char* model_version_to_str[] = {
"Flux Fill",
"Flux Control",
"Flex.2",
+ "Chroma Radiance",
"Wan 2.x",
"Wan 2.2 I2V",
"Wan 2.2 TI2V",
@@ -494,6 +495,9 @@ class StableDiffusionGGML {
version);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ first_stage_model = std::make_shared<FakeVAE>(vae_backend,
+ offload_params_to_cpu);
} else if (!use_tiny_autoencoder) {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
@@ -1041,7 +1045,7 @@ class StableDiffusionGGML {
struct ggml_tensor* c_concat = nullptr;
{
if (zero_out_masked) {
- c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
+ c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / get_vae_scale_factor(), height / get_vae_scale_factor(), 4, 1);
ggml_set_f32(c_concat, 0.f);
} else {
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
@@ -1375,6 +1379,53 @@ class StableDiffusionGGML {
return x;
}
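+ // Spatial downscale factor between image space and latent space: 8 for the standard
+ // VAEs, 16 for the Wan2.2 TI2V VAE, and 1 for the pixel-space Chroma Radiance model.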
+ int get_vae_scale_factor() {
+ int vae_scale_factor = 8;
+ if (version == VERSION_WAN2_2_TI2V) {
+ vae_scale_factor = 16;
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ vae_scale_factor = 1;
+ }
+ return vae_scale_factor;
+ }
+
+ int get_latent_channel() {
+ int latent_channel = 4;
+ if (sd_version_is_dit(version)) {
+ if (version == VERSION_WAN2_2_TI2V) {
+ latent_channel = 48;
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ latent_channel = 3;
+ } else {
+ latent_channel = 16;
+ }
+ }
+ return latent_channel;
+ }
+
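+ // Allocate the initial (image or video) latent at latent resolution and fill it with
+ // the first-stage shift factor.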
+ ggml_tensor* generate_init_latent(ggml_context* work_ctx,
+ int width,
+ int height,
+ int frames = 1,
+ bool video = false) {
+ int vae_scale_factor = get_vae_scale_factor();
+ int W = width / vae_scale_factor;
+ int H = height / vae_scale_factor;
+ int T = frames;
+ if (sd_version_is_wan(version)) {
+ T = ((T - 1) / 4) + 1;
+ }
+ int C = get_latent_channel();
+ ggml_tensor* init_latent;
+ if (video) {
+ init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
+ } else {
+ init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
+ }
+ ggml_set_f32(init_latent, shift_factor);
+ return init_latent;
+ }
+
void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
@@ -1410,6 +1461,8 @@ class StableDiffusionGGML {
}
}
}
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ // pass
} else {
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
@@ -1454,6 +1507,8 @@ class StableDiffusionGGML {
}
}
}
+ } else if (version == VERSION_CHROMA_RADIANCE) {
+ // pass
} else {
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
@@ -1495,11 +1550,11 @@ class StableDiffusionGGML {
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr;
- int W = x->ne[0] / 8;
- int H = x->ne[1] / 8;
+ int W = x->ne[0] / get_vae_scale_factor();
+ int H = x->ne[1] / get_vae_scale_factor();
+ int C = get_latent_channel();
if (vae_tiling_params.enabled && !encode_video) {
// TODO wan2.2 vae support?
- int C = sd_version_is_dit(version) ? 16 : 4;
int ne2;
int ne3;
if (sd_version_is_qwen_image(version)) {
@@ -1586,7 +1641,10 @@ class StableDiffusionGGML {
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
ggml_tensor* latent;
- if (use_tiny_autoencoder || sd_version_is_qwen_image(version) || sd_version_is_wan(version)) {
+ if (use_tiny_autoencoder ||
+ sd_version_is_qwen_image(version) ||
+ sd_version_is_wan(version) ||
+ version == VERSION_CHROMA_RADIANCE) {
latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {
latent = ggml_view_3d(work_ctx,
@@ -1613,18 +1671,14 @@ class StableDiffusionGGML {
}
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
- int64_t W = x->ne[0] * 8;
- int64_t H = x->ne[1] * 8;
+ int64_t W = x->ne[0] * get_vae_scale_factor();
+ int64_t H = x->ne[1] * get_vae_scale_factor();
int64_t C = 3;
ggml_tensor* result = nullptr;
if (decode_video) {
int T = x->ne[2];
if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1;
- if (version == VERSION_WAN2_2_TI2V) {
- W = x->ne[0] * 16;
- H = x->ne[1] * 16;
- }
}
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
@@ -2235,16 +2289,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
- int C = 4;
- if (sd_version_is_sd3(sd_ctx->sd->version)) {
- C = 16;
- } else if (sd_version_is_flux(sd_ctx->sd->version)) {
- C = 16;
- } else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
- C = 16;
- }
- int W = width / 8;
- int H = height / 8;
+ int C = sd_ctx->sd->get_latent_channel();
+ int W = width / sd_ctx->sd->get_vae_scale_factor();
+ int H = height / sd_ctx->sd->get_vae_scale_factor();
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* control_latent = nullptr;
@@ -2422,51 +2469,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
return result_images;
}
-ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
- ggml_context* work_ctx,
- int width,
- int height,
- int frames = 1,
- bool video = false) {
- int C = 4;
- int T = frames;
- int W = width / 8;
- int H = height / 8;
- if (sd_version_is_sd3(sd_ctx->sd->version)) {
- C = 16;
- } else if (sd_version_is_flux(sd_ctx->sd->version)) {
- C = 16;
- } else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
- C = 16;
- } else if (sd_version_is_wan(sd_ctx->sd->version)) {
- C = 16;
- T = ((T - 1) / 4) + 1;
- if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
- C = 48;
- W = width / 16;
- H = height / 16;
- }
- }
- ggml_tensor* init_latent;
- if (video) {
- init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
- } else {
- init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
- }
- if (sd_version_is_sd3(sd_ctx->sd->version)) {
- ggml_set_f32(init_latent, 0.0609f);
- } else if (sd_version_is_flux(sd_ctx->sd->version)) {
- ggml_set_f32(init_latent, 0.1159f);
- } else {
- ggml_set_f32(init_latent, 0.f);
- }
- return init_latent;
-}
-
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height;
+ int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
if (sd_version_is_dit(sd_ctx->sd->version)) {
if (width % 16 || height % 16) {
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
@@ -2562,20 +2569,20 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
1);
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
- int mx = ix * 8;
- int my = iy * 8;
+ int mx = ix * vae_scale_factor;
+ int my = iy * vae_scale_factor;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_latent->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
}
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
- for (int x = 0; x < 8; x++) {
- for (int y = 0; y < 8; y++) {
+ for (int x = 0; x < vae_scale_factor; x++) {
+ for (int y = 0; y < vae_scale_factor; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
- // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
- // python code was using "b (h 8) (w 8) -> b (8 8) h w"
- ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
+ // TODO: check if the way the mask is flattened is correct (is it supposed to be x*vae_scale_factor+y or x+vae_scale_factor*y?)
+ // python code was using "b (h 8) (w 8) -> b (8 8) h w"
+ ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * vae_scale_factor + y);
}
}
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
@@ -2598,11 +2605,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
{
// LOG_WARN("Inpainting with a base model is not great");
- denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
+ denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / vae_scale_factor, height / vae_scale_factor, 1, 1);
for (int ix = 0; ix < denoise_mask->ne[0]; ix++) {
for (int iy = 0; iy < denoise_mask->ne[1]; iy++) {
- int mx = ix * 8;
- int my = iy * 8;
+ int mx = ix * vae_scale_factor;
+ int my = iy * vae_scale_factor;
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(denoise_mask, m, ix, iy);
}
@@ -2613,7 +2620,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
}
- init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
+ init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height);
}
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
@@ -2741,6 +2748,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
+ int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
+
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
int high_noise_sample_steps = 0;
@@ -2838,7 +2847,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
});
- concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8]
+ concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -2848,7 +2857,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
concat_latent->ne[0],
concat_latent->ne[1],
concat_latent->ne[2],
- 4); // [b*4, t, w/8, h/8]
+ 4); // [b*4, t, w/vae_scale_factor, h/vae_scale_factor]
ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.0f;
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
@@ -2859,7 +2868,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
});
- concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8]
+ concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/vae_scale_factor, w/vae_scale_factor]
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
LOG_INFO("IMG2VID");
@@ -2870,7 +2879,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
- init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
+ init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
ggml_set_f32(denoise_mask, 1.f);
@@ -2927,8 +2936,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3);
});
- inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8]
- reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8]
+ inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
+ reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
int64_t length = inactive->ne[2];
if (ref_image_latent) {
@@ -2936,7 +2945,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
frames = (length - 1) * 4 + 1;
ref_image_num = 1;
}
- vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/8, w/8]
+ vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/vae_scale_factor, w/vae_scale_factor]
ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value;
if (i3 < 32) {
@@ -2953,7 +2962,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (ref_image_latent && i2 == 0) {
value = 0.f;
} else {
- int64_t vae_stride = 8;
+ int64_t vae_stride = vae_scale_factor;
int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride;
int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride;
value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0);
@@ -2966,7 +2975,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
if (init_latent == nullptr) {
- init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
+ init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
}
// Get learned condition
@@ -2997,16 +3006,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
- int W = width / 8;
- int H = height / 8;
+ int W = width / vae_scale_factor;
+ int H = height / vae_scale_factor;
int T = init_latent->ne[2];
- int C = 16;
-
- if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
- W = width / 16;
- H = height / 16;
- C = 48;
- }
+ int C = sd_ctx->sd->get_latent_channel();
struct ggml_tensor* final_latent;
struct ggml_tensor* x_t = init_latent;
diff --git a/vae.hpp b/vae.hpp
index 455edae04..202ebe7c0 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -533,6 +533,30 @@ struct VAE : public GGMLRunner {
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
};
+struct FakeVAE : public VAE {
+ FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
+ : VAE(backend, offload_params_to_cpu) {}
+ void compute(const int n_threads,
+ struct ggml_tensor* z,
+ bool decode_graph,
+ struct ggml_tensor** output,
+ struct ggml_context* output_ctx) override {
+ if (*output == nullptr && output_ctx != nullptr) {
+ *output = ggml_dup_tensor(output_ctx, z);
+ }
+ ggml_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+ float value = ggml_tensor_get_f32(z, i0, i1, i2, i3);
+ ggml_tensor_set_f32(*output, value, i0, i1, i2, i3);
+ });
+ }
+
+ void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
+
+ std::string get_desc() override {
+ return "fake_vae";
+ }
+};
+
struct AutoEncoderKL : public VAE {
bool decode_only = true;
AutoencodingEngine ae;