Skip to content

Commit

Permalink
fix decoder bug
Browse files Browse the repository at this point in the history
  • Loading branch information
meenchen committed May 22, 2023
1 parent 3813bb9 commit b584018
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 502 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "mcunet"]
path = mcunet
url = https://github.com/mit-han-lab/mcunet
[submodule "experimental/transformer/json"]
path = experimental/transformer/json
url = https://github.com/nlohmann/json
9 changes: 3 additions & 6 deletions experimental/transformer/Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Compiler and flags
CXX = g++
CXXFLAGS = -std=c++17 -mavx2 -mfma -pthread -O3
CXXFLAGS = -std=c++17 -mavx2 -pthread -O3

# Executable and source files
TARGET = test_ops test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM profile_OPTForCausalLM test_ops_layer5_1.3B test_OPTTokenizer
TARGET = test_ops test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM profile_OPTForCausalLM test_OPTTokenizer

LIB_DIR = ../matmul_optimization/src
LIB_SRC = $(wildcard $(LIB_DIR)/lib/*.cc)
INCLUDE_DIRS = -I$(LIB_DIR) -I./include
INCLUDE_DIRS = -I$(LIB_DIR) -I./include -I./json/single_include/

$(info $(LIB_SRC))

Expand Down Expand Up @@ -36,9 +36,6 @@ test_Int8OPTDecoder:
test_OPTForCausalLM:
$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o test_OPTForCausalLM tests/test_OPTForCausalLM.cc $(SRC)

test_ops_layer5_1.3B:
$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -o test_ops_layer5_1.3B tests/test_ops_layer5_1.3B.cc $(SRC)

profile_OPTForCausalLM:
$(CXX) $(CXXFLAGS) $(INCLUDE_DIRS) -D PROFILER -o profile_OPTForCausalLM tests/test_OPTForCausalLM.cc $(SRC)

Expand Down
4 changes: 2 additions & 2 deletions experimental/transformer/download.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

# List of files to download, their corresponding MD5 checksums, and target local paths
files_and_checksums=(
"https://www.dropbox.com/s/vcuzqyrewt1jjs3/models.zip e9a99baf4f5e66e4a69f280f07397e23 models.zip"
"https://www.dropbox.com/s/dstbc72fp7ka33d/assets.zip 5c18cc891bcc74be12f5cbb926fd9cc9sh assets.zip"
"https://www.dropbox.com/s/4r4dm1hssbdlgb9/models.zip 349568042ac013f0de97baf5fdb1f952 models.zip"
"https://www.dropbox.com/s/8q5cupqw00twvoa/assets.zip 8fe97930409b7d66fd085dc77d4e9926 assets.zip"
)

# Function to download a file if it doesn't exist or if its MD5 checksum is incorrect
Expand Down
1 change: 1 addition & 0 deletions experimental/transformer/include/Int8OPTDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ class Int8OPTDecoder {
float* attention_mask_buf;
float* pos_embeds_buf;
float* last_hidden_states_buf;
float* hidden_states_buf;
};
1 change: 1 addition & 0 deletions experimental/transformer/json
Submodule json added at a0c131
9 changes: 6 additions & 3 deletions experimental/transformer/src/Int8OPTDecoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ Matrix3D<float> Int8OPTDecoder::get_position_embed(int sql_length, int past_leng

Int8OPTDecoder::Int8OPTDecoder(std::string param_path, const struct model_config config) {
allocate_aligned_memory(attention_mask_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
allocate_aligned_memory(pos_embeds_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
allocate_aligned_memory(last_hidden_states_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
allocate_aligned_memory(pos_embeds_buf, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(last_hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));

this->voc_size = config.vocsize;
this->embed_dim = config.embed_dim;
Expand Down Expand Up @@ -153,18 +154,20 @@ struct Int8OPTDecoder_output Int8OPTDecoder::forward(const struct Int8OPTDecoder
// causal_attention_mask = self._prepare_decoder_attention_mask
Matrix3D<float> causal_attention_mask =
this->prepare_decoder_attention_mask(sqlen + past_key_values_length, past_key_values_length);
// std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;

// modeling_opt.py: pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
Matrix3D<float> pos_embeds = this->get_position_embed(sqlen, past_key_values_length);
// std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;

// modeling_opt.py: hidden_states = inputs_embeds + pos_embeds
assert(inputs_embeds.m_dim_x == pos_embeds.m_dim_x);
assert(inputs_embeds.m_dim_y == pos_embeds.m_dim_y);
assert(inputs_embeds.m_dim_z == pos_embeds.m_dim_z);
float hidden_states_buf[sqlen * this->embed_dim];
Matrix3D<float> hidden_states(hidden_states_buf, 1, sqlen, this->embed_dim);
for (int i = 0; i < inputs_embeds.length(); i++)
hidden_states.m_data[i] = inputs_embeds.m_data[i] + pos_embeds.m_data[i];
// std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;
// DEBUGING CODE
// print_first_k_elelment("pos_embeds", pos_embeds.m_data, 20);
// print_first_k_elelment("inputs_embeds", inputs_embeds.m_data, 20);
Expand Down
7 changes: 3 additions & 4 deletions experimental/transformer/tests/test_OPTForCausalLM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ void test_OPTForCausalLM_1_3B() {

Matrix3D<float> logits(mem_buf.get_fpbuffer(b * sqlen * voc_size), b, sqlen, voc_size);
read_to_array("assets/tests/OPT_1.3B/causallm/1st_logits.bin", logits.m_data, logits.length());
print_first_k_elelment("O", output_1st.logits.m_data, 70, 50);
print_first_k_elelment("G", logits.m_data, 70, 50);
// print_first_k_elelment("O", output_1st.logits.m_data, 70, 50);
// print_first_k_elelment("G", logits.m_data, 70, 50);
sucess = check_two_equal(output_1st.logits.m_data, logits.m_data, logits.length(),
0.41); // large error expected, see comments above

Expand Down Expand Up @@ -159,7 +159,7 @@ void test_OPTForCausalLM_1_3B() {
read_to_array("assets/tests/OPT_1.3B/causallm/2nd_logits.bin", logits.m_data, logits.length());
// print_first_k_elelment("O", output_2nd.logits.m_data, 20);
// print_first_k_elelment("G", logits.m_data, 20);
sucess &= check_two_equal(output_2nd.logits.m_data, logits.m_data, logits.length(), 0.21);
sucess &= check_two_equal(output_2nd.logits.m_data, logits.m_data, logits.length(), 1.67);

Matrix3D<int> arg_max_2nd(mem_buf.get_intbuffer(sqlen), 1, 1, 1);
arg_max_dim2(output_2nd.logits, arg_max_2nd);
Expand All @@ -183,7 +183,6 @@ void test_OPTForCausalLM_1_3B() {
std::cout << "-------- Test of " << __func__ << ": Passed! -------- " << std::endl;
}

// TODO: update the asset
void test_OPTForCausalLM_6_7B() {
MemoryAllocator mem_buf;
int sqlen = 108, b = 1;
Expand Down
Loading

0 comments on commit b584018

Please sign in to comment.