Skip to content

Commit c35baca

Browse files
daniandtheweb and ursg committed
Support BREAK pseudo-token
Co-authored-by: Urs Ganse <urs.ganse@helsinki.fi>
1 parent 40a6a87 commit c35baca

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

conditioner.hpp

Lines changed: 36 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -270,13 +270,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
270270
const std::string& curr_text = item.first;
271271
float curr_weight = item.second;
272272
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
273-
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
274273
int32_t clean_index = 0;
274+
if(curr_text == "BREAK" && curr_weight == -1.0f) {
275+
// Pad token array up to chunk size at this point.
276+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
277+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
278+
int padding_size = 75 - (tokens_acc % 75);
279+
for (int j = 0; j < padding_size; j++) {
280+
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
281+
clean_index++;
282+
}
283+
284+
// After padding, continue to the next iteration to process the following text as a new segment
285+
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
286+
weights.insert(weights.end(), padding_size, curr_weight);
287+
continue;
288+
}
289+
290+
// Regular token, process normally
291+
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
275292
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
276293
int token_id = curr_tokens[i];
277-
if (token_id == image_token)
294+
if (token_id == image_token) {
278295
class_token_index.push_back(clean_index - 1);
279-
else {
296+
} else {
280297
clean_input_ids.push_back(token_id);
281298
clean_index++;
282299
}
@@ -379,6 +396,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
379396
for (const auto& item : parsed_attention) {
380397
const std::string& curr_text = item.first;
381398
float curr_weight = item.second;
399+
400+
if(curr_text == "BREAK" && curr_weight == -1.0f) {
401+
// Pad token array up to chunk size at this point.
402+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
403+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
404+
size_t current_size = tokens.size();
405+
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding
406+
407+
if (padding_size > 0) {
408+
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
409+
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
410+
weights.insert(weights.end(), padding_size, 1.0f);
411+
}
412+
continue; // Skip to the next item after handling BREAK
413+
}
414+
382415
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
383416
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
384417
weights.insert(weights.end(), curr_tokens.size(), curr_weight);

util.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <codecvt>
66
#include <fstream>
77
#include <locale>
8+
#include <regex>
89
#include <sstream>
910
#include <string>
1011
#include <thread>
@@ -548,7 +549,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
548549
float round_bracket_multiplier = 1.1f;
549550
float square_bracket_multiplier = 1 / 1.1f;
550551

551-
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
552+
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:B]+|:|\bB)");
552553
std::regex re_break(R"(\s*\bBREAK\b\s*)");
553554

554555
auto multiply_range = [&](int start_position, float multiplier) {
@@ -557,7 +558,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
557558
}
558559
};
559560

560-
std::smatch m;
561+
std::smatch m,m2;
561562
std::string remaining_text = text;
562563

563564
while (std::regex_search(remaining_text, m, re_attention)) {
@@ -581,6 +582,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
581582
square_brackets.pop_back();
582583
} else if (text == "\\(") {
583584
res.push_back({text.substr(1), 1.0f});
585+
} else if (std::regex_search(text, m2, re_break)) {
586+
res.push_back({"BREAK", -1.0f});
584587
} else {
585588
res.push_back({text, 1.0f});
586589
}
@@ -611,4 +614,4 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
611614
}
612615

613616
return res;
614-
}
617+
}

0 commit comments

Comments
 (0)