Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 104 additions & 72 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ const char* modes_str[] = {
"img_gen",
"vid_gen",
"convert",
"upscale",
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert"
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"

enum SDMode {
IMG_GEN,
VID_GEN,
CONVERT,
UPSCALE,
MODE_COUNT
};

Expand Down Expand Up @@ -204,7 +206,7 @@ void print_usage(int argc, const char* argv[]) {
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n");
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n");
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n");
Expand All @@ -219,7 +221,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
printf(" If not specified, the default is the type of the weight file\n");
Expand Down Expand Up @@ -817,13 +819,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.n_threads = get_num_physical_cores();
}

if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) {
if (params.mode == IMG_GEN && params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
}

if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
print_usage(argc, argv);
exit(1);
Expand Down Expand Up @@ -883,6 +885,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}

if (params.mode == UPSCALE) {
if (params.esrgan_path.length() == 0) {
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
exit(1);
}
if( params.init_image_path.length() == 0) {
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
exit(1);
}
}

if (params.seed < 0) {
srand((int)time(NULL));
params.seed = rand();
Expand Down Expand Up @@ -1343,76 +1356,96 @@ int main(int argc, const char* argv[]) {
params.flow_shift,
};

sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
sd_image_t* results;
int num_results = 1;

if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
release_all_resources();
return 1;
}
if (params.mode == UPSCALE) {
// TODO image metadata should come from the original file

if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
num_results = 1;
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
if (results == NULL) {
printf("failed to allocate results array\n");
release_all_resources();
return 1;
}

sd_image_t* results;
int num_results = 1;
if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
params.increase_ref_index,
mask_image,
params.width,
params.height,
params.sample_params,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
params.pm_id_embed_path.c_str(),
params.pm_style_strength,
}, // pm_params
params.vae_tiling_params,
};

results = generate_image(sd_ctx, &img_gen_params);
num_results = params.batch_count;
} else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
end_image,
control_frames.data(),
(int)control_frames.size(),
params.width,
params.height,
params.sample_params,
params.high_noise_sample_params,
params.moe_boundary,
params.strength,
params.seed,
params.video_frames,
params.vace_strength,
};

results = generate_video(sd_ctx, &vid_gen_params, &num_results);
}
results[0] = init_image;
init_image.data = NULL;

} else {

sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);

if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
release_all_resources();
return 1;
}

if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}

if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
params.increase_ref_index,
mask_image,
params.width,
params.height,
params.sample_params,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
params.pm_id_embed_path.c_str(),
params.pm_style_strength,
}, // pm_params
params.vae_tiling_params,
};

results = generate_image(sd_ctx, &img_gen_params);
num_results = params.batch_count;
} else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
end_image,
control_frames.data(),
(int)control_frames.size(),
params.width,
params.height,
params.sample_params,
params.high_noise_sample_params,
params.moe_boundary,
params.strength,
params.seed,
params.video_frames,
params.vace_strength,
};

results = generate_video(sd_ctx, &vid_gen_params, &num_results);
}

if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}

if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}

int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
Expand All @@ -1425,7 +1458,7 @@ int main(int argc, const char* argv[]) {
if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");
} else {
for (int i = 0; i < params.batch_count; i++) {
for (int i = 0; i < num_results; i++) {
if (results[i].data == NULL) {
continue;
}
Expand Down Expand Up @@ -1511,7 +1544,6 @@ int main(int argc, const char* argv[]) {
results[i].data = NULL;
}
free(results);
free_sd_ctx(sd_ctx);

release_all_resources();

Expand Down
Loading