Skip to content

Commit

Permalink
fix toeknizer test
Browse files Browse the repository at this point in the history
  • Loading branch information
meenchen committed May 22, 2023
1 parent b584018 commit e2a7d2c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
4 changes: 2 additions & 2 deletions experimental/transformer/download.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

# List of files to download, their corresponding MD5 checksums, and target local paths
files_and_checksums=(
"https://www.dropbox.com/s/4r4dm1hssbdlgb9/models.zip 349568042ac013f0de97baf5fdb1f952 models.zip"
"https://www.dropbox.com/s/8q5cupqw00twvoa/assets.zip 8fe97930409b7d66fd085dc77d4e9926 assets.zip"
"https://www.dropbox.com/s/4r4dm1hssbdlgb9/models.zip 3c5d765f76093bcff4951180cdd899f4 models.zip"
"https://www.dropbox.com/s/8q5cupqw00twvoa/assets.zip dce6d88f2b79046b68e9560dd42c7cc2 assets.zip"
)

# Function to download a file if it doesn't exist or if its MD5 checksum is incorrect
Expand Down
4 changes: 2 additions & 2 deletions experimental/transformer/tests/test_OPTForCausalLM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void test_OPTForCausalLM_1_3B() {
// std::cout << "arg_max_2nd.m_data:" << arg_max_2nd.m_data[0] << ", arg_maxGT_2nd.m_data:" <<
// arg_maxGT_2nd.m_data[0] << std::endl;
hit_rate = (float)total_hit / (float)1;
std::cout << "sqlen:" << sqlen << ", hits:" << total_hit << ", hit rate:" << hit_rate << std::endl;
std::cout << "sqlen:" << 1 << ", hits:" << total_hit << ", hit rate:" << hit_rate << std::endl;
sucess &= hit_rate > 0.99;

Profiler::getInstance().report();
Expand Down Expand Up @@ -255,7 +255,7 @@ void test_OPTForCausalLM_6_7B() {
// std::cout << "arg_max_2nd.m_data:" << arg_max_2nd.m_data[0] << ", arg_maxGT_2nd.m_data:" <<
// arg_maxGT_2nd.m_data[0] << std::endl;
hit_rate = (float)total_hit / (float)1;
// std::cout << "sqlen:" << sqlen << ", hits:" << total_hit << ", hit rate:" << hit_rate << std::endl;
std::cout << "sqlen:" << 1 << ", hits:" << total_hit << ", hit rate:" << hit_rate << std::endl;
sucess &= hit_rate > 0.99;

Profiler::getInstance().report();
Expand Down
44 changes: 27 additions & 17 deletions experimental/transformer/tests/test_OPTTokenizer.cc
Original file line number Diff line number Diff line change
@@ -1,39 +1,49 @@
#include <iostream>
#include "OPTTokenizer.h"

void test_OPTTokenizer () {
#include "OPTTokenizer.h"

void test_OPTTokenizer() {
// Test bpe
//std::cout << "Test End!" << std::endl;
// std::cout << "Test End!" << std::endl;
}

int main() {
//test_OPTTokenizer();
std::string vocab_file = "./opt-125m/vocab.json";
std::string bpe_file = "./opt-125m/merges.txt";
// test_OPTTokenizer();
std::string vocab_file = "./models/OPT_125m/vocab.json";
std::string bpe_file = "./models/OPT_125m/merges.txt";

Encoder encoder = get_encoder(vocab_file, bpe_file);
//encoder.new_bytes_to_unicode();
//encoder.encode("Building. a don't.");
//encoder.encode("Building a website can be done in 10 simple steps. This message is for general people, so we assume they don't have basic concepts.");
std::vector<int> encoded = encoder.encode("Building a website can be done in 10 simple steps. This message is for general people, so we assume they don't have basic concepts.");
std::vector<int> encoded_answer = {37500, 10, 998, 64, 28, 626, 11, 158, 2007, 2402, 4, 152, 1579, 16, 13, 937, 82, 6, 98, 52, 6876, 51, 218, 75, 33, 3280, 14198, 4};
// encoder.new_bytes_to_unicode();
// encoder.encode("Building. a don't.");
// encoder.encode("Building a website can be done in 10 simple steps. This message is for general people, so we
// assume they don't have basic concepts.");
std::vector<int> encoded = encoder.encode(
"Building a website can be done in 10 simple steps. This message is for general people, so we assume they "
"don't have basic concepts.");
std::vector<int> encoded_answer = {37500, 10, 998, 64, 28, 626, 11, 158, 2007, 2402, 4, 152, 1579, 16,
13, 937, 82, 6, 98, 52, 6876, 51, 218, 75, 33, 3280, 14198, 4};
bool is_equal = true;
for (int i = 0; i < encoded.size(); i++) {
if (encoded[i] != encoded_answer[i]) {
is_equal = false;
break;
}
}
if (is_equal) std::cout << "-------- Test of Encoder::encode: Passed! -------- " << std::endl;
else std::cout << "-------- Test of Encoder::encode: Failed! -------- " << std::endl;

if (is_equal)
std::cout << "-------- Test of Encoder::encode: Passed! -------- " << std::endl;
else
std::cout << "-------- Test of Encoder::encode: Failed! -------- " << std::endl;

std::string decoded = encoder.decode(encoded);
std::string decoded_answer = "Building a website can be done in 10 simple steps. This message is for general people, so we assume they don't have basic concepts.";
std::string decoded_answer =
"Building a website can be done in 10 simple steps. This message is for general people, so we assume they "
"don't have basic concepts.";
is_equal = true;
if (decoded != decoded_answer) is_equal = false;
if (is_equal) std::cout << "-------- Test of Encoder::decode: Passed! -------- " << std::endl;
else std::cout << "-------- Test of Encoder::decode: Failed! -------- " << std::endl;
if (is_equal)
std::cout << "-------- Test of Encoder::decode: Passed! -------- " << std::endl;
else
std::cout << "-------- Test of Encoder::decode: Failed! -------- " << std::endl;

std::cout << "-------- End of test_OPTTokenizer --------" << std::endl;
};

0 comments on commit e2a7d2c

Please sign in to comment.