Skip to content

Commit b494575

Browse files
Merge pull request #104 from menloresearch/update-dev-from-master-2025-05-27-00-08
Sync master with upstream release b5501
2 parents de7bfe2 + cdf94a1 commit b494575

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1041
-423
lines changed

common/arg.cpp

Lines changed: 113 additions & 93 deletions
Large diffs are not rendered by default.

common/chat-parser.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
170170
}
171171

172172
// Tries to find the regex, consumes it (pos ends up right after the match) and returns the prelude (the text right before the match) together with the match groups.
173-
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
173+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
174174
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
175175
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
176176
return std::nullopt;
177177
}
178+
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
179+
pos_ = m.groups[0].end;
180+
181+
if (add_prelude_to_content) {
182+
add_content(prelude);
183+
}
178184
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
179185
if (is_partial()) {
180186
throw common_chat_msg_partial_exception(regex.str());
181187
}
182188
return std::nullopt;
183189
}
184-
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
185-
pos_ = m.groups[0].end;
186-
187190
return find_regex_result{prelude, m.groups};
188191
}
189192

common/chat-parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class common_chat_msg_parser {
3030
const std::string & healing_marker() const { return healing_marker_; }
3131
const bool & is_partial() const { return is_partial_; }
3232
const common_chat_msg & result() const { return result_; }
33+
const common_chat_syntax & syntax() const { return syntax_; }
3334

3435
void move_to(size_t pos) {
3536
if (pos > input_.size()) {
@@ -77,7 +78,7 @@ class common_chat_msg_parser {
7778
std::vector<common_string_range> groups;
7879
};
7980

80-
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
81+
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
8182

8283
bool try_consume_literal(const std::string & literal);
8384

common/chat.cpp

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur
3131
return current;
3232
}
3333
if (!string_starts_with(current, last)) {
34+
if (string_starts_with(last, current)) {
35+
// This happens if the last generation ended on a partial stop word (not erased),
36+
// and the current ended on a stop word (erased).
37+
return "";
38+
}
3439
throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
3540
}
3641
return current.substr(last.size());
@@ -101,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
101106
if (!args_diff.empty() || pref.id != newf.id) {
102107
auto & diff = diffs.emplace_back();
103108
diff.tool_call_index = idx;
104-
diff.tool_call_delta.name = newf.name;
105109
if (pref.id != newf.id) {
106110
diff.tool_call_delta.id = newf.id;
111+
diff.tool_call_delta.name = newf.name;
107112
}
108113
diff.tool_call_delta.arguments = args_diff;
109114
}
@@ -387,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
387392
delta["content"] = diff.content_delta;
388393
}
389394
if (diff.tool_call_index != std::string::npos) {
395+
json tool_call;
396+
tool_call["index"] = diff.tool_call_index;
397+
if (!diff.tool_call_delta.id.empty()) {
398+
tool_call["id"] = diff.tool_call_delta.id;
399+
tool_call["type"] = "function";
400+
}
390401
json function = json::object();
391402
if (!diff.tool_call_delta.name.empty()) {
392403
function["name"] = diff.tool_call_delta.name;
393404
}
394-
if (!diff.tool_call_delta.id.empty()) {
395-
function["id"] = diff.tool_call_delta.id;
396-
}
397-
if (!diff.tool_call_delta.arguments.empty()) {
398-
function["arguments"] = diff.tool_call_delta.arguments;
399-
}
400-
delta["tool_calls"] = json::array({
401-
json {
402-
{"index", diff.tool_call_index},
403-
{"function", function}
404-
}
405-
});
405+
function["arguments"] = diff.tool_call_delta.arguments;
406+
tool_call["function"] = function;
407+
delta["tool_calls"] = json::array({tool_call});
406408
}
407409
return delta;
408410
}
@@ -654,7 +656,6 @@ static void parse_json_tool_calls(
654656
}
655657
from = std::string::npos;
656658

657-
builder.add_content(res->prelude);
658659
auto maybe_raw_python = name == "python" && allow_raw_python;
659660
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
660661
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
@@ -684,7 +685,6 @@ static void parse_json_tool_calls(
684685
};
685686
if (block_open) {
686687
if (auto res = builder.try_find_regex(*block_open)) {
687-
builder.add_content(res->prelude);
688688
parse_tool_calls();
689689
} else {
690690
builder.add_content(builder.consume_rest());
@@ -697,7 +697,6 @@ static void parse_json_tool_calls(
697697
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
698698
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
699699
if (auto res = builder.try_find_regex(prefix)) {
700-
builder.add_content(res->prelude);
701700
builder.move_back(rstrip_prefix);
702701
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
703702
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
@@ -833,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
833832
return data;
834833
}
835834
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
835+
if (!builder.syntax().parse_tool_calls) {
836+
builder.add_content(builder.consume_rest());
837+
return;
838+
}
836839
static const std::vector<std::vector<std::string>> content_paths = {
837840
{"response"},
838841
};
@@ -905,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
905908
return data;
906909
}
907910
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
911+
if (!builder.syntax().parse_tool_calls) {
912+
builder.add_content(builder.consume_rest());
913+
return;
914+
}
915+
908916
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
909917
parse_prefixed_json_tool_call_array(builder, prefix);
910918
}
@@ -999,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
9991007

10001008
if (auto res = builder.try_find_regex(start_action_regex)) {
10011009
// If we didn't extract thoughts, prelude includes them.
1002-
builder.add_content(res->prelude);
10031010
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
10041011
for (const auto & tool_call : tool_calls.value) {
10051012
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
@@ -1014,11 +1021,7 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
10141021
}
10151022
builder.consume_regex(end_action_regex);
10161023
} else if (auto res = builder.try_find_regex(start_response_regex)) {
1017-
// If we didn't extract thoughts, prelude includes them.
1018-
builder.add_content(res->prelude);
1019-
if (auto res = builder.try_find_regex(end_response_regex)) {
1020-
builder.add_content(res->prelude);
1021-
} else {
1024+
if (!builder.try_find_regex(end_response_regex)) {
10221025
builder.add_content(builder.consume_rest());
10231026
throw common_chat_msg_partial_exception(end_response_regex.str());
10241027
}
@@ -1126,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
11261129
return data;
11271130
}
11281131
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
1132+
if (!builder.syntax().parse_tool_calls) {
1133+
builder.add_content(builder.consume_rest());
1134+
return;
1135+
}
1136+
11291137
static const common_regex function_regex(
11301138
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
11311139
static const common_regex close_regex("\\}\\s*");
@@ -1136,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
11361144
if (with_builtin_tools) {
11371145
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
11381146
if (auto res = builder.try_find_regex(builtin_call_regex)) {
1139-
builder.add_content(res->prelude);
1140-
11411147
auto fun_res = builder.consume_regex(function_name_regex);
11421148
auto function_name = builder.str(fun_res.groups[1]);
11431149

@@ -1253,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
12531259
}
12541260
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
12551261
builder.try_parse_reasoning("<think>", "</think>");
1262+
if (!builder.syntax().parse_tool_calls) {
1263+
builder.add_content(builder.consume_rest());
1264+
return;
1265+
}
12561266

12571267
static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)");
12581268
static const common_regex tool_calls_end("<|tool▁calls▁end|>");
@@ -1314,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
13141324
return data;
13151325
}
13161326
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
1327+
if (!builder.syntax().parse_tool_calls) {
1328+
builder.add_content(builder.consume_rest());
1329+
return;
1330+
}
13171331
static const common_regex prefix(regex_escape(" functools["));
13181332
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
13191333
}
@@ -1455,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
14551469
return data;
14561470
}
14571471
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
1458-
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
1459-
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
1460-
1461-
if (auto res = builder.try_find_regex(python_tag_regex)) {
1462-
builder.add_content(res->prelude);
1463-
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
1464-
builder.add_tool_call("python", "", arguments);
1472+
if (!builder.syntax().parse_tool_calls) {
1473+
builder.add_content(builder.consume_rest());
14651474
return;
14661475
}
1476+
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
1477+
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
14671478

14681479
static const common_regex function_regex(R"(<function=(\w+)>)");
14691480
static const common_regex close_regex(R"(</function>)");
@@ -1475,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
14751486
function_regex,
14761487
close_regex,
14771488
std::nullopt);
1489+
1490+
if (auto res = builder.try_find_regex(python_tag_regex)) {
1491+
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
1492+
builder.add_tool_call("python", "", arguments);
1493+
return;
1494+
}
14781495
}
14791496

14801497
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1593,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
15931610
}
15941611
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
15951612
builder.try_parse_reasoning("<think>", "</think>");
1613+
if (!builder.syntax().parse_tool_calls) {
1614+
builder.add_content(builder.consume_rest());
1615+
return;
1616+
}
15961617

15971618
static const common_regex open_regex(
15981619
"(?:"
@@ -1614,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16141635
);
16151636

16161637
if (auto res = builder.try_find_regex(open_regex)) {
1617-
builder.add_content(res->prelude);
1618-
16191638
const auto & block_start = res->groups[1];
16201639
std::string block_end = block_start.empty() ? "" : "```";
16211640

@@ -1851,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
18511870
builder.add_content(builder.consume_rest());
18521871
}
18531872

1854-
static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
1855-
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str());
1873+
static void common_chat_parse(common_chat_msg_parser & builder) {
1874+
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
18561875

1857-
switch (format) {
1876+
switch (builder.syntax().format) {
18581877
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
18591878
common_chat_parse_content_only(builder);
18601879
break;
@@ -1889,15 +1908,15 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
18891908
common_chat_parse_command_r7b(builder);
18901909
break;
18911910
default:
1892-
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format));
1911+
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
18931912
}
18941913
builder.finish();
18951914
}
18961915

18971916
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
18981917
common_chat_msg_parser builder(input, is_partial, syntax);
18991918
try {
1900-
common_chat_parse(builder, syntax.format);
1919+
common_chat_parse(builder);
19011920
} catch (const common_chat_msg_partial_exception & ex) {
19021921
LOG_DBG("Partial parse: %s\n", ex.what());
19031922
if (!is_partial) {

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ struct common_chat_syntax {
144144
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
145145
bool reasoning_in_content = false;
146146
bool thinking_forced_open = false;
147+
bool parse_tool_calls = true;
147148
};
148149

149150
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ struct common_params {
291291
int32_t verbosity = 0;
292292
int32_t control_vector_layer_start = -1; // layer range for control vector
293293
int32_t control_vector_layer_end = -1; // layer range for control vector
294+
bool offline = false;
294295

295296
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
296297
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

docs/backend/CANN.md

100644100755
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,15 @@ cmake --build build --config release
280280
### **GitHub contribution**:
281281
Please add the **[CANN]** prefix/tag to issue and PR titles so that the CANN team can check and address them without delay.
282282

283+
## Updates
284+
### Basic Flash Attention Support
285+
A basic Flash Attention (FA) kernel built on aclnn ops has been added in aclnn_ops.cpp.
286+
Currently, FA only supports cases with FP16 KV tensors and no logit softcap.
287+
Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future.
288+
289+
Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn).
290+
291+
We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request.
283292

284293
## TODO
285294
- Support more models and data types.

examples/embedding/embedding.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
4141

4242
// run model
4343
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
44-
if (llama_encode(ctx, batch) < 0) {
45-
LOG_ERR("%s : failed to encode\n", __func__);
44+
if (llama_decode(ctx, batch) < 0) {
45+
LOG_ERR("%s : failed to process\n", __func__);
4646
}
4747

4848
for (int i = 0; i < batch.n_tokens; i++) {

examples/retrieval/retrieval.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
8181
}
8282
}
8383

84-
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
84+
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
8585
// clear previous kv_cache values (irrelevant for embeddings)
8686
llama_kv_self_clear(ctx);
8787

8888
// run model
8989
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
90-
if (llama_encode(ctx, batch) < 0) {
91-
LOG_ERR("%s : failed to encode\n", __func__);
90+
if (llama_decode(ctx, batch) < 0) {
91+
LOG_ERR("%s : failed to process\n", __func__);
9292
}
9393

9494
for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
233233
// encode if at capacity
234234
if (batch.n_tokens + n_toks > n_batch) {
235235
float * out = emb + p * n_embd;
236-
batch_encode(ctx, batch, out, s, n_embd);
236+
batch_process(ctx, batch, out, s, n_embd);
237237
common_batch_clear(batch);
238238
p += s;
239239
s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
246246

247247
// final batch
248248
float * out = emb + p * n_embd;
249-
batch_encode(ctx, batch, out, s, n_embd);
249+
batch_process(ctx, batch, out, s, n_embd);
250250

251251
// save embeddings to chunks
252252
for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
267267
batch_add_seq(query_batch, query_tokens, 0);
268268

269269
std::vector<float> query_emb(n_embd, 0);
270-
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
270+
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
271271

272272
common_batch_clear(query_batch);
273273

examples/training/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ Proof of concept:
1010

1111
``` sh
1212
export model_name=llama_3.2-1b && export quantization=f32
13-
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
14-
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
13+
./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
14+
./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
1515
```
1616

1717
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

0 commit comments

Comments
 (0)