Create a quantized EmbedLayerNorm for ORT. #8124
Conversation
mrry
left a comment
Comments on the op and kernel implementation only. Note that most of these are nits and/or optional.
The only important thing to fix is the lack of (runtime) shape validation on various tensor inputs, which could cause reads past the end of a buffer. (I guess they aren't writes past the end of a buffer, but still...)
        word_embedding_zero_point) +
    Dequantize<uint8_t>(input_position_embedding[i],
                        position_embedding_scale,
                        position_embedding_zero_point);
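For reference, the per-element `Dequantize` call quoted above follows the standard affine dequantization formula. A minimal sketch, assuming a simple scalar helper (the signature here is illustrative, not ORT's actual API):

```cpp
#include <cstdint>

// Illustrative sketch of affine dequantization (not ORT's actual helper):
// real_value = (quantized_value - zero_point) * scale
template <typename QType>
float Dequantize(QType value, float scale, QType zero_point) {
  return static_cast<float>(static_cast<int32_t>(value) -
                            static_cast<int32_t>(zero_point)) *
         scale;
}
```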
This is slow. To improve performance, you can use a table lookup, an approach we use for other operators.
Since word_embedding_scale and word_embedding_zero_point are constant, you can build a lookup table whose key is input_word_embedding[i] and whose value is (input_word_embedding[i] - word_embedding_zero_point) * word_embedding_scale.
Even when word_embedding_scale and word_embedding_zero_point are non-constant, you can still compute the table dynamically ahead of the main loop.
Furthermore, you can make the value (input_word_embedding[i] - word_embedding_zero_point) * word_embedding_scale + (input_segment_embedding[i] - segment_embedding_zero_point) * segment_embedding_scale.
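The suggestion above can be sketched as follows. Since a uint8_t input can only take 256 values, the dequantized result for each value can be precomputed once per (scale, zero_point) pair; the names here (`BuildDequantizeTable`) are illustrative, not an ORT API:

```cpp
#include <array>
#include <cstdint>

// Sketch of the table-lookup idea: a uint8_t has only 256 possible values,
// so precompute the dequantized result for each one. Illustrative only.
std::array<float, 256> BuildDequantizeTable(float scale, uint8_t zero_point) {
  std::array<float, 256> table{};
  for (int q = 0; q < 256; ++q) {
    table[q] = static_cast<float>(q - static_cast<int>(zero_point)) * scale;
  }
  return table;
}
```

The hot loop then replaces a subtract and a multiply per element with a single indexed load, e.g. `output[i] = word_table[input_word_embedding[i]] + position_table[input_position_embedding[i]];`.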
This is great! I didn't know these things existed. Mind if I do this as a fast-follow? This PR is getting rather large.
I was thinking that a follow-up PR using these techniques could serve as a canonical, well-documented example to help future contributors write these optimizations for quantized kernels.
    T cur_beta = Dequantize<uint8_t>(beta_data[i],
                                     layer_norm_bias_scale,
                                     layer_norm_bias_zero_point);
    output[i] = output[i] / e * cur_gamma + cur_beta;
Does quantizing gamma_data and beta_data pay off? They are quite small.
The models we see have a large layer size (768). I'd like to stick with them for now and use Prepack() to help with this in the future.
nkreeger
left a comment
Some updates here - mostly around the shape inference function and some other bits.
Still working through how to handle dual uint8_t and int8_t hybrid approaches.
nkreeger
left a comment
@mrry @yufenglee PTAL - updated and should address all comments for now.
Reduces memory overhead for a quantized graph containing an EmbedLayerNormalization fused Op. All weights and initializers are converted to uint8_t. The runtime simply dequantizes on the fly during batched/threaded execution. This reduces memory consumption on transformer models with large word embeddings, since the entire embedding is no longer expanded to float32 on every invoke. Additionally, the uint8_t operations are much faster (up to ~3x on my machine with a large word embedding and a hidden size of 768).

NOTES:
- … uint8, but this works well for now - potential for a future pass.
- … qembed_layer_norm.h/.cc and embed_layer_norm.h/.cc, but was running into too many issues with complicated template declarations (see all the commits on kreeger/qembed_layer_norm). Ideally, the guts of QEmbedLayerNorm are eventually fully quantized and the logic diverges more.