Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/tokenize_tool/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

using namespace tokenizers;

std::string help(char *argv[]) {
std::string help(char* argv[]) {
std::stringstream ss;
ss << "Usage: " << argv[0] << " <type> <model> <input to tokenize...>"
<< std::endl
Expand All @@ -37,7 +37,7 @@ std::string help(char *argv[]) {
return ss.str();
}

int main(int argc, char *argv[]) {
int main(int argc, char* argv[]) {
// Check for the right number of CLI args
if (argc < 4) {
std::cerr << help(argv) << std::endl;
Expand Down Expand Up @@ -95,7 +95,7 @@ int main(int argc, char *argv[]) {
// Decode
std::cout << "Decoding..." << std::endl;
uint64_t prev = tok_ptr->bos_tok();
for (const auto &current : encoded) {
for (const auto& current : encoded) {
const auto decoded_result = tok_ptr->decode(prev, current);
std::cout << decoded_result.get();
prev = current;
Expand Down
2 changes: 1 addition & 1 deletion test/test_base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

#include "gtest/gtest.h"
#include <pytorch/tokenizers/base64.h>
#include "gtest/gtest.h"

namespace tokenizers {

Expand Down
8 changes: 4 additions & 4 deletions test/test_llama2c_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ namespace tokenizers {

namespace {
// Test case based on llama2.c tokenizer
static inline std::string _get_resource_path(const std::string &name) {
static inline std::string _get_resource_path(const std::string& name) {
#ifdef TOKENIZERS_FB_BUCK
return facebook::xplat::testing::getPathForTestResource("test/resources/" +
name);
return facebook::xplat::testing::getPathForTestResource(
"test/resources/" + name);
#else
return std::getenv("RESOURCES_PATH") + std::string("/") + name;
#endif
Expand All @@ -24,7 +24,7 @@ static inline std::string _get_resource_path(const std::string &name) {
} // namespace

class Llama2cTokenizerTest : public Test {
public:
public:
void SetUp() override {
tokenizer_ = std::make_unique<Llama2cTokenizer>();
modelPath_ = _get_resource_path("test_llama2c_tokenizer.bin");
Expand Down
87 changes: 62 additions & 25 deletions test/test_pre_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ using namespace tokenizers;

// Helpers /////////////////////////////////////////////////////////////////////

static void assert_split_match(const PreTokenizer &ptok,
const std::string &prompt,
const std::vector<std::string> &expected) {
static void assert_split_match(
const PreTokenizer& ptok,
const std::string& prompt,
const std::vector<std::string>& expected) {
re2::StringPiece prompt_view(prompt);
const auto &got = ptok.pre_tokenize(prompt_view);
const auto& got = ptok.pre_tokenize(prompt_view);
EXPECT_EQ(expected.size(), got.size());
for (auto i = 0; i < got.size(); ++i) {
EXPECT_EQ(expected[i], got[i]);
Expand All @@ -34,14 +35,16 @@ static void assert_split_match(const PreTokenizer &ptok,
class RegexPreTokenizerTest : public ::testing::Test {};

// Test the basic construction
TEST_F(RegexPreTokenizerTest, Construct) { RegexPreTokenizer ptok("[0-9]+"); }
TEST_F(RegexPreTokenizerTest, Construct) {
RegexPreTokenizer ptok("[0-9]+");
}

// Test basic splitting using the expression for Tiktoken
TEST_F(RegexPreTokenizerTest, TiktokenExpr) {
RegexPreTokenizer ptok(
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)");
assert_split_match(ptok, "How are you doing?",
{"How", " are", " you", " doing", "?"});
assert_split_match(
ptok, "How are you doing?", {"How", " are", " you", " doing", "?"});
}

// DigitsPreTokenizer //////////////////////////////////////////////////////////
Expand All @@ -51,15 +54,18 @@ class DigitsPreTokenizerTest : public ::testing::Test {};
TEST_F(DigitsPreTokenizerTest, IndividualDigits) {
DigitsPreTokenizer ptok(true);
assert_split_match(
ptok, "The number 1 then 234 then 5.",
ptok,
"The number 1 then 234 then 5.",
{"The number ", "1", " then ", "2", "3", "4", " then ", "5", "."});
}

// Test digit splitting with contiguous digits
TEST_F(DigitsPreTokenizerTest, ContiguousDigits) {
DigitsPreTokenizer ptok(false);
assert_split_match(ptok, "The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
assert_split_match(
ptok,
"The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
}

// ByteLevelPreTokenizer ///////////////////////////////////////////////////////
Expand All @@ -69,7 +75,8 @@ TEST_F(ByteLevelPreTokenizerTest, PreTokenizeDefault) {
ByteLevelPreTokenizer ptok;
assert_split_match(ptok, "Hello World", {"ĠHello", "ĠWorld"});
assert_split_match(
ptok, "The number 1 then 234 then 5.",
ptok,
"The number 1 then 234 then 5.",
{"ĠThe", "Ġnumber", "Ġ1", "Ġthen", "Ġ234", "Ġthen", "Ġ5", "."});
}

Expand All @@ -90,9 +97,22 @@ TEST_F(SequencePreTokenizerTest, PreTokenizeDigitAndByteLevel) {
PreTokenizer::Ptr dptok(new DigitsPreTokenizer(true));
PreTokenizer::Ptr bptok(new ByteLevelPreTokenizer(false));
SequencePreTokenizer ptok({dptok, bptok});
assert_split_match(ptok, "The number 1 then 234 then 5.",
{"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
"Ġthen", "Ġ", "5", "."});
assert_split_match(
ptok,
"The number 1 then 234 then 5.",
{"The",
"Ġnumber",
"Ġ",
"1",
"Ġthen",
"Ġ",
"2",
"3",
"4",
"Ġthen",
"Ġ",
"5",
"."});
}

// PreTokenizerConfig //////////////////////////////////////////////////////////
Expand Down Expand Up @@ -132,12 +152,14 @@ TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {

// Sequence
EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error);
EXPECT_THROW(PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
std::runtime_error);
EXPECT_THROW(PreTokenizerConfig("Sequence")
.set_pretokenizers({PreTokenizerConfig("Split")})
.create(),
std::runtime_error);
EXPECT_THROW(
PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
std::runtime_error);
EXPECT_THROW(
PreTokenizerConfig("Sequence")
.set_pretokenizers({PreTokenizerConfig("Split")})
.create(),
std::runtime_error);

// Unsupported
EXPECT_THROW(PreTokenizerConfig("Unsupported").create(), std::runtime_error);
Expand All @@ -161,9 +183,22 @@ TEST_F(PreTokenizerConfigTest, ParseJson) {
}},
})
.create();
assert_split_match(*ptok, "The number 1 then 234 then 5.",
{"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
"Ġthen", "Ġ", "5", "."});
assert_split_match(
*ptok,
"The number 1 then 234 then 5.",
{"The",
"Ġnumber",
"Ġ",
"1",
"Ġthen",
"Ġ",
"2",
"3",
"4",
"Ġthen",
"Ġ",
"5",
"."});
}

TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
Expand All @@ -173,8 +208,10 @@ TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
{"type", "Digits"},
})
.create();
assert_split_match(*ptok, "The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
assert_split_match(
*ptok,
"The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
}

TEST_F(PreTokenizerConfigTest, Split) {
Expand Down
2 changes: 1 addition & 1 deletion test/test_sentencepiece.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace tokenizers {

namespace {
static inline std::string _get_resource_path(const std::string &name) {
static inline std::string _get_resource_path(const std::string& name) {
return std::getenv("RESOURCES_PATH") + std::string("/") + name;
}
} // namespace
Expand Down
51 changes: 31 additions & 20 deletions test/test_tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,35 @@ static constexpr int32_t kSpecialTokensSize = 256;
static inline std::unique_ptr<std::vector<std::string>> _get_special_tokens() {
auto special_tokens =
std::make_unique<std::vector<std::string>>(std::vector<std::string>{
"<|begin_of_text|>", "<|end_of_text|>",
"<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
"<|reserved_special_token_2|>", "<|reserved_special_token_3|>",
"<|start_header_id|>", "<|end_header_id|>",
"<|reserved_special_token_4|>", "<|eot_id|>"});
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>"});

// pad the rest of the special tokens with reserved tokens
ssize_t reserved_special_token_num = 5;
while (special_tokens->size() < kSpecialTokensSize) {
special_tokens->emplace_back("<|reserved_special_token_" +
std::to_string(reserved_special_token_num++) +
"|>");
special_tokens->emplace_back(
"<|reserved_special_token_" +
std::to_string(reserved_special_token_num++) + "|>");
}
return special_tokens;
}

static inline std::string _get_resource_path(const std::string &name) {
static inline std::string _get_resource_path(const std::string& name) {
return std::getenv("RESOURCES_PATH") + std::string("/") + name;
}

} // namespace

class TiktokenTest : public Test {
public:
public:
void SetUp() override {
tokenizer_ = std::make_unique<Tiktoken>(_get_special_tokens(), 0, 1);
modelPath_ = _get_resource_path("test_tiktoken_tokenizer.model");
Expand Down Expand Up @@ -110,23 +115,29 @@ TEST_F(TiktokenTest, ConstructionWithInvalidBOSIndex) {
// gtest death test doesn't work on iOS:
// https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
EXPECT_EXIT(std::make_unique<Tiktoken>(
std::make_unique<std::vector<std::string>>(
std::vector<std::string>{"<|end_of_text|>"}),
1, 0),
::testing::KilledBySignal(SIGABRT), "");
EXPECT_EXIT(
std::make_unique<Tiktoken>(
std::make_unique<std::vector<std::string>>(
std::vector<std::string>{"<|end_of_text|>"}),
1,
0),
::testing::KilledBySignal(SIGABRT),
"");
#endif
}

TEST_F(TiktokenTest, ConstructionWithInvalidEOSIndex) {
// gtest death test doesn't work on iOS:
// https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
EXPECT_EXIT(std::make_unique<Tiktoken>(
std::make_unique<std::vector<std::string>>(
std::vector<std::string>{"<|begin_of_text|>"}),
0, 1),
::testing::KilledBySignal(SIGABRT), "");
EXPECT_EXIT(
std::make_unique<Tiktoken>(
std::make_unique<std::vector<std::string>>(
std::vector<std::string>{"<|begin_of_text|>"}),
0,
1),
::testing::KilledBySignal(SIGABRT),
"");
#endif
}

Expand Down