Better check_chat_messages & encode_messages
* Check USER and ASSISTANT messages only

* ignore .venv

* Make encode_messages never throw with system prompt (see the illustrative sketch after the changed-file summary below)

* Disable Apple M1 chips on CI

---------

Co-authored-by: Jiahao Li <liplus17@163.com>
sswater and li-plus committed Apr 28, 2024
1 parent a46f474 commit d0f45ba
Showing 6 changed files with 63 additions and 16 deletions.
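
Taken together, the message-checking changes alter what callers may pass to Pipeline::chat: a leading system message is now skipped by the role check instead of breaking the user/assistant alternation. The sketch below is illustrative only; it assumes the Pipeline, ChatMessage and GenerationConfig API exactly as exercised in chatglm_test.cpp further down, and the model file name is a placeholder.

// Illustrative sketch only: the model file name is a placeholder.
#include <iostream>
#include <vector>
#include "chatglm.h"

int main() {
    using namespace chatglm;
    Pipeline pipeline("chatglm-ggml.bin");
    GenerationConfig gen_config;
    gen_config.do_sample = false;

    // A {system, user} sequence is accepted now that non-user/assistant roles
    // are filtered out before the alternation check; it used to fail validation.
    std::vector<ChatMessage> messages{
        {ChatMessage::ROLE_SYSTEM, "You are a helpful assistant."},
        {ChatMessage::ROLE_USER, "你好"},
    };
    ChatMessage reply = pipeline.chat(messages, gen_config);
    std::cout << reply.content << std::endl;
    return 0;
}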
2 changes: 1 addition & 1 deletion .github/workflows/cmake.yml
@@ -18,7 +18,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest, macos-13]
 
     steps:
     - uses: actions/checkout@v3
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ dist/
 *.so
 *.whl
 .hypothesis/
+.venv
 
 # cpp
 build/
48 changes: 34 additions & 14 deletions chatglm.cpp
@@ -149,12 +149,28 @@ const std::string ChatMessage::ROLE_SYSTEM = "system";
 const std::string ChatMessage::ROLE_OBSERVATION = "observation";
 
 void BaseTokenizer::check_chat_messages(const std::vector<ChatMessage> &messages) {
-    CHATGLM_CHECK(messages.size() % 2 == 1) << "invalid chat messages size " << messages.size();
+    std::string target_role = ChatMessage::ROLE_USER;
     for (size_t i = 0; i < messages.size(); i++) {
-        const std::string &target_role = (i % 2 == 0) ? ChatMessage::ROLE_USER : ChatMessage::ROLE_ASSISTANT;
+        if (messages[i].role != ChatMessage::ROLE_USER && messages[i].role != ChatMessage::ROLE_ASSISTANT) {
+            continue;
+        }
         CHATGLM_CHECK(messages[i].role == target_role)
             << "expect messages[" << i << "].role to be " << target_role << ", but got " << messages[i].role;
+        target_role = (target_role == ChatMessage::ROLE_USER) ? ChatMessage::ROLE_ASSISTANT : ChatMessage::ROLE_USER;
     }
+    CHATGLM_CHECK(target_role == ChatMessage::ROLE_ASSISTANT)
+        << "expect last message role to be " << ChatMessage::ROLE_USER << ", but got " << ChatMessage::ROLE_ASSISTANT;
 }
 
+std::vector<ChatMessage> BaseTokenizer::filter_user_assistant_messages(const std::vector<ChatMessage> &messages) {
+    std::vector<ChatMessage> user_assistant_messages;
+    user_assistant_messages.reserve(messages.size());
+    for (const auto &msg : messages) {
+        if (msg.role == ChatMessage::ROLE_USER || msg.role == ChatMessage::ROLE_ASSISTANT) {
+            user_assistant_messages.emplace_back(msg);
+        }
+    }
+    return user_assistant_messages;
+}
+
 // Adapted from https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp
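
For clarity, here is a standalone sketch (plain STL, no library macros) of the rule the rewritten check_chat_messages enforces: roles other than user/assistant are skipped, the remaining roles must alternate starting with a user message, and the last kept message must come from the user.

// Standalone illustration of the alternation rule above; not part of the library.
#include <iostream>
#include <string>
#include <vector>

static bool valid_chat(const std::vector<std::string> &roles) {
    std::string target = "user";
    for (const std::string &role : roles) {
        if (role != "user" && role != "assistant") {
            continue;  // "system" / "observation" messages are ignored by the check
        }
        if (role != target) {
            return false;  // alternation broken
        }
        target = (target == "user") ? "assistant" : "user";
    }
    return target == "assistant";  // true iff the last kept message was a user message
}

int main() {
    std::cout << valid_chat({"system", "user"}) << "\n";             // 1: accepted
    std::cout << valid_chat({"user", "assistant", "user"}) << "\n";  // 1: accepted
    std::cout << valid_chat({"user", "assistant"}) << "\n";          // 0: must end with a user turn
    std::cout << valid_chat({"user", "user"}) << "\n";               // 0: roles must alternate
    return 0;
}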
@@ -998,15 +1014,16 @@ std::vector<int> ChatGLMTokenizer::encode_messages(const std::vector<ChatMessage
 
 std::string ChatGLMTokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
     check_chat_messages(messages);
+    std::vector<ChatMessage> user_assistant_messages = filter_user_assistant_messages(messages);
 
     std::ostringstream oss_prompt;
-    if (messages.size() == 1) {
-        oss_prompt << messages.front().content;
+    if (user_assistant_messages.size() == 1) {
+        oss_prompt << user_assistant_messages.front().content;
     } else {
-        for (size_t i = 0; i < messages.size(); i += 2) {
-            oss_prompt << "[Round " << i / 2 << "]\n问:" << messages[i].content << "\n答:";
-            if (i + 1 < messages.size()) {
-                oss_prompt << messages[i + 1].content << "\n";
+        for (size_t i = 0; i < user_assistant_messages.size(); i += 2) {
+            oss_prompt << "[Round " << i / 2 << "]\n问:" << user_assistant_messages[i].content << "\n答:";
+            if (i + 1 < user_assistant_messages.size()) {
+                oss_prompt << user_assistant_messages[i + 1].content << "\n";
             }
         }
     }
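
For reference, this standalone sketch replays the loop above on made-up turns to show the prompt text ChatGLM (v1) ends up with once system messages are filtered out; the conversation content is invented for illustration.

// Standalone illustration of the [Round i] / 问 / 答 prompt layout; not library code.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
    // user/assistant contents only, in order, as left by filter_user_assistant_messages
    std::vector<std::string> turns = {"你好", "你好!有什么可以帮你?", "推荐一本书"};

    std::ostringstream oss_prompt;
    if (turns.size() == 1) {
        oss_prompt << turns.front();
    } else {
        for (size_t i = 0; i < turns.size(); i += 2) {
            oss_prompt << "[Round " << i / 2 << "]\n问:" << turns[i] << "\n答:";
            if (i + 1 < turns.size()) {
                oss_prompt << turns[i + 1] << "\n";
            }
        }
    }
    std::cout << oss_prompt.str() << std::endl;
    // Prints:
    // [Round 0]
    // 问:你好
    // 答:你好!有什么可以帮你?
    // [Round 1]
    // 问:推荐一本书
    // 答:
    return 0;
}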
@@ -1223,12 +1240,13 @@ std::vector<int> ChatGLM2Tokenizer::encode_messages(const std::vector<ChatMessag
 
 std::string ChatGLM2Tokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
     check_chat_messages(messages);
+    std::vector<ChatMessage> user_assistant_messages = filter_user_assistant_messages(messages);
 
     std::ostringstream oss_prompt;
-    for (size_t i = 0; i < messages.size(); i += 2) {
-        oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << messages[i].content << "\n\n答:";
-        if (i < messages.size() - 1) {
-            oss_prompt << messages[i + 1].content << "\n\n";
+    for (size_t i = 0; i < user_assistant_messages.size(); i += 2) {
+        oss_prompt << "[Round " << i / 2 + 1 << "]\n\n问:" << user_assistant_messages[i].content << "\n\n答:";
+        if (i < user_assistant_messages.size() - 1) {
+            oss_prompt << user_assistant_messages[i + 1].content << "\n\n";
         }
     }
     return oss_prompt.str();
@@ -1528,10 +1546,11 @@ std::string BaichuanTokenizer::decode(const std::vector<int> &ids) const {
 
 std::vector<int> BaichuanTokenizer::encode_messages(const std::vector<ChatMessage> &messages, int max_length) const {
     check_chat_messages(messages);
+    std::vector<ChatMessage> user_assistant_messages = filter_user_assistant_messages(messages);
 
     std::vector<int> ids;
     ids.reserve(max_length);
-    for (const auto &msg : messages) {
+    for (const auto &msg : user_assistant_messages) {
         ids.push_back((msg.role == ChatMessage::ROLE_USER) ? USER_TOKEN_ID : ASSISTANT_TOKEN_ID);
         std::vector<int> content_ids = encode(msg.content, max_length);
         ids.insert(ids.end(), content_ids.begin(), content_ids.end());
@@ -1677,9 +1696,10 @@ std::vector<int> InternLMTokenizer::encode_messages(const std::vector<ChatMessag
 
 std::string InternLMTokenizer::build_prompt(const std::vector<ChatMessage> &messages) {
     check_chat_messages(messages);
+    std::vector<ChatMessage> user_assistant_messages = filter_user_assistant_messages(messages);
 
     std::ostringstream oss_prompt;
-    for (const auto &msg : messages) {
+    for (const auto &msg : user_assistant_messages) {
         if (msg.role == ChatMessage::ROLE_USER) {
             oss_prompt << "<|User|>:" << msg.content << "<eoh>\n<|Bot|>:";
         } else {
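
The InternLM hunk above is truncated by the diff view (the assistant branch is cut off), but for a conversation consisting of a single user message the visible branch alone determines the prompt. A standalone sketch with invented content:

// Standalone illustration of the InternLM prompt for a single user turn; not library code.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> user_turns = {"你好"};  // one user message, nothing else

    std::ostringstream oss_prompt;
    for (const std::string &content : user_turns) {
        oss_prompt << "<|User|>:" << content << "<eoh>\n<|Bot|>:";
    }
    std::cout << oss_prompt.str() << std::endl;  // prints "<|User|>:你好<eoh>\n<|Bot|>:"
    return 0;
}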
2 changes: 2 additions & 0 deletions chatglm.h
@@ -274,6 +274,8 @@ class BaseTokenizer {
 
   protected:
     static void check_chat_messages(const std::vector<ChatMessage> &messages);
+
+    static std::vector<ChatMessage> filter_user_assistant_messages(const std::vector<ChatMessage> &messages);
 };
 
 struct ggml_context_deleter_t {
25 changes: 25 additions & 0 deletions chatglm_test.cpp
@@ -976,6 +976,24 @@ static void check_tokenizer(const BaseTokenizer *tokenizer, const std::vector<To
     }
 }
 
+static void check_chat_format(const Pipeline &pipeline) {
+    GenerationConfig gen_config;
+    gen_config.max_new_tokens = 1;
+    EXPECT_THROW(
+        {
+            pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_USER, "user"}}, gen_config);
+        },
+        std::runtime_error);
+    EXPECT_THROW({ pipeline.chat({{ChatMessage::ROLE_ASSISTANT, "assistant"}}, gen_config); }, std::runtime_error);
+    EXPECT_THROW(
+        {
+            pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_ASSISTANT, "assistant"}}, gen_config);
+        },
+        std::runtime_error);
+    // never throw with system prompt
+    pipeline.chat({{ChatMessage::ROLE_SYSTEM, "system"}, {ChatMessage::ROLE_USER, "user"}}, gen_config);
+}
+
 TEST(Pipeline, ChatGLM) {
     fs::path model_path = fs::path(__FILE__).parent_path() / "chatglm-ggml.bin";
     if (!fs::exists(model_path)) {
@@ -1029,6 +1047,7 @@ TEST(Pipeline, ChatGLM) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
@@ -1093,6 +1112,7 @@ TEST(Pipeline, ChatGLM2) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
@@ -1189,6 +1209,7 @@ TEST(Pipeline, ChatGLM3) {
 
     // chat
     {
+        // check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
@@ -1359,6 +1380,7 @@ TEST(Pipeline, Baichuan13B) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         gen_config.repetition_penalty = 1.1;
@@ -1413,6 +1435,7 @@ TEST(Pipeline, Baichuan2_7B) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         gen_config.repetition_penalty = 1.05;
@@ -1455,6 +1478,7 @@ TEST(Pipeline, Baichuan2_13B) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         gen_config.repetition_penalty = 1.05;
@@ -1510,6 +1534,7 @@ TEST(Pipeline, InternLM) {
 
     // chat
     {
+        check_chat_format(pipeline);
         GenerationConfig gen_config;
         gen_config.do_sample = false;
         std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "你好"}};
1 change: 0 additions & 1 deletion tests/test_convert.py
@@ -444,7 +444,6 @@ def make_data_glm2_model():
 
 
 def make_data_glm3_model():
-
     def _forward_steps(model, seq_len):
         # self attention
         x1 = torch.arange(seq_len, dtype=torch.int64)[None, :]
