diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 02f4767b..56611f39 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -41,13 +41,15 @@ const char* modes_str[] = { "img_gen", "vid_gen", "convert", + "upscale", }; -#define SD_ALL_MODES_STR "img_gen, vid_gen, convert" +#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale" enum SDMode { IMG_GEN, VID_GEN, CONVERT, + UPSCALE, MODE_COUNT }; @@ -204,7 +206,7 @@ void print_usage(int argc, const char* argv[]) { printf("\n"); printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n"); + printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n"); printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n"); @@ -219,7 +221,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); - printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); printf(" If not specified, the default is the type of the weight file\n"); @@ -817,13 +819,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.n_threads = get_num_physical_cores(); } - if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) { + if (params.mode == IMG_GEN && params.prompt.length() == 0) { fprintf(stderr, "error: the following arguments are required: prompt\n"); print_usage(argc, argv); exit(1); } - if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { + if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); print_usage(argc, argv); exit(1); @@ -883,6 +885,17 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } + if (params.mode == UPSCALE) { + if (params.esrgan_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n"); + exit(1); + } + if( params.init_image_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n"); + exit(1); + } + } + if (params.seed < 0) { srand((int)time(NULL)); params.seed = rand(); @@ -1343,76 +1356,96 @@ int main(int argc, const char* argv[]) { params.flow_shift, }; - sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + sd_image_t* results; + int num_results = 1; - if (sd_ctx == NULL) { - printf("new_sd_ctx_t failed\n"); - release_all_resources(); - return 1; - } + if (params.mode == UPSCALE) { + // TODO image metadata should come from the original file - if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { - params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); - } + num_results = 1; + results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t)); + if (results == NULL) { + printf("failed to allocate results array\n"); + release_all_resources(); + return 1; + } - sd_image_t* results; - int num_results = 1; - if (params.mode == IMG_GEN) { - sd_img_gen_params_t img_gen_params = { - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - init_image, - ref_images.data(), - (int)ref_images.size(), - params.increase_ref_index, - mask_image, - params.width, - params.height, - params.sample_params, - params.strength, - params.seed, - params.batch_count, - control_image, - params.control_strength, - { - pmid_images.data(), - (int)pmid_images.size(), - params.pm_id_embed_path.c_str(), - params.pm_style_strength, - }, // pm_params - params.vae_tiling_params, - }; - - results = generate_image(sd_ctx, &img_gen_params); - num_results = params.batch_count; - } else if (params.mode == VID_GEN) { - sd_vid_gen_params_t vid_gen_params = { - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - init_image, - end_image, - control_frames.data(), - (int)control_frames.size(), - params.width, - params.height, - params.sample_params, - params.high_noise_sample_params, - params.moe_boundary, - params.strength, - params.seed, - params.video_frames, - params.vace_strength, - }; - - results = generate_video(sd_ctx, &vid_gen_params, &num_results); - } + results[0] = init_image; + init_image.data = NULL; + + } else { + + sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + + if (sd_ctx == NULL) { + printf("new_sd_ctx_t failed\n"); + release_all_resources(); + return 1; + } + + if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { + params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); + } + + if (params.mode == IMG_GEN) { + sd_img_gen_params_t img_gen_params = { + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + init_image, + ref_images.data(), + (int)ref_images.size(), + params.increase_ref_index, + mask_image, + params.width, + params.height, + params.sample_params, + params.strength, + params.seed, + params.batch_count, + control_image, + params.control_strength, + { + pmid_images.data(), + (int)pmid_images.size(), + params.pm_id_embed_path.c_str(), + params.pm_style_strength, + }, // pm_params + params.vae_tiling_params, + }; + + results = generate_image(sd_ctx, &img_gen_params); + num_results = params.batch_count; + } else if (params.mode == VID_GEN) { + sd_vid_gen_params_t vid_gen_params = { + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + init_image, + end_image, + control_frames.data(), + (int)control_frames.size(), + params.width, + params.height, + params.sample_params, + params.high_noise_sample_params, + params.moe_boundary, + params.strength, + params.seed, + params.video_frames, + params.vace_strength, + }; + + results = generate_video(sd_ctx, &vid_gen_params, &num_results); + } + + if (results == NULL) { + printf("generate failed\n"); + free_sd_ctx(sd_ctx); + return 1; + } - if (results == NULL) { - printf("generate failed\n"); free_sd_ctx(sd_ctx); - return 1; } int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth @@ -1425,7 +1458,7 @@ int main(int argc, const char* argv[]) { if (upscaler_ctx == NULL) { printf("new_upscaler_ctx failed\n"); } else { - for (int i = 0; i < params.batch_count; i++) { + for (int i = 0; i < num_results; i++) { if (results[i].data == NULL) { continue; } @@ -1511,7 +1544,6 @@ int main(int argc, const char* argv[]) { results[i].data = NULL; } free(results); - free_sd_ctx(sd_ctx); release_all_resources();