
Create a quantized EmbedLayerNorm for ORT.#8124

Merged
nkreeger merged 15 commits into master from kreeger/QEmbedLayerNorm
Jun 25, 2021
Conversation

@nkreeger (Contributor) commented Jun 22, 2021

Reduces the memory overhead of a quantized graph containing an EmbedLayerNormalization fused op. All weights and initializers are converted to uint8_t, and the runtime simply dequantizes on the fly during batched/threaded execution. This reduces memory consumption on transformer models with large word embeddings, since the entire embedding is no longer expanded to float32 on every invoke. Additionally, the uint8_t operations are much faster (up to ~3x on my machine with a large word embedding and a hidden size of 768).

NOTES:

  • I considered outputting as uint8, but this works well for now; that is a potential future pass.
  • I tried consolidating logic between qembed_layer_norm.h/.cc and embed_layer_norm.h/.cc but ran into too many issues with complicated template declarations (see all the commits on kreeger/qembed_layer_norm). Ideally, the guts of QEmbedLayerNorm are eventually fully quantized and the logic diverges further.

@nkreeger nkreeger requested a review from a team as a code owner June 22, 2021 21:30
@nkreeger nkreeger requested review from mrry and yufenglee June 22, 2021 21:30
@mrry (Contributor) left a comment

Comments on the op and kernel implementation only. Note that most of these are nits and/or optional.

The only important thing to fix is the lack of (runtime) shape validation on various tensor inputs, which could cause reads off the end of a buffer. (I guess they aren't writes off the end of a buffer, but still....)

word_embedding_zero_point) +
Dequantize<uint8_t>(input_position_embedding[i],
position_embedding_scale,
position_embedding_zero_point);
Member:

This is slow. To improve performance, you can use a table lookup, an approach we use for other operators.

Since word_embedding_scale and word_embedding_zero_point are constant, you can build a lookup table whose key is input_word_embedding[i] and whose value is (input_word_embedding[i] - word_embedding_zero_point) * word_embedding_scale.

Member:

Even in the case where word_embedding_scale and word_embedding_zero_point are non-constant, you can compute the lookup table dynamically ahead of the main loop.

Member:

Furthermore, you can make the value (input_word_embedding[i] - word_embedding_zero_point) * word_embedding_scale + (input_segment_embedding[i] - segment_embedding_zero_point) * segment_embedding_scale.


@nkreeger (Contributor, Author):

This is great! I didn't know these things existed. Mind if I do this as a fast-follow? This PR is getting rather large.

I was thinking a follow-up PR that used these techniques could serve as a canonical, documented example to help future contributors write these optimizations for quantized kernels.

Member:

Sounds good.

T cur_beta = Dequantize<uint8_t>(beta_data[i],
layer_norm_bias_scale,
layer_norm_bias_zero_point);
output[i] = output[i] / e * cur_gamma + cur_beta;
Member:

Does quantizing gamma_data and beta_data pay off? They are quite small.

@nkreeger (Contributor, Author):

The models we see have a large layer (768). I'd like to stick with them for now and use Prepack() to help with this in the future.

@nkreeger (Contributor, Author) left a comment:

Some updates here, mostly around the shape inference function and some other bits.

Still working through how to handle dual uint8_t and int8_t hybrid approaches.

@nkreeger (Contributor, Author) left a comment:

@mrry @yufenglee PTAL - updated and should address all comments for now.

mrry previously approved these changes Jun 24, 2021
mrry previously approved these changes Jun 25, 2021
@nkreeger nkreeger merged commit 800b62a into master Jun 25, 2021
@nkreeger nkreeger deleted the kreeger/QEmbedLayerNorm branch June 25, 2021 22:51
4 participants