
Add KV Cache Layout Control and Attention Tests #667

Merged
merged 1 commit into main on Jun 3, 2024

Conversation

@morgandu (Collaborator) commented on May 23, 2024

This PR enables KV cache layout control, allowing layout permutation experiments aimed at improving performance.
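For readers new to this feature, here is a minimal, self-contained sketch of the idea rather than the MaxText implementation: the KV cache is held with its axes permuted, and the attention contraction reads it in that stored order. The logical axis order (batch, seq, kv_heads, head_dim), the interpretation of the layout tuple as a transpose permutation, and all shapes below are illustrative assumptions.

```python
import jax.numpy as jnp

B, S, H, D = 8, 1024, 32, 128                  # batch, cache length, kv heads, head dim (illustrative)

def store(cache_logical, axis_order):
    # Physically hold the cache with its axes permuted by `axis_order`.
    return jnp.transpose(cache_logical, axis_order)

def qk_product(query, key_stored, axis_order):
    # query: (B, 1, H, D); key_stored: logical (B, S, H, D) permuted by `axis_order`.
    inv = tuple(axis_order.index(i) for i in range(4))
    key = jnp.transpose(key_stored, inv)       # back to logical order
    # XLA commonly folds this transpose into the dot_general's dimension
    # numbers, so the contraction reads the cache in its stored order.
    return jnp.einsum("bqhd,bshd->bhqs", query, key)

key = jnp.zeros((B, S, H, D), jnp.bfloat16)
q = jnp.zeros((B, 1, H, D), jnp.bfloat16)

before_order = (1, 2, 0, 3)                    # "Before" layout in this PR
after_order = (0, 2, 3, 1)                     # "After" layout in this PR
print(qk_product(q, store(key, after_order), after_order).shape)  # (8, 32, 1, 1024)
```

In the PR itself the ordering is configurable (see the base.yml review thread below); the sketch only illustrates that the stored order determines the memory access pattern of the memory-bound AR attention step.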

Checklist

  • Add KV Cache Layout Control
  • Add Unit Tests for Attention
  • Update Docstring
  • Microbenchmark - Performance
  • E2E tests with JetStream v0.2.2 - Accuracy and Performance

Microbenchmark - Performance

Setup

  • v5e-4
  • llama2-7b
  • quantization=int8
  • quantize_kvcache=true
  • per_device_batch_size=24

Results

Summary

| Throughput Before (tokens/s) | Throughput After (tokens/s) | Throughput Improvement |
| --- | --- | --- |
| 1469.555 | 2150.441 | 46.33% |

Before

  • default layout, i.e. 1,2,0,3, for both prefill/AR key and value caches
AutoRegressive results:
        AR step average time: 65.326 ms
        AR step average time per seq: 0.680 ms
        AR global batch size: 96
        AR throughput: 1469.555 tokens/second
        AR memory bandwidth per device: 227.221 GB/s

Xprof: https://xprof.corp.google.com/overview_page/morgandu-14884055542603397573

  • HBM BW util of qk_product's dot_general is 26%
  • HBM BW util of wv_product's dot_general fusions are around 27%

After

  • layout control, i.e. 0,2,3,1, for both prefill/AR key and value caches
AutoRegressive results:
        AR step average time: 44.642 ms
        AR step average time per seq: 0.465 ms
        AR global batch size: 96
        AR throughput: 2150.441 tokens/second
        AR memory bandwidth per device: 332.499 GB/s

Xprof: https://xprof.corp.google.com/overview_page/morgandu-16660388030961792706

  • HBM BW util of qk_product's dot_general is 89.12%
  • HBM BW util of wv_product's dot_general fusions are around 85%
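As a quick consistency check on the numbers above (using only the reported step times and global batch size), AR throughput is the global batch divided by the step time, which reproduces both throughputs and the 46.33% improvement in the summary table:

```python
# Consistency check of the reported numbers (values copied from above).
def ar_throughput(global_batch, step_time_ms):
    return global_batch / (step_time_ms / 1e3)     # tokens per second

before = ar_throughput(96, 65.326)                 # ~1469.6 tokens/s
after = ar_throughput(96, 44.642)                  # ~2150.4 tokens/s
print(before, after, f"{after / before - 1:.2%}")  # improvement ~46.33%
```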

NOTE:

  • There is an expensive dynamic-update-slice op in the After run; this is likely because the layout is still not optimal (a sketch of this cache-update op follows below).
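For context on the note above, here is a hedged sketch of the kind of per-step cache update involved; the shapes, the stored axis order, and the helper are illustrative assumptions rather than the MaxText code:

```python
# Illustrative sketch of writing one decode step's key into a cache that is
# stored in a permuted axis order; not the MaxText code.
import jax
import jax.numpy as jnp

B, S, H, D = 8, 1024, 32, 128
order = (0, 2, 3, 1)                             # assumed stored order: (batch, heads, head_dim, seq)

cache = jnp.zeros((B, H, D, S), jnp.bfloat16)    # cache held in the permuted order
new_k = jnp.ones((B, 1, H, D), jnp.bfloat16)     # this step's key in logical order

def update(cache, new_k, pos):
    # Bring the single-step update into the stored order, then write it at
    # position `pos` along the sequence axis. How cheap this write is depends
    # on where the sequence axis sits in the physical layout.
    update_stored = jnp.transpose(new_k, order)  # (B, H, D, 1)
    return jax.lax.dynamic_update_slice(cache, update_stored, (0, 0, 0, pos))

cache = update(cache, new_k, 5)
```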

E2E tests with JetStream v0.2.2 - Accuracy and Performance

Setup

  • Llama2-7b
  • v5e-8
  • JetStream release v0.2.2
  • OpenOrca chat dataset (accuracy is not meaningful for the base model, so base runs are performance only)
  • Same, possibly suboptimal, layout as mentioned in the Note above

Results

Summary

  • No accuracy regression after introducing layout control
  • Performance improvements are observed as follows:
    • For the base model (quantized KV cache only), with full warmup, throughput improved 33.22%
    • For the chat model (quantized KV cache only), with full warmup, throughput improved 26.80%
    • NOTE: The chat model with quant mode w-b16-kv-i8 is the main combination of interest, as weight quantization is still under development
  • Full warmup yields better results than the existing partial warmup, especially for the chat model, where throughput improved about 20% from full warmup alone

Details

| Model Name | Model Mode | Quant Mode | Batch Size | Exit Rouge1 | Layout Rouge1 | Exit Throughput | Layout Throughput | Layout Improvement |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama2-7b | base | w-b16-kv-b16 | 10 | - | - | 2586.307341 | 3335.753476 | 28.98% |
| llama2-7b | base | w-b16-kv-i8 | 24 | - | - | 2348.965528 | 3129.199496 | 33.22% |
| llama2-7b | base | w-i8-kv-i8 | 24 | - | - | 2491.088673 | 3347.33998 | 34.37% |
| llama2-7b | chat | w-b16-kv-b16 | 10 | 45.3415 | 45.3326 | 2003.580613 | 2470.159128 | 23.29% |
| llama2-7b | chat | w-b16-kv-i8 | 24 | 45.4035 | 45.24 | 1470.583332 | 1864.755825 | 26.80% |
| llama2-7b | chat | w-i8-kv-i8 | 24 | 44.8522 | 44.7558 | 1470.031543 | 1930.533052 | 31.33% |
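As a sanity check on the table, the Layout Improvement column is the ratio of Layout to Exit throughput minus one; for example, using the two quantized-KV-cache rows:

```python
# "Layout Improvement" = Layout Throughput / Exit Throughput - 1 (values from the table).
rows = {
    "base / w-b16-kv-i8": (2348.965528, 3129.199496),   # -> 33.22%
    "chat / w-b16-kv-i8": (1470.583332, 1864.755825),   # -> 26.80%
}
for name, (exit_thrpt, layout_thrpt) in rows.items():
    print(name, f"{layout_thrpt / exit_thrpt - 1:.2%}")
```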

@vipannalla (Collaborator) left a comment:

LGTM, a few comments. Also, the microbenchmark results look great. Were you able to do a full end-to-end run on JetStream?

Resolved review threads:
  • MaxText/configs/base.yml
  • MaxText/layers/attentions.py
  • MaxText/layers/quantizations.py
  • MaxText/layers/attentions.py (outdated)
@morgandu force-pushed the mor--kv-cache-layout branch 2 times, most recently from 2c3ecf8 to 185563d on May 23, 2024 18:14
@rwitten assigned gobbleturk and unassigned rwitten on May 23, 2024
@vipannalla (Collaborator) left a comment:

LGTM

@morgandu marked this pull request as draft on May 24, 2024 16:42
@morgandu changed the title from "Add KV Cache Layout Control" to "[WIP] Add KV Cache Layout Control" on May 24, 2024
@morgandu force-pushed the mor--kv-cache-layout branch 2 times, most recently from 58796a7 to 187cb3d on May 30, 2024 05:34
@morgandu changed the title from "[WIP] Add KV Cache Layout Control" to "Add KV Cache Layout Control and Attention Tests" on May 30, 2024
@morgandu marked this pull request as ready for review on May 30, 2024 06:02
@morgandu force-pushed the mor--kv-cache-layout branch 2 times, most recently from 58b5b31 to 4b4eaaa on May 31, 2024 19:16
@copybara-service bot merged commit 34412d4 into main on Jun 3, 2024
13 checks passed
@copybara-service bot deleted the mor--kv-cache-layout branch on June 3, 2024 15:22