212 changes: 188 additions & 24 deletions esrgan.hpp
@@ -83,39 +83,44 @@ class RRDB : public GGMLBlock {

class RRDBNet : public GGMLBlock {
protected:
int scale = 4; // default RealESRGAN_x4plus_anime_6B
int num_block = 6; // default RealESRGAN_x4plus_anime_6B
int scale = 4;
int num_block = 23;
int num_in_ch = 3;
int num_out_ch = 3;
int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
int num_feat = 64;
int num_grow_ch = 32;

public:
RRDBNet() {
RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
: scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
for (int i = 0; i < num_block; i++) {
std::string name = "body." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
}
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
// upsample
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
if (scale >= 2) {
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
if (scale == 4) {
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
}

int get_scale() { return scale; }
int get_num_block() { return num_block; }

struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, num_in_ch, h, w]
// return: [n, num_out_ch, h*4, w*4]
// return: [n, num_out_ch, h*scale, w*scale]
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);

@@ -130,28 +135,36 @@ class RRDBNet : public GGMLBlock {
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx, feat, body_feat);
// upsample
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
if (scale >= 2) {
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
if (scale == 4) {
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
}
}
// for all scales
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
return out;
}
};
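The parameterized RRDBNet constructor above makes the common presets explicit instead of hard-coding the anime_6B layout. A minimal sketch, assuming esrgan.hpp is included (RealESRGAN_x4plus uses 23 RRDB blocks, the anime_6B variant uses 6; both are scale 4 with 3 input/output channels, 64 features and a growth channel of 32):

    // Sketch only: explicit construction of the two well-known x4 configurations.
    RRDBNet x4plus(/*scale=*/4, /*num_block=*/23, /*num_in_ch=*/3, /*num_out_ch=*/3, /*num_feat=*/64, /*num_grow_ch=*/32);
    RRDBNet x4plus_anime_6b(/*scale=*/4, /*num_block=*/6, /*num_in_ch=*/3, /*num_out_ch=*/3, /*num_feat=*/64, /*num_grow_ch=*/32);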

struct ESRGAN : public GGMLRunner {
RRDBNet rrdb_net;
std::unique_ptr<RRDBNet> rrdb_net;
int scale = 4;
int tile_size = 128; // avoid cuda OOM for 4gb VRAM

ESRGAN(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: GGMLRunner(backend, offload_params_to_cpu) {
rrdb_net.init(params_ctx, tensor_types, "");
// rrdb_net will be created in load_from_file
}

void enable_conv2d_direct() {
if (!rrdb_net) return;
std::vector<GGMLBlock*> blocks;
rrdb_net.get_all_blocks(blocks);
rrdb_net->get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
@@ -167,31 +180,182 @@ struct ESRGAN : public GGMLRunner {
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading esrgan from '%s'", file_path.c_str());

alloc_params_buffer();
std::map<std::string, ggml_tensor*> esrgan_tensors;
rrdb_net.get_param_tensors(esrgan_tensors);

ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) {
LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
return false;
}

bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
// Get tensor names
auto tensor_names = model_loader.get_tensor_names();

// Detect if it's ESRGAN format
bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();

// Detect parameters from tensor names
int detected_num_block = 0;
if (is_ESRGAN) {
for (const auto& name : tensor_names) {
if (name.find("model.1.sub.") == 0) {
size_t first_dot = name.find('.', 12);
if (first_dot != std::string::npos) {
size_t second_dot = name.find('.', first_dot + 1);
if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
try {
int idx = std::stoi(name.substr(12, first_dot - 12));
detected_num_block = std::max(detected_num_block, idx + 1);
} catch (...) {}
}
}
}
}
} else {
// Original format
for (const auto& name : tensor_names) {
if (name.find("body.") == 0) {
size_t pos = name.find('.', 5);
if (pos != std::string::npos) {
try {
int idx = std::stoi(name.substr(5, pos - 5));
detected_num_block = std::max(detected_num_block, idx + 1);
} catch (...) {}
}
}
}
}

int detected_scale = 4; // default
if (is_ESRGAN) {
// For ESRGAN format, detect scale by highest model number
int max_model_num = 0;
for (const auto& name : tensor_names) {
if (name.find("model.") == 0) {
size_t dot_pos = name.find('.', 6);
if (dot_pos != std::string::npos) {
try {
int num = std::stoi(name.substr(6, dot_pos - 6));
max_model_num = std::max(max_model_num, num);
} catch (...) {}
}
}
}
if (max_model_num <= 4) {
detected_scale = 1;
} else if (max_model_num <= 7) {
detected_scale = 2;
} else {
detected_scale = 4;
}
} else {
// Original format
bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
return name == "conv_up2.weight";
});
bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
return name == "conv_up1.weight";
});
if (has_conv_up2) {
detected_scale = 4;
} else if (has_conv_up1) {
detected_scale = 2;
} else {
detected_scale = 1;
}
}

int detected_num_in_ch = 3;
int detected_num_out_ch = 3;
int detected_num_feat = 64;
int detected_num_grow_ch = 32;

// Create RRDBNet with detected parameters
rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
rrdb_net->init(params_ctx, {}, "");

alloc_params_buffer();
std::map<std::string, ggml_tensor*> esrgan_tensors;
rrdb_net->get_param_tensors(esrgan_tensors);

bool success;
if (is_ESRGAN) {
// Build name mapping for ESRGAN format
std::map<std::string, std::string> expected_to_model;
expected_to_model["conv_first.weight"] = "model.0.weight";
expected_to_model["conv_first.bias"] = "model.0.bias";

for (int i = 0; i < detected_num_block; i++) {
for (int j = 1; j <= 3; j++) {
for (int k = 1; k <= 5; k++) {
std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
std::string model_weight = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
expected_to_model[expected_weight] = model_weight;

std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
std::string model_bias = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
expected_to_model[expected_bias] = model_bias;
}
}
}

if (detected_scale == 1) {
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
expected_to_model["conv_hr.weight"] = "model.2.weight";
expected_to_model["conv_hr.bias"] = "model.2.bias";
expected_to_model["conv_last.weight"] = "model.4.weight";
expected_to_model["conv_last.bias"] = "model.4.bias";
} else {
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
if (detected_scale >= 2) {
expected_to_model["conv_up1.weight"] = "model.3.weight";
expected_to_model["conv_up1.bias"] = "model.3.bias";
}
if (detected_scale == 4) {
expected_to_model["conv_up2.weight"] = "model.6.weight";
expected_to_model["conv_up2.bias"] = "model.6.bias";
expected_to_model["conv_hr.weight"] = "model.8.weight";
expected_to_model["conv_hr.bias"] = "model.8.bias";
expected_to_model["conv_last.weight"] = "model.10.weight";
expected_to_model["conv_last.bias"] = "model.10.bias";
} else if (detected_scale == 2) {
expected_to_model["conv_hr.weight"] = "model.5.weight";
expected_to_model["conv_hr.bias"] = "model.5.bias";
expected_to_model["conv_last.weight"] = "model.7.weight";
expected_to_model["conv_last.bias"] = "model.7.bias";
}
}

std::map<std::string, ggml_tensor*> model_tensors;
for (auto& p : esrgan_tensors) {
auto it = expected_to_model.find(p.first);
if (it != expected_to_model.end()) {
model_tensors[it->second] = p.second;
}
}

success = model_loader.load_tensors(model_tensors, {}, n_threads);
} else {
success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
}

if (!success) {
LOG_ERROR("load esrgan tensors from model loader failed");
return false;
}

LOG_INFO("esrgan model loaded");
scale = rrdb_net->get_scale();
LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
return success;
}

struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
if (!rrdb_net) return nullptr;
constexpr int kGraphNodes = 1 << 16; // 65k
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
x = to_backend(x);
struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
ggml_build_forward_expand(gf, out);
return gf;
}
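The auto-detection in load_from_file above reduces to scanning tensor names. As a standalone illustration, a minimal sketch of the block counting used for the original (non-ESRGAN) layout; the "body.<i>.rdb<j>.conv<k>" naming is the only assumption:

    // Sketch only: count RRDB blocks by taking the highest index <i> seen in
    // tensor names of the form "body.<i>....", mirroring the loop above.
    #include <algorithm>
    #include <string>
    #include <vector>

    static int detect_num_block(const std::vector<std::string>& tensor_names) {
        int detected = 0;
        for (const auto& name : tensor_names) {
            if (name.rfind("body.", 0) == 0) {       // starts with "body."
                size_t pos = name.find('.', 5);      // dot that terminates the block index
                if (pos != std::string::npos) {
                    try {
                        int idx = std::stoi(name.substr(5, pos - 5));
                        detected = std::max(detected, idx + 1);
                    } catch (...) {
                        // non-numeric segment, ignore
                    }
                }
            }
        }
        return detected;
    }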
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 52 files
+2 −2 CMakeLists.txt
+1 −2 include/ggml-backend.h
+0 −3 include/ggml-zdnn.h
+183 −279 scripts/release.sh
+1 −1 scripts/sync-llama.last
+126 −264 src/ggml-alloc.c
+0 −8 src/ggml-backend.cpp
+2 −2 src/ggml-cpu/arch/x86/repack.cpp
+3 −17 src/ggml-cpu/ggml-cpu.c
+2 −10 src/ggml-cpu/ops.cpp
+1 −1 src/ggml-cuda/ggml-cuda.cu
+32 −40 src/ggml-cuda/set-rows.cu
+10 −43 src/ggml-impl.h
+18 −6 src/ggml-metal/ggml-metal-common.cpp
+38 −29 src/ggml-metal/ggml-metal-device.cpp
+3 −2 src/ggml-metal/ggml-metal-device.h
+2 −2 src/ggml-metal/ggml-metal-device.m
+9 −4 src/ggml-metal/ggml-metal-impl.h
+160 −109 src/ggml-metal/ggml-metal-ops.cpp
+1 −0 src/ggml-metal/ggml-metal-ops.h
+119 −179 src/ggml-metal/ggml-metal.metal
+0 −4 src/ggml-opencl/CMakeLists.txt
+23 −409 src/ggml-opencl/ggml-opencl.cpp
+2 −40 src/ggml-opencl/kernels/cvt.cl
+0 −140 src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl
+0 −222 src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl
+0 −125 src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl
+0 −202 src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl
+2 −96 src/ggml-opencl/kernels/set_rows.cl
+0 −1 src/ggml-quants.c
+35 −41 src/ggml-rpc/ggml-rpc.cpp
+1 −1 src/ggml-sycl/ggml-sycl.cpp
+33 −43 src/ggml-sycl/set_rows.cpp
+50 −262 src/ggml-vulkan/ggml-vulkan.cpp
+2 −22 src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
+2 −9 src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
+0 −1 src/ggml-vulkan/vulkan-shaders/exp.comp
+8 −20 src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+9 −9 src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp
+13 −28 src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+1 −1 src/ggml-webgpu/ggml-webgpu.cpp
+0 −1 src/ggml-zdnn/.gitignore
+0 −59 src/ggml-zdnn/common.hpp
+98 −0 src/ggml-zdnn/ggml-zdnn-impl.h
+168 −19 src/ggml-zdnn/ggml-zdnn.cpp
+0 −80 src/ggml-zdnn/mmf.cpp
+0 −12 src/ggml-zdnn/mmf.hpp
+0 −79 src/ggml-zdnn/utils.cpp
+0 −19 src/ggml-zdnn/utils.hpp
+1 −1 src/ggml.c
+12 −30 tests/test-backend-ops.cpp
+10 −1 tests/test-quantize-perf.cpp
8 changes: 8 additions & 0 deletions model.h
@@ -252,6 +252,14 @@ class ModelLoader {
std::set<std::string> ignore_tensors = {},
int n_threads = 0);

std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names;
for (const auto& ts : tensor_storages) {
names.push_back(ts.name);
}
return names;
}

bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
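A hypothetical caller-side sketch of the new accessor, assuming "model.h" is on the include path and the checkpoint file is readable; it mirrors the ESRGAN-format probe performed in esrgan.hpp above:

    // Sketch only: detect the old-style ESRGAN tensor layout by probing for "model.0.weight".
    #include <algorithm>
    #include <string>
    #include "model.h"

    bool looks_like_esrgan_layout(const std::string& file_path) {
        ModelLoader loader;
        if (!loader.init_from_file(file_path)) {
            return false;  // unreadable or unsupported file
        }
        auto names = loader.get_tensor_names();
        return std::find(names.begin(), names.end(), std::string("model.0.weight")) != names.end();
    }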
2 changes: 2 additions & 0 deletions stable-diffusion.h
@@ -283,6 +283,8 @@ SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
sd_image_t input_image,
uint32_t upscale_factor);

SD_API int get_upscale_factor(upscaler_ctx_t* upscaler_ctx);

SD_API bool convert(const char* input_path,
const char* vae_path,
const char* output_path,
7 changes: 7 additions & 0 deletions upscaler.cpp
@@ -138,6 +138,13 @@ sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_
return upscaler_ctx->upscaler->upscale(input_image, upscale_factor);
}

int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {
if (upscaler_ctx == NULL || upscaler_ctx->upscaler == NULL || upscaler_ctx->upscaler->esrgan_upscaler == NULL) {
return 1;
}
return upscaler_ctx->upscaler->esrgan_upscaler->scale;
}
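Together with the new declaration in stable-diffusion.h, callers can size output buffers from the model's native factor instead of hard-coding 4x. A hedged usage sketch, assuming an upscaler_ctx_t obtained from the library's existing constructor; only upscale() and get_upscale_factor() are taken from this header:

    // Sketch only: upscale by the detected native scale (1, 2 or 4) of the loaded ESRGAN model.
    #include "stable-diffusion.h"

    sd_image_t upscale_with_native_factor(upscaler_ctx_t* ctx, sd_image_t input) {
        int factor = get_upscale_factor(ctx);  // returns 1 when no ESRGAN model is loaded
        return upscale(ctx, input, (uint32_t)factor);
    }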

void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
if (upscaler_ctx->upscaler != NULL) {
delete upscaler_ctx->upscaler;