common.hpp: 2 changes (1 addition & 1 deletion)
@@ -28,7 +28,7 @@ class DownSampleBlock : public GGMLBlock {
        if (vae_downsample) {
            auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);

-           x = ggml_pad(ctx, x, 1, 1, 0, 0);
+           x = sd_pad(ctx, x, 1, 1, 0, 0);
            x = conv->forward(ctx, x);
        } else {
            auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
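The functional change here is the swap from `ggml_pad` to `sd_pad` at the VAE downsample call site. The definition of `sd_pad` is not part of this diff; conceptually it is a thin wrapper that behaves exactly like `ggml_pad` (zero padding at the end of each dimension) unless circular padding was requested, in which case the padded region is taken from the opposite edge of the tensor. A minimal sketch of such a wrapper, assuming the `circular_pad` flag added in `main.cpp` below is threaded through as a flag; this is illustrative, not the PR's actual implementation:

```cpp
// Illustrative sketch only; the real sd_pad is defined outside this diff.
// Like ggml_pad, this pads at the end of dims 0 and 1, but when circular
// padding is enabled the pad is a wrapped-around copy of the leading edge.
static bool circular_pad_enabled = false;  // assumed to be set from --circular

struct ggml_tensor* sd_pad(struct ggml_context* ctx, struct ggml_tensor* x,
                           int p0, int p1, int p2, int p3) {
    if (!circular_pad_enabled || p2 != 0 || p3 != 0) {
        return ggml_pad(ctx, x, p0, p1, p2, p3);  // default path: zero padding
    }
    if (p0 > 0) {
        // wrap the first p0 columns (dim 0, width) onto the end
        struct ggml_tensor* front = ggml_view_4d(ctx, x, p0, x->ne[1], x->ne[2], x->ne[3],
                                                 x->nb[1], x->nb[2], x->nb[3], 0);
        x = ggml_concat(ctx, x, ggml_cont(ctx, front), 0);
    }
    if (p1 > 0) {
        // wrap the first p1 rows (dim 1, height) onto the end
        struct ggml_tensor* top = ggml_view_4d(ctx, x, x->ne[0], p1, x->ne[2], x->ne[3],
                                               x->nb[1], x->nb[2], x->nb[3], 0);
        x = ggml_concat(ctx, x, ggml_cont(ctx, top), 1);
    }
    return x;
}
```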
examples/cli/main.cpp: 5 changes (5 additions & 0 deletions)
@@ -113,6 +113,7 @@ struct SDParams {
    bool diffusion_flash_attn = false;
    bool diffusion_conv_direct = false;
    bool vae_conv_direct = false;
+   bool circular_pad = false;
    bool canny_preprocess = false;
    bool color = false;
    int upscale_repeats = 1;
@@ -183,6 +184,7 @@ void print_params(SDParams params) {
printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
printf(" vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false");
printf(" circular padding: %s\n", params.circular_pad ? "true" : "false");
printf(" control_strength: %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -304,6 +306,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" This might crash if it is not supported by the backend.\n");
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --circular use circular padding for convolutions and pad ops\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color colors the logging tags according to level\n");
@@ -573,6 +576,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
{"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
{"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
{"", "--circular", "", true, &params.circular_pad},
{"", "--canny", "", true, &params.canny_preprocess},
{"-v", "--verbose", "", true, &params.verbose},
{"", "--color", "", true, &params.color},
@@ -1386,6 +1390,7 @@ int main(int argc, const char* argv[]) {
        params.diffusion_flash_attn,
        params.diffusion_conv_direct,
        params.vae_conv_direct,
+       params.circular_pad,
        params.force_sdxl_vae_conv_scale,
        params.chroma_use_dit_mask,
        params.chroma_use_t5_mask,
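All five changes in this file plumb one new option through the CLI: a `circular_pad` field on `SDParams`, a `--circular` flag in `parse_args`, a line each in `print_params` and `print_usage`, and the value forwarded with the other context parameters in `main`. Circular padding makes the convolution stack treat the image as wrapping around at its edges, which is what makes outputs tileable. A toy, self-contained illustration of the difference between the two padding modes (not code from the PR):

```cpp
#include <cstdio>
#include <vector>

// Toy 1-D model of the padding modes toggled by --circular: zero padding
// extends a signal with zeros, while circular padding wraps around, so the
// left pad comes from the right edge and vice versa.
std::vector<float> circular_pad_1d(const std::vector<float>& x, int pad) {
    const int n = (int)x.size();
    std::vector<float> out(n + 2 * pad);
    for (int i = 0; i < (int)out.size(); ++i) {
        // shift by -pad, then wrap into [0, n); C++'s % can be negative, hence +n
        out[i] = x[(((i - pad) % n) + n) % n];
    }
    return out;
}

int main() {
    for (float v : circular_pad_1d({1, 2, 3, 4}, 1)) {
        printf("%g ", v);  // prints: 4 1 2 3 4 1  (zero padding would give 0 1 2 3 4 0)
    }
    printf("\n");
    return 0;
}
```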
flux.hpp: 67 changes (34 additions & 33 deletions)
@@ -1,6 +1,7 @@
#ifndef __FLUX_HPP__
#define __FLUX_HPP__

+#include <memory>
#include <vector>

#include "ggml_extend.hpp"
@@ -18,7 +19,7 @@ namespace Flux {
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
// x: [..., in_dim]
// return: [..., hidden_dim]
auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
@@ -36,7 +37,7 @@ namespace Flux {
    int64_t hidden_size;
    float eps;

-   void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+   void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
        ggml_type wtype = GGML_TYPE_F32;
        params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
    }
@@ -47,7 +48,7 @@ namespace Flux {
        : hidden_size(hidden_size),
          eps(eps) {}

-   struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+   struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override {
        struct ggml_tensor* w = params["scale"];
        x = ggml_rms_norm(ctx, x, eps);
        x = ggml_mul(ctx, x, w);
@@ -136,11 +137,11 @@ namespace Flux {
};

struct ModulationOut {
-   ggml_tensor* shift = NULL;
-   ggml_tensor* scale = NULL;
-   ggml_tensor* gate  = NULL;
+   ggml_tensor* shift = nullptr;
+   ggml_tensor* scale = nullptr;
+   ggml_tensor* gate  = nullptr;

-   ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL)
+   ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr)
        : shift(shift), scale(scale), gate(gate) {}

    ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
@@ -259,7 +260,7 @@ namespace Flux {
                                struct ggml_tensor* txt,
                                struct ggml_tensor* vec,
                                struct ggml_tensor* pe,
-                               struct ggml_tensor* mask = NULL) {
+                               struct ggml_tensor* mask = nullptr) {
        // img: [N, n_img_token, hidden_size]
        // txt: [N, n_txt_token, hidden_size]
        // pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
@@ -398,15 +399,15 @@ namespace Flux {

    ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
        int64_t offset = 3 * idx;
-       return ModulationOut(ctx, vec, offset);
+       return {ctx, vec, offset};
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
                                ggml_backend_t backend,
                                struct ggml_tensor* x,
                                struct ggml_tensor* vec,
                                struct ggml_tensor* pe,
-                               struct ggml_tensor* mask = NULL) {
+                               struct ggml_tensor* mask = nullptr) {
        // x: [N, n_token, hidden_size]
        // pe: [n_token, d_head/2, 2, 2]
        // return: [N, n_token, hidden_size]
@@ -485,7 +486,7 @@ namespace Flux {
        auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0));  // [N, dim]
        auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1));  // [N, dim]
        // No gate
-       return ModulationOut(shift, scale, NULL);
+       return {shift, scale, nullptr};
    }

struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -664,15 +665,15 @@ namespace Flux {
                                 struct ggml_tensor* y,
                                 struct ggml_tensor* guidance,
                                 struct ggml_tensor* pe,
-                                struct ggml_tensor* mod_index_arange = NULL,
+                                struct ggml_tensor* mod_index_arange = nullptr,
                                 std::vector<int> skip_layers = {}) {
        auto img_in      = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
        auto txt_in      = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
        auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);

        img = img_in->forward(ctx, img);
        struct ggml_tensor* vec;
-       struct ggml_tensor* txt_img_mask = NULL;
+       struct ggml_tensor* txt_img_mask = nullptr;
        if (params.is_chroma) {
            int64_t mod_index_length = 344;
            auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
@@ -681,7 +682,7 @@

            // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
            // ggml_arange not working on a lot of backends, precomputing it on CPU instead
-           GGML_ASSERT(arange != NULL);
+           GGML_ASSERT(arange != nullptr);
            auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f);  // [1, 344, 32]

            // Batch broadcast (will it ever be useful)
@@ -695,15 +696,15 @@
            vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3));  // [344, N, 64]
            vec = approx->forward(ctx, vec);                           // [344, N, hidden_size]

-           if (y != NULL) {
+           if (y != nullptr) {
                txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
            }
        } else {
            auto time_in   = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
            auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
            vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
            if (params.guidance_embed) {
-               GGML_ASSERT(guidance != NULL);
+               GGML_ASSERT(guidance != nullptr);
                auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
                // bf16 and fp16 result is different
                auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
@@ -775,14 +776,14 @@ namespace Flux {
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,
-                                    struct ggml_tensor* mod_index_arange = NULL,
+                                    struct ggml_tensor* mod_index_arange = nullptr,
                                     std::vector<ggml_tensor*> ref_latents = {},
                                     std::vector<int> skip_layers = {}) {
        // Forward pass of DiT.
        // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        // timestep: (N,) tensor of diffusion timesteps
        // context: (N, L, D)
-       // c_concat: NULL, or for (N,C+M, H, W) for Fill
+       // c_concat: nullptr, or for (N,C+M, H, W) for Fill
        // y: (N, adm_in_channels) tensor of class labels
        // guidance: (N,)
        // pe: (L, d_head/2, 2, 2)
@@ -801,7 +802,7 @@
        uint64_t img_tokens = img->ne[1];

        if (params.version == VERSION_FLUX_FILL) {
-           GGML_ASSERT(c_concat != NULL);
+           GGML_ASSERT(c_concat != nullptr);
            ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
            ggml_tensor* mask   = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

@@ -810,7 +811,7 @@

            img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
        } else if (params.version == VERSION_FLEX_2) {
-           GGML_ASSERT(c_concat != NULL);
+           GGML_ASSERT(c_concat != nullptr);
            ggml_tensor* masked  = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
            ggml_tensor* mask    = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
            ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
@@ -825,7 +826,7 @@

            img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
        } else if (params.version == VERSION_FLUX_CONTROLS) {
-           GGML_ASSERT(c_concat != NULL);
+           GGML_ASSERT(c_concat != nullptr);

            ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
            control              = patchify(ctx, control, patch_size);
@@ -924,7 +925,7 @@ namespace Flux {
        flux.init(params_ctx, tensor_types, prefix);
    }

-   std::string get_desc() {
+   std::string get_desc() override {
        return "flux";
    }

@@ -944,18 +945,18 @@
        GGML_ASSERT(x->ne[3] == 1);
        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

-       struct ggml_tensor* mod_index_arange = NULL;
+       struct ggml_tensor* mod_index_arange = nullptr;

        x       = to_backend(x);
        context = to_backend(context);
-       if (c_concat != NULL) {
+       if (c_concat != nullptr) {
            c_concat = to_backend(c_concat);
        }
        if (flux_params.is_chroma) {
            guidance = ggml_set_f32(guidance, 0);

            if (!use_mask) {
-               y = NULL;
+               y = nullptr;
            }

            // ggml_arange is not working on some backends, precompute it
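For the Chroma path, `mod_index_arange` is filled on the CPU rather than with `ggml_arange`; the lines that do so are collapsed out of this diff view. A sketch of what that precompute amounts to, reusing `mod_index_length = 344` from the earlier hunk (the surrounding names are assumptions; the buffer must stay alive until the graph is computed):

```cpp
// Assumed sketch of the collapsed precompute: build 0..343 on the CPU and
// hand the buffer to the backend tensor, mirroring what ggml_arange would do.
const int64_t mod_index_length = 344;
std::vector<float> mod_index_arange_vec(mod_index_length);
for (int64_t i = 0; i < mod_index_length; ++i) {
    mod_index_arange_vec[i] = (float)i;
}
mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_length);
set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
```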
@@ -987,7 +988,7 @@
        auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
        // pe->data = pe_vec.data();
        // print_ggml_tensor(pe);
-       // pe->data = NULL;
+       // pe->data = nullptr;
        set_backend_tensor_data(pe, pe_vec.data());

struct ggml_tensor* out = flux.forward(compute_ctx,
@@ -1017,8 +1018,8 @@
                 struct ggml_tensor* guidance,
                 std::vector<ggml_tensor*> ref_latents = {},
                 bool increase_ref_index = false,
-                struct ggml_tensor** output = NULL,
-                struct ggml_context* output_ctx = NULL,
+                struct ggml_tensor** output = nullptr,
+                struct ggml_context* output_ctx = nullptr,
                 std::vector<int> skip_layers = std::vector<int>()) {
        // x: [N, in_channels, h, w]
        // timesteps: [N, ]
@@ -1035,11 +1036,11 @@
    void test() {
        struct ggml_init_params params;
        params.mem_size = static_cast<size_t>(20 * 1024 * 1024);  // 20 MB
-       params.mem_buffer = NULL;
+       params.mem_buffer = nullptr;
        params.no_alloc = false;

        struct ggml_context* work_ctx = ggml_init(params);
-       GGML_ASSERT(work_ctx != NULL);
+       GGML_ASSERT(work_ctx != nullptr);

{
// cpu f16:
Expand All @@ -1063,10 +1064,10 @@ namespace Flux {
ggml_set_f32(y, 0.01f);
// print_ggml_tensor(y);

struct ggml_tensor* out = NULL;
struct ggml_tensor* out = nullptr;

int t0 = ggml_time_ms();
compute(8, x, timesteps, context, NULL, y, guidance, {}, false, &out, work_ctx);
compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
@@ -1078,7 +1079,7 @@
        // ggml_backend_t backend = ggml_backend_cuda_init(0);
        ggml_backend_t backend    = ggml_backend_cpu_init();
        ggml_type model_data_type = GGML_TYPE_Q8_0;
-       std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend, false));
+       std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend, false);
        {
            LOG_INFO("loading from '%s'", file_path.c_str());

ggml: 2 changes (1 addition & 1 deletion)
Submodule ggml updated 413 files