Skip to content

Commit 4e3f76c

Browse files
committed
add easycache
1 parent 8f6c5c2 commit 4e3f76c

File tree

3 files changed

+449
-10
lines changed

3 files changed

+449
-10
lines changed

examples/cli/main.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ struct SDParams {
9898
std::vector<int> high_noise_skip_layers = {7, 8, 9};
9999
sd_sample_params_t high_noise_sample_params;
100100

101+
std::string easycache_option;
102+
sd_easycache_params_t easycache_params;
103+
101104
float moe_boundary = 0.875f;
102105
int video_frames = 1;
103106
int fps = 16;
@@ -139,6 +142,7 @@ struct SDParams {
139142
sd_sample_params_init(&sample_params);
140143
sd_sample_params_init(&high_noise_sample_params);
141144
high_noise_sample_params.sample_steps = -1;
145+
sd_easycache_params_init(&easycache_params);
142146
}
143147
};
144148

@@ -208,6 +212,11 @@ void print_params(SDParams params) {
208212
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
209213
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
210214
printf(" video_frames: %d\n", params.video_frames);
215+
printf(" easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
216+
params.easycache_params.enabled ? "enabled" : "disabled",
217+
params.easycache_params.reuse_threshold,
218+
params.easycache_params.start_percent,
219+
params.easycache_params.end_percent);
211220
printf(" vace_strength: %.2f\n", params.vace_strength);
212221
printf(" fps: %d\n", params.fps);
213222
free(sample_params_str);
@@ -593,6 +602,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
593602
"--upscale-model",
594603
"path to esrgan model.",
595604
&params.esrgan_path},
605+
{"",
606+
"--easycache",
607+
"enable EasyCache for DiT models with \"threshold,start_percent,end_percent\" (example: 0.2,0.15,0.95)",
608+
&params.easycache_option},
596609
};
597610

598611
options.int_options = {
@@ -1117,6 +1130,59 @@ void parse_args(int argc, const char** argv, SDParams& params) {
11171130
exit(1);
11181131
}
11191132

1133+
if (!params.easycache_option.empty()) {
1134+
float values[3] = {0.0f, 0.0f, 0.0f};
1135+
std::stringstream ss(params.easycache_option);
1136+
std::string token;
1137+
int idx = 0;
1138+
while (std::getline(ss, token, ',')) {
1139+
auto trim = [](std::string& s) {
1140+
const char* whitespace = " \t\r\n";
1141+
auto start = s.find_first_not_of(whitespace);
1142+
if (start == std::string::npos) {
1143+
s.clear();
1144+
return;
1145+
}
1146+
auto end = s.find_last_not_of(whitespace);
1147+
s = s.substr(start, end - start + 1);
1148+
};
1149+
trim(token);
1150+
if (token.empty()) {
1151+
fprintf(stderr, "error: invalid easycache option '%s'\n", params.easycache_option.c_str());
1152+
exit(1);
1153+
}
1154+
if (idx >= 3) {
1155+
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1156+
exit(1);
1157+
}
1158+
try {
1159+
values[idx] = std::stof(token);
1160+
} catch (const std::exception&) {
1161+
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
1162+
exit(1);
1163+
}
1164+
idx++;
1165+
}
1166+
if (idx != 3) {
1167+
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1168+
exit(1);
1169+
}
1170+
if (values[0] < 0.0f) {
1171+
fprintf(stderr, "error: easycache threshold must be non-negative\n");
1172+
exit(1);
1173+
}
1174+
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1175+
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1176+
exit(1);
1177+
}
1178+
params.easycache_params.enabled = true;
1179+
params.easycache_params.reuse_threshold = values[0];
1180+
params.easycache_params.start_percent = values[1];
1181+
params.easycache_params.end_percent = values[2];
1182+
} else {
1183+
params.easycache_params.enabled = false;
1184+
}
1185+
11201186
if (params.n_threads <= 0) {
11211187
params.n_threads = get_num_physical_cores();
11221188
}
@@ -1716,6 +1782,7 @@ int main(int argc, const char* argv[]) {
17161782
params.pm_style_strength,
17171783
}, // pm_params
17181784
params.vae_tiling_params,
1785+
params.easycache_params,
17191786
};
17201787

17211788
results = generate_image(sd_ctx, &img_gen_params);
@@ -1738,6 +1805,7 @@ int main(int argc, const char* argv[]) {
17381805
params.seed,
17391806
params.video_frames,
17401807
params.vace_strength,
1808+
params.easycache_params,
17411809
};
17421810

17431811
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

0 commit comments

Comments
 (0)