Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 23 additions & 135 deletions src/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
parse_extra_sample_args(extra_sample_args);
}

static std::string trim(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}

void parse_extra_sample_args(const char* extra_sample_args) {
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return;
}

std::string raw(extra_sample_args);
size_t start = 0;
auto parse_arg = [&](const std::string& item) {
std::string token = trim(item);
if (token.empty()) {
return;
}
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
return;
}

std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
auto parse_float = [&](float* out) -> bool {
try {
size_t consumed = 0;
float parsed = std::stof(value, &consumed);
if (!trim(value.substr(consumed)).empty()) {
return false;
}
*out = parsed;
return true;
} catch (const std::exception&) {
return false;
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
if (key == "max_shift") {
if (!parse_strict_float(value, max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
};
try {
if (key == "max_shift") {
if (!parse_float(&max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "base_shift") {
if (!parse_float(&base_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "terminal") {
if (!parse_float(&terminal)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "stretch") {
std::string v = value;
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
if (v == "1" || v == "true" || v == "yes" || v == "on") {
stretch = true;
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
stretch = false;
} else {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
} else if (key == "base_shift") {
if (!parse_strict_float(value, base_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
};

for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
parse_arg(raw.substr(start, pos - start));
start = pos + 1;
} else if (key == "terminal") {
if (!parse_strict_float(value, terminal)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else if (key == "stretch") {
if (!parse_strict_bool(value, stretch)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
}
}
}
Expand Down Expand Up @@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
return x;
}

using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
using SamplerExtraArgs = KeyValueArgs;

static sd::Tensor<float> sample_lcm(denoise_cb_t model,
sd::Tensor<float> x,
Expand All @@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,

for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "noise_clip_std") {
Expand Down Expand Up @@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,

for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "gamma") {
Expand Down Expand Up @@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
return x;
}

static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
SamplerExtraArgs pairs;

if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return pairs;
}

auto trim = [](std::string value) -> std::string {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
};

std::string raw(extra_sample_args);
size_t start = 0;

for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
std::string item = raw.substr(start, pos - start);
std::string token = trim(item);

if (!token.empty()) {
size_t eq = token.find('=');
if (eq != std::string::npos) {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
}

return pairs;
}

// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
Expand All @@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float eta,
bool is_flow_denoiser,
const char* extra_sample_args) {
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
switch (method) {
case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
Expand Down
63 changes: 10 additions & 53 deletions src/ltx_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,65 +1251,22 @@ struct LTXVideoVAE : public VAE {
temporal_tiling_enabled = enabled;
}

static std::string trim_tiling_arg(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}

static bool parse_tiling_int(const std::string& value, int& parsed) {
try {
size_t consumed = 0;
parsed = std::stoi(value, &consumed);
return trim_tiling_arg(value.substr(consumed)).empty();
} catch (...) {
return false;
}
}

void set_tiling_params(const sd_tiling_params_t& params) override {
temporal_tiling_enabled = params.temporal_tiling;
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;

const char* extra_tiling_args = params.extra_tiling_args;
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
return;
}

std::string raw(extra_tiling_args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
continue;
}

std::string token = trim_tiling_arg(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
} else {
std::string key = trim_tiling_arg(token.substr(0, eq));
std::string value = trim_tiling_arg(token.substr(eq + 1));
int parsed = 0;
if (!parse_tiling_int(value, parsed)) {
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
temporal_tile_frames = std::max(1, parsed);
} else if (key == "temporal_tile_overlap") {
temporal_tile_overlap = std::max(0, parsed);
} else {
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
}
}
for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
int parsed = 0;
if (!parse_strict_int(value, parsed)) {
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
temporal_tile_frames = std::max(1, parsed);
} else if (key == "temporal_tile_overlap") {
temporal_tile_overlap = std::max(0, parsed);
} else {
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
}

start = pos + 1;
}
}

Expand Down
84 changes: 84 additions & 0 deletions src/util.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "util.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <codecvt>
#include <cstdarg>
#include <exception>
#include <fstream>
#include <locale>
#include <regex>
Expand Down Expand Up @@ -406,6 +408,88 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
return result;
}

KeyValueArgs parse_key_value_args(const char* args, const char* context) {
KeyValueArgs pairs;

if (args == nullptr || args[0] == '\0') {
return pairs;
}

std::string raw(args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
continue;
}

std::string token = trim(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
const char* log_context = context ? context : "key=value arg";
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
} else {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}

start = pos + 1;
}

return pairs;
}

KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
return parse_key_value_args(args.c_str(), context);
}

bool parse_strict_float(const std::string& text, float& value) {
try {
size_t consumed = 0;
float parsed = std::stof(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}

bool parse_strict_int(const std::string& text, int& value) {
try {
size_t consumed = 0;
int parsed = std::stoi(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}

bool parse_strict_bool(const std::string& text, bool& value) {
std::string lowered = trim(text);
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});

if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
value = true;
return true;
}
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
value = false;
return true;
}
return false;
}

static std::string build_progress_bar(int step, int steps) {
std::string progress = " |";
int max_progress = 50;
Expand Down
Loading
Loading