Commit b53b088: minor

meenchen committed May 22, 2023
1 parent e6cfd2d · commit b53b088

Showing 5 changed files with 15 additions and 12 deletions.
experimental/transformer/Makefile (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 # Compiler and flags
 CXX = g++
-CXXFLAGS = -std=c++17 -mavx2 -pthread -Wno-deprecated-declarations
+CXXFLAGS = -std=c++17 -mavx2 -O3 -pthread -Wno-deprecated-declarations
 
 # Executable and source files
 TARGET = test_ops test_Int8OPTAttention test_Int8OPTDecoderLayer test_Int8OPTDecoder test_OPTForCausalLM profile_OPTForCausalLM test_ops_layer5_1.3B
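Note on the flag change: with no -O option, g++ compiles at -O0, so the AVX2 kernels built here (-mavx2) run essentially unoptimized. -O3 enables the compiler's highest standard optimization level, including auto-vectorization and aggressive inlining, which these intrinsics-heavy targets rely on for throughput.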
experimental/transformer/include/Int8OPTDecoder.h (1 change: 1 addition & 0 deletions)

@@ -40,4 +40,5 @@ class Int8OPTDecoder {
     float* attention_mask_buf;
     float* pos_embeds_buf;
     float* last_hidden_states_buf;
+    float* hidden_states_buf;
 };
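The new hidden_states_buf member pairs with an allocation added in the Int8OPTDecoder constructor (see the Int8OPTDecoder.cc hunks below): the hidden-states working buffer moves from a per-call stack array in forward() to heap memory that is allocated once and reused.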
experimental/transformer/src/Int8OPTAttention.cc (12 changes: 6 additions & 6 deletions)

@@ -198,8 +198,8 @@ struct Int8OPTAttention_output Int8OPTAttention::forward(const struct Int8OPTAtt
     const int sqlen = input.hidden_states.m_dim_y, b = input.hidden_states.m_dim_x;
     assert(b == 1);
 
-    std::cout << "input.hidden_states(md5):" << input.hidden_states.getMD5() << std::endl;
-    std::cout << "input.attention_mask(md5):" << input.attention_mask.getMD5() << std::endl;
+    // std::cout << "input.hidden_states(md5):" << input.hidden_states.getMD5() << std::endl;
+    // std::cout << "input.attention_mask(md5):" << input.attention_mask.getMD5() << std::endl;
     Matrix3D<int8_t> query_states_unshape(query_states_unshape_arr, b, sqlen, embed_dim);
     // opt.py: query_states = self.q_proj(hidden_states)
     this->q_proj.forward(input.hidden_states, query_states_unshape);

@@ -278,13 +278,13 @@ struct Int8OPTAttention_output Int8OPTAttention::forward(const struct Int8OPTAtt
     // float attn_weights_arr[this->num_heads * sqlen * tgz];
     Matrix3D<float> attn_weights(attn_weights_arr, this->num_heads, sqlen, tgz);
     this->qk_bmm.forward(query_states, final_key_states, attn_weights);
-    std::cout << "attn_weights(md5):" << attn_weights.getMD5() << std::endl;
-    std::cout << "input.attention_mask(md5):" << input.attention_mask.getMD5() << std::endl;
+    // std::cout << "attn_weights(md5):" << attn_weights.getMD5() << std::endl;
+    // std::cout << "input.attention_mask(md5):" << input.attention_mask.getMD5() << std::endl;
     // print_first_k_elelment("attn_weights.m_data", attn_weights.m_data, 20);
 
     // opt.py: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
     batch_Add(attn_weights, input.attention_mask, attn_weights);
-    std::cout << "attn_weights(md5):" << attn_weights.getMD5() << std::endl;
+    // std::cout << "attn_weights(md5):" << attn_weights.getMD5() << std::endl;
     // print_first_k_elelment("attn_weights.m_data", attn_weights.m_data, 20);
 
     // opt.py: attn_probs = nn.functional.softmax(attn_weights, dim=-1)

@@ -337,7 +337,7 @@ struct Int8OPTAttention_output Int8OPTAttention::forward(const struct Int8OPTAtt
     output.attn_output = attn_output_fp;
     output.past_key_value = {final_key_states, final_value_states};
 
-    exit(0);
+    // exit(0);
 
     PROFILE_END(profile_name);
     return output;
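With this change, every MD5 checksum dump in the attention path is commented out rather than deleted, and the hard exit(0) that stopped the program after the first attention forward is disabled, so the layer now runs to completion. A compile-time switch is one way to keep such dumps available without hand-editing comments each time. A minimal sketch: it assumes only the Matrix3D::getMD5() seen above; the DUMP_MD5 macro and DEBUG_MD5 flag are hypothetical, not part of this repo.

    #include <iostream>

    // Hypothetical debug switch: build with `g++ -DDEBUG_MD5 ...` to enable
    // the checksum dumps; without the flag the macro compiles away to nothing.
    #ifdef DEBUG_MD5
    #define DUMP_MD5(tag, mat) \
        (std::cout << (tag) << "(md5):" << (mat).getMD5() << std::endl)
    #else
    #define DUMP_MD5(tag, mat) ((void)0)
    #endif

    // Usage inside Int8OPTAttention::forward():
    //   DUMP_MD5("attn_weights", attn_weights);
    //   DUMP_MD5("input.attention_mask", input.attention_mask);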
experimental/transformer/src/Int8OPTDecoder.cc (10 changes: 6 additions & 4 deletions)

@@ -43,8 +43,9 @@ Matrix3D<float> Int8OPTDecoder::get_position_embed(int sql_length, int past_leng
 
 Int8OPTDecoder::Int8OPTDecoder(std::string param_path, const struct model_config config) {
     allocate_aligned_memory(attention_mask_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
-    allocate_aligned_memory(pos_embeds_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
-    allocate_aligned_memory(last_hidden_states_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
+    allocate_aligned_memory(pos_embeds_buf, config.max_sqlen * config.embed_dim * sizeof(float));
+    allocate_aligned_memory(last_hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));
+    allocate_aligned_memory(hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));
 
     this->voc_size = config.vocsize;
     this->embed_dim = config.embed_dim;

@@ -153,19 +154,20 @@ struct Int8OPTDecoder_output Int8OPTDecoder::forward(const struct Int8OPTDecoder
     // causal_attention_mask = self._prepare_decoder_attention_mask
     Matrix3D<float> causal_attention_mask =
         this->prepare_decoder_attention_mask(sqlen + past_key_values_length, past_key_values_length);
-    std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;
+    // std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;
 
     // modeling_opt.py: pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
     Matrix3D<float> pos_embeds = this->get_position_embed(sqlen, past_key_values_length);
+    // std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;
 
     // modeling_opt.py: hidden_states = inputs_embeds + pos_embeds
     assert(inputs_embeds.m_dim_x == pos_embeds.m_dim_x);
     assert(inputs_embeds.m_dim_y == pos_embeds.m_dim_y);
     assert(inputs_embeds.m_dim_z == pos_embeds.m_dim_z);
-    float hidden_states_buf[sqlen * this->embed_dim];
     Matrix3D<float> hidden_states(hidden_states_buf, 1, sqlen, this->embed_dim);
     for (int i = 0; i < inputs_embeds.length(); i++)
         hidden_states.m_data[i] = inputs_embeds.m_data[i] + pos_embeds.m_data[i];
+    // std::cout << "causal_attention_mask(md5):" << causal_attention_mask.getMD5() << std::endl;
     // DEBUGING CODE
     // print_first_k_elelment("pos_embeds", pos_embeds.m_data, 20);
     // print_first_k_elelment("inputs_embeds", inputs_embeds.m_data, 20);
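Two fixes land in this file. First, pos_embeds_buf and last_hidden_states_buf each hold a [1, sqlen, embed_dim] activation, so the right capacity is max_sqlen * embed_dim floats; the old max_sqlen * max_sqlen sizing used the wrong dimension and under-allocates whenever embed_dim exceeds max_sqlen (for OPT-6.7B, embed_dim is 4096 against a typical max_sqlen of 2048). Second, the hidden-states sum previously lived in a C-style variable-length array on the stack; at sqlen = 2048 and embed_dim = 4096 that is 2048 * 4096 * 4 bytes = 32 MiB, far past the usual 8 MiB stack limit. The new heap-backed hidden_states_buf member keeps the old variable's name, so the Matrix3D wrapper line is untouched, and the buffer is sized once in the constructor and reused on every forward() call. For reference, a sketch of what a helper of allocate_aligned_memory's shape typically does; the repo's actual helper may differ in signature, alignment, and error handling:

    #include <assert.h>
    #include <stdlib.h>

    // Sketch only: point `ptr` at a 64-byte-aligned heap block of n_bytes.
    // 64-byte alignment keeps rows cache-line- and AVX2-friendly.
    template <typename T>
    void allocate_aligned_memory(T *&ptr, size_t n_bytes) {
        void *p = nullptr;
        int ret = posix_memalign(&p, /*alignment=*/64, n_bytes);
        assert(ret == 0 && "aligned allocation failed");
        ptr = static_cast<T *>(p);
    }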
experimental/transformer/tests/test_OPTForCausalLM.cc (2 changes: 1 addition & 1 deletion)

@@ -204,7 +204,7 @@ void test_OPTForCausalLM_6_7B() {
 
     Matrix3D<float> logits(mem_buf.get_fpbuffer(b * sqlen * voc_size), b, sqlen, voc_size);
     read_to_array("assets/tests/OPT_6.7B/causallm/1st_logits.bin", logits.m_data, logits.length());
-    print_first_k_elelment("O", output_1st.logits.m_data, 70, 50);
+    // print_first_k_elelment("O", output_1st.logits.m_data, 70, 50);
     // print_first_k_elelment("G", logits.m_data, 70, 50);
     // sucess = check_two_equal(output_1st.logits.m_data, logits.m_data, logits.length(),
     //                          0.507); // large error expected, see comments above
