From d21182046cabb359ad30157908478fc661e119e0 Mon Sep 17 00:00:00 2001 From: botbw Date: Thu, 23 May 2024 22:58:09 +0800 Subject: [PATCH] [gemini] fix ci (#5748) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Inference] First PR for rebuild colossal-infer (#5143) * add engine and scheduler * add dirs --------- Co-authored-by: CjhHa1 * [Inference] Add readme (roadmap) and fulfill request handler (#5147) * request handler * add readme --------- Co-authored-by: CjhHa1 * [Inference/NFC] Clean outdated inference tests and deprecated kernels (#5159) * [inference/nfc] remove outdated inference tests * remove outdated kernel tests * remove deprecated triton kernels * remove imports from deprecated kernels * [Inference]Add BatchInferState, Sequence and InferConfig (#5149) * add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * [Inference] Add CacheBlock and KV-Cache Manager (#5156) * [Inference] Add KVCache Manager * function refactored * add test for KVCache Manager * add attr beam width * Revise alloc func in CacheManager * Fix docs and pytests * add tp slicing for head number * optimize shapes of tensors used as physical cache * Apply using InferenceConfig on KVCacheManager * rm duplicate config file * Optimize cache allocation: use contiguous cache * Fix config in pytest (and config) * [Inference]Update inference config and fix test (#5178) * unify the config setting * fix test * fix import * fix test * fix * fix * add logger * revise log info --------- Co-authored-by: CjhHa1 * [Inference] Add the logic of the inference engine (#5173) * add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related o tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txt * [Inference] add logit processor and request handler (#5166) * add logit processor and request handler * add * add * add * fix * add search tokens and update func * finish request handler * add running list test * fix test * fix some bug * add * add * fix bugs * fix some bugs * fix bug * fix * fix * add copy fun * del useless attn * fix request status --------- Co-authored-by: CjhHa1 * Add padding llama model * Fixed a bug in the inference frame * fix bugs in request_handler * precision alignment * Fixed a writing error * [kernel] Add triton kernel for context attention (FAv2) without padding (#5192) * add context attn unpadded triton kernel * test compatibility * kv cache copy (testing) * fix k/v cache copy * fix kv cache copy and test * fix boundary of block ptrs * add support for GQA/MQA and testing * fix import statement --------- Co-authored-by: Round Heng * add context_attention_unpadded * fix bugs in sampler * Fixed a typo * fix beam_width * [Inference] Pytorch Attention func, pad&nopad input support (#5219) * add attn * add attention test * fix attn forward * fix decoding * fix bugs in attention.py and request_handler.py * adapted to pad_context_forward * [Hotfix] Fix accuracy and align attention method api with Triton kernel (#5229) * fix accuracy * alignment in attention * fix attention * fix * fix bugs * fix bugs * fix bugs * fix bugs related to processing padding mask * fix CI bugs * rm torch.cuda.synchronize * fix bugs in request_handler.py and engine.py * [Inference] Kernel: no pad rotary embedding (#5252) * fix bugs * comment * use more accurate atol * fix * [kernel] Add flash decoding triton kernel for blocked kv cache (#5249) * add flash decoding unpad triton kernel * rename flash decoding kernel * add kernel testing (draft) * revise pytest * support kv group (GQA) * (trivial) fix api and pytest * (trivial) func renaming * (trivial) func/file renaming * refactor pytest for attention * (trivial) format and consistent vars of context/decode attn * (trivial) remove test redundancy * [git] fixed rebased files * [kernel] Add KV cache copy kernel during decoding (#5261) * add kv copy triton kernel during decoding stage * add pytest and fix kernel * fix test utilities * revise kernel config * add benchmark for kvcache copy * [doc] updated inference readme (#5269) * [Inference] Fix request handler and add recycle logic (#5260) * fix request handler * fix comment * [kernel] Revise KVCache copy triton kernel API (#5273) * [kernel/fix] revise kvcache copy kernel api * fix benchmark * [Inference]Adapted to the triton attn kernels (#5264) * adapted to the triton attn kernels * fix pad input * adapted to copy_kv_to_blocked_cache * fix ci test * update kv memcpy * remove print * [kernel] Add RMSLayerNorm triton kernel (#5262) * add layerrmsnorm triton kernel * add layerrmsnorm kernel * modify the atol and rtol in test file * Remove the logics of mean computations, and update the name of ther kernel functions and files * add benchmark of rms norm * [Hotfix] Fix bugs in testing continuous batching (#5270) * fix bug * fix bugs * fix bugs * fix bugs and add padding * add funcs and fix bugs * fix typos * fix bugs * add func * [kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274) * prevent re-creating intermediate tensors * add singleton class holding intermediate values * fix triton kernel api * add benchmark in pytest * fix kernel api and add benchmark * revise flash decoding triton kernel in/out shapes * fix calling of triton kernel in modeling * fix pytest: extract to util functions * [inference] Adapted to Rotary Embedding and RMS Norm (#5283) * adapted to rotary_embedding * adapted to nopad rms norm * fix bugs in benchmark * fix flash_decoding.py * add utils.py * [Inference] Benchmarking rotary embedding and add a fetch function (#5277) * fix bugs and add a cos/sin cache fetch func * add docstring * fix bug * fix * [Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301) * fix decoding kernel pytest * revise and add triton context attn benchmark * [Inference]Add fused rotary kernel and get cos cache kernel (#5302) * add fused rotary and get cos cache func * staged * fix bugs * fix bugs * [hotfix] fix boundary check in batch (#5306) * [inference]Optimize the usage of the mid tensors space in flash attn (#5304) * opt flash attn * opt tmp tensor * fix benchmark_llama * fix code style * fix None logic for output tensor * fix adapted to get_xine_cache * add comment * fix ci bugs * fix some codes * rm duplicated codes * rm duplicated codes * fix code style * add _get_dtype in config.py * fix (#5311) * [Inference] Update rms norm kernel, benchmark with vLLM (#5315) * add * xi * del * del * fix * [DOC] Update inference readme (#5280) * add readme * add readme * 1 * update engine * finish readme * add readme * [Inference]Add Nopadding Llama Modeling (#5327) * add nopadding llama modeling * add nopadding_llama.py * rm unused codes * fix bugs in test_xine_copy.py * fix code style * [Infer] Optimize Blocked KVCache And Kernels Using It (#5325) * revise shape of kvcache (context attn kernel) * revise shape of kvcache (flash decoding kernel) * revise shape of kvcache (kvcache copy) and attn func * init of kvcache in kvcache manager * revise llama modeling * revise block size retrieval * use torch for rms_norm benchmarking * revise block size retrieval * [Inference] Kernel Fusion, fused copy kv cache into rotary embedding (#5336) * revise rotary embedding * remove useless print * adapt * [inference] simplified config verification (#5346) * [inference] simplified config verification * polish * polish * [Inference]Repalce Attention layer and MLP layer by shardformer to optimize the weight transpose operation,add fused_qkv and fused linear_add (#5340) * add fused qkv * replace attn and mlp by shardformer * fix bugs in mlp * add docstrings * fix test_inference_engine.py * add optimize unbind * add fused_addmm * rm squeeze(1) * refactor codes * fix ci bugs * rename ShardFormerLlamaMLP and ShardFormerLlamaAttention * Removed the dependency on LlamaFlashAttention2 * rollback test_inference_engine.py * [inference] removed redundancy init_batch (#5353) * [inference] moved ops tests to test_infer (#5354) * [doc] updated inference readme (#5343) * [Inference/opt]Optimize the mid tensor of RMS Norm (#5350) * opt rms_norm * fix bugs in rms_layernorm * [Inference]Optimize generation process of inference engine (#5356) * opt inference engine * fix run_benchmark.sh * fix generate in engine.py * rollback tesh_inference_engine.py * [Fix/Infer] Remove unused deps and revise requirements (#5341) * remove flash-attn dep * rm padding llama * revise infer requirements * move requirements out of module * [Inference]Fused the gate and up proj in mlp,and optimized the autograd process. (#5365) * fused the gate and up proj in mlp * fix code styles * opt auto_grad * rollback test_inference_engine.py * modifications based on the review feedback. * fix bugs in flash attn * Change reshape to view * fix test_rmsnorm_triton.py * [Inference] Adapt to Fused rotary (#5348) * revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix * Revert "[Inference] Adapt to Fused rotary (#5348)" (#5373) This reverts commit 9f4ab2eb924b938348df2c713bb4580972f18eb1. * [inference] added inference template (#5375) * [Inference/opt] Fused KVCahce Memcopy (#5374) * fused kv memcopy * add TODO in test_kvcache_copy.py * [Inference] User Experience: update the logic of default tokenizer and generation config. (#5337) * add * fix * fix * pause * fix * fix pytest * align * fix * license * fix * fix * fix readme * fix some bugs * remove tokenizer config * [inference] refactored config (#5376) * [Inference]Support vllm testing in benchmark scripts (#5379) * add vllm benchmark scripts * fix code style * update run_benchmark.sh * fix code style * [Inference] Optimize and Refactor Inference Batching/Scheduling (#5367) * add kvcache manager funcs for batching * add batch bucket for batching * revise RunningList struct in handler * add kvcache/batch funcs for compatibility * use new batching methods * fix indexing bugs * revise abort logic * use cpu seq lengths/block tables * rm unused attr in Sequence * fix type conversion/default arg * add and revise pytests * revise pytests, rm unused tests * rm unused statements * fix pop finished indexing issue * fix: use index in batch when retrieving inputs/update seqs * use dict instead of odict in batch struct * arg type hinting * fix make compress * refine comments * fix: pop_n_seqs to pop the first n seqs * add check in request handler * remove redundant conversion * fix test for request handler * fix pop method in batch bucket * fix prefill adding * [Inference]Fused kv copy into rotary calculation (#5383) * revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix * fused kv copy * fused copy * colossalai/kernel/triton/no_pad_rotary_embedding.py * del padding llama * del * Optimized the execution interval time between cuda kernels caused by view and memcopy (#5390) * opt_view_and_memcopy * fix bugs in ci * fix ci bugs * update benchmark scripts * fix ci bugs * [Fix/Inference] Fix format of input prompts and input model in inference engine (#5395) * Fix bugs in inference_engine * fix bugs in engine.py * rm CUDA_VISIBLE_DEVICES * add request_ids in generate * fix bug in engine.py * add logger.debug for BatchBucket * [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * [Inference]Add CUDA KVCache Kernel (#5406) * add cuda KVCache kernel * annotation benchmark_kvcache_copy * add use cuda * fix import path * move benchmark scripts to example/ * rm benchmark codes in test_kv_cache_memcpy.py * rm redundancy codes * rm redundancy codes * pr was modified according to the review * [Inference]Move benchmark-related code to the example directory. (#5408) * move benchmark-related code to the example directory. * fix bugs in test_fused_rotary_embedding.py * add silu_and_mul for infer * [feat] cuda graph support and refactor non-functional api * add reusable utils for cuda * refactor code * feat rmsnorm cuda kernel and add unittest, benchmark script (#5417) * [fix] multi graphs capture error * [fix] multi graphs capture error * [doc] add doc * refactor code * optimize rmsnorm: add vectorized elementwise op, feat loop unrolling (#5441) * fix include path * fix rmsnorm template function invocation problem(template function partial specialization is not allowed in Cpp) and luckily pass e2e precision test (#5454) * [Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418) * add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline * [fix] pytest and fix dyn grid bug * diverse tests * add implementatino for GetGPULaunchConfig1D * [fix] tmp for test * add some comments * refactor vector utils * [feat] add use_cuda_kernel option * add vec_type_trait implementation (#5473) * [fix] unused option * [fix] * [fix] * [fix] remove unused comment * [Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461) * Support FP16/BF16 Flash Attention 2 * fix bugs in test_kv_cache_memcpy.py * add context_kv_cache_memcpy_kernel.cu * rm typename MT * add tail process * add high_precision * add high_precision to config.py * rm unused code * change the comment for the high_precision parameter * update test_rotary_embdding_unpad.py * fix vector_copy_utils.h * add comment for self.high_precision when using float32 * [fix] PR #5354 (#5501) * [fix] * [fix] * Update config.py docstring * [fix] docstring align * [fix] docstring align * [fix] docstring align * [Inference] Optimize request handler of llama (#5512) * optimize request_handler * fix ways of writing * The writing style of tail processing and the logic related to macro definitions have been optimized. (#5519) * [Inference/Kernel]Add get_cos_and_sin Kernel (#5528) * Add get_cos_and_sin kernel * fix code comments * fix code typos * merge common codes of get_cos_and_sin kernel. * Fixed a typo * Changed 'asset allclose' to 'assert equal'. * [Inference] Add Reduce Utils (#5537) * add reduce utils * add using to delele namespace prefix * [Fix/Inference] Remove unused and non-functional functions (#5543) * [fix] remove unused func * rm non-functional partial * add cast and op_functor for cuda build-in types (#5546) * remove unused triton kernels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove outdated triton test * [Infer] Revise and Adapt Triton Kernels for Spec-Dec (#5401) * [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * resolve conflicts for revising flash-attn * adapt kv cache copy kernel for spec-dec * fix seqlen-n kvcache copy kernel/tests * test kvcache copy - use torch.equal * add assertions * (trivial) comment out * [Inference/SpecDec] Add Basic Drafter Model Container (#5405) * [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * add drafter model container (basic ver) * [Inference/SpecDec] Add Speculative Decoding Implementation (#5423) * fix flash decoding mask during verification * add spec-dec * add test for spec-dec * revise drafter init * remove drafter sampling * retire past kv in drafter * (trivial) rename attrs * (trivial) rename arg * revise how we enable/disable spec-dec * [SpecDec] Fix inputs for speculation and revise past KV trimming (#5449) * fix drafter pastkv and usage of batch bucket * [Inference/SpecDec] Support GLIDE Drafter Model (#5455) * add glide-llama policy and modeling * update glide modeling, compitable with transformers 4.36.2 * revise glide llama modeling/usage * fix issues of glimpsing large kv * revise the way re-loading params for glide drafter * fix drafter and engine tests * enable convert to glide strict=False * revise glide llama modeling * revise vicuna prompt template * revise drafter and tests * apply usage of glide model in engine * [doc] Add inference/speculative-decoding README (#5552) * add README for spec-dec * update roadmap * [Fix] resolve conflicts of rebasing feat/speculative-decoding (#5557) - resolve conflicts of rebasing feat/speculative-decoding * [Fix] Llama Modeling Control with Spec-Dec (#5580) - fix ref before asgmt - fall back to use triton kernels when using spec-dec * refactor csrc (#5582) * [Inference/Refactor] Delete Duplicated code and refactor vec_copy utils and reduce utils (#5593) * delete duplicated code and refactor vec_copy utils and reduce utils * delete unused header file * [inference/model]Adapted to the baichuan2-7B model (#5591) * Adapted to the baichuan2-7B model * modified according to the review comments. * Modified the method of obtaining random weights. * modified according to the review comments. * change mlp layewr 'NOTE' * [Inference/Kernel] Add Paged Decoding kernel, sequence split within the same thread block (#5531) * feat flash decoding for paged attention * refactor flashdecodingattention * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feat]Tensor Model Parallel Support For Inference (#5563) * tensor parallel support naive source * [fix]precision, model load and refactor the framework * add tp unit test * docstring * fix do_sample * feat baichuan2 rmsnorm whose hidden size equals to 5120 (#5611) * [Fix/Inference] Fix GQA Triton and Support Llama3 (#5624) * [fix] GQA calling of flash decoding triton * fix kv cache alloc shape * fix rotary triton - GQA * fix sequence max length assigning * Sequence max length logic * fix scheduling and spec-dec * skip without import error * fix pytest - skip without ImportError --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Fix/Inference]Fix CUDA Rotary Rmbedding GQA (#5623) * fix rotary embedding GQA * change test_rotary_embdding_unpad.py KH * [example] Update Llama Inference example (#5629) * [example] add infernece benchmark llama3 * revise inference config - arg * remove unused args * add llama generation demo script * fix init rope in llama policy * add benchmark-llama3 - cleanup * [Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613) * refactor compilation mechanism and unified multi hw * fix file path bug * add init.py to make pybind a module to avoid relative path error caused by softlink * delete duplicated micros * fix micros bug in gcc * [Fix/Inference]Fix vllm benchmark (#5630) * Fix bugs about OOM when running vllm-0.4.0 * rm used params * change generation_config * change benchmark log file name * [Inference/Kernel] Optimize paged attention: Refactor key cache layout (#5643) * optimize flashdecodingattention: refactor code with different key cache layout(from [num_blocks, num_kv_heads, block_size, head_size] to [num_blocks, num_kv_heads, head_size/x, block_size, x]) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Fix] Remove obsolete files - inference (#5650) * [Inference]Adapt to baichuan2 13B (#5614) * adapt to baichuan2 13B * adapt to baichuan2 13B * change BAICHUAN_MODEL_NAME_OR_PATH * fix test_decoding_attn.py * Modifications based on review comments. * change BAICHUAN_MODEL_NAME_OR_PATH * mv attn mask processes to test flash decoding * mv get_alibi_slopes baichuan modeling * fix bugs in test_baichuan.py * [kernel] Support new KCache Layout - Context Attention Triton Kernel (#5658) * add context attn triton kernel - new kcache layout * add benchmark triton * tiny revise * trivial - code style, comment * [Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656) * [Inference/Feat] Feat quant kvcache step2 (#5674) * [Inference] Adapt Baichuan2-13B TP (#5659) * adapt to baichuan2 13B * add baichuan2 13B TP * update baichuan tp logic * rm unused code * Fix TP logic * fix alibi slopes tp logic * rm nn.Module * Polished the code. * change BAICHUAN_MODEL_NAME_OR_PATH * Modified the logic for loading Baichuan weights. * fix typos * [Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663) * refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention * [Inference/Feat] Add kvcache quant support for fused_rotary_embedding_cache_copy (#5680) * [inference]Add alibi to flash attn function (#5678) * add alibi to flash attn function * rm redundant modifications * [Inference] Fix quant bits order (#5681) * [kernel] Support New KCache Layout - Triton Kernel (#5677) * kvmemcpy triton for new kcache layout * revise tests for new kcache layout * naive triton flash decoding - new kcache layout * rotary triton kernel - new kcache layout * remove redundancy - triton decoding * remove redundancy - triton kvcache copy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Fix] Fix & Update Inference Tests (compatibility w/ main) * [Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679) * [Inference/Feat] Add quant kvcache support for decode_kv_cache_memcpy (#5686) * [hotfix] Fix KV Heads Number Assignment in KVCacheManager (#5695) - Fix key value number assignment in KVCacheManager, as well as method of accessing * [Fix] Fix Inference Example, Tests, and Requirements (#5688) * clean requirements * modify example inference struct * add test ci scripts * mark test_infer as submodule * rm deprecated cls & deps * import of HAS_FLASH_ATTN * prune inference tests to be run * prune triton kernel tests * increment pytest timeout mins * revert import path in openmoe * [hotfix] fix OpenMOE example import path (#5697) * [Inference]Adapt temperature processing logic (#5689) * Adapt temperature processing logic * add ValueError for top_p and top_k * add GQA Test * fix except_msg * [Inference] Support the logic related to ignoring EOS token (#5693) * Adapt temperature processing logic * add ValueError for top_p and top_k * add GQA Test * fix except_msg * support ignore EOS token * change variable's name * fix annotation * [Inference] ADD async and sync Api server using FastAPI (#5396) * add api server * fix * add * add completion service and fix bug * add generation config * revise shardformer * fix bugs * add docstrings and fix some bugs * fix bugs and add choices for prompt template * [Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432) * finish online test and add examples * fix test_contionus_batching * fix some bugs * fix bash * fix * fix inference * finish revision * fix typos * revision * [Online Server] Chat Api for streaming and not streaming response (#5470) * fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.n * [Inference] resolve rebase conflicts fix * [Inference] Fix bugs and docs for feat/online-server (#5598) * fix test bugs * add do sample test * del useless lines * fix comments * fix tests * delete version tag * delete version tag * add * del test sever * fix test * fix * Revert "add" This reverts commit b9305fb02440d5cd566d32b508bee9f9c13dda15. * resolve rebase conflicts on Branch feat/online-serving * [Inference] Add example test_ci script * [Inference/Feat] Add quant kvcache interface (#5700) * add quant kvcache interface * delete unused output * complete args comments * [Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706) * add convert_fp8 op for fp8 test in the future * rerun ci * [Inference]Adapt repetition_penalty and no_repeat_ngram_size (#5708) * Adapt repetition_penalty and no_repeat_ngram_size * fix no_repeat_ngram_size_logit_process * remove batch_updated * fix annotation * modified codes based on the review feedback. * rm get_batch_token_ids * [Feat]Inference RPC Server Support (#5705) * rpc support source * kv cache logical/physical disaggregation * sampler refactor * colossalai launch built in * Unitest * Rpyc support --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add paged-attetionv2: support seq length split across thread block (#5707) * [Inference] Delete duplicated copy_vector (#5716) * [ci] Fix example tests (#5714) * [fix] revise timeout value on example CI * trivial * [Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717) * Fix Llama3 Load error * Omit Checkpoint IO Temporarily * [Inference] Fix API server, test and example (#5712) * fix api server * fix generation config * fix api server * fix comments * fix infer hanging bug * resolve comments, change backend to free port * 【Inference] Delete duplicated package (#5723) * [example] Update Inference Example (#5725) * [example] update inference example * [lazy] fix lazy cls init (#5720) * fix * fix * fix * fix * fix * remove kernel intall * rebase revert fix * fix * fix * [Inference] Fix Inference Generation Config and Sampling (#5710) * refactor and add * config default values * fix gen config passing * fix rpc generation config * [Fix/Inference] Add unsupported auto-policy error message (#5730) * [fix] auto policy error message * trivial * [doc] Update Inference Readme (#5736) * [doc] update inference readme * add contents * trivial * [Shardformer] Add parallel output for shardformer models(bloom, falcon) (#5702) * [pre-commit.ci] auto fixes from pre-commit.com hooks * add parallel cross entropy output for falcon model & fix some typos in bloom.py * fix module name error, self.model -> self.transformers in bloom, falcon model * Fix the overflow bug of distributed cross entropy loss function when training with fp16 * add dtype to parallel cross entropy loss function * fix dtype related typos adn prettify the loss.py * fix grad dtype and update dtype mismatch error * fix typo bugs * [bug] fix silly bug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [chore] add test for prefetch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [ci] Temporary fix for build on pr (#5741) * temporary fix for CI * timeout to 90 * [NFC] Fix code factors on inference triton kernels (#5743) * [NFC] fix requirements (#5744) * [inference] release (#5747) * [inference] release * [inference] release * [inference] release * [inference] release * [inference] release * [inference] release * [inference] release --------- Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Round Heng Co-authored-by: FrankLeeeee Co-authored-by: Yaozheng Fang <62918515+nkfyz@users.noreply.github.com> Co-authored-by: xs_courtesy Co-authored-by: Runyu Lu Co-authored-by: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: Yuanheng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: CjhHa1 Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: Haze188 Co-authored-by: binmakeswell --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/doc_test_on_pr.yml | 2 +- .github/workflows/example_check_on_pr.yml | 25 +- .../workflows/example_check_on_schedule.yml | 2 +- README.md | 37 +- colossalai/inference/README.md | 355 ++++---- colossalai/inference/__init__.py | 6 +- colossalai/inference/batch_bucket.py | 523 +++++++++++ colossalai/inference/config.py | 342 +++++++ colossalai/inference/core/__init__.py | 4 + colossalai/inference/core/async_engine.py | 332 +++++++ colossalai/inference/core/engine.py | 763 ++++++++++++++++ colossalai/inference/core/plugin.py | 140 +++ colossalai/inference/core/request_handler.py | 401 +++++++++ colossalai/inference/core/rpc_engine.py | 296 ++++++ colossalai/inference/engine/__init__.py | 3 - colossalai/inference/engine/engine.py | 195 ---- .../inference/engine/microbatch_manager.py | 248 ----- .../inference/engine/modeling/__init__.py | 5 - .../inference/engine/modeling/_utils.py | 67 -- colossalai/inference/engine/modeling/bloom.py | 452 ---------- .../inference/engine/modeling/chatglm2.py | 492 ---------- colossalai/inference/engine/modeling/llama.py | 503 ----------- .../inference/engine/policies/__init__.py | 11 - colossalai/inference/engine/policies/bloom.py | 127 --- .../inference/engine/policies/chatglm2.py | 89 -- colossalai/inference/engine/policies/llama.py | 206 ----- colossalai/inference/executor/rpc_worker.py | 306 +++++++ colossalai/inference/flash_decoding_utils.py | 82 ++ colossalai/inference/graph_runner.py | 100 ++ colossalai/inference/kv_cache/__init__.py | 6 +- .../inference/kv_cache/batch_infer_state.py | 118 --- colossalai/inference/kv_cache/block_cache.py | 58 ++ .../inference/kv_cache/kvcache_manager.py | 658 ++++++++++++-- colossalai/inference/logit_processors.py | 189 ++++ .../smoothquant => modeling}/__init__.py | 0 .../inference/modeling/layers/__init__.py | 0 .../inference/modeling/layers/attention.py | 313 +++++++ .../modeling/layers/baichuan_tp_linear.py | 43 + .../inference/modeling/models/__init__.py | 0 .../inference/modeling/models/glide_llama.py | 475 ++++++++++ .../modeling/models/nopadding_baichuan.py | 421 +++++++++ .../modeling/models/nopadding_llama.py | 697 ++++++++++++++ .../inference/modeling/policy/__init__.py | 16 + .../inference/modeling/policy/glide_llama.py | 45 + .../modeling/policy/nopadding_baichuan.py | 110 +++ .../modeling/policy/nopadding_llama.py | 112 +++ colossalai/inference/quant/__init__.py | 1 - colossalai/inference/quant/gptq/__init__.py | 5 - .../inference/quant/gptq/cai_gptq/__init__.py | 14 - .../quant/gptq/cai_gptq/cai_quant_linear.py | 354 -------- .../inference/quant/gptq/cai_gptq/gptq_op.py | 58 -- .../inference/quant/gptq/gptq_manager.py | 61 -- .../quant/smoothquant/models/__init__.py | 10 - .../quant/smoothquant/models/base_model.py | 494 ---------- .../quant/smoothquant/models/linear.py | 189 ---- .../quant/smoothquant/models/llama.py | 852 ------------------ .../smoothquant/models/parallel_linear.py | 264 ------ colossalai/inference/sampler.py | 111 +++ colossalai/inference/server/README.md | 27 + colossalai/inference/server/__init__.py | 0 colossalai/inference/server/api_server.py | 228 +++++ colossalai/inference/server/chat_service.py | 142 +++ .../inference/server/completion_service.py | 34 + colossalai/inference/server/utils.py | 36 + colossalai/inference/spec/__init__.py | 4 + colossalai/inference/spec/drafter.py | 121 +++ colossalai/inference/spec/struct.py | 55 ++ colossalai/inference/struct.py | 180 ++++ colossalai/inference/utils.py | 115 +++ colossalai/kernel/kernel_loader.py | 6 + colossalai/kernel/triton/__init__.py | 32 +- colossalai/kernel/triton/context_attention.py | 434 --------- .../kernel/triton/context_attn_unpad.py | 727 +++++++++++++++ .../kernel/triton/copy_kv_cache_dest.py | 71 -- colossalai/kernel/triton/custom_autotune.py | 176 ---- colossalai/kernel/triton/flash_decoding.py | 574 +++++++++++- colossalai/kernel/triton/fused_layernorm.py | 78 -- .../kernel/triton/fused_rotary_embedding.py | 181 ++++ colossalai/kernel/triton/gptq_triton.py | 543 ----------- .../triton/int8_rotary_embedding_kernel.py | 117 --- colossalai/kernel/triton/kvcache_copy.py | 296 ++++++ .../kernel/triton/no_pad_rotary_embedding.py | 659 ++++++++++++++ colossalai/kernel/triton/rms_layernorm.py | 116 +++ colossalai/kernel/triton/rotary_cache_copy.py | 147 +++ .../kernel/triton/self_attention_nofusion.py | 164 ---- colossalai/kernel/triton/smooth_attention.py | 652 -------------- .../kernel/triton/token_attention_kernel.py | 244 ----- colossalai/lazy/pretrained.py | 23 + .../legacy/inference/hybridengine/engine.py | 2 +- colossalai/shardformer/layer/embedding.py | 1 - colossalai/shardformer/layer/loss.py | 15 +- colossalai/shardformer/modeling/bloom.py | 100 +- colossalai/shardformer/modeling/falcon.py | 99 +- colossalai/shardformer/modeling/gpt2.py | 2 + colossalai/shardformer/modeling/llama.py | 2 + colossalai/shardformer/modeling/mistral.py | 2 + colossalai/shardformer/modeling/opt.py | 2 + colossalai/shardformer/policies/bloom.py | 9 +- colossalai/shardformer/policies/falcon.py | 16 +- colossalai/shardformer/shard/shard_config.py | 6 - colossalai/shardformer/shard/shardformer.py | 7 +- docs/README-zh-Hans.md | 37 +- examples/inference/benchmark_llama.py | 167 ---- .../benchmark_context_attn_unpad.py | 133 +++ .../benchmark_ops/benchmark_decoding_attn.py | 143 +++ .../benchmark_flash_decoding_attention.py | 185 ++++ .../benchmark_fused_rotary_embdding_unpad.py | 137 +++ .../benchmark_kv_cache_memcopy.py | 91 ++ .../benchmark_ops/benchmark_rmsnorm.py | 87 ++ .../benchmark_rotary_embedding.py | 76 ++ .../benchmark_ops/benchmark_xine_copy.py | 54 ++ examples/inference/benchmark_ops/test_ci.sh | 0 .../inference/build_smoothquant_weight.py | 59 -- examples/inference/client/locustfile.py | 58 ++ examples/inference/client/run_locust.sh | 28 + examples/inference/client/test_ci.sh | 4 + examples/inference/llama/README.md | 47 + examples/inference/llama/benchmark_llama.py | 275 ++++++ examples/inference/llama/benchmark_llama3.py | 216 +++++ examples/inference/llama/llama_generation.py | 105 +++ examples/inference/llama/run_benchmark.sh | 33 + examples/inference/llama/test_ci.sh | 4 + examples/inference/run_benchmark.sh | 15 - examples/inference/run_llama_inference.py | 98 -- .../openmoe/model/modeling_openmoe.py | 15 +- extensions/__init__.py | 19 +- extensions/cpp_extension.py | 4 + extensions/csrc/__init__.py | 11 - extensions/csrc/common/data_type.h | 53 ++ .../{cuda/type_shim.h => common/micros.h} | 131 +-- extensions/csrc/common/mp_type_traits.h | 55 ++ extensions/csrc/common/target.h | 134 +++ extensions/csrc/common/vec_type_traits.h | 76 ++ extensions/csrc/cuda/compat.h | 10 - extensions/csrc/cuda/include/block_reduce.h | 312 ------- .../csrc/cuda/scaled_masked_softmax_cuda.cu | 89 -- ...scaled_upper_triang_masked_softmax_cuda.cu | 75 -- extensions/csrc/funcs/binary_functor.h | 231 +++++ extensions/csrc/funcs/cast_functor.h | 503 +++++++++++ extensions/csrc/funcs/reduce_function.h | 94 ++ extensions/csrc/funcs/ternary_functor.h | 214 +++++ extensions/csrc/funcs/unary_functor.h | 67 ++ .../csrc/{ => kernel}/arm/cpu_adam_arm.cpp | 0 .../csrc/{ => kernel}/arm/cpu_adam_arm.h | 0 .../csrc/kernel/cuda/activation_kernel.cu | 77 ++ .../kernel/cuda/attention/attention_utils.h | 210 +++++ .../cuda/context_kv_cache_memcpy_kernel.cu | 238 +++++ .../csrc/kernel/cuda/convert_fp8_kernel.cu | 127 +++ .../cuda/decode_kv_cache_memcpy_kernel.cu | 216 +++++ .../cuda/flash_decoding_attention_kernel.cu | 831 +++++++++++++++++ .../cuda/fused_rotary_emb_and_cache_kernel.cu | 526 +++++++++++ .../kernel/cuda/get_cos_and_sin_kernel.cu | 218 +++++ .../cuda/layer_norm_kernel.cu} | 34 +- .../cuda/moe_kernel.cu} | 56 +- .../cuda/multi_tensor_adam_kernel.cu} | 2 +- .../{ => kernel}/cuda/multi_tensor_apply.cuh | 4 +- .../cuda/multi_tensor_l2norm_kernel.cu | 114 ++- .../cuda/multi_tensor_lamb_kernel.cu} | 8 +- .../cuda/multi_tensor_scale_kernel.cu | 2 +- .../cuda/multi_tensor_sgd_kernel.cu | 2 +- .../csrc/kernel/cuda/rms_layernorm_kernel.cu | 348 +++++++ .../cuda/scaled_masked_softmax_kernel.cu} | 199 ++-- ...led_upper_triang_masked_softmax_kernel.cu} | 221 ++--- .../kernel/cuda/utils/gpu_launch_config.h | 78 ++ extensions/csrc/kernel/cuda/utils/micros.h | 18 + .../csrc/kernel/cuda/utils/nvgpu_dev_info.h | 60 ++ extensions/csrc/kernel/cuda/utils/vec_copy.h | 64 ++ .../csrc/{cuda => kernel/x86}/cpu_adam.cpp | 0 .../csrc/{cuda => kernel/x86}/cpu_adam.h | 0 extensions/csrc/scaled_softmax.py | 190 ---- extensions/cuda_extension.py | 7 + extensions/pybind/__init__.py | 0 extensions/{ => pybind}/cpu_adam/__init__.py | 0 .../{ => pybind}/cpu_adam/cpu_adam_arm.py | 9 +- .../{ => pybind}/cpu_adam/cpu_adam_x86.py | 11 +- .../{ => pybind}/flash_attention/__init__.py | 0 .../flash_attention_dao_cuda.py | 2 +- .../flash_attention/flash_attention_npu.py | 2 +- .../flash_attention_sdpa_cuda.py | 2 +- extensions/pybind/inference/__init__.py | 3 + extensions/pybind/inference/inference.cpp | 111 +++ .../pybind/inference/inference_ops_cuda.py | 32 + extensions/{ => pybind}/layernorm/__init__.py | 0 .../layernorm/layer_norm.cpp} | 2 +- .../{ => pybind}/layernorm/layernorm_cuda.py | 12 +- extensions/{ => pybind}/moe/__init__.py | 0 .../cuda/moe_cuda.cpp => pybind/moe/moe.cpp} | 0 extensions/{ => pybind}/moe/moe_cuda.py | 14 +- extensions/{ => pybind}/optimizer/__init__.py | 0 .../optimizer/fused_optimizer_cuda.py | 23 +- .../optimizer/optimizer.cpp} | 0 extensions/{ => pybind}/softmax/__init__.py | 0 .../softmax}/scaled_masked_softmax.cpp | 26 +- .../softmax/scaled_masked_softmax_cuda.py | 14 +- .../scaled_upper_triang_masked_softmax.cpp | 14 +- ...aled_upper_triangle_masked_softmax_cuda.py | 14 +- requirements/requirements-infer.txt | 4 - requirements/requirements-test.txt | 1 - requirements/requirements.txt | 1 + setup.py | 1 - tests/test_gptq/test_gptq_linear.py | 144 --- tests/test_infer/__init__.py | 0 tests/test_infer/_utils.py | 0 .../test_async_engine/test_async_engine.py | 80 ++ .../test_async_engine/test_request_tracer.py | 68 ++ tests/test_infer/test_batch_bucket.py | 144 +++ tests/test_infer/test_config_and_struct.py | 46 + tests/test_infer/test_continuous_batching.py | 103 +++ tests/test_infer/test_cuda_graph.py | 96 ++ tests/test_infer/test_drafter.py | 74 ++ tests/test_infer/test_hybrid_bloom.py | 121 --- tests/test_infer/test_hybrid_chatglm2.py | 129 --- tests/test_infer/test_hybrid_llama.py | 126 --- tests/test_infer/test_inference_engine.py | 208 +++++ tests/test_infer/test_kernels/__init__.py | 0 .../test_infer/test_kernels/cuda/__init__.py | 0 .../test_kernels/cuda/test_convert_fp8.py | 57 ++ .../cuda/test_flash_decoding_attention.py | 334 +++++++ .../test_kernels/cuda/test_get_cos_and_sin.py | 53 ++ .../test_kernels/cuda/test_kv_cache_memcpy.py | 157 ++++ .../test_kernels/cuda/test_rms_layernorm.py | 51 ++ .../cuda/test_rotary_embdding_unpad.py | 130 +++ .../test_kernels/cuda/test_silu_and_mul.py | 33 + .../test_kernels/triton/__init__.py | 0 .../test_kernels/triton/kernel_utils.py | 348 +++++++ .../triton/test_context_attn_unpad.py | 179 ++++ .../test_kernels/triton/test_decoding_attn.py | 197 ++++ .../triton/test_fused_rotary_embedding.py | 50 + .../test_kernels/triton/test_kvcache_copy.py | 168 ++++ .../triton/test_rmsnorm_triton.py | 55 ++ .../triton/test_rotary_embdding_unpad.py | 100 ++ .../test_kernels/triton/test_xine_copy.py | 66 ++ tests/test_infer/test_kvcache_manager.py | 213 ++++- .../test_infer/test_models/test_attention.py | 145 +++ tests/test_infer/test_models/test_baichuan.py | 138 +++ tests/test_infer/test_request_handler.py | 105 +++ tests/test_infer/test_rpc_engine.py | 105 +++ 238 files changed, 20771 insertions(+), 9780 deletions(-) create mode 100644 colossalai/inference/batch_bucket.py create mode 100644 colossalai/inference/config.py create mode 100644 colossalai/inference/core/__init__.py create mode 100644 colossalai/inference/core/async_engine.py create mode 100644 colossalai/inference/core/engine.py create mode 100644 colossalai/inference/core/plugin.py create mode 100644 colossalai/inference/core/request_handler.py create mode 100644 colossalai/inference/core/rpc_engine.py delete mode 100644 colossalai/inference/engine/__init__.py delete mode 100644 colossalai/inference/engine/engine.py delete mode 100644 colossalai/inference/engine/microbatch_manager.py delete mode 100644 colossalai/inference/engine/modeling/__init__.py delete mode 100644 colossalai/inference/engine/modeling/_utils.py delete mode 100644 colossalai/inference/engine/modeling/bloom.py delete mode 100644 colossalai/inference/engine/modeling/chatglm2.py delete mode 100644 colossalai/inference/engine/modeling/llama.py delete mode 100644 colossalai/inference/engine/policies/__init__.py delete mode 100644 colossalai/inference/engine/policies/bloom.py delete mode 100644 colossalai/inference/engine/policies/chatglm2.py delete mode 100644 colossalai/inference/engine/policies/llama.py create mode 100644 colossalai/inference/executor/rpc_worker.py create mode 100644 colossalai/inference/flash_decoding_utils.py create mode 100644 colossalai/inference/graph_runner.py delete mode 100644 colossalai/inference/kv_cache/batch_infer_state.py create mode 100644 colossalai/inference/kv_cache/block_cache.py create mode 100644 colossalai/inference/logit_processors.py rename colossalai/inference/{quant/smoothquant => modeling}/__init__.py (100%) create mode 100644 colossalai/inference/modeling/layers/__init__.py create mode 100644 colossalai/inference/modeling/layers/attention.py create mode 100644 colossalai/inference/modeling/layers/baichuan_tp_linear.py create mode 100644 colossalai/inference/modeling/models/__init__.py create mode 100644 colossalai/inference/modeling/models/glide_llama.py create mode 100644 colossalai/inference/modeling/models/nopadding_baichuan.py create mode 100644 colossalai/inference/modeling/models/nopadding_llama.py create mode 100644 colossalai/inference/modeling/policy/__init__.py create mode 100644 colossalai/inference/modeling/policy/glide_llama.py create mode 100644 colossalai/inference/modeling/policy/nopadding_baichuan.py create mode 100644 colossalai/inference/modeling/policy/nopadding_llama.py delete mode 100644 colossalai/inference/quant/__init__.py delete mode 100644 colossalai/inference/quant/gptq/__init__.py delete mode 100644 colossalai/inference/quant/gptq/cai_gptq/__init__.py delete mode 100644 colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py delete mode 100644 colossalai/inference/quant/gptq/cai_gptq/gptq_op.py delete mode 100644 colossalai/inference/quant/gptq/gptq_manager.py delete mode 100644 colossalai/inference/quant/smoothquant/models/__init__.py delete mode 100644 colossalai/inference/quant/smoothquant/models/base_model.py delete mode 100644 colossalai/inference/quant/smoothquant/models/linear.py delete mode 100644 colossalai/inference/quant/smoothquant/models/llama.py delete mode 100644 colossalai/inference/quant/smoothquant/models/parallel_linear.py create mode 100644 colossalai/inference/sampler.py create mode 100644 colossalai/inference/server/README.md create mode 100644 colossalai/inference/server/__init__.py create mode 100644 colossalai/inference/server/api_server.py create mode 100644 colossalai/inference/server/chat_service.py create mode 100644 colossalai/inference/server/completion_service.py create mode 100644 colossalai/inference/server/utils.py create mode 100644 colossalai/inference/spec/__init__.py create mode 100644 colossalai/inference/spec/drafter.py create mode 100644 colossalai/inference/spec/struct.py create mode 100644 colossalai/inference/struct.py create mode 100644 colossalai/inference/utils.py delete mode 100644 colossalai/kernel/triton/context_attention.py create mode 100644 colossalai/kernel/triton/context_attn_unpad.py delete mode 100644 colossalai/kernel/triton/copy_kv_cache_dest.py delete mode 100644 colossalai/kernel/triton/custom_autotune.py delete mode 100644 colossalai/kernel/triton/fused_layernorm.py create mode 100644 colossalai/kernel/triton/fused_rotary_embedding.py delete mode 100644 colossalai/kernel/triton/gptq_triton.py delete mode 100644 colossalai/kernel/triton/int8_rotary_embedding_kernel.py create mode 100644 colossalai/kernel/triton/kvcache_copy.py create mode 100644 colossalai/kernel/triton/no_pad_rotary_embedding.py create mode 100644 colossalai/kernel/triton/rms_layernorm.py create mode 100644 colossalai/kernel/triton/rotary_cache_copy.py delete mode 100644 colossalai/kernel/triton/self_attention_nofusion.py delete mode 100644 colossalai/kernel/triton/smooth_attention.py delete mode 100644 colossalai/kernel/triton/token_attention_kernel.py delete mode 100644 examples/inference/benchmark_llama.py create mode 100644 examples/inference/benchmark_ops/benchmark_context_attn_unpad.py create mode 100644 examples/inference/benchmark_ops/benchmark_decoding_attn.py create mode 100644 examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py create mode 100644 examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py create mode 100644 examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py create mode 100644 examples/inference/benchmark_ops/benchmark_rmsnorm.py create mode 100644 examples/inference/benchmark_ops/benchmark_rotary_embedding.py create mode 100644 examples/inference/benchmark_ops/benchmark_xine_copy.py create mode 100644 examples/inference/benchmark_ops/test_ci.sh delete mode 100644 examples/inference/build_smoothquant_weight.py create mode 100644 examples/inference/client/locustfile.py create mode 100644 examples/inference/client/run_locust.sh create mode 100644 examples/inference/client/test_ci.sh create mode 100644 examples/inference/llama/README.md create mode 100644 examples/inference/llama/benchmark_llama.py create mode 100644 examples/inference/llama/benchmark_llama3.py create mode 100644 examples/inference/llama/llama_generation.py create mode 100755 examples/inference/llama/run_benchmark.sh create mode 100644 examples/inference/llama/test_ci.sh delete mode 100755 examples/inference/run_benchmark.sh delete mode 100644 examples/inference/run_llama_inference.py create mode 100644 extensions/csrc/common/data_type.h rename extensions/csrc/{cuda/type_shim.h => common/micros.h} (85%) create mode 100644 extensions/csrc/common/mp_type_traits.h create mode 100644 extensions/csrc/common/target.h create mode 100644 extensions/csrc/common/vec_type_traits.h delete mode 100644 extensions/csrc/cuda/compat.h delete mode 100644 extensions/csrc/cuda/include/block_reduce.h delete mode 100644 extensions/csrc/cuda/scaled_masked_softmax_cuda.cu delete mode 100644 extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu create mode 100644 extensions/csrc/funcs/binary_functor.h create mode 100644 extensions/csrc/funcs/cast_functor.h create mode 100644 extensions/csrc/funcs/reduce_function.h create mode 100644 extensions/csrc/funcs/ternary_functor.h create mode 100644 extensions/csrc/funcs/unary_functor.h rename extensions/csrc/{ => kernel}/arm/cpu_adam_arm.cpp (100%) rename extensions/csrc/{ => kernel}/arm/cpu_adam_arm.h (100%) create mode 100644 extensions/csrc/kernel/cuda/activation_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/attention/attention_utils.h create mode 100644 extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/convert_fp8_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu create mode 100644 extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu rename extensions/csrc/{cuda/layer_norm_cuda_kernel.cu => kernel/cuda/layer_norm_kernel.cu} (96%) rename extensions/csrc/{cuda/moe_cuda_kernel.cu => kernel/cuda/moe_kernel.cu} (93%) rename extensions/csrc/{cuda/multi_tensor_adam.cu => kernel/cuda/multi_tensor_adam_kernel.cu} (99%) rename extensions/csrc/{ => kernel}/cuda/multi_tensor_apply.cuh (98%) rename extensions/csrc/{ => kernel}/cuda/multi_tensor_l2norm_kernel.cu (81%) rename extensions/csrc/{cuda/multi_tensor_lamb.cu => kernel/cuda/multi_tensor_lamb_kernel.cu} (98%) rename extensions/csrc/{ => kernel}/cuda/multi_tensor_scale_kernel.cu (99%) rename extensions/csrc/{ => kernel}/cuda/multi_tensor_sgd_kernel.cu (99%) create mode 100644 extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu rename extensions/csrc/{cuda/scaled_masked_softmax.h => kernel/cuda/scaled_masked_softmax_kernel.cu} (80%) rename extensions/csrc/{cuda/scaled_upper_triang_masked_softmax.h => kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu} (83%) create mode 100644 extensions/csrc/kernel/cuda/utils/gpu_launch_config.h create mode 100644 extensions/csrc/kernel/cuda/utils/micros.h create mode 100644 extensions/csrc/kernel/cuda/utils/nvgpu_dev_info.h create mode 100644 extensions/csrc/kernel/cuda/utils/vec_copy.h rename extensions/csrc/{cuda => kernel/x86}/cpu_adam.cpp (100%) rename extensions/csrc/{cuda => kernel/x86}/cpu_adam.h (100%) delete mode 100644 extensions/csrc/scaled_softmax.py create mode 100644 extensions/pybind/__init__.py rename extensions/{ => pybind}/cpu_adam/__init__.py (100%) rename extensions/{ => pybind}/cpu_adam/cpu_adam_arm.py (80%) rename extensions/{ => pybind}/cpu_adam/cpu_adam_x86.py (83%) rename extensions/{ => pybind}/flash_attention/__init__.py (100%) rename extensions/{ => pybind}/flash_attention/flash_attention_dao_cuda.py (98%) rename extensions/{ => pybind}/flash_attention/flash_attention_npu.py (97%) rename extensions/{ => pybind}/flash_attention/flash_attention_sdpa_cuda.py (97%) create mode 100644 extensions/pybind/inference/__init__.py create mode 100644 extensions/pybind/inference/inference.cpp create mode 100644 extensions/pybind/inference/inference_ops_cuda.py rename extensions/{ => pybind}/layernorm/__init__.py (100%) rename extensions/{csrc/cuda/layer_norm_cuda.cpp => pybind/layernorm/layer_norm.cpp} (99%) rename extensions/{ => pybind}/layernorm/layernorm_cuda.py (57%) rename extensions/{ => pybind}/moe/__init__.py (100%) rename extensions/{csrc/cuda/moe_cuda.cpp => pybind/moe/moe.cpp} (100%) rename extensions/{ => pybind}/moe/moe_cuda.py (58%) rename extensions/{ => pybind}/optimizer/__init__.py (100%) rename extensions/{ => pybind}/optimizer/fused_optimizer_cuda.py (50%) rename extensions/{csrc/cuda/colossal_C_frontend.cpp => pybind/optimizer/optimizer.cpp} (100%) rename extensions/{ => pybind}/softmax/__init__.py (100%) rename extensions/{csrc/cuda => pybind/softmax}/scaled_masked_softmax.cpp (71%) rename extensions/{ => pybind}/softmax/scaled_masked_softmax_cuda.py (66%) rename extensions/{csrc/cuda => pybind/softmax}/scaled_upper_triang_masked_softmax.cpp (79%) rename extensions/{ => pybind}/softmax/scaled_upper_triangle_masked_softmax_cuda.py (65%) delete mode 100644 requirements/requirements-infer.txt delete mode 100644 tests/test_gptq/test_gptq_linear.py create mode 100644 tests/test_infer/__init__.py mode change 100644 => 100755 tests/test_infer/_utils.py create mode 100644 tests/test_infer/test_async_engine/test_async_engine.py create mode 100644 tests/test_infer/test_async_engine/test_request_tracer.py create mode 100644 tests/test_infer/test_batch_bucket.py create mode 100755 tests/test_infer/test_config_and_struct.py create mode 100644 tests/test_infer/test_continuous_batching.py create mode 100644 tests/test_infer/test_cuda_graph.py create mode 100644 tests/test_infer/test_drafter.py delete mode 100644 tests/test_infer/test_hybrid_bloom.py delete mode 100644 tests/test_infer/test_hybrid_chatglm2.py delete mode 100644 tests/test_infer/test_hybrid_llama.py create mode 100644 tests/test_infer/test_inference_engine.py create mode 100644 tests/test_infer/test_kernels/__init__.py create mode 100644 tests/test_infer/test_kernels/cuda/__init__.py create mode 100644 tests/test_infer/test_kernels/cuda/test_convert_fp8.py create mode 100644 tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py create mode 100644 tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py create mode 100644 tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py create mode 100644 tests/test_infer/test_kernels/cuda/test_rms_layernorm.py create mode 100644 tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py create mode 100644 tests/test_infer/test_kernels/cuda/test_silu_and_mul.py create mode 100644 tests/test_infer/test_kernels/triton/__init__.py create mode 100644 tests/test_infer/test_kernels/triton/kernel_utils.py create mode 100644 tests/test_infer/test_kernels/triton/test_context_attn_unpad.py create mode 100644 tests/test_infer/test_kernels/triton/test_decoding_attn.py create mode 100644 tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py create mode 100644 tests/test_infer/test_kernels/triton/test_kvcache_copy.py create mode 100644 tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py create mode 100644 tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py create mode 100644 tests/test_infer/test_kernels/triton/test_xine_copy.py mode change 100644 => 100755 tests/test_infer/test_kvcache_manager.py create mode 100644 tests/test_infer/test_models/test_attention.py create mode 100644 tests/test_infer/test_models/test_baichuan.py create mode 100644 tests/test_infer/test_request_handler.py create mode 100644 tests/test_infer/test_rpc_engine.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 5bdadca783b3..0c3a55905764 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -91,7 +91,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny - timeout-minutes: 60 + timeout-minutes: 90 defaults: run: shell: bash diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index 27f7e76af4fe..31c421846e2c 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -58,7 +58,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm - timeout-minutes: 20 + timeout-minutes: 30 defaults: run: shell: bash diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 6170628e10a9..56fa006b1633 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -8,6 +8,7 @@ on: # any change in the examples folder will trigger check for the corresponding example. paths: - "examples/**" + - "!examples/**.md" jobs: # This is for changed example files detect and output a matrix containing all the corresponding directory name. @@ -19,6 +20,7 @@ jobs: outputs: matrix: ${{ steps.setup-matrix.outputs.matrix }} anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }} + anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }} name: Detect changed example files concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change @@ -37,6 +39,16 @@ jobs: echo $commonCommit echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + - name: Find the changed extension-related files + id: find-extension-change + uses: tj-actions/changed-files@v35 + with: + base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} + files: | + op_builder/** + colossalai/kernel/** + setup.py + - name: Get all changed example files id: changed-files uses: tj-actions/changed-files@v35 @@ -79,17 +91,28 @@ jobs: container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm - timeout-minutes: 20 + timeout-minutes: 30 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true steps: - uses: actions/checkout@v3 + - name: Restore Colossal-AI Cache + if: needs.detect.outputs.anyExtensionFileChanged != 'true' + run: | + if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then + cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ + fi + - name: Install Colossal-AI run: | BUILD_EXT=1 pip install -v . + - name: Store Colossal-AI Cache + run: | + cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ + - name: Test the example run: | example_dir=${{ matrix.directory }} diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 2588ac8243d9..6ec1b0591fc3 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -36,7 +36,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm - timeout-minutes: 10 + timeout-minutes: 30 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/README.md b/README.md index 3157d74c91c5..e41b75c46987 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## Latest News +* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) @@ -75,11 +76,9 @@
  • Inference
  • @@ -377,6 +376,19 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt ## Inference +### Colossal-Inference +

    + +

    + +

    + +

    + + - Large AI models inference speed doubled, compared to the offline inference performance of vLLM in some cases. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference) +[[blog]](https://hpc-ai.com/blog/colossal-inference) + ### Grok-1

    @@ -389,30 +401,13 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt [[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) [[ModelScope Grok-1 PyTorch model weights]](https://www.modelscope.cn/models/colossalai/grok-1-pytorch/summary) +### SwiftInfer

    - [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Inference performance improved by 46%, open source solution breaks the length limit of LLM for multi-round conversations -

    - -

    - -- [Energon-AI](https://github.com/hpcaitech/EnergonAI): 50% inference acceleration on the same hardware - -

    - -

    - -- [OPT Serving](https://colossalai.org/docs/advanced_tutorials/opt_service): Try 175-billion-parameter OPT online services - -

    - -

    - -- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 176-billion-parameter BLOOM by more than 10 times. -

    (back to top)

    ## Installation diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 0bdaf347d295..cdb32a0f89a2 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -1,229 +1,256 @@ -# 🚀 Colossal-Inference +# ⚡️ ColossalAI-Inference +## 📚 Table of Contents -## Table of Contents +- [⚡️ ColossalAI-Inference](#️-colossalai-inference) + - [📚 Table of Contents](#-table-of-contents) + - [📌 Introduction](#-introduction) + - [🕹 Usage](#-usage) + - [🗺 Roadmap](#-roadmap) + - [🪅 Support Matrix](#-support-matrix) + - [🛠 Design and Components](#-design-and-components) + - [Overview](#overview) + - [Engine](#engine) + - [Blocked KV Cache Manager](#kv-cache) + - [Batching](#batching) + - [Modeling](#modeling) + - [🌟 Acknowledgement](#-acknowledgement) -- [💡 Introduction](#introduction) -- [🔗 Design](#design) -- [🔨 Usage](#usage) - - [Quick start](#quick-start) - - [Example](#example) -- [📊 Performance](#performance) -## Introduction +## 📌 Introduction +ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference) -`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. +

    + +

    -## Design +

    + +

    -Colossal Inference is composed of three main components: +## 🕹 Usage -1. High performance kernels and ops: which are inspired from existing libraries and modified correspondingly. -2. Efficient memory management mechanism:which includes the key-value cache manager, allowing for zero memory waste during inference. - 1. `cache manager`: serves as a memory manager to help manage the key-value cache, it integrates functions such as memory allocation, indexing and release. - 2. `batch_infer_info`: holds all essential elements of a batch inference, which is updated every batch. -3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods. - 1. `HybridEngine`: it is a high level interface that integrates with shardformer, especially for multi-card (tensor parallel, pipline parallel) inference: - 2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference. (in this case : llama) - 3. `policies.llama.LlamaModelInferPolicy` : contains the policies for `llama` models, which is used to call `shardformer` and segmentate the model forward in tensor parallelism way. +### :arrow_right: Quick Start +The sample usage of the inference engine is given below: -## Architecture of inference: - -In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes. - -Colossal-Inference - -## Roadmap of our implementation - -- [x] Design cache manager and batch infer state -- [x] Design TpInference engine to integrates with `Shardformer` -- [x] Register corresponding high-performance `kernel` and `ops` -- [x] Design policies and forwards (e.g. `Llama` and `Bloom`) - - [x] policy - - [x] context forward - - [x] token forward - - [x] support flash-decoding -- [x] Support all models - - [x] Llama - - [x] Llama-2 - - [x] Bloom - - [x] Chatglm2 -- [x] Quantization - - [x] GPTQ - - [x] SmoothQuant -- [ ] Benchmarking for all models - -## Get started +```python +import torch +import transformers +import colossalai +from colossalai.inference import InferenceEngine, InferenceConfig +from pprint import pprint -### Installation +colossalai.launch_from_torch() -```bash -pip install -e . +# Step 1: create a model in "transformers" way +model_path = "lmsys/vicuna-7b-v1.3" +model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + +# Step 2: create an inference_config +inference_config = InferenceConfig( + dtype=torch.float16, + max_batch_size=4, + max_input_len=1024, + max_output_len=512, + use_cuda_kernel=True, + ) + +# Step 3: create an engine with model and config +engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + +# Step 4: try inference +prompts = ['Who is the best player in the history of NBA?'] +response = engine.generate(prompts) +pprint(response) ``` -### Requirements - -Install dependencies. - +You could run the sample code by ```bash -pip install -r requirements/requirements-infer.txt - -# if you want use smoothquant quantization, please install torch-int -git clone --recurse-submodules https://github.com/Guangxuan-Xiao/torch-int.git -cd torch-int -git checkout 65266db1eadba5ca78941b789803929e6e6c6856 -pip install -r requirements.txt -source environment.sh -bash build_cutlass.sh -python setup.py install +colossalai run --nproc_per_node 1 your_sample_name.py ``` -### Docker +For detailed examples, you might want to check [inference examples](../../examples/inference/llama/README.md). -You can use docker run to use docker container to set-up environment +### :bookmark: Customize your inference engine +Besides the basic quick-start inference, you can also customize your inference engine via modifying inference config or uploading your own models, policies, or decoding components (logits processors or sampling strategies). -``` -# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support -docker pull hpcaitech/colossalai-inference:v2 -docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash +#### Inference Config +Inference Config is a unified config for initializing the inference engine, controlling multi-GPU generation (Tensor Parallelism), as well as presetting generation configs. Below are some commonly used `InferenceConfig`'s arguments: -# enter into docker container -cd /path/to/ColossalAI -pip install -e . +- `max_batch_size`: The maximum batch size. Defaults to 8. +- `max_input_len`: The maximum input length (number of tokens). Defaults to 256. +- `max_output_len`: The maximum output length (number of tokens). Defaults to 256. +- `dtype`: The data type of the model for inference. This can be one of `fp16`, `bf16`, or `fp32`. Defaults to `fp16`. +- `kv_cache_dtype`: The data type used for KVCache. Defaults to the same data type as the model (`dtype`). KVCache quantization will be automatically enabled if it is different from that of model (`dtype`). +- `use_cuda_kernel`: Determine whether to use CUDA kernels or not. If disabled, Triton kernels will be used. Defaults to False. +- `tp_size`: Tensor-Parallelism size. Defaults to 1 (tensor parallelism is turned off by default). -``` +#### Generation Config +Refer to transformers [GenerationConfig](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig) on functionalities and usage of specific configs. In ColossalAI-Inference, generation configs can be preset in `InferenceConfig`. Supported generation configs include: -## Usage -### Quick start +- `do_sample`: Whether or not to use sampling. Defaults to False (greedy decoding). +- `top_k`: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 50. +- `top_p`: If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to 1.0. +- `temperature`: The value used to modulate the next token probabilities. Defaults to 1.0. +- `no_repeat_ngram_size`: If set to int > 0, all ngrams of that size can only occur once. Defaults to 0. +- `repetition_penalty`: The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0. +- `forced_eos_token_id`: The id of the token to force as the last generated token when max_length is reached. Defaults to `None`. -example files are in +Users can also create a transformers [GenerationConfig](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig) as an input argument for `InferenceEngine.generate` API. For example -```bash -cd ColossalAI/examples -python hybrid_llama.py --path /path/to/model --tp_size 2 --pp_size 2 --batch_size 4 --max_input_size 32 --max_out_len 16 --micro_batch_size 2 +```python +generation_config = GenerationConfig( + max_length=128, + do_sample=True, + temperature=0.7, + top_k=50, + top_p=1.0, +) +response = engine.generate(prompts=prompts, generation_config=generation_config) ``` +## 🗺 Roadmap +We will follow the following roadmap to develop major features of ColossalAI-Inference: -### Example -```python -# import module -from colossalai.inference import CaiInferEngine -import colossalai -from transformers import LlamaForCausalLM, LlamaTokenizer +- [x] Blocked KV Cache +- [x] Paged Attention +- 🟩 Fused Kernels +- [x] Speculative Decoding +- [x] Continuous Batching +- 🟩 Tensor Parallelism +- [ ] Online Inference +- [ ] Beam Search +- [ ] SplitFuse -#launch distributed environment -colossalai.launch_from_torch() +Notations: +- [x] Completed +- 🟩 Model specific and in still progress. -# load original model and tokenizer -model = LlamaForCausalLM.from_pretrained("/path/to/model") -tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") - -# generate token ids -input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] -data = tokenizer(input, return_tensors='pt') - -# set parallel parameters -tp_size=2 -pp_size=2 -max_output_len=32 -micro_batch_size=1 - -# initial inference engine -engine = CaiInferEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, -) +## 🪅 Support Matrix -# inference -output = engine.generate(data) +| Model | Model Card | Tensor Parallel | Lazy Initialization | Paged Attention | Fused Kernels | Speculative Decoding | +|-----------|------------------------------------------------------------------------------------------------|-----------------|---------------------|-----------------|---------------|----------------------| +| Baichuan | `baichuan-inc/Baichuan2-7B-Base`,
    `baichuan-inc/Baichuan2-13B-Base`, etc | ✅ | [ ] | ✅ | ✅ | [ ] | +| ChatGLM | | [ ] | [ ] | [ ] | [ ] | [ ] | +| DeepSeek | | [ ] | [ ] | [ ] | [ ] | [ ] | +| Llama | `meta-llama/Llama-2-7b`,
    `meta-llama/Llama-2-13b`,
    `meta-llama/Meta-Llama-3-8B`,
    `meta-llama/Meta-Llama-3-70B`, etc | ✅ | [ ] | ✅ | ✅ | ✅ | +| Mixtral | | [ ] | [ ] | [ ] | [ ] | [ ] | +| Qwen | | [ ] | [ ] | [ ] | [ ] | [ ] | +| Vicuna | `lmsys/vicuna-13b-v1.3`,
    `lmsys/vicuna-7b-v1.5` | ✅ | [ ] | ✅ | ✅ | ✅ | +| Yi | `01-ai/Yi-34B`, etc | ✅ | [ ] | ✅ | ✅ | ✅ | -# get results -if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" -``` +## 🛠 Design and Components -## Performance +### Overview -### environment: +ColossalAI-Inference has **4** major components, namely `engine`, `request handler`, `kv cache manager`, and `modeling`. -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `colossal-inference` and original `hugging-face torch fp16`. +

    + colossalai-inference-components-overview +
    +

    -For various models, experiments were conducted using multiple batch sizes under the consistent model configuration of `7 billion(7b)` parameters, `1024` input length, and 128 output length. The obtained results are as follows (due to time constraints, the evaluation has currently been performed solely on the `A100` single GPU performance; multi-GPU performance will be addressed in the future): +- **Engine**: It orchestrates the inference step. During inference, it recives a request, calls `request handler` to schedule a decoding batch, and executes the model forward pass to perform a iteration. It returns the inference results back to the user at the end. +- **Request Handler**: It manages requests and schedules a proper batch from exisiting requests. +- **KV Cache Manager** It is bound within the `request handler`, updates cache blocks and logical block tables as scheduled by the `request handler`. +- **Modelling**: We rewrite the model and layers of LLMs to simplify and optimize the forward pass for inference. -### Single GPU Performance: -Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned. +An overview of the inter-component interaction is given below (RPC version). We would also introduce more details in the next few sections. -### Tensor Parallelism Inference +

    + colossalai-inference-framework-rpc +
    +

    -##### Llama +### Engine -| batch_size | 8 | 16 | 32 | -|:-----------------------:|:------:|:------:|:------:| -| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 | -| colossal-inference | 326.4 | 582.72 | 816.64 | +Engine is designed as the entry point where the user kickstarts an inference loop. User can easily initialize an inference engine with the inference configurations and execute with their requests. We provided several versions of inference engines, namely `InferenceEngine`, `RPCInferenceEngine`, and `AsyncInferenceEngine`, which are used for different conditions and purposes. -![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png) +For examples/inference/llama and `RPCInferenceEngine`, we expose the following APIs for inference: -#### Bloom +- `generate`: main function which handles inputs, performs inference and returns outputs. +- `add_request`: add a single or multiple requests to the inference engine. +- `step`: perform one decoding iteration. The `request handler` first schedules a batch to do prefill/decoding. Then, it invokes a model to generate a batch of token and afterwards does logit processing and sampling, checks and decodes finished requests. +- `enable_spec_dec`: used for speculative decoding. Enable speculative decoding for subsequent generations. +- `disable_spec_dec`: used for speculative decoding. Disable speculative decoding for subsequent generations +- `clear_spec_dec`: clear structures and models related to speculative decoding, if exists. -| batch_size | 8 | 16 | 32 | -|:-----------------------:|:------:|:------:|:------:| -| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 | -| colossal-inference | 323.28 | 538.52 | 611.64 | +For `AsyncInferenceEngine`, we expose the following APIs for inference: +- `add_request`: async method. Add a request to the inference engine, as well as to the waiting queue of the background tracker. +- `generate`: async method. Perform inference from a request. +- `step`: async method. Perform one decoding iteration, if there exists any request in waiting queue. -![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png) +For now, `InferenceEngine` is used for offline generation; `AsyncInferenceEngine` is used for online serving with a single card; and `RPCInferenceEngine` is used for online serving with multiple cards. In future, we will focus on `RPCInferenceEngine` and improve user experience of LLM serving. -### Pipline Parallelism Inference -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. We set input length=1024, output length=128. +### KV cache +Learnt from [PagedAttention](https://arxiv.org/abs/2309.06180) by [vLLM](https://github.com/vllm-project/vllm) team, we use a unified blocked KV cache and cache manager to allocate and manage memory. The physical memory is pre-allocated during initialization and represented by a logical block table. During decoding process, cache manager administrates the physical memory through `block table` of a batch and so that other components (i.e. engine) can focus on the lightweight `block table`. More details are given below. -#### A10 7b, fp16 +- `logical cache block`: We group physical memory into different memory blocks. A typical cache block is shaped `(num_kv_heads, block_size, head_size)`. We determine the block number beforehand. The memory allocation and computation are executed at the granularity of memory block. +- `block table`: Block table is the logical representation of cache blocks. Concretely, a block table of a single sequence is a 1D tensor, with each element holding a block ID. Block ID of `-1` means "Not Allocated". In each iteration, we pass through a batch block table to the corresponding model. -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16) | -|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:| -| Pipeline Inference | 40.35 | 77.10 | 139.03 | 232.70 | 257.81 | OOM | -| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM | +

    + +
    + Example of block table for a batch +

    -![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png) +### Batching -#### A10 13b, fp16 +Request handler is responsible for managing requests and scheduling a proper batch from exisiting requests. Based on [Orca's](https://www.usenix.org/conference/osdi22/presentation/yu) and [vLLM's](https://github.com/vllm-project/vllm) research and work on batching requests, we applied continuous batching with unpadded sequences, which enables various number of sequences to pass projections (i.e. Q, K, and V) together in different steps by hiding the dimension of number of sequences, and decrement the latency of incoming sequences by inserting a prefill batch during a decoding step and then decoding together. -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(4) | -|:----------------------------:|:-----:|:-----:|:-----:|:-----:| -| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | -| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | +

    + +
    + Naive Batching: decode until each sequence encounters eos in a batch +

    -![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png) +

    + +
    + Continuous Batching: dynamically adjust the batch size by popping out finished sequences and inserting prefill batch +

    +### Modeling -#### A800 7b, fp16 +Modeling contains models, layers, and policy, which are hand-crafted for better performance easier usage. Integrated with `shardformer`, users can define their own policy or use our preset policies for specific models. Our modeling files are aligned with [Transformers](https://github.com/huggingface/transformers). For more details about the usage of modeling and policy, please check `colossalai/shardformer`. -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -|:----------------------------:|:-----:|:------:|:------:|:------:|:------:| -| Pipeline Inference | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | -| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | -![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png) +## 🌟 Acknowledgement -### Quantization LLama +This project was written from scratch but we learned a lot from several other great open-source projects during development. Therefore, we wish to fully acknowledge their contribution to the open-source community. These projects include -| batch_size | 8 | 16 | 32 | -|:-------------:|:------:|:------:|:------:| -| auto-gptq | 199.20 | 232.56 | 253.26 | -| smooth-quant | 142.28 | 222.96 | 300.59 | -| colossal-gptq | 231.98 | 388.87 | 573.03 | +- [vLLM](https://github.com/vllm-project/vllm) +- [flash-attention](https://github.com/Dao-AILab/flash-attention) -![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png) +If you wish to cite relevant research papars, you can find the reference below. +```bibtex +# vllm +@inproceedings{kwon2023efficient, + title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, + author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica}, + booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, + year={2023} +} - -The results of more models are coming soon! +# flash attention v1 & v2 +@inproceedings{dao2022flashattention, + title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author={Dao, Tri}, + year={2023} +} +``` diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index a95205efaa78..5f2effca65a0 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,4 +1,4 @@ -from .engine import InferenceEngine -from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy +from .config import InferenceConfig +from .core import InferenceEngine -__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"] +__all__ = ["InferenceConfig", "InferenceEngine"] diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py new file mode 100644 index 000000000000..f8571c0ca030 --- /dev/null +++ b/colossalai/inference/batch_bucket.py @@ -0,0 +1,523 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from colossalai.inference.struct import Sequence +from colossalai.utils import get_current_device + + +class BatchBucket: + """Container for a batch of Sequences, which is used to manage the batch of sequences. + + Attrs: + _sequences_dict (Dict[int, Sequence]): Map sequence uid to sequence struct + seq_uid -> Sequence + _sequences_indexes (Dict[int, int]): Map sequence uid to index in the batch + seq_uid -> index in the batch (indexing used in sequence_lengths and block_tables) + _sequence_lengths (torch.Tensor): Length of each sequence in the batch. + The size of the tensor is (max_batch_size,) + _block_tables (torch.Tensor): Block table of each sequence in the batch + The size of the tensor is (max_batch_size, max_blocks_per_seq) + """ + + def __init__( + self, + num_heads, + head_dim, + max_batch_size, + max_length, + block_size, + kv_max_split_num, + fd_interm_tensor=None, + device=None, + dtype=torch.float16, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.max_batch_size = max_batch_size + self.max_length = max_length # in + out len + self.block_size = block_size + self.kv_max_split_num = kv_max_split_num # Hint used for flash decoding + self.fd_interm_tensor = fd_interm_tensor + self.device = device or get_current_device() + self.dtype = dtype + + self._use_spec_dec = False + self._num_tokens_to_verify = None + + self._current_batch_size = 0 + self._sequences_dict = dict() + self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) + self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) + self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) + max_blocks_per_seq = (self.max_length + block_size - 1) // block_size + self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) + self._block_tables_helper = torch.full_like(self._block_tables, -1) + + @property + def is_empty(self): + return self._current_batch_size == 0 + + @property + def current_batch_size(self): + return self._current_batch_size + + def __len__(self): + return self._current_batch_size + + @property + def available_batch_size(self): + return self.max_batch_size - self._current_batch_size + + @property + def block_tables(self): + return self._block_tables + + @property + def seq_lengths(self): + return self._sequence_lengths + + @property + def seqs_ids(self): + return list(self._sequences_dict.keys()) + + @property + def seqs_li(self): + return list(self._sequences_dict.values()) + + @property + def is_compact(self): + assert len(self._sequences_dict) == len(self._sequences_indexes), "BatchBucket indexing is not consistent" + return ( + len(self._sequences_dict) + == torch.nonzero(self._sequence_lengths).view(-1).numel() + == torch.nonzero(self._block_tables[:, 0] >= 0).numel() + ) + + @property + def use_spec_dec(self) -> bool: + return self._use_spec_dec + + @property + def num_tokens_to_verify(self) -> int: + return self._num_tokens_to_verify + + @property + def batch_token_ids(self) -> List[List[int]]: + out = [] + for seq in self.seqs_li: + out.append(seq.input_token_id + seq.output_token_id) + return out + + def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: + """Set batch bucket to use speculatvie decoding. + This will notify the adjust the lengths of inputs during modeling, + and let the main model verifies tokens in parallel. + """ + self._use_spec_dec = True + self._num_tokens_to_verify = num_tokens_to_verify + + def reset_use_spec_dec(self) -> None: + """Reset the usage of speculative decoding for the batch bucket""" + self._use_spec_dec = False + self._num_tokens_to_verify = None + + def _make_compact(self) -> None: + # Clean and Compress the batch based on its sequences dict. + # Namely,compress sequences to the front and clean the seq lengths and block tables tensors. + # NOTE Prevent calling this method multiple times in a single step + if self.is_compact: + return + valid_seq_ids = self._sequences_dict.keys() + valid_num = len(valid_seq_ids) + valid_indexes = [self._sequences_indexes[seq_id] for seq_id in valid_seq_ids] + assert valid_num == len(self._sequences_indexes), "BatchBucket indexing is not consistent" + self._sequence_lengths_helper[:valid_num] = self._sequence_lengths[valid_indexes] + self._sequence_lengths[:] = self._sequence_lengths_helper[:] + self._block_tables_helper[:valid_num, :] = self.block_tables[valid_indexes] + self.block_tables[:] = self._block_tables_helper[:] + new_idx = 0 + for seq_id in valid_seq_ids: + self._sequences_indexes[seq_id] = new_idx + new_idx += 1 + self._sequence_lengths_helper.fill_(0) + self._block_tables_helper.fill_(-1) + self._current_batch_size = valid_num + + def add_seq( + self, + seq: Sequence, + alloc_block_table: torch.Tensor = None, + alloc_block_table_fn: Callable[[torch.Tensor, int], None] = None, + ) -> Union[torch.Tensor, None]: + """Add a single sequence to the batch. + User could opt to provide either a block table or a function to allocate block tables. + + Args: + seq (Sequence): The sequence to be added to the batch + alloc_block_table (torch.Tensor): The block tables to be copied and used for the sequence + alloc_block_table_fn (Callable[[torch.Tensor, int], None]): The function to allocate blocks for the sequence, + which is expected to reserve blocks and update status of kv-cache manager. + + Returns: + block_table (torch.Tensor): The block table of the added sequence, used for block allocation in kv-cache manager. + None if the sequence cannot be added. + """ + block_table = None + # TODO might consider sorting by length + if self._current_batch_size < self.max_batch_size: + self._sequences_dict[seq.request_id] = seq + self._sequences_indexes[seq.request_id] = self._current_batch_size + self._sequence_lengths[self._current_batch_size] = seq.sentence_len + # NOTE the added seq still require block table allocation by kvcache manager + block_table = self._block_tables[self._current_batch_size - 1] + if alloc_block_table is not None: + # copy block ids from provided block tables + self._block_tables[self._current_batch_size - 1] = alloc_block_table + elif alloc_block_table_fn: + alloc_block_table_fn(block_table, self._sequence_lengths[self._current_batch_size - 1].item()) + self._current_batch_size += 1 + return block_table + + def add_seqs( + self, + seqs: List[Sequence], + alloc_block_tables: torch.Tensor = None, + alloc_block_tables_fn: Callable[[torch.Tensor, torch.Tensor], None] = None, + ) -> Union[torch.Tensor, None]: + """Add a list of sequences to the batch. + User could opt to provide either block tables or a function to allocate block tables. + + Args: + seqs (List[Sequence]): The sequences to be added to the batch + alloc_block_tables (torch.Tensor): The block tables to be copied and used for the sequence + alloc_block_table_fn (Callable[[torch.Tensor, torch.Tensor], None]): The function to allocate blocks for multiple sequences, + which is expected to reserve blocks and update status of kv-cache manager. + + Returns: + block_tables (torch.Tensor): The block tables of the added sequences, used for block allocation in kv-cache manager. + None if the sequences cannot be added. + """ + + assert ( + alloc_block_tables is None or alloc_block_tables_fn is None + ), "`alloc_block_tables` and `alloc_block_tables_fn` cannot be provided at the same time" + + num_seqs_to_add = min(self.max_batch_size - self._current_batch_size, len(seqs)) + block_tables = None + if num_seqs_to_add > 0: + for i, seq in enumerate(seqs[:num_seqs_to_add]): + self._sequences_dict[seq.request_id] = seq + self._sequences_indexes[seq.request_id] = self._current_batch_size + i + # TODO external (rename): modify Sequence.sentence_len to seq_len + self._sequence_lengths[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = torch.tensor([seq.sentence_len for seq in seqs[:num_seqs_to_add]], dtype=torch.int32) + # NOTE block tables to be updated by kvcache manager + block_tables = self._block_tables[self._current_batch_size : self._current_batch_size + num_seqs_to_add] + if alloc_block_tables is not None: + # copy block ids from provided block tables + self._block_tables[ + self._current_batch_size : self._current_batch_size + num_seqs_to_add + ] = alloc_block_tables + elif alloc_block_tables_fn: + alloc_block_tables_fn( + block_tables, + self._sequence_lengths[self._current_batch_size : self._current_batch_size + num_seqs_to_add], + ) + + self._current_batch_size += num_seqs_to_add + seqs[:] = seqs[num_seqs_to_add:] + + return block_tables + + def pop_seq_update_batch( + self, request_id: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[Sequence, Union[torch.Tensor, None]]: + """Pop a single sequence by id from the batch, and update the batch bucket status. + + Args: + request_id (int): The uid of the sequence + free_block_table_fn (Callable): The function to free the block table of a sequence, + if not provided, then we have to release the block table manually after calling this method + + Returns: + A tuple of: seq (Sequence): The target sequence + and block_table (torch.Tensor): block table of the target sequence indicating corresponding blocks, + none if the sequence is not found or free_block_table_fn is provided. + """ + seq: Sequence = self._sequences_dict.get(request_id) + block_table = None + if seq is not None: + assert request_id in self._sequences_indexes, "Inconsistency in BatchBucket indexing" + self._sequences_dict.pop(request_id) + seq_b_idx = self._sequences_indexes.get(request_id) + + if self.current_batch_size > 1: + # replace seq length of the target seq with that of the last seq in the batch + last_seq_b_idx = self.current_batch_size - 1 + last_seq_id = next( + (uid for uid, index in self._sequences_indexes.items() if index == last_seq_b_idx), + None, + ) + assert last_seq_id is not None + self._sequences_indexes[last_seq_id] = seq_b_idx + self._sequence_lengths[seq_b_idx] = self._sequence_lengths[last_seq_b_idx] + self._sequence_lengths[last_seq_b_idx].fill_(0) + # free the block table of the seq, or return a copy of the block table (to be processed outside) + if free_block_table_fn: + free_block_table_fn(self._block_tables[seq_b_idx]) + else: + block_table = self._block_tables[seq_b_idx].detach().clone() + # replace block table of the target seq with that of the last seq in the batch + self._block_tables[seq_b_idx] = self._block_tables[last_seq_b_idx] + self._block_tables[last_seq_b_idx].fill_(-1) + else: + if free_block_table_fn: + free_block_table_fn(self._block_tables[0]) + else: + block_table = self._block_tables[0].detach().clone() + self._sequence_lengths[0].fill_(0) + self._block_tables[0].fill_(-1) + self._sequences_indexes.pop(request_id) + self._current_batch_size -= 1 + + return seq, block_table + + def pop_seqs( + self, request_ids: List[int], free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Iteratively pop a list of sequences by uid. + + Args: + request_ids (List[int]): The uids of the sequences + free_block_table_fn (Callable): The function to free the block table of a sequence, + if not provided, then we have to release the block table manually after calling this method + Returns: + A tuple of: seqs (List[Sequence]): The target sequences + and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks + """ + seqs = [] + block_tables = [] + for request_id in request_ids: + seq, block_table = self.pop_seq_update_batch(request_id, free_block_table_fn) + if seq is not None: + seqs.append(seq) + if block_table is not None: + block_tables.append(block_table) + return seqs, block_tables + + def pop_n_seqs( + self, n: int, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Pop the first n sequences in the batch (FIFO). + If n is greater than the current batch szie, pop all the sequences in the batch. + + Args: + n (int): The number of sequences to pop out + free_block_table_fn (Callable): The function to free the block table of a single sequence + Returns: + A tuple of: seqs (List[Sequence]): The target sequences, + and block_tables (List[torch.Tensor]): block tables of the target sequences indicating corresponding blocks + """ + # NOTE Prevent calling this method multiple times in a single step + seqs = [] + block_tables = [] + n = min(n, self.current_batch_size) + seq_ids = list(self._sequences_dict.keys())[:n] + for seq_id in seq_ids: + seq = self._sequences_dict.pop(seq_id) + seq_b_idx = self._sequences_indexes.pop(seq_id) + if free_block_table_fn: + free_block_table_fn(self.block_tables[seq_b_idx]) + else: + block_tables.append(self.block_tables[seq_b_idx].detach().clone()) + seqs.append(seq) + if not self.is_compact: + self._make_compact() + + return seqs, block_tables + + def pop_finished( + self, free_block_table_fn: Callable[[torch.Tensor], None] = None + ) -> Tuple[List[Sequence], List[torch.Tensor]]: + """Pop finished sequences in the batch and a list of block tables of the finished sequences, + if free_block_table_fn is not provided. + + Args: + free_block_table_fn (Callable): The function to free the block table of a single sequence + Returns: + A tuple of: finished_seqs (List[Sequence]): The finished sequences, + and finished_block_tables (List[torch.Tensor]): block tables of the finished sequences. + """ + finished_seqs = [] + finished_block_tables = [] + for seq in self._sequences_dict.values(): + if seq.check_finish(): + finished_seqs.append(seq) + # Use `pop_seq_update_batch`` to update the batch status for just a few of finished seqs, + # otherwise, pop seqs directly and then call `_make_compact` to compress the batch. + # For now, the performance difference is not significant, so we use the frist method to pop seqs. + # Precise evaluations to be done. + for seq in finished_seqs: + _, block_table = self.pop_seq_update_batch(seq.request_id, free_block_table_fn) + if block_table is not None: + finished_block_tables.append(block_table) + + return finished_seqs, finished_block_tables + + # TODO arg type not support beam search sampling yet + def append_batch_tokens(self, tokens: torch.Tensor) -> None: + """Append a batch of tokens to the sequences in the batch""" + assert self.current_batch_size == tokens.size(0), "Batch size mismatch" + + if self.current_batch_size > 0: + tokens = tokens.tolist() + for seq_id, seq in self._sequences_dict.items(): + index_in_b = self._sequences_indexes[seq_id] + curr_tokens = tokens[index_in_b] + if not isinstance(curr_tokens, list): + curr_tokens = [curr_tokens] + seq.output_token_id += curr_tokens + seq.check_finish() + self._sequence_lengths[: self.current_batch_size] += 1 + + def revoke_batch_tokens(self, n_tokens: int, n_seqs: int = 1) -> None: + """Revoke the last n output tokens of the sequences in the batch + + Args: + n_tokens (int): The number of output tokens to revoke from each sequence. + It does not count in the context tokens (input tokens). + n_seqs (int): The first n sequences to revoke tokens from. Defaults to 1. + For now, speculative decoding only supports batch size 1. + """ + if n_tokens >= 1: + seqs_iter = iter(self._sequences_dict.items()) + for _ in range(n_seqs): + seq_id, seq = next(seqs_iter) + assert seq.output_len >= n_tokens, "Revoking len exceeds the current output len of the sequence" + seq.output_token_id = seq.output_token_id[:-n_tokens] + seq.revoke_finished_status() + self._sequence_lengths[self._sequences_indexes[seq_id]] -= n_tokens + + def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]: + """Clear all the sequences in the batch. + + free_block_tables_fn (Optional[Callable]): The function to free the block tables of all the sequences in a batch + """ + seqs = list(self._sequences_dict.values()) + self._sequences_dict.clear() + self._sequences_indexes.clear() + if free_block_tables_fn: + free_block_tables_fn(self.block_tables, self._current_batch_size) + self._block_tables.fill_(-1) + self._sequence_lengths.fill_(0) + self._current_batch_size = 0 + return seqs + + def merge(self, other: "BatchBucket") -> List[int]: + """Merge the sequences in the other batch into the current batch. + Merge as possible as the current batch can, if it does not have available spaces + holding all the sequences in the other batch + + Usage: + > New incoming sequence added to prefil batch + prefill bb curr batch size < prefil_ratio * prefill bb max batch size + > New incoming sequence added to prefil batch + prefill bb curr batch size == prefil_ratio * prefill bb max batch size + > Pause Decoding + > Prefill + > Move sequences in prefill bb => decoding bb + > Put back the out-of-volume sequences into the running pool + + Returns: + unmerged_ids (List[int]): a list of sequence uids that are not merged into the current batch + """ + unmerged_ids = [] + num_seqs_to_merge = min(self.available_batch_size, other.current_batch_size) + if num_seqs_to_merge > 0: + seqs, block_tables_li = other.pop_n_seqs(num_seqs_to_merge) + block_tables = torch.stack(block_tables_li) + self.add_seqs(seqs, alloc_block_tables=block_tables) + unmerged_ids = other.seqs_ids + + return unmerged_ids + + ########## The following methods are expected to be used in modeling ########### + + # For compatibility. + # NOTE: This is an assumption way to determine the stage of the batch. + @property + def is_prompts(self) -> bool: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) + if first_seq.output_len == 0: + return True + return False + + def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor: + # Used for main model verification in **Decoding Stage** + # `n` is the number of tokens to be verified, + # and so that prepare the last `n` tokens of each sequence as the inputs + assert len(self._sequences_dict) > 0, "No sequence in the batch" + assert all( + seq.output_len >= n for seq in self._sequences_dict.values() + ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified." + out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.output_token_id[-n:]) + return torch.tensor(out_li, dtype=torch.long, device=self.device) + + # For compatibility + def get_1D_inputs(self) -> torch.Tensor: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence + if first_seq.output_len == 0: + # Assume prefill stage + assert all( + seq.output_len == 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.input_token_id) + return torch.tensor(out_li, dtype=torch.long, device=self.device) + else: + # Assume decoding stage + if self.use_spec_dec: + # For Speculative Decoding + # the number of tokens to be verified in parallel plus the correct token in the last step + return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1) + assert all( + seq.output_len > 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + assert self.is_compact, "BatchBucket is not compact" + out = torch.empty([self.current_batch_size], dtype=torch.long) + for seq_id, index_in_b in self._sequences_indexes.items(): + seq: Sequence = self._sequences_dict[seq_id] + out[index_in_b] = seq.output_token_id[-1] + return out.to(device=self.device) + + # For compatibility + def get_block_table_tensor(self) -> torch.Tensor: + assert self.is_compact # Debug usage + block_table = self.block_tables[: self.current_batch_size] + return block_table.to(device=self.device) + + # For compatibility + def get_sequence_lengths(self) -> torch.Tensor: + assert self.is_compact # Debug usage + sequence_lengths = self.seq_lengths[: self.current_batch_size] + return sequence_lengths.to(device=self.device) + + # For compatibility + @property + def fd_inter_tensor(self) -> None: + assert self.fd_interm_tensor is not None, "fd_interm_tensor is not provided" + return self.fd_interm_tensor + + def __repr__(self) -> str: + return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})" diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py new file mode 100644 index 000000000000..61bc7c8abc9c --- /dev/null +++ b/colossalai/inference/config.py @@ -0,0 +1,342 @@ +""" +Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. +""" +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields +from typing import Any, Dict, List, Optional, Union + +import torch +from transformers.generation import GenerationConfig + +from colossalai.inference.flash_decoding_utils import FDIntermTensors + +GibiByte = 1024**3 + +logger = logging.Logger(__name__) + +_DTYPE_MAPPING = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + +_ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + +_DEFAULT_PROMPT_TEMPLATES = { + "llama": "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n{input_text}[/INST]", + "baichuan": " {input_text} ", + "vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ", +} + + +class RPC_PARAM(ABC): + """ + NOTE(lry89757) We use rpyc to transport param between client and server. + Rpyc only support the type of `POD` in python as the param, so we should take some smart ways to transport the data like tensor or some sophisticated classes. + Drawing on the logic of `__setstate__`, `__getstate__`, we will let some classes(will be rpc param later) inherit this base class, and rewrite the to_rpc_param and from_rpc_param. We will invoke `to_rpc_param` in client to pass the params and recover the param in server side by `from_rpc_param`. + """ + + @abstractmethod + def to_rpc_param(self): + return NotImplementedError + + @staticmethod + @abstractmethod + def from_rpc_param(): + return NotImplementedError + + +@dataclass +class InputMetaData(RPC_PARAM): + """The input info for a single step + + Args: + block_tables (torch.Tensor, optional): Sequences' BlockTables Defaults to None. + sequence_lengths (torch.Tensor): A tensor containing sequence lengths. + fd_inter_tensor (torch.Tensor, optional): A tensor representing intermediate data for flash decoding. Defaults to None. + batch_size (int, optional): The current batch size. Defaults to 64. + is_prompts (bool, optional): Indicates whether prefill or decoding. Defaults to False(decoding). + use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally + use_cuda_graph (bool, optional): Indicates whether to use the CUDA graph. Defaults to False. + kv_seq_len (int, optional): Key-value sequence length. Defaults to 512. + head_dim (int, optional): Head dimension. Defaults to 32. + high_precision(bool, optional): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, Defaults to False. + dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. + use_spec_dec (bool): Indicate whether to use speculative decoding. + num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. + batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process. + """ + + block_tables: torch.Tensor = None + sequence_lengths: torch.Tensor = None + fd_inter_tensor: FDIntermTensors = None + batch_size: int = 64 # current_batch_size + is_prompts: bool = False + use_cuda_kernel: bool = False + use_cuda_graph: bool = False + kv_seq_len: int = 512 + head_dim: int = 32 + high_precision: bool = False + dtype: torch.dtype = torch.float32 + use_spec_dec: bool = False + num_tokens_to_verify: int = 0 + batch_token_ids: Optional[ + List[List[int]] + ] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process + + def to_rpc_param(self) -> Dict[str, any]: + return { + "block_tables": self.block_tables.tolist(), + "sequence_lengths": self.sequence_lengths.tolist(), + "batch_size": self.batch_size, + "is_prompts": self.is_prompts, + "use_cuda_kernel": self.use_cuda_kernel, + "use_cuda_graph": self.use_cuda_graph, + "kv_seq_len": self.kv_seq_len, + "head_dim": self.head_dim, + "high_precision": self.high_precision, + "dtype": str(self.dtype).split(".")[-1], + "use_spec_dec": self.use_spec_dec, + "num_tokens_to_verify": self.num_tokens_to_verify, + "batch_token_ids": self.batch_token_ids, + } + + @staticmethod + def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": + """ + We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message + """ + from colossalai.accelerator import get_accelerator + + dtype = getattr(torch, rpc_dict["dtype"]) + return InputMetaData( + block_tables=torch.tensor( + rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() + ), + sequence_lengths=torch.tensor( + rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() + ), + batch_size=rpc_dict["batch_size"], + is_prompts=rpc_dict["is_prompts"], + use_cuda_kernel=rpc_dict["use_cuda_kernel"], + use_cuda_graph=rpc_dict["use_cuda_graph"], + kv_seq_len=rpc_dict["kv_seq_len"], + head_dim=rpc_dict["head_dim"], + high_precision=rpc_dict["high_precision"], + dtype=dtype, + use_spec_dec=rpc_dict["use_spec_dec"], + num_tokens_to_verify=rpc_dict["num_tokens_to_verify"], + batch_token_ids=rpc_dict["batch_token_ids"], + ) + + def __repr__(self) -> str: + return ( + f"InputMetaData(block_tables={self.block_tables}, " + f"sequence_lengths={self.sequence_lengths}, " + f"fd_inter_tensor={self.fd_inter_tensor}, " + f"batch_size={self.batch_size}, " + f"is_prompts={self.is_prompts}, " + f"use_cuda_kernel={self.use_cuda_kernel}, " + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_seq_len={self.kv_seq_len}, " + f"use_spec_dec={self.use_spec_dec}, " + f"num_tokens_to_verify={self.num_tokens_to_verify})" + ) + + +@dataclass +class InferenceConfig(RPC_PARAM): + """The inference configuration. + + Args: + max_batch_size (int): Maximum batch size, defaults to 8. + max_output_len (int): Maximum output length, defaults to 256. + max_input_len (int): Maximum input length, defaults to 256. + dtype (Union[str, torch.dtype]): The data type for weights and activations. + kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None. + prompt_template (Optional[str]): The prompt template for generation, defaults to None. + do_sample (bool): Whether to use sampling for generation, defaults to False. + beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1. + During generation, the beam width provided as sampling parameter should be less than or equivalent to this value. + prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, defaults to 1.2. We will do a step of prefill + when the actual value exceeds this ratio. + pad_input: Whether to pad all inputs to the max length. + early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False. + top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None. + top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None. + temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0. + repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. + no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. + n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. + glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. + block_size (int): The number of blocks in a logical block, defaults to 16. + tp_size (int): Tensor parallel size, defaults to 1. + pp_size (int): Pipeline parallel size, defaults to 1. + micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally + use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. + max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. + """ + + # NOTE: arrange configs according to their importance and frequency of usage + + # runtime limit + max_batch_size: int = 8 + max_output_len: int = 256 + max_input_len: int = 256 + + # general configs + dtype: Union[str, torch.dtype] = torch.float16 # use fp16 by default + kv_cache_dtype: Optional[str] = None + + # generation configs + prompt_template: Optional[str] = None + do_sample: bool = False + beam_width: int = 1 # TODO: beam search is not support for now + prefill_ratio: Optional[ + float + ] = 1.2 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio + pad_input: bool = False + early_stopping: Optional[bool] = False + top_k: Optional[int] = 50 + top_p: Optional[float] = 1.0 + temperature: Optional[float] = 1.0 + no_repeat_ngram_size: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + forced_eos_token_id: int = None + + # speculative decoding configs + max_n_spec_tokens: int = 5 + glimpse_large_kv: bool = False + + # paged attention configs + block_size: int = 16 + + # model parallelism configs + tp_size: int = 1 + pp_size: int = 1 + micro_batch_size: int = 1 + micro_batch_buffer_size: int = None + high_precision: Optional[bool] = False + + # cuda kernel option + use_cuda_kernel: bool = False + + # cuda_graph + use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference + max_context_len_to_capture: int = 512 + ignore_eos: bool = False + + def __post_init__(self): + self.max_context_len_to_capture = self.max_input_len + self.max_output_len + self._verify_config() + + def _verify_config(self) -> None: + """ + Verify the input config + """ + # check dtype + if isinstance(self.dtype, str): + # convert string dtype to torch dtype + assert ( + self.dtype in _DTYPE_MAPPING + ), f"Expected the dtype string argument to be in {list(_DTYPE_MAPPING.keys())} but found an unknown dtype: {self.dtype}" + self.dtype = _DTYPE_MAPPING[self.dtype] + assert ( + self.dtype in _ALLOWED_DTYPES + ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}" + + if self.kv_cache_dtype: + assert ( + self.use_cuda_kernel and self.kv_cache_dtype == "fp8" + ), f"FP8 kv_cache is only supported with use_cuda_kernel open now" + self.kv_cache_dtype = torch.uint8 + + # skip using casting when the data type is float32 + if self.dtype == torch.float32: + self.high_precision = False + + # check prompt template + if self.prompt_template is None: + return + + if self.prompt_template in _DEFAULT_PROMPT_TEMPLATES: + self.prompt_template = _DEFAULT_PROMPT_TEMPLATES[self.prompt_template] + else: + # make sure the template can be formatted with input_text + assert ( + "{input_text}" in self.prompt_template + ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '" + + def to_generation_config(self, model_config) -> GenerationConfig: + meta_config = { + "max_length": self.max_input_len + self.max_output_len, + "max_new_tokens": self.max_output_len, + "early_stopping": self.early_stopping, + "do_sample": self.do_sample, + "num_beams": self.beam_width, + } + for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]: + if hasattr(self, type): + meta_config[type] = getattr(self, type) + for type in ["pad_token_id", "bos_token_id", "eos_token_id"]: + if hasattr(model_config, type): + meta_config[type] = getattr(model_config, type) + + return GenerationConfig.from_dict(meta_config) + + def to_rpc_param(self) -> dict: + kwargs = { + "dtype": str(self.dtype).split(".")[-1], + "max_n_spec_tokens": self.max_n_spec_tokens, + "max_batch_size": self.max_batch_size, + "max_input_len": self.max_input_len, + "max_output_len": self.max_output_len, + "tp_size": self.tp_size, + "pp_size": self.pp_size, + "pad_input": self.pad_input, + "early_stopping": self.early_stopping, + "do_sample": self.do_sample, + "beam_width": self.beam_width, + "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1], + } + return kwargs + + @staticmethod + def from_rpc_param(rpc_dict: dict) -> "InferenceConfig": + """ + We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message + """ + return InferenceConfig( + dtype=getattr(torch, rpc_dict["dtype"]), + max_n_spec_tokens=rpc_dict["max_n_spec_tokens"], + max_batch_size=rpc_dict["max_batch_size"], + max_input_len=rpc_dict["max_input_len"], + max_output_len=rpc_dict["max_output_len"], + tp_size=rpc_dict["tp_size"], + pp_size=rpc_dict["pp_size"], + pad_input=rpc_dict["pad_input"], + early_stopping=rpc_dict["early_stopping"], + do_sample=rpc_dict["do_sample"], + beam_width=rpc_dict["beam_width"], + kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None), + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in fields(cls)] + inference_config_args = {} + for attr in attrs: + if attr in config_dict: + inference_config_args[attr] = config_dict[attr] + else: + inference_config_args[attr] = getattr(cls, attr) + + # Set the attributes from the parsed arguments. + inference_config = cls(**inference_config_args) + return inference_config diff --git a/colossalai/inference/core/__init__.py b/colossalai/inference/core/__init__.py new file mode 100644 index 000000000000..c18c2e59b522 --- /dev/null +++ b/colossalai/inference/core/__init__.py @@ -0,0 +1,4 @@ +from .engine import InferenceEngine +from .request_handler import RequestHandler + +__all__ = ["InferenceEngine", "RequestHandler"] diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py new file mode 100644 index 000000000000..03f7f13f25bd --- /dev/null +++ b/colossalai/inference/core/async_engine.py @@ -0,0 +1,332 @@ +import asyncio +import logging +from functools import partial +from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type + +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.sampler import search_tokens + +# CLI logger +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("colossalai-inference") + + +def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None: + msg = "Task finished unexpectedly. This should never happen! " + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc + raise RuntimeError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + +class RequstStream: + """ + A stream of Output for a request that can be iterated over asynchronously. + Attributes: 1.request_id: The id of the request. + 2._future: A future that will be set when the request is finished. + Methods: set_result and get_result, results will be set when finished, for once, and + the `self.future` will be set to done. + + """ + + def __init__(self, request_id: int) -> None: + self.request_id = request_id + self._future = asyncio.Future() + + def set_result(self, result) -> None: + """Set final result and signal taht it's ready""" + if not self._future.done(): + self._future.set_result(result) + + async def get_result(self): + """Wait for the result to be set and return it.""" + return await self._future + + @property + def finished(self) -> bool: + """Check if the stream has finished by checking if the future is done.""" + return self._future.done() + + +class Tracer: + """ + Recording new requests and finished requests. + Attributes: 1._request_streams: We create one stream for each request to trace the output. + 2._finished_requests: A queue to store the finished requests. + 3._new_requests: New requests will be stored in this queue first, before sending them to the engine. + 4.new_requests_event: An event to notify the engine that there are new requests. + """ + + def __init__(self) -> None: + self._request_streams: Dict[int, RequstStream] = {} + self._finished_requests: asyncio.Queue[int] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._request_streams + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None: + """ + Propagate an exception to request streams (all if request_id is None). + """ + if request_id is not None: + self._request_streams[request_id].set_result(exc) + else: + for stream in self._request_streams.values(): + stream.set_result(exc) + + def process_finished_request(self, finished_request) -> None: + """Process a finished request from the engine.""" + request_id = finished_request.request_id + try: + self._request_streams[request_id].set_result(finished_request) + except: + raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check") + self.abort_request(request_id) + + def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream: + """ + Add a request to be sent to the engine on the next background + loop iteration. + """ + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = RequstStream(request_id) + logger.info(f"Added request {request_id}.") + self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) + self.new_requests_event.set() + + return stream + + def abort_request(self, request_id: int, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + + if request_id not in self._request_streams or self._request_streams[request_id].finished: + # The request has already finished or been aborted. + # The requests in new_requests will be aborted when try to get them(if marked aborted) + return + + self._request_streams[request_id].set_result(None) + + def get_new_requests(self): + """ + Get new requests from http server. + """ + new_requests: List[Dict] = [] + finished_requests: Set[int] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + if new_request["request_id"] in finished_requests: + # The request has been aborted. + stream.set_result(None) + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + self.new_requests_event.clear() + + return new_requests + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + +class _AsyncInferenceEngine(InferenceEngine): + """ + Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for + Methods: 1. async_step: The async version of Engine.step() + """ + + async def async_step(self) -> List[str]: + """ + The async version of Engine.step() + Performs one decoding iteration and returns newly generated results. + + It first schedules the sequences to be executed in the next iteration. + Then, it executes the model and updates the scheduler with the model + outputs. Finally, it decodes the sequences and returns the newly + generated results. + """ + batch = self.request_handler.schedule() + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + loop = asyncio.get_running_loop() + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # Use run_in_executor to asyncally run the sync method model.forward(). + logits = await loop.run_in_executor( + None, + model_executable, + input_token_ids, + output_tensor, + input_meta_data, + self.k_cache, + self.v_cache, + ) + + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids + ) + + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() + + for sequence in finished_sequences: + sequence.output = self.tokenizer.decode(sequence.output_token_id) + + return finished_sequences, not self.request_handler.running_list.is_empty() + + def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None): + prompts = [prompt] + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict) + + +class AsyncInferenceEngine: + """An asynchronous wrapper for the InferenceEngine class. + + This class is used to wrap the InferenceEngine class to make it asynchronous. + It uses asyncio to create a background loop that keeps processing incoming + requests. Note that this class does not hold model directly, when incoming a new + request, it first called `add_request` and the Tracer will record the request, putting + it to the background `InferenceEngine`(done in background loop) to process. You can + consider this engine as an interface. + """ + + _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine + + def __init__(self, start_engine_loop: bool = True, **kwargs): + self.engine = self._init_engine(**kwargs) + self.background_loop = None + # reference to the unshielded loop + self._background_loop_unshielded = None + self.start_engine_loop = start_engine_loop + self._request_tracer = Tracer() + + @property + def background_loop_status(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.background_loop_status: + raise RuntimeError("Existing loop is running") + + self._request_tracer.init_event() + + self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop()) + self._background_loop_unshielded.add_done_callback( + partial(_raise_exception_on_finish, request_tracker=self._request_tracer) + ) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def _init_engine(self, **kwargs): + return self._engine_class(**kwargs) + + async def step(self): + """ + Run engine to process requests + + Returns True if there are in-progress requests. + """ + new_requests = self._request_tracer.get_new_requests() + for new_request in new_requests: + self.engine.add_single_request(**new_request) + newly_finished_seqs, has_running_requests = await self.engine.async_step() + for seq in newly_finished_seqs: + self._request_tracer.process_finished_request(seq) + + return has_running_requests + + async def _engine_abort(self, request_ids: Iterable[int]): + self.engine.abort_request(request_ids) + + async def abort(self, request_id: int): + """ + Abort a single request + """ + if not self.background_loop_status: + raise RuntimeError("Background loop is not running or launched correctly.") + return self._abort(request_id) + + def _abort(self, request_id: int): + self._request_tracer.abort_request(request_id) + + async def run_engine_loop(self): + processing_requests = False + while True: + if not processing_requests: + await self._request_tracer.wait_for_new_requests() + processing_requests = await self.step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: int, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + generation_config=None, + ) -> RequstStream: + """ + Add a request to the background tracker(waiting queue), start the background loop if needed. + """ + if not self.background_loop_status: + if self.start_engine_loop: + self.start_background_loop() + else: + raise RuntimeError("Background loop is not running.") + stream = self._request_tracer.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + generation_config=generation_config, + ) + return stream + + async def generate( + self, + request_id: int, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + generation_config=None, + ) -> AsyncIterator[str]: + """ + Generate output from a request. It receives the request from http server, adds it into the + waitting queue of Async Engine and streams the output sequence. + """ + try: + stream = await self.add_request( + request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config + ) + return await stream.get_result() + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. + self._abort(request_id) + raise e diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py new file mode 100644 index 000000000000..96c2b15ee16e --- /dev/null +++ b/colossalai/inference/core/engine.py @@ -0,0 +1,763 @@ +import time +from itertools import count +from typing import Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.graph_runner import CUDAGraphRunner +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.sampler import search_tokens +from colossalai.inference.spec import Drafter, GlideInput +from colossalai.inference.struct import Sequence +from colossalai.inference.utils import get_model_size +from colossalai.interface import ModelWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +from .request_handler import RequestHandler + +__all__ = ["InferenceEngine"] + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class InferenceEngine: + + """ + InferenceEngine which manages the inference process.. + + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model. + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. + """ + + def __init__( + self, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + inference_config: InferenceConfig, + verbose: bool = False, + model_policy: Union[Policy, Type[Policy]] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + + self.init_model(model_or_path, model_policy) + + self.generation_config = inference_config.to_generation_config(self.model_config) + self.generation_config_dict = self.generation_config.to_dict() + + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.request_handler = RequestHandler(self.inference_config, self.model_config) + self.k_cache, self.v_cache = self.request_handler.get_kvcache() + # DISCUSS maybe move this into batch info? + + self.counter = count() + + self.use_cuda_graph = self.inference_config.use_cuda_graph + if self.use_cuda_graph: + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + if verbose: + self.logger.info("Colossal AI CUDA Graph Capture on") + + self.capture_model(self.k_cache, self.v_cache) + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = False + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self._verify_args() + + def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model + """ + + if isinstance(model_or_path, str): + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + arch = getattr(hf_config, "architectures")[0] + if arch in _supported_models.keys(): + # NOTE(lry89757) Currently we load the model using transformers-api, + # but we will use lazy tensor and checkpoint io to accelerate + # the model load process in the future. + model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True) + else: + raise ValueError(f"Model {arch} is not supported.") + + except Exception as e: + self.logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + model = model.to(self.dtype).eval() + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + prefix = "nopadding" if not self.inference_config.pad_input else "padding" + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" + model_policy = model_policy_map.get(model_policy_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor + # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM): + # from colossalai.inference.core.plugin import InferCheckpoint_io + + # cpt_io = InferCheckpoint_io() + # if_has_index_file, model_index_file = has_index_file(model_or_path) + # assert if_has_index_file, "the model path is invalid" + # cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + @torch.inference_mode() + def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): + assert self.use_cuda_graph, "please turn on the cuda graph" + + if self.verbose: + self.logger.info("Colossal AI CUDA Graph Capture begin") + + t_capture_begin = time.perf_counter() + + block_size = self.inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_context_len_to_capture = self.inference_config.max_context_len_to_capture + max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size + input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() + # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) + self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) + self.graph_block_tables[0, :] = np.arange( + 0, max_num_blocks + ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + output_tensor = torch.zeros( + (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device + ) + fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor + + max_num_seqs = self.inference_config.max_batch_size + batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] + sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() + # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + sequence_lengths[0] = torch.tensor( + self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 + ).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + if self.verbose: + self.logger.info(f"batch size {batch_size} graph capturing") + + input_meta_data = InputMetaData( + block_tables=block_tables[:batch_size], + sequence_lengths=sequence_lengths[:batch_size], + fd_inter_tensor=fd_inter_tensor, + batch_size=batch_size, + is_prompts=False, + use_cuda_graph=True, + high_precision=False, + kv_seq_len=sequence_lengths[:batch_size].max().item(), + head_dim=head_dim, + dtype=self.dtype, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens_ids[:batch_size], + output_tensor[:batch_size], + input_meta_data, + k_caches=k_cache, + v_caches=v_cache, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + t_capture_end = time.perf_counter() + + if self.verbose: + self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.model, nn.Module): + raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + if isinstance(self.model, ModelWrapper): + model = self.model.module + assert ( + model.__class__.__name__ in _supported_models.keys() + ), f"Model {self.model.__class__.__name__} is not supported." + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy to shardformer model which is determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. + tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer. + """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model + + def enable_spec_dec( + self, + drafter_model: nn.Module = None, + n_spec_tokens: int = None, + use_glide_drafter: bool = False, + ) -> None: + """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. + + Args: + drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. + If provided, the previous drafter and drafter model, if exist, will be overwritten. + n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. + If not provided, `max_n_spec_tokens` in InferenceConfig will be used. + use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. + If True, the drafter model will be replaced by a glide model. + + ```python + ... + engine = InferenceEngine(model, tokenizer, inference_config) + + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + engine.generate(...) # Speculative Decoding + + engine.disable_spec_dec() + engine.generate(...) # Normal generation + + engine.enable_spec_dec() + engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens + engine.clear_spec_dec() + ``` + """ + if drafter_model is None and self.drafter is None: + raise ValueError("Drafter not initialized. Please provide a Drafter Model") + if n_spec_tokens is not None: + assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens + self.n_spec_tokens = n_spec_tokens + if drafter_model is not None: + assert isinstance(drafter_model, nn.Module) + # overwrite the drafter, if exists + self.clear_spec_dec() + self.drafter_model = drafter_model + self.drafter = Drafter( + self.drafter_model, + self.tokenizer, + device=self.device, + dtype=self.dtype, + ) + + # check if the provided drafter model is compatible with GLIDE structure + # when `use_glide_drafter` is set to True + if ( + use_glide_drafter + and hasattr(drafter_model, "model") + and hasattr(drafter_model.model, "layers") + and hasattr(drafter_model.model.layers[0], "cross_attn") + ): + self.use_glide = use_glide_drafter + elif use_glide_drafter: + self.logger.warning( + f"`use_glide_drafter` is provided as {use_glide_drafter}, " + f"but the provided drafter model is not compatible with GLIDE structure." + f"Falling back to use the default drafter model (non-GLIDE)." + ) + self.request_handler.set_spec_dec_mode(self.n_spec_tokens) + # using speculative decoding for subsequent generations + self.use_spec_dec = True + + def disable_spec_dec(self) -> None: + """Disable using speculative decoding for subsequent generations.""" + self.request_handler.unset_spec_dec_mode() + # set back to the maximum number of tokens to speculate + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.use_glide = False + self.use_spec_dec = False + + def clear_spec_dec(self) -> None: + """Clear relatable structures of speculative decoding, if exist.""" + if self.use_spec_dec: + self.disable_spec_dec() + if self.drafter_model or self.drafter: + self.drafter_model = None + self.drafter = None + torch.cuda.empty_cache() + self.use_glide = False + self.use_spec_dec = False + + def steps_spec_dec(self) -> List[Sequence]: + """ + Run Speculative Decoding steps. This is like retrieving a single batch and launch inference + with many steps of speculating by a drafter model as well as verifying by a main model. + + Returns: + List[Sequence]: finished sequences generated by one step. + """ + batch = self.request_handler.schedule() # prefill batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # 1. Prefill small model (Drafter) - fill past kv cache for drafter model + # NOTE For glide drafter models, we won't actually apply glide during prefill stage + drafter_out = self.drafter.speculate(input_token_ids, 1, None) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + # 2. Prefill main model (Verifier) - fill past kv cache for main model + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + # append new inputs to the batch, temporarily + batch.append_batch_tokens(next_tokens) + self.request_handler.allocate_batch_spec_dec(batch, 1) + already_allocated_kv_len = batch.seq_lengths[0].item() + input_token_ids = batch.get_1D_inputs_spec_dec(1) + + finished_sequences = self.request_handler.update() + + while True: + # HACK Retrieve the running batch + # Using RequestHandler.schedule here will re-allocate same kv cache for the batch + batch = self.request_handler.running_bb # running batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + # 3. Decoding - Drafter model speculates `n` tokens + glide_input = None + if self.use_glide: + glide_input = GlideInput( + batch.get_block_table_tensor(), + self.k_cache[-1], # use kv cahces of the last layer + self.v_cache[-1], + batch.get_sequence_lengths(), + ) + + drafter_out = self.drafter.speculate( + input_token_ids, + self.n_spec_tokens, + drafter_past_key_values, + glide_input=glide_input, + ) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + drafter_spec_length = drafter_out.speculated_length + + for next_token_id_spec in next_token_ids_spec: + self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) + cur_length = batch.seq_lengths[0].item() + if already_allocated_kv_len < cur_length: + self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) + already_allocated_kv_len = cur_length + + # 4. Decoding - Main model verifies `n` tokens in parallel + if drafter_spec_length < batch.num_tokens_to_verify: + batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + + # 5. Compare and process the results + diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) + n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() + + # revoke appended tokens for each Sequence in the current batch + batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens + + # append the last correct token generated by the main model + self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) + + # trim past key values of the drafter model + drafter_past_key_values = Drafter.trim_kv_cache( + drafter_past_key_values, drafter_spec_length - n_matches - 1 + ) + + # prepare inputs for the next round of speculation + n = 1 if n_matches < drafter_spec_length else 2 + input_token_ids = batch.get_1D_inputs_spec_dec(n) + + self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) + finished_sequences = self.request_handler.update() + if len(finished_sequences) > 0: + break + + # Reset back the number of speculated tokens of the batch, + # this is used to handle the last round of speculation, in which case the number of speculated tokens + # by the drafter is less than the number of speculated tokens set to the engine. + batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) + + return finished_sequences + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + return_token_ids: bool = False, + generation_config: Optional[GenerationConfig] = None, + ) -> List[str]: + """ + Executing the inference step. + + Args: + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + request_ids (List[int], optional): The request ID. Defaults to None. + return_token_ids (bool): Whether to return output token ids. Defaults to False. + generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None. + + Returns: + List[str]: Inference result returned by one generation. + """ + + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None or prompts_token_ids is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + prompts_token_ids=prompts_token_ids, + **gen_config_dict, + ) + + output_seqs_list = [] + total_tokens_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. + if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + if self.use_spec_dec: + assert self.drafter is not None, "Drafter Model is not initialized." + while self.request_handler.check_unfinished_seqs(): + output_seqs_list += self.steps_spec_dec() + else: + while self.request_handler.check_unfinished_seqs(): + output_seqs_list += self.step() + + output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) + + for seq in output_seqs_list: + total_tokens_list.append(seq.input_token_id + seq.output_token_id) + + output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) + + if return_token_ids: + output_tokens_list = [seq.output_token_id for seq in output_seqs_list] + return output_str, output_tokens_list + else: + return output_str + + @property + def has_prompt_template(self) -> bool: + """ """ + return self.inference_config.prompt_template is not None + + def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: + """ + This method will format the input prompt according to the prompt template given to the InferenceConfig. + """ + assert ( + self.has_prompt_template + ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." + + if isinstance(prompts, (list, tuple)): + return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] + elif isinstance(prompts, str): + return self.inference_config.prompt_template.format(input_text=prompts) + else: + raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + + def add_request( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + **kwargs, + ) -> None: + """ + Add requests. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + """ + + # apply the prompt template to the input prompts + + if self.has_prompt_template and prompts is not None: + prompts = self.format_prompt(prompts) + + block_size = self.inference_config.block_size + + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if prompts is not None and not isinstance(prompts, list): + prompts = [prompts] + + if prompts_token_ids is None: + assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." + prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ + "input_ids" + ] + + # list of torch Tensor + if isinstance(prompts_token_ids, list): + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] + elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): + prompts_token_ids = prompts_token_ids.tolist() + else: + raise TypeError( + f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." + ) + + assert ( + len(prompts_token_ids[0]) <= self.inference_config.max_input_len + ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." + + prompts_num = len(prompts_token_ids) + + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + if prompts == None: + prompt = None + else: + prompt = prompts[i] + + max_length = kwargs.get("max_length", None) + max_new_tokens = kwargs.get("max_new_tokens", None) + if max_length is None and max_new_tokens is None: + max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len + elif max_length is not None: + max_new_tokens = max_length - len(prompts_token_ids[i]) + + sequence = Sequence( + request_id, + prompt, + prompts_token_ids[i], + block_size, + None, + self.tokenizer.eos_token_id, + self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) + + def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + output_tensor = torch.zeros( + (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + ) + + batch_token_ids = None + if ( + self.generation_config.repetition_penalty != 1.0 + or self.generation_config.no_repeat_ngram_size > 0 + or self.generation_config.forced_eos_token_id is not None + ): + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=batch.fd_inter_tensor, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids, output_tensor, input_meta_data + + def step(self) -> List[str]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. Get the input, inputinfo and output placeholder from the batchbucket + 3. Run model to generate the next token + 4. Update waiting list and running list in RequestHandler and get finished sequences. + 5. Decode and return finished sequences. + + Returns: + List[str]: Decoded finished sequences generated by one step. + """ + + batch = self.request_handler.schedule() + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids + ) + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() + + return finished_sequences diff --git a/colossalai/inference/core/plugin.py b/colossalai/inference/core/plugin.py new file mode 100644 index 000000000000..d6a2b8b16550 --- /dev/null +++ b/colossalai/inference/core/plugin.py @@ -0,0 +1,140 @@ +import logging +import os +from functools import reduce +from pathlib import Path +from typing import Optional + +import torch + +from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class InferCheckpoint_io(GeneralCheckpointIO): + """ + This class is for inference model loading, most codes are copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO. + Origin HybridParallelCheckpointIO contains some codes about MixPrecision-Training, so we remove them and build a relatively clean class specifically for Inference. + """ + + def __init__( + self, + verbose: bool = True, + ) -> None: + super().__init__() + self.verbose = verbose + self.coordinator = DistCoordinator() + + def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + model = model.unwrap() + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + missing_keys = [] + missing_file_keys = [] + + def _load(name: str): + if name not in weight_map: + missing_file_keys.append(name) + return + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + + load_state_dict_into_model( + model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True + ) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persistent_buffers: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + _load(extra_state_key) + + if self.verbose and self.coordinator.is_master(): + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + if len(missing_keys) == 0: + raise RuntimeError( + "No weigth is loaded into the model. Please check the checkpoint files and the model structure." + ) + + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + remain_keys = remain_keys.union(set(missing_file_keys)) + if len(remain_keys) > 0: + if strict: + error_msgs = "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + else: + if self.coordinator.is_master(): + logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}") + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + return NotImplementedError diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py new file mode 100644 index 000000000000..5085c55558b4 --- /dev/null +++ b/colossalai/inference/core/request_handler.py @@ -0,0 +1,401 @@ +from typing import Dict, List, Union + +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationConfig + +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager +from colossalai.inference.struct import RequestStatus, Sequence +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) + +__all__ = ["RunningList", "RequestHandler"] + + +class RunningList: + """ + RunningList is an structure for recording the running sequences, contains prefill and decoding list. + Prefilling samples will be hold until the actual ratio of prefill samples versus decoding samples exceeds ratio. + + Args: + prefill_ratio: (float) A ratio for determing whether to perform prefill or not. + _prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence. + _decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence. + """ + + def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None: + self.prefill_ratio = prefill_ratio + self._decoding: Dict[int, Sequence] = dict() + self._prefill: Dict[int, Sequence] = ( + dict({seq.request_id: seq for seq in self._prefill}) if prefill is not None else dict() + ) + + @property + def decoding(self): + return list(self._decoding.values()) + + @property + def prefill(self): + return list(self._prefill.values()) + + @property + def prefill_seq_num(self): + return len(self._prefill) + + @property + def decoding_seq_num(self): + return len(self._decoding) + + @property + def total_seq_num(self): + return self.prefill_seq_num + self.decoding_seq_num + + def append(self, seq: Sequence): + assert (seq.request_id not in self._prefill) and ( + seq.request_id not in self._decoding + ), f"Sequence uid {seq.request_id} already exists." + self._prefill[seq.request_id] = seq + + def extend(self, seqs: List[Sequence]): + for seq in seqs: + self._prefill[seq.request_id] = seq + + def find_seq(self, request_id) -> Union[Sequence, None]: + seq = None + if request_id in self._decoding: + seq = self._decoding[request_id] + elif request_id in self._prefill: + seq = self._prefill[request_id] + return seq + + def remove(self, seq: Sequence) -> None: + if seq.request_id in self._decoding: + self._decoding.pop(seq.request_id) + elif seq.request_id in self._prefill: + self._prefill.pop(seq.request_id) + else: + raise ValueError(f"Sequence {seq.request_id} is not in running list") + + def ready_for_prefill(self): + if not self._decoding: + return len(self._prefill) > 0 + return len(self._prefill) / len(self._decoding) >= self.prefill_ratio + + def is_empty(self): + return not self._decoding and not self._prefill + + def mark_prefill_running(self) -> None: + for seq_id in self._prefill: + self._prefill[seq_id].mark_running() + + def move_prefill_to_decoding(self, seq_ids: List[int]) -> None: + for seq_id in seq_ids: + assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list" + self._decoding[seq_id] = self._prefill.pop(seq_id) + + +class RequestHandler: + """ + RequestHandler is the core for handling existing requests and updating current batch. + During generation process, we call schedule function each iteration to update current batch. + + Args: + inference_config: Configuration for initialize and manage kv cache. + model_config: Configuration for model + dtype (torch.dtype): The data type for weights and activations. + """ + + def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None: + self.inference_config = inference_config + self.running_list: RunningList = RunningList(inference_config.prefill_ratio) + self.waiting_list: List[List] = [[], [], []] + self.done_list: List[Sequence] = [] + self.dtype = inference_config.dtype + self.max_batch_size = inference_config.max_batch_size + + # initialize cache + self._init_cache(model_config) + + # initialize batch + device = torch.cuda.current_device() + kv_max_split_num = ( + inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 + ) // inference_config.block_size + head_dim = model_config.hidden_size // model_config.num_attention_heads + + fd_inter_tensor = FDIntermTensors() + + if fd_inter_tensor._tensors_initialized: + fd_inter_tensor._reset() + + # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq + max_n_tokens = self.max_batch_size + max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 + + fd_inter_tensor.initialize( + max_batch_size=max_n_tokens, + num_attn_heads=model_config.num_attention_heads // inference_config.tp_size, + kv_max_split_num=kv_max_split_num, + head_dim=head_dim, + dtype=self.dtype, + device=device, + ) + + # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, + # which may cause bugs and this issue should be fixed later. + self.running_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=fd_inter_tensor, + dtype=self.dtype, + device=device, + ) + self.prefill_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=fd_inter_tensor, + dtype=self.dtype, + device=device, + ) + + def _init_cache(self, model_config): + self.cache_manager = KVCacheManager(self.inference_config, model_config) + + def _has_waiting(self) -> bool: + return any(lst for lst in self.waiting_list) + + def get_kvcache(self): + return self.cache_manager.get_kv_cache() + + def set_spec_dec_mode(self, n_spec_tokens: int): + self.prefill_bb.set_use_spec_dec(n_spec_tokens) + self.running_bb.set_use_spec_dec(n_spec_tokens) + + def unset_spec_dec_mode(self): + self.prefill_bb.reset_use_spec_dec() + self.running_bb.reset_use_spec_dec() + + def schedule(self): + """ + The main logic of request handler. + """ + if self._has_waiting(): + # Try to allocate cache blocks for the sequence using a priority of prompt length. + for lst in reversed(self.waiting_list): + if lst: + remove_list = [] + for seq in lst: + if seq.input_len > self.inference_config.max_input_len: + # If the prompt length is longer than max_input_len, abort the sequence. + logger.warning( + f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence." + ) + self.abort_sequence(seq.request_id) + remove_list.append(seq) + break + + num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num) + # for now the recycle logic is not working + remove_list.extend(lst[:num_seqs_to_add]) + self.running_list.extend(lst[:num_seqs_to_add]) + + for seq in remove_list: + lst.remove(seq) + + if self.running_list.ready_for_prefill(): + num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size) + # overwrite the number of sequences to add to 1 if use_spec_dec is enabled + # TODO (zhaoyuanheng): support speculative decoding for batch size > 1 + if self.prefill_bb.use_spec_dec: + num_seqs_to_add = 1 + + for seq in self.running_list.prefill[:num_seqs_to_add]: + seq.mark_running() + # allocate blocks for the prefill batch + self.prefill_bb.add_seqs( + self.running_list.prefill[:num_seqs_to_add], + alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables, + ) + + return self.prefill_bb + + if not self.running_bb.is_empty: + seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables( + self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size + ) + if seqs_ids_to_recycle: + seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle) + for seq in seqs_to_recycle: + seq.recycle() + self.running_list.remove(seq) + self.waiting_list[-1].append(seq) + # the recycled sequences are handled with highest priority. + + return self.running_bb + + def allocate_batch_spec_dec(self, batch: BatchBucket, n: int): + assert batch.use_spec_dec + if n > 0: + self.cache_manager.allocate_n_tokens_from_block_tables( + batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n + ) + + def add_sequence(self, req: Sequence): + """ + Add the request to waiting list. + """ + assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists." + assert ( + req.input_len <= self.inference_config.max_input_len + ), f"Sequence {req.request_id} exceeds input length limit" + self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req) + + def abort_sequence(self, request_id: int): + """ + Abort the request. + """ + result = self._find_sequence(request_id) + if result is not None: + seq, priority = result + if seq.status == RequestStatus.WAITING: + seq.mark_aborted() + self.waiting_list[priority].remove(seq) + elif seq.status.is_running(): + self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) + self.running_list.remove(seq) + else: + try: + self.done_list.remove(seq) + except: + return + return + + def _find_sequence(self, request_id: int) -> Sequence: + """ + Find the request by request_id. + """ + for priority, lst in enumerate(self.waiting_list): + for seq in lst: + if seq.request_id == request_id: + return seq, priority + + if self.running_list.find_seq(request_id): + return seq, None + + return None + + def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig): + if ( + sequence.output_token_id[-1] == generation_config.eos_token_id + or sequence.output_len >= generation_config.max_length + ): + sequence.mark_finished() + + def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig): + for seq in batch.seqs_li: + max_length = generation_config.max_length + max_new_tokens = generation_config.max_new_tokens + if max_length is not None: + max_new_tokens = max_length - seq.input_len + if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens: + seq.mark_finished() + + def check_unfinished_seqs(self) -> bool: + return self._has_waiting() or not self.running_list.is_empty() + + def total_requests_in_batch_bucket(self) -> int: + return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size + + def append_next_tokens(self, sample_tokens: torch.Tensor): + assert sample_tokens.dim() == 1 + n_elements = sample_tokens.size(0) + if not self.prefill_bb.is_empty: + assert ( + self.prefill_bb.current_batch_size == n_elements + ), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}" + self.prefill_bb.append_batch_tokens(sample_tokens) + else: + assert ( + self.running_bb.current_batch_size == n_elements + ), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}" + self.running_bb.append_batch_tokens(sample_tokens) + + def update(self): + """ + Update current running list and done list + """ + if not self.prefill_bb.is_empty: + self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids) + self.running_bb.merge(self.prefill_bb) + # clear the prefill batch without assigning a free_block_tables_fn + # since we want to reuse the memory recorded on the block tables + self.prefill_bb.clear(free_block_tables_fn=None) + + finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table) + for seq in finished_seqs: + self.running_list.remove(seq) + self.done_list.extend(finished_seqs) + + return finished_seqs + + +class RPCRequestHandler(RequestHandler): + """ + RPC Version of request handler + """ + + def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None: + self.inference_config = inference_config + self.running_list: RunningList = RunningList(inference_config.prefill_ratio) + self.waiting_list: List[List] = [[], [], []] + self.done_list: List[Sequence] = [] + self.dtype = inference_config.dtype + self.max_batch_size = inference_config.max_batch_size + + # initialize cache + self._init_cache(model_config) + + # initialize batch + torch.cuda.current_device() + kv_max_split_num = ( + inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 + ) // inference_config.block_size + head_dim = model_config.hidden_size // model_config.num_attention_heads + + # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, + # which may cause bugs and this issue should be fixed later. + self.running_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=None, + dtype=self.dtype, + ) + self.prefill_bb = BatchBucket( + num_heads=model_config.num_attention_heads // inference_config.tp_size, + head_dim=head_dim, + max_batch_size=self.max_batch_size, + max_length=inference_config.max_input_len + inference_config.max_output_len, + block_size=inference_config.block_size, + kv_max_split_num=kv_max_split_num, + fd_interm_tensor=None, + dtype=self.dtype, + ) + + def _init_cache(self, model_config): + self.cache_manager = RPCKVCacheManager(self.inference_config, model_config) diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py new file mode 100644 index 000000000000..439c4b0b5fff --- /dev/null +++ b/colossalai/inference/core/rpc_engine.py @@ -0,0 +1,296 @@ +import asyncio +from itertools import count +from time import sleep +from typing import List, Tuple, Union + +import rpyc +import torch +import torch.nn as nn +from rpyc.utils.server import ThreadedServer +from torch import multiprocessing as mp +from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.configuration_utils import PretrainedConfig + +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.executor.rpc_worker import rpcWorkerService +from colossalai.inference.utils import find_available_ports +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .engine import InferenceEngine +from .request_handler import RPCRequestHandler + +__all__ = ["RPCInferenceEngine"] + + +def run_server(host, port, event: mp.Event = None): + server = ThreadedServer( + rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True} + ) + if event: + event.set() + server.start() + + +class RPCInferenceEngine(InferenceEngine): + + """ + InferenceEngine which manages the inference process.. + + NOTE This `RPCInferenceEngine` is designed for multiple-card/online serving. + Original `InferenceEngine` is designed for single card and offline service, though it supports multi-card offline inference. + + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model, Currently we don't support `nn.Module` Format + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. + """ + + def __init__( + self, + model_or_path: Union[nn.Module, str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + inference_config: InferenceConfig, + verbose: bool = False, + model_policy: Policy = None, + ) -> None: + """ + If you input a real model loaded by transformers, the init will take quite a long time + Currently we don't support model(nn.Module) format as the param. + """ + + torch.multiprocessing.set_start_method("spawn", force=True) + + self.inference_config = inference_config + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + + try: + if isinstance(model_or_path, str): + self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + elif isinstance(model_or_path, nn.Module): + self.logger.error( + f"An exception occurred during loading model Config: For {__class__.__name__}, we don't support param like nn.Module currently\n" + ) + # self.model_config = model_or_path.config + else: + self.logger.error( + f"An exception occurred during loading model Config: Please pass right param for {__class__.__name__}\n" + ) + except Exception as e: + self.logger.error( + f"An exception occurred during loading model Config: {e}, The path should be transformers-like\n" + ) + self.generation_config = inference_config.to_generation_config(self.model_config) + + self.tp_size = inference_config.tp_size + self.events = [mp.Event() for _ in range(self.tp_size)] + + # This operation will init the dist env and models + self.workers: List[rpcWorkerService] = [] + self.init_workers() + + asyncio.run(self.init_model(model_or_path, model_policy)) + + # init the scheduler and logic block manager + self.request_handler = self.init_scheduler(self.inference_config, self.model_config) + + # init the physical cache + alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape() + self.init_device_cache(alloc_shape) + + self.use_cuda_graph = self.inference_config.use_cuda_graph + self.high_precision = inference_config.high_precision + self.dtype = inference_config.dtype + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = False + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self.counter = count() + self._verify_args() + + self.logger.info("engine init over ") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + + def init_workers(self): + rpc_ports = find_available_ports(self.tp_size) + self.worker_processes = [] + # mp.set_start_method('spawn') + for event, rpc_port in zip(self.events, rpc_ports): + p = mp.Process(target=run_server, args=("localhost", rpc_port, event)) + p.start() + self.worker_processes.append(p) + self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...") + + # Wait for all servers to start + for event in self.events: + event.wait() + event.clear() + + sleep(0.05) + + self.logger.info(f"init rpc server done.") + + for rpc_port in rpc_ports: + try: + conn = rpyc.connect( + "localhost", + rpc_port, + config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True}, + ) + self.workers.append(conn.root) + except: + raise Exception("conn error!") + self.logger.info(f"Build RPC Connection Success! Begin to load model...") + asyncio.run(self.init_worker_env()) + self.logger.info(f"init dist env over") + + async def async_parallel_wrapper(self, f, *args, **kwargs): + async_res = rpyc.async_(f)(*args, **kwargs) + await asyncio.to_thread(async_res.wait) + assert async_res.ready + return async_res.value + + async def init_worker_env(self): + assert len(self.workers) == self.tp_size, "init workers first" + + dist_group_port = find_available_ports(1)[0] + init_tasks = [ + self.async_parallel_wrapper( + worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port + ) + for rank, worker in enumerate(self.workers) + ] + + await asyncio.gather(*init_tasks) + + async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + assert len(self.workers) == self.tp_size, "init workers first" + + inference_config_param = self.inference_config.to_rpc_param() + model_path = model_or_path + model_policy_param = model_policy.to_rpc_param() if model_policy else None + + init_tasks = [ + self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param) + for rank, worker in enumerate(self.workers) + ] + + await asyncio.gather(*init_tasks) + + def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler: + return RPCRequestHandler(inference_config, model_config) + + async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]): + assert len(self.workers) == self.tp_size, "init workers first" + + init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers] + + await asyncio.gather(*init_tasks) + + def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): + asyncio.run(self._init_device_cache(alloc_shape)) + + def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + + batch_token_ids = None + config_dict = self.generation_config.to_dict() + # process repetition_penalty, no_repeat_ngram_size + for type in ["repetition_penalty", "no_repeat_ngram_size"]: + if type in config_dict and config_dict[type] is not None: + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=None, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids.tolist(), input_meta_data + + async def step_(self, input_token_ids, input_meta_data: InputMetaData): + assert len(self.workers) == self.tp_size, "init workers first" + + init_tasks = [ + self.async_parallel_wrapper( + worker.execute_model_forward, + input_token_ids, + input_meta_data.to_rpc_param(), + self.generation_config_dict, + ) + for worker in self.workers + ] + ret = await asyncio.gather(*init_tasks) + + return ret[0] + + def step(self) -> List[str]: + batch = self.request_handler.schedule() + + input_token_ids, input_meta_data = self.prepare_input(batch) + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) + + # update the request_handler + next_tokens = torch.tensor(next_tokens, dtype=torch.int) + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() + return finished_sequences + + def kill_workers(self): + """ + I don't find a good way to implicit invoke self.kill_workers + """ + assert len(self.workers) != 0 + for proc in self.worker_processes: + proc.kill() + proc.join() + self.logger.info(f"worker killed, serving end") + + def __del__(self): + self.kill_workers() diff --git a/colossalai/inference/engine/__init__.py b/colossalai/inference/engine/__init__.py deleted file mode 100644 index 6e60da695a22..000000000000 --- a/colossalai/inference/engine/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .engine import InferenceEngine - -__all__ = ["InferenceEngine"] diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py deleted file mode 100644 index 61da5858aa86..000000000000 --- a/colossalai/inference/engine/engine.py +++ /dev/null @@ -1,195 +0,0 @@ -from typing import Union - -import torch -import torch.distributed as dist -import torch.nn as nn -from transformers.utils import logging - -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.schedule.generate import GenerateSchedule -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.base_policy import Policy - -from ..kv_cache import MemoryManager -from .microbatch_manager import MicroBatchManager -from .policies import model_policy_map - -PP_AXIS, TP_AXIS = 0, 1 - -_supported_models = [ - "LlamaForCausalLM", - "BloomForCausalLM", - "LlamaGPTQForCausalLM", - "SmoothLlamaForCausalLM", - "ChatGLMForConditionalGeneration", -] - - -class InferenceEngine: - """ - InferenceEngine is a class that handles the pipeline parallel inference. - - Args: - tp_size (int): the size of tensor parallelism. - pp_size (int): the size of pipeline parallelism. - dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'. - model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. - model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided. - micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1. - micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - max_batch_size (int): the maximum batch size. - max_input_len (int): the maximum input length. - max_output_len (int): the maximum output length. - quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None. - verbose (bool): whether to return the time cost of each step. - - """ - - def __init__( - self, - tp_size: int = 1, - pp_size: int = 1, - dtype: str = "fp16", - model: nn.Module = None, - model_policy: Policy = None, - micro_batch_size: int = 1, - micro_batch_buffer_size: int = None, - max_batch_size: int = 4, - max_input_len: int = 32, - max_output_len: int = 32, - quant: str = None, - verbose: bool = False, - # TODO: implement early_stopping, and various gerneration options - early_stopping: bool = False, - do_sample: bool = False, - num_beams: int = 1, - ) -> None: - if quant == "gptq": - from ..quant.gptq import GPTQManager - - self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len) - model = model.model - elif quant == "smoothquant": - model = model.model - - assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." - assert ( - tp_size * pp_size == dist.get_world_size() - ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" - assert model, "Model should be provided." - assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - - assert max_batch_size <= 64, "Max batch size exceeds the constraint" - assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" - assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'" - self.pp_size = pp_size - self.tp_size = tp_size - self.quant = quant - - logger = logging.get_logger(__name__) - if quant == "smoothquant" and dtype != "fp32": - dtype = "fp32" - logger.warning_once("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32") - - if dtype == "fp16": - self.dtype = torch.float16 - model.half() - elif dtype == "bf16": - self.dtype = torch.bfloat16 - model.to(torch.bfloat16) - else: - self.dtype = torch.float32 - - if model_policy is None: - model_policy = model_policy_map[model.config.model_type]() - - # Init pg mesh - pg_mesh = ProcessGroupMesh(pp_size, tp_size) - - stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False) - self.cache_manager_list = [ - self._init_manager(model, max_batch_size, max_input_len, max_output_len) - for _ in range(micro_batch_buffer_size or pp_size) - ] - self.mb_manager = MicroBatchManager( - stage_manager.stage, - micro_batch_size, - micro_batch_buffer_size or pp_size, - max_input_len, - max_output_len, - self.cache_manager_list, - ) - self.verbose = verbose - self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) - - self.model = self._shardformer( - model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None - ) - if quant == "gptq": - self.gptq_manager.post_init_gptq_buffer(self.model) - - def generate(self, input_list: Union[list, dict]): - """ - Args: - input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. - - Returns: - out (list): a list of output data, each element is a list of token. - timestamp (float): the time cost of the inference, only return when verbose is `True`. - """ - - out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) - if self.verbose: - return out, timestamp - else: - return out - - def _shardformer(self, model, model_policy, stage_manager, tp_group): - shardconfig = ShardConfig( - tensor_parallel_process_group=tp_group, - pipeline_stage_manager=stage_manager, - enable_tensor_parallelism=(self.tp_size > 1), - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - extra_kwargs={"quant": self.quant}, - ) - shardformer = ShardFormer(shard_config=shardconfig) - shard_model, _ = shardformer.optimize(model, model_policy) - return shard_model.cuda() - - def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: - max_total_token_num = max_batch_size * (max_input_len + max_output_len) - if model.config.model_type == "llama": - head_dim = model.config.hidden_size // model.config.num_attention_heads - head_num = model.config.num_key_value_heads // self.tp_size - num_hidden_layers = ( - model.config.num_hidden_layers - if hasattr(model.config, "num_hidden_layers") - else model.config.num_layers - ) - layer_num = num_hidden_layers // self.pp_size - elif model.config.model_type == "bloom": - head_dim = model.config.hidden_size // model.config.n_head - head_num = model.config.n_head // self.tp_size - num_hidden_layers = model.config.n_layer - layer_num = num_hidden_layers // self.pp_size - elif model.config.model_type == "chatglm": - head_dim = model.config.hidden_size // model.config.num_attention_heads - if model.config.multi_query_attention: - head_num = model.config.multi_query_group_num // self.tp_size - else: - head_num = model.config.num_attention_heads // self.tp_size - num_hidden_layers = model.config.num_layers - layer_num = num_hidden_layers // self.pp_size - else: - raise NotImplementedError("Only support llama, bloom and chatglm model.") - - if self.quant == "smoothquant": - cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) - else: - cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) - return cache_manager diff --git a/colossalai/inference/engine/microbatch_manager.py b/colossalai/inference/engine/microbatch_manager.py deleted file mode 100644 index 7264b81e06a0..000000000000 --- a/colossalai/inference/engine/microbatch_manager.py +++ /dev/null @@ -1,248 +0,0 @@ -from enum import Enum -from typing import Dict - -import torch - -from ..kv_cache import BatchInferState, MemoryManager - -__all__ = "MicroBatchManager" - - -class Status(Enum): - PREFILL = 1 - GENERATE = 2 - DONE = 3 - COOLDOWN = 4 - - -class MicroBatchDescription: - """ - This is the class to record the information of each microbatch, and also do some update operation. - This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more - details, please refer to the doc of these two classes blow. - - Args: - inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. - """ - - def __init__( - self, - inputs_dict: Dict[str, torch.Tensor], - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - ) -> None: - self.mb_length = inputs_dict["input_ids"].shape[-1] - self.target_length = self.mb_length + max_output_len - self.infer_state = BatchInferState.init_from_batch( - batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager - ) - # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") - - def update(self, *args, **kwargs): - pass - - @property - def state(self): - """ - Return the state of current micro batch, when current length is equal to target length, - the state is DONE, otherwise GENERATE - - """ - # TODO: add the condition for early stopping - if self.cur_length == self.target_length: - return Status.DONE - elif self.cur_length == self.target_length - 1: - return Status.COOLDOWN - else: - return Status.GENERATE - - @property - def cur_length(self): - """ - Return the current sequence length of micro batch - - """ - - -class HeadMicroBatchDescription(MicroBatchDescription): - """ - This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` - and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the - information and the condition to determine the state is different from other stages. - - Args: - inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. - - """ - - def __init__( - self, - inputs_dict: Dict[str, torch.Tensor], - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) - assert inputs_dict is not None - assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None - self.input_ids = inputs_dict["input_ids"] - self.attn_mask = inputs_dict["attention_mask"] - self.new_tokens = None - - def update(self, new_token: torch.Tensor = None): - if new_token is not None: - self._update_newtokens(new_token) - if self.state is not Status.DONE and new_token is not None: - self._update_attnmask() - - def _update_newtokens(self, new_token: torch.Tensor): - if self.new_tokens is None: - self.new_tokens = new_token - else: - self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1) - - def _update_attnmask(self): - self.attn_mask = torch.cat( - (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 - ) - - @property - def cur_length(self): - """ - When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token - - """ - if self.new_tokens is None: - return self.mb_length - else: - return self.mb_length + len(self.new_tokens[0]) - - -class BodyMicroBatchDescription(MicroBatchDescription): - """ - This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, - - Args: - inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. - """ - - def __init__( - self, - inputs_dict: Dict[str, torch.Tensor], - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) - - @property - def cur_length(self): - """ - When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 - - """ - return self.infer_state.seq_len.max().item() - - -class MicroBatchManager: - """ - MicroBatchManager is a class that manages the micro batch. - - Args: - stage (int): stage id of current stage. - micro_batch_size (int): the micro batch size. - micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - - """ - - def __init__( - self, - stage: int, - micro_batch_size: int, - micro_batch_buffer_size: int, - max_input_len: int, - max_output_len: int, - cache_manager_list: MemoryManager, - ): - self.stage = stage - self.micro_batch_size = micro_batch_size - self.buffer_size = micro_batch_buffer_size - self.max_input_len = max_input_len - self.max_output_len = max_output_len - self.cache_manager_list = cache_manager_list - self.mb_description_buffer = {} - self.new_tokens_buffer = {} - self.idx = 0 - - def add_description(self, inputs_dict: Dict[str, torch.Tensor]): - if self.stage == 0: - self.mb_description_buffer[self.idx] = HeadMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] - ) - else: - self.mb_description_buffer[self.idx] = BodyMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] - ) - - def step(self, new_token: torch.Tensor = None): - """ - Update the state if microbatch manager, 2 conditions. - 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs. - 2. For other condition, only receive the output of previous stage, and update the description. - - Args: - inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. - new_token (torch.Tensor): the new token generated by current stage. - """ - # Add description first if the description is None - self.cur_description.update(new_token) - return self.cur_state - - def export_new_tokens(self): - new_tokens_list = [] - for i in self.mb_description_buffer.values(): - new_tokens_list.extend(i.new_tokens.tolist()) - return new_tokens_list - - def is_micro_batch_done(self): - if len(self.mb_description_buffer) == 0: - return False - for mb in self.mb_description_buffer.values(): - if mb.state != Status.DONE: - return False - return True - - def clear(self): - self.mb_description_buffer.clear() - for cache in self.cache_manager_list: - cache.free_all() - - def next(self): - self.idx = (self.idx + 1) % self.buffer_size - - def _remove_description(self): - self.mb_description_buffer.pop(self.idx) - - @property - def cur_description(self) -> MicroBatchDescription: - return self.mb_description_buffer.get(self.idx) - - @property - def cur_infer_state(self): - if self.cur_description is None: - return None - return self.cur_description.infer_state - - @property - def cur_state(self): - """ - Return the state of current micro batch, when current description is None, the state is PREFILL - - """ - if self.cur_description is None: - return Status.PREFILL - return self.cur_description.state diff --git a/colossalai/inference/engine/modeling/__init__.py b/colossalai/inference/engine/modeling/__init__.py deleted file mode 100644 index 8a9e9999d3c5..000000000000 --- a/colossalai/inference/engine/modeling/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .bloom import BloomInferenceForwards -from .chatglm2 import ChatGLM2InferenceForwards -from .llama import LlamaInferenceForwards - -__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/inference/engine/modeling/_utils.py b/colossalai/inference/engine/modeling/_utils.py deleted file mode 100644 index 068b64b4f829..000000000000 --- a/colossalai/inference/engine/modeling/_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Utils for model inference -""" -import os - -import torch - -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest - - -def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - """ - This function copies the key and value cache to the memory cache - Args: - layer_id : id of current layer - key_buffer : key cache - value_buffer : value cache - context_mem_index : index of memory cache in kv cache manager - mem_manager : cache manager - """ - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - - -def init_to_get_rotary(self, base=10000, use_elem=False): - """ - This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer - Args: - self : Model that holds the rotary positional embedding - base : calculation arg - use_elem : activated when using chatglm-based models - """ - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - - # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) - - if ntk_alpha is not None: - ntk_alpha = float(ntk_alpha) - assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" - if ntk_alpha > 1: - print(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - - n_elem = self.config.head_dim_ - if use_elem: - n_elem //= 2 - - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/engine/modeling/bloom.py b/colossalai/inference/engine/modeling/bloom.py deleted file mode 100644 index 4c098d3e4c80..000000000000 --- a/colossalai/inference/engine/modeling/bloom.py +++ /dev/null @@ -1,452 +0,0 @@ -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch.nn import functional as F -from transformers.models.bloom.modeling_bloom import ( - BaseModelOutputWithPastAndCrossAttentions, - BloomAttention, - BloomBlock, - BloomForCausalLM, - BloomModel, -) -from transformers.utils import logging - -from colossalai.inference.kv_cache.batch_infer_state import BatchInferState -from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd -from colossalai.pipeline.stage_manager import PipelineStageManager - -try: - from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_bloom_context_attention_fwd, - ) - - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - - -def generate_alibi(n_head, dtype=torch.float16): - """ - This method is adapted from `_generate_alibi` function - in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py` - of the ModelTC/lightllm GitHub repository. - This method is originally the `build_alibi_tensor` function - in `transformers/models/bloom/modeling_bloom.py` - of the huggingface/transformers GitHub repository. - """ - - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - return [start * start**i for i in range(n)] - - def get_slopes(n): - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) - slopes_double = get_slopes(2 * closest_power_of_2) - slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2] - return slopes_combined - - slopes = get_slopes(n_head) - return torch.tensor(slopes, dtype=dtype) - - -class BloomInferenceForwards: - """ - This class serves a micro library for bloom inference forwards. - We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention, - as well as prepare_inputs_for_generation method for BloomForCausalLM. - For future improvement, we might want to skip replacing methods for BloomForCausalLM, - and call BloomModel.forward iteratively in TpInferEngine - """ - - @staticmethod - def bloom_for_causal_lm_forward( - self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = False, - infer_state: BatchInferState = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - tp_group: Optional[dist.ProcessGroup] = None, - **deprecated_arguments, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - logger = logging.get_logger(__name__) - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - - # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - # If is first stage and hidden_states is not None, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - outputs = BloomInferenceForwards.bloom_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - tp_group=tp_group, - ) - - return outputs - - @staticmethod - def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = None, - infer_state: BatchInferState = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - tp_group: Optional[dist.ProcessGroup] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - logger = logging.get_logger(__name__) - - # add warnings here - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - if use_cache: - logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") - use_cache = False - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - # first stage - if stage_manager.is_first_stage(): - # check inputs and inputs embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # other stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - if seq_length != 1: - # prefill stage - infer_state.is_context_stage = True # set prefill stage, notify attention layer - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - BatchInferState.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - - if attention_mask is None: - attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model, - # or store to BatchInferState to prevent re-calculating - # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here - tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 - curr_tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 - alibi = ( - generate_alibi(self.num_heads * tp_size) - .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] - .cuda() - ) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - infer_state.decode_layer_id = 0 - - start_idx, end_idx = stage_index[0], stage_index[1] - if past_key_values is None: - past_key_values = tuple([None] * (end_idx - start_idx + 1)) - - for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): - block = self.h[idx] - outputs = block( - hidden_states, - layer_past=past_key_value, - attention_mask=causal_mask, - head_mask=head_mask[idx], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - infer_state=infer_state, - ) - - infer_state.decode_layer_id += 1 - hidden_states = outputs[0] - - if stage_manager.is_last_stage() or stage_manager.num_stages == 1: - hidden_states = self.ln_f(hidden_states) - - # update indices - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - # always return dict for imediate stage - return {"hidden_states": hidden_states} - - @staticmethod - def bloom_block_forward( - self: BloomBlock, - hidden_states: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - # hidden_states: [batch_size, seq_length, hidden_size] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - - # Layer norm post the self attention. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - # Self attention. - attn_outputs = self.self_attention( - layernorm_output, - residual, - layer_past=layer_past, - attention_mask=attention_mask, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - infer_state=infer_state, - ) - - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - layernorm_output = self.post_attention_layernorm(attention_output) - - # Get residual - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = attention_output - - # MLP. - output = self.mlp(layernorm_output, residual) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - @staticmethod - def bloom_attention_forward( - self: BloomAttention, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, q_length, H, D_HEAD = query_layer.shape - k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - - mem_manager = infer_state.cache_manager - layer_id = infer_state.decode_layer_id - - if infer_state.is_context_stage: - # context process - max_input_len = q_length - b_start_loc = infer_state.start_loc - b_seq_len = infer_state.seq_len[:batch_size] - q = query_layer.reshape(-1, H, D_HEAD) - - copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) - - # output = self.output[:batch_size*q_length, :, :] - output = torch.empty_like(q) - - if HAS_LIGHTLLM_KERNEL: - lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) - else: - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - else: - # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) - assert q_length == 1, "for non-context process, we only support q_length == 1" - q = query_layer.reshape(-1, H, D_HEAD) - - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(k) - cache_v.copy_(v) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] - copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) - - b_start_loc = infer_state.start_loc - b_loc = infer_state.block_loc - b_seq_len = infer_state.seq_len - output = torch.empty_like(q) - token_attention_fwd( - q, - mem_manager.key_buffer[layer_id], - mem_manager.value_buffer[layer_id], - output, - b_loc, - b_start_loc, - b_seq_len, - infer_state.max_len_in_batch, - alibi, - ) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, - # we create the past key value pair from the cache manager - present = None - - # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) - - # dropout is not required here during inference - output_tensor = residual + output_tensor - - outputs = (output_tensor, present) - assert output_attentions is False, "we do not support output_attentions at this time" - - return outputs diff --git a/colossalai/inference/engine/modeling/chatglm2.py b/colossalai/inference/engine/modeling/chatglm2.py deleted file mode 100644 index 56e777bb2b87..000000000000 --- a/colossalai/inference/engine/modeling/chatglm2.py +++ /dev/null @@ -1,492 +0,0 @@ -from typing import List, Optional, Tuple - -import torch -from transformers.utils import logging - -from colossalai.inference.kv_cache import BatchInferState -from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, - GLMTransformer, - SelfAttention, - split_tensor_along_last_dim, -) - -from ._utils import copy_kv_to_mem_cache - -try: - from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) - - HAS_LIGHTLLM_KERNEL = True -except: - print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") - HAS_LIGHTLLM_KERNEL = False - - -def get_masks(self, input_ids, past_length, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - if past_length: - full_attention_mask = torch.cat( - ( - torch.ones(batch_size, seq_length, past_length, device=input_ids.device), - full_attention_mask, - ), - dim=-1, - ) - - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - -def get_position_ids(batch_size, seq_length, device): - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - -class ChatGLM2InferenceForwards: - """ - This class holds forwards for Chatglm2 inference. - We intend to replace the forward methods for ChatGLMModel, ChatGLMEecoderLayer, and ChatGLMAttention. - """ - - @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = True, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - infer_state: Optional[BatchInferState] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ): - logger = logging.get_logger(__name__) - - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - # If is first stage and hidden_states is not None, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - return {"logits": lm_logits} - - outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config, - ) - - return outputs - - @staticmethod - def chatglm_model_forward( - self: ChatGLMModel, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: BatchInferState = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - if position_ids is None: - position_ids = get_position_ids(batch_size, seq_length, input_ids.device) - hidden_states = inputs_embeds - else: - assert hidden_states is not None, "hidden_states should not be None in non-first stage" - seq_length, batch_size, _ = hidden_states.shape - if position_ids is None: - position_ids = get_position_ids(batch_size, seq_length, hidden_states.device) - - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - seq_length_with_past = seq_length + past_key_values_length - - # prefill stage at first - if seq_length != 1: - infer_state.is_context_stage = True - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - # related to rotary embedding - if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt( - batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype, - ) - if attention_mask is not None: - attention_mask = torch.cat( - [ - attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask, - ], - dim=-1, - ) - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = get_masks( - self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask - ) - - # Run encoder. - hidden_states = self.encoder( - hidden_states, - full_attention_mask, - kv_caches=past_key_values, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - infer_state=infer_state, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config, - ) - - # update indices - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - return {"hidden_states": hidden_states} - - @staticmethod - def chatglm_encoder_forward( - self: GLMTransformer, - hidden_states, - attention_mask, - kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - infer_state: Optional[BatchInferState] = None, - stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - ): - hidden_states = hidden_states.transpose(0, 1).contiguous() - - infer_state.decode_layer_id = 0 - start_idx, end_idx = stage_index[0], stage_index[1] - if kv_caches is None: - kv_caches = tuple([None] * (end_idx - start_idx + 1)) - - for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches): - layer = self.layers[idx] - layer_ret = layer( - hidden_states, - attention_mask, - kv_cache=kv_cache, - use_cache=use_cache, - infer_state=infer_state, - ) - infer_state.decode_layer_id += 1 - - hidden_states, _ = layer_ret - - hidden_states = hidden_states.transpose(0, 1).contiguous() - - if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1): - # Final layer norm. - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states - - @staticmethod - def chatglm_glmblock_forward( - self: GLMBlock, - hidden_states, - attention_mask, - kv_cache=None, - use_cache=True, - infer_state: Optional[BatchInferState] = None, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - kv_cache=kv_cache, - use_cache=use_cache, - infer_state=infer_state, - ) - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - return output, kv_cache - - @staticmethod - def chatglm_flash_attn_kvcache_forward( - self: SelfAttention, - hidden_states, - attention_mask, - kv_cache=None, - use_cache=True, - infer_state: Optional[BatchInferState] = None, - ): - assert use_cache is True, "use_cache should be set to True using this chatglm attention" - # hidden_states: original :[sq, b, h] --> this [b, sq, h] - batch_size = hidden_states.shape[0] - hidden_size = hidden_states.shape[-1] - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] - + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - cos, sin = infer_state.position_cos, infer_state.position_sin - - chatglm2_rotary_emb_fwd( - query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin - ) - if self.multi_query_attention: - chatglm2_rotary_emb_fwd( - key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), - cos, - sin, - ) - else: - chatglm2_rotary_emb_fwd( - key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - cos, - sin, - ) - - # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 - query_layer = query_layer.reshape( - -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head - ) - key_layer = key_layer.reshape( - -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head - ) - value_layer = value_layer.reshape( - -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head - ) - - if infer_state.is_context_stage: - # first token generation: - # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_layer, - value_layer, - infer_state.context_mem_index, - infer_state.cache_manager, - ) - attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) - - # NOTE: no bug in context attn fwd (del it ) - lightllm_llama2_context_attention_fwd( - query_layer, - key_layer, - value_layer, - attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - - else: - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(key_layer) - cache_v.copy_(value_layer) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_layer, - value_layer, - infer_state.decode_mem_index, - infer_state.cache_manager, - ) - - # second token and follows - attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - : infer_state.decode_mem_end, :, : - ] - - # ================================== - # core attention computation is replaced by triton kernel - # ================================== - Llama2TokenAttentionForwards.token_attn( - query_layer, - cache_k, - cache_v, - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) - - # ================= - # Output:[b,sq, h] - # ================= - output = self.dense(attn_output).reshape(batch_size, -1, hidden_size) - - return output, kv_cache diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py deleted file mode 100644 index a7efb4026be0..000000000000 --- a/colossalai/inference/engine/modeling/llama.py +++ /dev/null @@ -1,503 +0,0 @@ -# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py -import math -from typing import List, Optional, Tuple - -import torch -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel -from transformers.utils import logging - -from colossalai.inference.kv_cache.batch_infer_state import BatchInferState -from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd -from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards -from colossalai.pipeline.stage_manager import PipelineStageManager - -from ._utils import copy_kv_to_mem_cache - -try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) - from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_context_attention_fwd, - ) - from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd - - HAS_LIGHTLLM_KERNEL = True -except: - print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") - HAS_LIGHTLLM_KERNEL = False - -try: - from colossalai.kernel.triton.flash_decoding import token_flash_decoding - - HAS_TRITON_FLASH_DECODING_KERNEL = True -except: - print( - "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8" - ) - HAS_TRITON_FLASH_DECODING_KERNEL = False - -try: - from flash_attn import flash_attn_with_kvcache - - HAS_FLASH_KERNEL = True -except: - HAS_FLASH_KERNEL = False - print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def llama_triton_context_attention( - query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 -): - if num_key_value_groups == 1: - if HAS_LIGHTLLM_KERNEL is False: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - else: - lightllm_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - else: - assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" - lightllm_llama2_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - - -def llama_triton_token_attention( - query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1 -): - if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1: - token_flash_decoding( - q=query_states, - o_tensor=attn_output, - infer_state=infer_state, - q_head_num=q_head_num, - head_dim=head_dim, - cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - ) - return - - if num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - else: - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) - - -class LlamaInferenceForwards: - """ - This class holds forwards for llama inference. - We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. - """ - - @staticmethod - def llama_causal_lm_forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: BatchInferState = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - """ - logger = logging.get_logger(__name__) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - # If is first stage and hidden_states is None, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = LlamaInferenceForwards.llama_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - return outputs - - @staticmethod - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: BatchInferState = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - use_cache = use_cache if use_cache is not None else self.config.use_cache - # retrieve input_ids and inputs_embeds - if stage_manager is None or stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - assert stage_manager is not None - assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}" - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - # NOTE: differentiate with prefill stage - # block_loc require different value-assigning method for two different stage - if use_cache and seq_length != 1: - # NOTE assume prefill stage - # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - else: - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.repeat(batch_size, 1) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device - ) - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) - - # decoder layers - infer_state.decode_layer_id = 0 - - start_idx, end_idx = stage_index[0], stage_index[1] - if past_key_values is None: - past_key_values = tuple([None] * (end_idx - start_idx + 1)) - - for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): - decoder_layer = self.layers[idx] - # NOTE: modify here for passing args to decoder layer - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_state=infer_state, - ) - infer_state.decode_layer_id += 1 - hidden_states = layer_outputs[0] - - if stage_manager.is_last_stage() or stage_manager.num_stages == 1: - hidden_states = self.norm(hidden_states) - - # update indices - # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - return {"hidden_states": hidden_states} - - @staticmethod - def llama_decoder_layer_forward( - self: LlamaDecoderLayer, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_state=infer_state, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - @staticmethod - def llama_flash_attn_kvcache_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - assert use_cache is True, "use_cache should be set to True using this llama attention" - - bsz, q_len, _ = hidden_states.size() - - # NOTE might think about better way to handle transposed k and v - # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] - # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - # NOTE might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - - cos, sin = infer_state.position_cos, infer_state.position_sin - - llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) - - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) - - if infer_state.is_context_stage: - # first token generation - # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.context_mem_index, - infer_state.cache_manager, - ) - attn_output = torch.empty_like(query_states) - - llama_triton_context_attention( - query_states, - key_states, - value_states, - attn_output, - infer_state, - num_key_value_groups=self.num_key_value_groups, - ) - else: - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(key_states) - cache_v.copy_(value_states) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.decode_mem_index, - infer_state.cache_manager, - ) - - if HAS_LIGHTLLM_KERNEL: - attn_output = torch.empty_like(query_states) - llama_triton_token_attention( - query_states=query_states, - attn_output=attn_output, - infer_state=infer_state, - num_key_value_groups=self.num_key_value_groups, - q_head_num=q_len * self.num_heads, - head_dim=self.head_dim, - ) - else: - self.num_heads // self.num_key_value_heads - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) - copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - - attn_output = flash_attn_with_kvcache( - q=query_states, - k_cache=copy_cache_k, - v_cache=copy_cache_v, - softmax_scale=1 / math.sqrt(self.head_dim), - causal=True, - ) - - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - # return past_key_value as None - return attn_output, None, None diff --git a/colossalai/inference/engine/policies/__init__.py b/colossalai/inference/engine/policies/__init__.py deleted file mode 100644 index 269d1c57b276..000000000000 --- a/colossalai/inference/engine/policies/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .bloom import BloomModelInferPolicy -from .chatglm2 import ChatGLM2InferPolicy -from .llama import LlamaModelInferPolicy - -model_policy_map = { - "llama": LlamaModelInferPolicy, - "bloom": BloomModelInferPolicy, - "chatglm": ChatGLM2InferPolicy, -} - -__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy", "model_polic_map"] diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py deleted file mode 100644 index 5bc47c3c1a49..000000000000 --- a/colossalai/inference/engine/policies/bloom.py +++ /dev/null @@ -1,127 +0,0 @@ -from functools import partial -from typing import List - -import torch -from torch.nn import LayerNorm, Module - -import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription -from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy - -from ..modeling.bloom import BloomInferenceForwards - -try: - from colossalai.kernel.triton import layer_norm - - HAS_TRITON_NORM = True -except: - print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") - HAS_TRITON_NORM = False - - -def get_triton_layernorm_forward(): - if HAS_TRITON_NORM: - - def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): - return layer_norm(hidden_states, self.weight.data, self.bias, self.eps) - - return _triton_layernorm_forward - else: - return None - - -class BloomModelInferPolicy(BloomForCausalLMPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel - - policy = super().module_policy() - if self.shard_config.extra_kwargs.get("quant", None) == "gptq": - from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - - policy[BloomBlock] = ModulePolicyDescription( - attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size - // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size - // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 3}, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} - ), - ], - ) - # NOTE set inference mode to shard config - self.shard_config._infer() - - # set as default, in inference we also use pipeline style forward, just setting stage as 1 - self.set_pipeline_forward( - model_cls=BloomForCausalLM, - new_forward=partial( - BloomInferenceForwards.bloom_for_causal_lm_forward, - tp_group=self.shard_config.tensor_parallel_process_group, - ), - policy=policy, - ) - - method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) - - method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) - - method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=BloomAttention - ) - - if HAS_TRITON_NORM: - infer_method = get_triton_layernorm_forward() - method_replacement = {"forward": partial(infer_method)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LayerNorm - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - - if self.model.__class__.__name__ == "BloomModel": - module = self.model - else: - module = self.model.transformer - stage_manager = self.pipeline_stage_manager - - held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.h)) - if stage_manager.is_first_stage(): - held_layers.append(module.word_embeddings) - held_layers.append(module.word_embeddings_layernorm) - held_layers.append(self.model.lm_head) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) - - return held_layers diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py deleted file mode 100644 index c7c6f3b927e1..000000000000 --- a/colossalai/inference/engine/policies/chatglm2.py +++ /dev/null @@ -1,89 +0,0 @@ -from typing import List - -import torch.nn as nn - -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, - GLMTransformer, - SelfAttention, -) - -# import colossalai -from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy - -from ..modeling._utils import init_to_get_rotary -from ..modeling.chatglm2 import ChatGLM2InferenceForwards - -try: - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -class ChatGLM2InferPolicy(ChatGLMModelPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - self.shard_config._infer() - - model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward - method_replacement = {"forward": model_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) - - encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward - method_replacement = {"forward": encoder_infer_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=GLMTransformer - ) - - encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward - method_replacement = {"forward": encoder_layer_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) - - attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward - method_replacement = {"forward": attn_infer_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=SelfAttention - ) - if self.shard_config.enable_tensor_parallelism: - policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = ( - self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size - ) - # for rmsnorm and others, we need to check the shape - - self.set_pipeline_forward( - model_cls=ChatGLMForConditionalGeneration, - new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward, - policy=policy, - ) - - return policy - - def get_held_layers(self) -> List[nn.Module]: - module = self.model.transformer - stage_manager = self.pipeline_stage_manager - - held_layers = [] - layers_per_stage = stage_manager.distribute_layers(module.num_layers) - if stage_manager.is_first_stage(): - held_layers.append(module.embedding) - held_layers.append(module.output_layer) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.encoder.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - if module.encoder.post_layer_norm: - held_layers.append(module.encoder.final_layernorm) - - # rotary_pos_emb is needed for all stages - held_layers.append(module.rotary_pos_emb) - - return held_layers - - def postprocess(self): - init_to_get_rotary(self.model.transformer) - return self.model diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py deleted file mode 100644 index a57a4e50cdb9..000000000000 --- a/colossalai/inference/engine/policies/llama.py +++ /dev/null @@ -1,206 +0,0 @@ -from functools import partial -from typing import List - -import torch -from torch.nn import Module -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, -) - -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription - -# import colossalai -from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy - -from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards - -try: - from colossalai.kernel.triton import rmsnorm_forward - - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -def get_triton_rmsnorm_forward(): - if HAS_TRITON_RMSNORM: - - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) - - return _triton_rmsnorm_forward - else: - return None - - -class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.num_key_value_heads": self.model.config.num_key_value_heads - // self.shard_config.tensor_parallel_size, - } - if self.shard_config.extra_kwargs.get("quant", None) == "gptq": - from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - - policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - ], - ) - - elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant": - from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer - from colossalai.inference.quant.smoothquant.models.parallel_linear import ( - ColW8A8BFP32OFP32Linear, - RowW8A8B8O8Linear, - RowW8A8BFP32O32LinearSiLU, - RowW8A8BFP32OFP32Linear, - ) - - policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=RowW8A8B8O8Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=ColW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=RowW8A8BFP32O32LinearSiLU, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=RowW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=ColW8A8BFP32OFP32Linear, - kwargs={"split_num": 1}, - ), - ], - ) - self.shard_config._infer() - - infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaDecoderLayer - ) - - infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaAttention - ) - - # set as default, in inference we also use pipeline style forward, just setting stage as 1 - self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy - ) - - infer_forward = None - if HAS_TRITON_RMSNORM: - infer_forward = get_triton_rmsnorm_forward() - - if infer_forward is not None: - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaRMSNorm - ) - - return policy - - def postprocess(self): - init_to_get_rotary(self.model.model) - return self.model - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - assert self.pipeline_stage_manager is not None - - if self.model.__class__.__name__ == "LlamaModel": - module = self.model - else: - module = self.model.model - stage_manager = self.pipeline_stage_manager - - held_layers = [] - layers_per_stage = stage_manager.distribute_layers(len(module.layers)) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - held_layers.append(self.model.lm_head) - start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) - - return held_layers diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py new file mode 100644 index 000000000000..913b8667dcf9 --- /dev/null +++ b/colossalai/inference/executor/rpc_worker.py @@ -0,0 +1,306 @@ +from typing import List, Tuple, Union + +import rpyc +import torch +import torch.distributed as dist +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import InferenceConfig, InputMetaData +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.policy import ( + NoPaddingBaichuanModelInferPolicy, + NoPaddingLlamaModelInferPolicy, + model_policy_map, +) +from colossalai.inference.sampler import search_tokens +from colossalai.inference.utils import get_model_size +from colossalai.interface import ModelWrapper +from colossalai.logging import get_dist_logger +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +PP_AXIS, TP_AXIS = 0, 1 + +_SUPPORTED_MODELS = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_SUPPORTED_MODEL_POLICIES = { + "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy, + "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy, +} + +logger = get_dist_logger(__name__) + + +class rpcWorkerService(rpyc.Service): + + """ + Execute the computation tasks and manage its own kv cache + + Func with prefix `exposed_` will be invoked by client. + """ + + def exposed_init_dist_env(self, rank, world_size, master_address, master_port): + logger.info(f"init process group for rank {rank}") + colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) + logger.info(f"init process group done for rank {rank}") + + def exposed_init_model( + self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None + ): + assert dist.is_initialized(), "invoke init_dist_env first please!" + + self.inference_config = InferenceConfig.from_rpc_param(inference_config_param) + model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None + + self.dtype = self.inference_config.dtype + self.verbose = True + + self._init_model(model_or_path, model_policy) + self._init_fd_tensor() + self._init_output_tensor() + logger.info(f"init model done for rank {dist.get_rank()}") + + def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): + """Initialize the physical cache on the device. + + For each layer of the model, we allocate two tensors for key and value respectively, + with shape of [num_blocks, num_kv_heads, block_size, head_size] + """ + kalloc_shape, valloc_shape = alloc_shape + num_layers = self.model_config.num_hidden_layers + + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + for _ in range(num_layers): + self.k_cache.append( + torch.zeros( + kalloc_shape, + dtype=self.inference_config.kv_cache_dtype, + device=get_accelerator().get_current_device(), + ) + ) + self.v_cache.append( + torch.zeros( + valloc_shape, + dtype=self.inference_config.kv_cache_dtype, + device=get_accelerator().get_current_device(), + ) + ) + logger.info("physical cache init over") + + def exposed_execute_model_forward( + self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict + ): + # prepare the data for model forward + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + if input_meta_data.is_prompts: + n_tokens = input_meta_data.sequence_lengths.sum().item() + else: + n_tokens = input_meta_data.batch_size + input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) + + # execute the model + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) + + # sampler + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + generation_config_param, + logits, + input_meta_data.is_prompts, + input_meta_data.batch_token_ids, + ) + + # return the tokens generated to scheduler + return next_tokens.tolist() + + def _init_output_tensor(self): + alloc_shape = ( + self.inference_config.max_batch_size + * (self.inference_config.max_input_len + self.inference_config.max_output_len), + self.model_config.hidden_size // self.inference_config.tp_size, + ) + self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device) + + def _init_fd_tensor(self): + fd_inter_tensor = FDIntermTensors() + + if fd_inter_tensor._tensors_initialized: + fd_inter_tensor._reset() + + # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq + max_n_tokens = self.inference_config.max_batch_size + max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 + + inference_config = self.inference_config + kv_max_split_num = ( + inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 + ) // inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + fd_inter_tensor.initialize( + max_batch_size=max_n_tokens, + num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size, + kv_max_split_num=kv_max_split_num, + head_dim=head_dim, + dtype=self.dtype, + device=get_accelerator().get_current_device(), + ) + + self.fd_inter_tensor = fd_inter_tensor + + def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + """ + Shard model or/and Load weight + + Shard model: When we set tp_size > 1, we will shard the model by given model_policy. + Load Weight: If we pass a local model path, we will load the model weight by checkpoint_io. If it is a remote-transformer url, we will use `AutoModel.from_pretrained` api of transformers lib + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model + """ + + if isinstance(model_or_path, str): + # is_local = os.path.isdir(model_or_path) + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + arch = getattr(hf_config, "architectures")[0] + # NOTE(lry89757) Currently we load the model using transformers-api, + # but we will use lazy tensor and checkpoint io to accelerate + # the model load process in the future. + model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) + # if is_local: + # model = _SUPPORTED_MODELS[arch](hf_config) + # else: + # # load the real checkpoint + # model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) + except Exception as e: + logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + torch.cuda.set_device(self.device) + if self.verbose: + logger.info(f"the device is {self.device}") + + model = model.to(dtype=self.dtype, non_blocking=False).eval() + + if self.verbose: + logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + if self.inference_config.pad_input: + model_type = "padding_" + self.model_config.model_type + else: + model_type = "nopadding_" + self.model_config.model_type + model_policy = model_policy_map[model_type]() + + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device()) + + if self.verbose: + logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor + # if isinstance(model_or_path, str) and is_local: + # from colossalai.inference.core.plugin import InferCheckpoint_io + + # cpt_io = InferCheckpoint_io() + # if_has_index_file, model_index_file = has_index_file(model_or_path) + # assert if_has_index_file, "the model path is invalid" + # cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy to shardformer model which is determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. + tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer. + """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model + + def exposed_compute_only_for_test(self): + dist_rank = dist.get_rank() + + # Dummy data for each worker + data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank) + dist.barrier() + + # Perform distributed all_reduce + dist.all_reduce(data, op=dist.ReduceOp.SUM) + + dist.barrier() + logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") + + return data.item() diff --git a/colossalai/inference/flash_decoding_utils.py b/colossalai/inference/flash_decoding_utils.py new file mode 100644 index 000000000000..48f43bf51622 --- /dev/null +++ b/colossalai/inference/flash_decoding_utils.py @@ -0,0 +1,82 @@ +import torch + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.utils import get_current_device + + +class FDIntermTensors(metaclass=SingletonMeta): + """Singleton class to hold tensors used for storing intermediate values in flash-decoding. + For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv) + """ + + def __init__(self): + self._tensors_initialized = False + + def _reset(self): + self._tensors_initialized = False + del self._mid_output + del self._mid_output_lse + del self._exp_sums + del self._max_logits + + @property + def is_initialized(self): + return self._tensors_initialized + + @property + def mid_output(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._mid_output + + @property + def mid_output_lse(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._mid_output_lse + + @property + def exp_sums(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._exp_sums + + @property + def max_logits(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._max_logits + + def initialize( + self, + max_batch_size: int, + num_attn_heads: int, + kv_max_split_num: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + device: torch.device = get_current_device(), + ) -> None: + """Initialize tensors. + + Args: + max_batch_size (int): The maximum batch size over all the model forward. + This could be greater than the batch size in attention forward func when using dynamic batch size. + num_attn_heads (int)): Number of attention heads. + kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm. + **The maximum length/size of blocks splitted on kv should be the kv cache block size.** + head_dim (int): Head dimension. + dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors. + device (torch.device, optional): Device used to initialize intermediate tensors. + """ + assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized." + + self._mid_output = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device + ) + self._mid_output_lse = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device + ) + self._exp_sums = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device + ) + self._max_logits = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device + ) + + self._tensors_initialized = True diff --git a/colossalai/inference/graph_runner.py b/colossalai/inference/graph_runner.py new file mode 100644 index 000000000000..e8b805574e43 --- /dev/null +++ b/colossalai/inference/graph_runner.py @@ -0,0 +1,100 @@ +from typing import Dict, List + +import torch +from torch import nn + +from colossalai.inference.config import InputMetaData +from colossalai.logging import get_dist_logger + + +class CUDAGraphRunner: + def __init__(self, model: nn.Module): + self.model = model + self.graph = None + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + self.logger = get_dist_logger(__name__) + + def capture( + self, + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, + memory_pool=None, + ) -> None: + assert self.graph is None + + # run kernel once to cache the kernel, avoid stream capture error + hidden_states_origin_model = self.model( + input_tokens_ids, + output_tensor, + inputmetadata, + k_caches, + v_caches, + ) + torch.cuda.synchronize() + + # Capture the graph. + # self.logger.info(f"begin capture model...") + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): + hidden_states_cuda_graph = self.model( + input_tokens_ids, + output_tensor, + inputmetadata, + k_caches, + v_caches, + ) + torch.cuda.synchronize() + + # Save the input and output buffers, because replay always uses the same virtual memory space + self.input_buffers = { + "input_tokens_ids": input_tokens_ids, + "output_tensor": output_tensor, + "block_tables": inputmetadata.block_tables, + "sequence_lengths": inputmetadata.sequence_lengths, + # "fd_inter_tensor_mid_output": inputmetadata.fd_inter_tensor._mid_output, + # "fd_inter_tensor_mid_output_lse": inputmetadata.fd_inter_tensor._mid_output_lse, + "k_caches": k_caches, + "v_caches": v_caches, + } + self.output_buffers = {"logits": hidden_states_cuda_graph} + return + + def forward( + self, + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, + ) -> torch.Tensor: + # Copy the input tensors to the input buffers. + self.input_buffers["input_tokens_ids"].copy_(input_tokens_ids, non_blocking=True) + self.input_buffers["output_tensor"].copy_(output_tensor, non_blocking=True) + + # for flexible block_table + self.input_buffers["block_tables"].fill_(-1) + M, N = inputmetadata.block_tables.shape + self.input_buffers["block_tables"][:M, :N].copy_(inputmetadata.block_tables, non_blocking=True) + + self.input_buffers["sequence_lengths"].copy_(inputmetadata.sequence_lengths, non_blocking=True) + + # we only have a global fd_inter_tensor so we don't need to copy them + # self.input_buffers["fd_inter_tensor_mid_output"].copy_(inputmetadata.fd_inter_tensor.mid_output, non_blocking=True) + # self.input_buffers["fd_inter_tensor_mid_output_lse"].copy_(inputmetadata.fd_inter_tensor.mid_output_lse, non_blocking=True) + + # KV caches are fixed tensors, so we don't need to copy them. + # self.input_buffers["k_caches"].copy_(k_caches, non_blocking=True) + # self.input_buffers["v_caches"].copy_(v_caches, non_blocking=True) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["logits"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/colossalai/inference/kv_cache/__init__.py b/colossalai/inference/kv_cache/__init__.py index 5b6ca182efae..b232db936774 100644 --- a/colossalai/inference/kv_cache/__init__.py +++ b/colossalai/inference/kv_cache/__init__.py @@ -1,2 +1,4 @@ -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager +from .block_cache import CacheBlock +from .kvcache_manager import KVCacheManager, RPCKVCacheManager + +__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"] diff --git a/colossalai/inference/kv_cache/batch_infer_state.py b/colossalai/inference/kv_cache/batch_infer_state.py deleted file mode 100644 index f707a86df37e..000000000000 --- a/colossalai/inference/kv_cache/batch_infer_state.py +++ /dev/null @@ -1,118 +0,0 @@ -# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later -from dataclasses import dataclass - -import torch -from transformers.tokenization_utils_base import BatchEncoding - -from .kvcache_manager import MemoryManager - - -# adapted from: lightllm/server/router/model_infer/infer_batch.py -@dataclass -class BatchInferState: - r""" - Information to be passed and used for a batch of inputs during - a single model forward - """ - batch_size: int - max_len_in_batch: int - - cache_manager: MemoryManager = None - - block_loc: torch.Tensor = None - start_loc: torch.Tensor = None - seq_len: torch.Tensor = None - past_key_values_len: int = None - - is_context_stage: bool = False - context_mem_index: torch.Tensor = None - decode_is_contiguous: bool = None - decode_mem_start: int = None - decode_mem_end: int = None - decode_mem_index: torch.Tensor = None - decode_layer_id: int = None - - device: torch.device = torch.device("cuda") - - @property - def total_token_num(self): - # return self.batch_size * self.max_len_in_batch - assert self.seq_len is not None and self.seq_len.size(0) > 0 - return int(torch.sum(self.seq_len)) - - def set_cache_manager(self, manager: MemoryManager): - self.cache_manager = manager - - # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 - @staticmethod - def init_block_loc( - b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor - ): - """in-place update block loc mapping based on the sequence length of the inputs in current bath""" - start_index = 0 - seq_len_numpy = seq_len.cpu().numpy() - for i, cur_seq_len in enumerate(seq_len_numpy): - b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ - start_index : start_index + cur_seq_len - ] - start_index += cur_seq_len - return - - @classmethod - def init_from_batch( - cls, - batch: torch.Tensor, - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - ): - if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): - raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") - - input_ids_list = None - attention_mask = None - - if isinstance(batch, (BatchEncoding, dict)): - input_ids_list = batch["input_ids"] - attention_mask = batch["attention_mask"] - else: - input_ids_list = batch - if isinstance(input_ids_list[0], int): # for a single input - input_ids_list = [input_ids_list] - attention_mask = [attention_mask] if attention_mask is not None else attention_mask - - batch_size = len(input_ids_list) - - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - start_index = 0 - - max_len_in_batch = -1 - if isinstance(batch, (BatchEncoding, dict)): - for i, attn_mask in enumerate(attention_mask): - curr_seq_len = len(attn_mask) - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - else: - length = max(len(input_id) for input_id in input_ids_list) - for i, input_ids in enumerate(input_ids_list): - curr_seq_len = length - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") - - return cls( - batch_size=batch_size, - max_len_in_batch=max_len_in_batch, - seq_len=seq_lengths.to("cuda"), - start_loc=seq_start_indexes.to("cuda"), - block_loc=block_loc, - decode_layer_id=0, - past_key_values_len=0, - is_context_stage=True, - cache_manager=cache_manager, - ) diff --git a/colossalai/inference/kv_cache/block_cache.py b/colossalai/inference/kv_cache/block_cache.py new file mode 100644 index 000000000000..755c9581e224 --- /dev/null +++ b/colossalai/inference/kv_cache/block_cache.py @@ -0,0 +1,58 @@ +from typing import Any + +__all__ = ["CacheBlock"] + + +class CacheBlock: + """A simplified version of logical cache block used for Paged Attention.""" + + def __init__(self, block_id: int, block_size: int, elem_size: int, k_ptrs: Any = None, v_ptrs: Any = None): + # Unique id of a cache block + self.block_id = block_id + + # size/capacity of the block in terms of the number of tokens it can hold + self.block_size = block_size + + # element size in bytes + self.elem_size = elem_size + + # For common cases, we track the relationships between logical and physical caches in KV Cache Manager, + # Additionally, k, v pointers can be optionally used for tracking the physical cache by CacheBlock itself. + self.k_ptrs = k_ptrs + self.v_ptrs = v_ptrs + + self.ref_count = 0 + # the number of slots that have been allocated (i.e. the number of tokens occupying the block) + self.allocated_size = 0 + # the token ids whose KV Cache would be written to corresponding physical caches + # TODO add logics to update token_ids + self.token_ids = [None] * self.block_size + + @property + def available_space(self) -> int: + # `allocated_size` is ensured to be less than or equal to `block_size` + return self.block_size - self.allocated_size + + def add_ref(self) -> None: + self.ref_count += 1 + + def remove_ref(self) -> None: + assert self.ref_count > 0, f"Block#{self.block_id} has no reference to remove." + self.ref_count -= 1 + + def has_ref(self) -> bool: + return self.ref_count > 0 + + def allocate(self, size: int) -> None: + assert size <= self.available_space, f"Block#{self.block_id} has no available space to allocate." + self.allocated_size += size + + def is_empty(self): + return self.allocated_size < 1 + + def clear(self) -> None: + self.ref_count = 0 + self.allocated_size = 0 + + def __repr__(self): + return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})" diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index dda46a756cc3..a20bd8ee79ea 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -1,106 +1,576 @@ -""" -Refered/Modified from lightllm/common/mem_manager.py -of the ModelTC/lightllm GitHub repository -https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py -we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. -""" +from typing import List, Tuple + import torch -from transformers.utils import logging +from transformers.configuration_utils import PretrainedConfig + +from colossalai.inference.config import InferenceConfig +from colossalai.inference.struct import Sequence +from colossalai.logging import get_dist_logger +from colossalai.utils import get_current_device + +from .block_cache import CacheBlock + +__all__ = ["KVCacheManager"] + +GIGABYTE = 1024**3 -class MemoryManager: - r""" - Manage token block indexes and allocate physical memory for key and value cache +class KVCacheManager: + """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors). - Args: - size: maximum token number used as the size of key and value buffer - dtype: data type of cached key and value - head_num: number of heads the memory manager is responsible for - head_dim: embedded size per head - layer_num: the number of layers in the model - device: device used to store the key and value cache + NOTE: The KVCacheManager is designed to be interacted with indices of logical blocks. + That is, it won't allocate and return a physical cache to the engine or scheduler; + instead, it will mark the logical block as allocated and update the block id representing + the physical cache to the caller. The physical cache is actually used and updated in kernels. + + Example + A block table of a single sequence before block allocation might be: + | -1 | -1 | -1 | -1 | -1 | -1 | + where the maximum blocks per sequence is 6 + The block table after block allocation might be: + | 0 | 1 | 2 | -1 | -1 | -1 | + Then the logical blocks with id 0, 1, and 2, are allocated for this sequence, + and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer, + corresponding to these blocks will be used to read/write KV Caches in kernels. + + For a batch of sequences, the block tables after allocation might be: + | 0 | 1 | 2 | -1 | -1 | -1 | + | 3 | 4 | 5 | 6 | 7 | -1 | + | 8 | 9 | 10 | 11 | -1 | -1 | + | 12 | 13 | 14 | 15 | -1 | -1 | + where 16 logical cache blocks are allocated and the same number of physical cache blocks will be used in kernels. + + Currently, allocations and updates are done at granularity of a single sequence. + That is, the block table should be a 1D tensor of shape [max_blocks_per_sequence]. + And it's possible to have a batch of sequences with different lengths of block tables. """ - def __init__( + def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None: + self.logger = get_dist_logger(__name__) + self.device = get_current_device() + + # Parallel settings + self.tp_size = config.tp_size + # Model settings + self.dtype = config.dtype + + if config.kv_cache_dtype is None: + self.kv_cache_dtype = config.dtype + else: + self.kv_cache_dtype = config.kv_cache_dtype + + self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() + self.num_layers = model_config.num_hidden_layers + self.head_num = model_config.num_attention_heads + self.head_size = model_config.hidden_size // self.head_num + if hasattr(model_config, "num_key_value_heads"): + self.kv_head_num = model_config.num_key_value_heads + else: + self.kv_head_num = self.head_num + + assert ( + self.kv_head_num % self.tp_size == 0 + ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}" + self.kv_head_num //= self.tp_size + self.beam_width = config.beam_width + self.max_batch_size = config.max_batch_size + self.max_input_length = config.max_input_len + self.max_output_length = config.max_output_len + # Cache block settings + self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + + # Physical cache allocation + if config.use_cuda_kernel: + x = 16 // torch.tensor([], dtype=config.dtype).element_size() + kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x) + valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info( + f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks." + ) + self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape) + else: + alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape) + self.total_physical_cache_size_in_bytes = ( + self.elem_size_in_bytes + * self.num_layers + * 2 + * self.num_blocks + * self.block_size + * self.kv_head_num + * self.head_size + ) + self.logger.info( + f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}." + ) + # Logical cache blocks allocation + self._available_blocks = self.num_blocks + self._cache_blocks = tuple(self._init_logical_caches()) + # block availablity state 0->allocated, 1->free + self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool) + self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64) + self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64) + + @property + def total_num_blocks(self) -> int: + """Get the total number of logical cache blocks.""" + return self.num_blocks + + @property + def num_available_blocks(self) -> int: + """Get the number of available cache blocks.""" + return self._available_blocks + + def get_head_size(self): + return self.head_size + + def get_kv_cache(self): + """Get k_cache and v_cache""" + return self._kv_caches + + def get_max_blocks_per_sequence(self) -> int: + """Get the maximum number of blocks that can be allocated for a single sequence.""" + # TODO Consider removing this function as we plan to implement "half-dynamic" batching in schduler/request handler, + # which will make the max_blocks_per_sequence dynamic based on the prompt lengths of sequences + # in the current batch. + return self.max_blocks_per_sequence + + def check_allocation(self, seq: Sequence) -> bool: + num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size + return num_blocks_needed <= self.num_available_blocks + + def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]: + """Get the key and value pointers of physical caches (of specific layer) corresponding to a logical cache block.""" + block: CacheBlock = self._cache_blocks[block_id] + return block.k_ptrs[layer_id], block.v_ptrs[layer_id] + + def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> Tuple[int, int]: + """Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table.""" + k_ptrs = [] + v_ptrs = [] + for block_id in block_table: + if block_id >= 0: + block: CacheBlock = self._cache_blocks[block_id] + k_ptrs.append(block.k_ptrs[layer_id]) + v_ptrs.append(block.v_ptrs[layer_id]) + return k_ptrs, v_ptrs + + def allocate_context_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None: + """Allocate the logical cache blocks for a single sequence during prefill stage, + and updates the provided block table with the allocated block ids. + + Args: + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. + context_len: The length of the processing sequnece. + """ + assert block_table.dim() == 1 + if not torch.all(block_table < 0): + self.logger.error("Some slots on provided block table have been allocated.") + blocks_required = (context_len + self.block_size - 1) // self.block_size + if blocks_required > self._available_blocks: + self.logger.warning( + f"No enough blocks to allocate. Available blocks {self._available_blocks}; context length {context_len}." + ) + return + + # Try contiguous allocation + torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:]) + torch.subtract( + self._block_states_cum[blocks_required:], + self._block_states_cum[:-blocks_required], + out=self._block_finder[blocks_required - 1 :], + ) + end_indexes = torch.nonzero(self._block_finder == blocks_required, as_tuple=False).view(-1) + if end_indexes.numel() > 0: + # contiguous cache exists + end_idx = end_indexes[0].item() + 1 # open interval + start_idx = end_idx - blocks_required # closed interval + block_indexes = torch.arange(start_idx, end_idx, device=block_table.device) + else: + # non-contiguous cache + available_block_indexes = torch.nonzero(self._block_states == 0).view(-1) + block_indexes = available_block_indexes[:blocks_required] + # Update block table + block_table[:blocks_required] = block_indexes + # Update cache blocks + self._block_states[block_indexes] = 0 + self._available_blocks -= blocks_required + for block_id in block_indexes.tolist(): + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + if block_id == block_indexes[-1].item(): + self._allocate_on_block( + block, block.block_size if context_len % block.block_size == 0 else context_len % block.block_size + ) + else: + self._allocate_on_block(block, block.block_size) + + def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None: + """Allocate logical cache blocks for a batch of sequences during prefill stage. + + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz]] + """ + assert block_tables.dim() == 2 + assert block_tables.size(0) == context_lengths.size(0) + if not torch.all(block_tables < 0): + self.logger.error("Some slots on provided block table have been allocated.") + blocks_required = (context_lengths + self.block_size - 1) // self.block_size + num_blocks_required = torch.sum(blocks_required).item() + assert isinstance(num_blocks_required, int) + if num_blocks_required > self._available_blocks: + self.logger.warning( + f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}." + ) + return + + bsz = block_tables.size(0) + # Try contiguous allocation + torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:]) + torch.subtract( + self._block_states_cum[num_blocks_required:], + self._block_states_cum[:-num_blocks_required], + out=self._block_finder[num_blocks_required - 1 :], + ) + end_indexes = torch.nonzero(self._block_finder == num_blocks_required, as_tuple=False).view(-1) + if end_indexes.numel() > 0: + # contiguous cache exists + end_idx = end_indexes[0].item() + 1 # open interval + start_idx = end_idx - num_blocks_required # closed interval + alloc_block_ids = torch.arange(start_idx, end_idx) + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, :curr_required] = torch.arange( + start_idx, start_idx + curr_required, device=block_tables.device + ) + start_idx += curr_required + else: + # non-contiguous cache + available_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = available_block_ids[:num_blocks_required] + alloc_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device) + start_idx = 0 + for i in range(bsz): + curr_required = blocks_required[i] + block_tables[i, :curr_required] = alloc_block_ids[start_idx, start_idx + curr_required] + start_idx += curr_required + + # Update cache blocks + self._block_states[alloc_block_ids] = 0 + self._available_blocks -= num_blocks_required + last_block_locs = torch.cumsum(blocks_required, dim=0) - 1 + last_block_locs = last_block_locs.to(device=alloc_block_ids.device) + + for i, block_id in enumerate(alloc_block_ids[last_block_locs]): + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block( + block, + block.block_size + if context_lengths[i] % block.block_size == 0 + else context_lengths[i].item() % block.block_size, + ) + for block_id in alloc_block_ids: + if block_id in alloc_block_ids[last_block_locs]: + continue + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._allocate_on_block(block, block.block_size) + + def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None: + """Allocate the logical cache block for a single sequence during decoding stage, + and updates the provided block table if a new cache block is needed. + + Args: + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. + context_len: The length of the processing sequnece (already-allocated length). + """ + assert block_table.dim() == 1 + # The last allocated block may be either partially or fully occupied. + # `alloc_local_block_idx` is the index of block to be allocated on provided block table. + alloc_local_block_idx = context_len // self.block_size + return self.allocate_single_block(block_table, alloc_local_block_idx) + + def allocate_tokens_from_block_tables( + self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None + ) -> List[int]: + """Allocate logical cache blocks for a batch of sequences during decoding stage. + + Usage: + allocate_context_from_block_tables + model forward (block tables & context lengths passed) + update context lengths + allocate_tokens_from_block_tables + model forward + update context lengths + allocate_tokens_from_block_tables + model forward + update context lengths + ... + + Args: + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] + context_lengths (torch.Tensor): [bsz] + + Returns: + List[int]: list of sequence uid to be recycled + """ + assert block_tables.dim() == 2 + assert context_lens.dim() == 1 + + bsz = block_tables.size(0) if bsz is None else bsz + + alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size + block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes] + seqs_to_recycle = [] + new_blocks_required = torch.sum(block_global_ids < 0).item() + seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).squeeze() + + if new_blocks_required > 0: + if new_blocks_required > self._available_blocks: + # TODO might want to revise the logic here + # Process the first (_available_blocks) sequences that require new blocks + # Put the rest of the sequences back to recycled + seqs_req_new_blocks, seqs_to_recycle = ( + seqs_req_new_blocks[: self._available_blocks], + seqs_req_new_blocks[self._available_blocks :], + ) + for seq_id in seqs_to_recycle: + self.free_block_table(block_tables[seq_id]) + new_blocks_required = self._available_blocks + + # NOTE might want to alloc contiguous logic + free_block_ids = torch.nonzero(self._block_states > 0).view(-1) + alloc_block_ids = free_block_ids[:new_blocks_required].to( + dtype=block_tables.dtype, device=block_tables.device + ) + + for block_id in alloc_block_ids: + block: CacheBlock = self._cache_blocks[block_id] + block.add_ref() + self._block_states[block_id] = 0 + self._available_blocks -= 1 + block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids + block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes] + + for block_id in block_global_ids: + self._allocate_on_block(self._cache_blocks[block_id], 1) + + return seqs_to_recycle + + def allocate_n_tokens_from_block_tables( self, - size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: torch.device = torch.device("cuda"), - ): - self.logger = logging.get_logger(__name__) - self.available_size = size - self.max_len_in_batch = 0 - self._init_mem_states(size, device) - self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - - def _init_mem_states(self, size, device): - """Initialize tensors used to manage memory states""" - self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) - self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) - self.indexes = torch.arange(0, size, dtype=torch.long, device=device) - - def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): - """Initialize key buffer and value buffer on specified device""" - self.key_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + block_tables: torch.Tensor, + context_lens: torch.Tensor, + bsz: int, + n: int, + ) -> List[int]: + """Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage.""" + assert block_tables.dim() == 2 + assert context_lens.dim() == 1 + + bsz = block_tables.size(0) if bsz is None else bsz + assert bsz == 1, "Support bsz 1 for now" # TODO support bsz > 1 + + seqs_to_recycle = [] + for i in range(n): + seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz) + + return seqs_to_recycle + + def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int: + """Allocate space asked on a single block in the block table, specified by the provided position id, + and updates the provided block table with the allocated block. + + Args: + block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id. + block_local_idx: The index of the block in the block table. + space_asked: i.e. The number of tokens to be assigned space for. + Returns: + The remaining space required to be allocated (in other blocks). + """ + space_asked = 1 + block_global_id = block_table[block_local_idx].item() + if block_global_id < 0: + # Allocate a new block if the current position is not assigned a block yet + if self._available_blocks <= 0: + # No available blocks to allocate, we free current sequence and return it to + self.free_block_table(block_table) + return True + free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0] + block: CacheBlock = self._cache_blocks[free_block_id] + block.add_ref() + block_global_id = block.block_id + self._available_blocks -= 1 + self._block_states[block_global_id] = 0 + block_table[block_local_idx] = block_global_id + block: CacheBlock = self._cache_blocks[block_global_id] + return self._allocate_on_block(block, space_asked) + # only when space asked if fully satisfied, the return value will be zero. + + def free_block_table(self, block_table: torch.Tensor) -> None: + """Free the logical cache blocks for **a single sequence**.""" + assert block_table.dim() == 1 + for i, global_block_id in enumerate(block_table.tolist()): + if global_block_id < 0: + return + block: CacheBlock = self._cache_blocks[global_block_id] + block.remove_ref() + if not block.has_ref(): + block.allocated_size = 0 + self._available_blocks += 1 + self._block_states[global_block_id] = 1 + # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine) + block_table[i] = -1 + + def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None: + """Release the logical cache blocks for a batch of sequences. + If `first_n` is provided, only the blocks for the first several sequences will be released. + """ + assert block_tables.dim() == 2 + first_n = block_tables.size(0) if first_n is None else first_n + for block_table in block_tables[:first_n]: + self.free_block_table(block_table) + + def clear_all(self) -> None: + """Clear all the references and allocations on all the cache blocks.""" + for block in self._cache_blocks: + block.clear() + self._available_blocks = self.num_blocks + self._block_states[:] = 1 + + def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" + return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] + + def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int: + """Allocate a specific size of space on a provided cache block. + + Returns: + The remaining space required to be allocated (in other blocks). + """ + assert block.available_space > 0, f"Found no available space left in the chosen block {block}." + space_to_allocate = min(block.available_space, space_asked) + block.allocate(space_to_allocate) + return space_asked - space_to_allocate + + def _init_logical_caches(self): + """Initialize the logical cache blocks. + + NOTE This function should be called only after the physical caches have been allocated. + The data pointers of physical caches will be binded to each logical cache block. + """ + assert self._kv_caches is not None and len(self._kv_caches[0]) > 0 + blocks = [] + physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size + k_ptrs = [ + self._kv_caches[0][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers) ] - self.value_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + v_ptrs = [ + self._kv_caches[1][layer_idx].data_ptr() - physical_block_size for layer_idx in range(self.num_layers) ] + for i in range(self.num_blocks): + k_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in k_ptrs] + v_ptrs = [first_block_ptr + physical_block_size for first_block_ptr in v_ptrs] + cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs, v_ptrs) + blocks.append(cache_block) + return blocks - @torch.no_grad() - def alloc(self, required_size): - """allocate space of required_size by providing indexes representing available physical spaces""" - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) - select_index = self.indexes[select_index] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - return select_index - - @torch.no_grad() - def alloc_contiguous(self, required_size): - """allocate contiguous space of required_size""" - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - sum_size = len(self.mem_cum_sum) - loc_sums = ( - self.mem_cum_sum[required_size - 1 :] - - self.mem_cum_sum[0 : sum_size - required_size + 1] - + self.mem_state[0 : sum_size - required_size + 1] - ) - can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] - if can_used_loc.shape[0] == 0: + def _init_device_caches( + self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Initialize the physical cache on the device. + + For each layer of the model, we allocate two tensors for key and value respectively, + with shape of [num_blocks, num_kv_heads, block_size, head_size] + """ + k_cache: List[torch.Tensor] = [] + v_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device)) + v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) + return k_cache, v_cache + + +class RPCKVCacheManager(KVCacheManager): + def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: + self.logger = get_dist_logger(__name__) + self.device = get_current_device() + self.config = config + + # Parallel settings + self.tp_size = config.tp_size + # Model settings + self.dtype = config.dtype + self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() + self.num_layers = model_config.num_hidden_layers + self.head_num = model_config.num_attention_heads + self.head_size = model_config.hidden_size // self.head_num + if hasattr(model_config, "num_key_value_heads"): + self.kv_head_num = model_config.num_key_value_heads + else: + self.kv_head_num = self.head_num + + if config.kv_cache_dtype is None: + self.kv_cache_dtype = config.dtype + else: + self.kv_cache_dtype = config.kv_cache_dtype + + assert ( + self.kv_head_num % self.tp_size == 0 + ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}" + self.kv_head_num //= self.tp_size + self.beam_width = config.beam_width + self.max_batch_size = config.max_batch_size + self.max_input_length = config.max_input_len + self.max_output_length = config.max_output_len + # Cache block settings + self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size + self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width + + # Logical cache blocks allocation + self._available_blocks = self.num_blocks + self._cache_blocks = tuple(self._init_logical_caches()) + # block availablity state 0->allocated, 1->free + self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool) + self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64) + self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64) + + def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + # Physical cache allocation + if self.config.use_cuda_kernel: + x = 16 // torch.tensor([], dtype=self.config.dtype).element_size() + kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x) + valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) self.logger.info( - f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks." ) - return None - start_loc = can_used_loc[0] - select_index = self.indexes[start_loc : start_loc + required_size] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - start = start_loc.item() - end = start + required_size - return select_index, start, end - - @torch.no_grad() - def free(self, free_index): - """free memory by updating memory states based on given indexes""" - self.available_size += free_index.shape[0] - self.mem_state[free_index] = 1 - - @torch.no_grad() - def free_all(self): - """free all memory by updating memory states""" - self.available_size = len(self.mem_state) - self.mem_state[:] = 1 - self.max_len_in_batch = 0 - # self.logger.info("freed all space of memory manager") + else: + alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) + kalloc_shape = alloc_shape + valloc_shape = alloc_shape + self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") + return kalloc_shape, valloc_shape + + def get_kv_cache(self): + """Get k_cache and v_cache""" + return NotImplementedError + + def _init_logical_caches(self): + """Initialize the logical cache blocks.""" + blocks = [] + for i in range(self.num_blocks): + cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None) + blocks.append(cache_block) + return blocks diff --git a/colossalai/inference/logit_processors.py b/colossalai/inference/logit_processors.py new file mode 100644 index 000000000000..ea73f833242e --- /dev/null +++ b/colossalai/inference/logit_processors.py @@ -0,0 +1,189 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py +import logging +from typing import List, Union + +import torch +import torch.nn.functional as F + +_LOGITS_PROCESSOR_MAP = {} + + +def register_logits_processor(process_type): + """ + register flops computation function for operation. + """ + + def register(func): + global _LOGITS_PROCESSOR_MAP + _LOGITS_PROCESSOR_MAP[process_type] = func + return func + + return register + + +@register_logits_processor("no_repeat_ngram_size") +def apply_no_repeat_ngram_size(logits, ngram_size: int, batch_token_ids: List[List[int]]): + """ + enforces no repetition of n-grams to avoid repetitions of word sequences. + """ + + if not isinstance(ngram_size, int) or ngram_size < 0: + raise ValueError(f"'temperature={ngram_size}' should be a strictly positive integer.") + + if ngram_size != 0: + batch_size = len(batch_token_ids) + + for batch_id in range(batch_size): + current_token_ids = batch_token_ids[batch_id] + current_len = len(current_token_ids) + if current_len + 1 < ngram_size: + continue + + ngrams_dict = {} + + for ngram in zip(*[current_token_ids[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]] + + prev_ngrams = tuple(current_token_ids[current_len + 1 - ngram_size : current_len]) + banned_token = ngrams_dict.get(prev_ngrams, []) + + logits[batch_id, banned_token] = -float("inf") + + return logits + + +@register_logits_processor("repetition_penalty") +def apply_repetition_penalty(logits, penalty: float, batch_token_ids: List[List[int]]): + """ + apply the penalty to the tokens present in the prompt. + """ + + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"'penalty={penalty}' has to be a strictly positive float and greater than 0.") + + logits_list = [] + + # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. + if penalty != 1.0: + for batch_id in range(len(batch_token_ids)): + current_logit = logits[batch_id] + current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device) + + curretn_socre = torch.gather(current_logit, 0, current_token) + curretn_socre = torch.where(curretn_socre < 0, curretn_socre * penalty, curretn_socre / penalty) + logits_list.append(current_logit.scatter(0, current_token, curretn_socre)) + + logits = torch.stack(logits_list) + + return logits + + +@register_logits_processor("temperature") +def apply_temperature(logits, temperature: float): + """ + apply temperature scaling. + """ + + if not isinstance(temperature, float) or not (0.0 < temperature <= 1.0): + except_msg = f"'temperature={temperature}' should be a strictly positive float, less than or equal to 1.0 and greater than 0." + if temperature == 0.0: + except_msg += "if you want to use greedy decoding strategies, set `do_sample=False`." + raise ValueError(except_msg) + + return logits if temperature == 1.0 else logits / temperature + + +@register_logits_processor("top_k") +def apply_top_k(logits, top_k: int): + """ + top_k logit processor + """ + + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` should be a strictly positive integer, but got {top_k}.") + + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float("inf") + return logits + + +@register_logits_processor("top_p") +def apply_top_p(logits, top_p: float): + """ + top_p logit processor + """ + + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` should be a float > 0 and < 1, but got {top_p}.") + + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + + sorted_indices_to_remove = torch.roll(sorted_indices_to_remove, 1, -1) + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits[indices_to_remove] = -float("inf") + return logits + + +@register_logits_processor("forced_eos_token_id") +def apply_forced_eos_token_id( + logits: torch.Tensor, + sequence_lengths: Union[torch.Tensor, List[int]], + max_lengths: Union[torch.Tensor, List[int]], + eos_token_id: Union[int, List[int]], +): + """ + Enforces the specified token as the last generated token when the maximum output length + is reached. Notice that the maximum output lengths for different sequences, even if they're + in the same batch, can be different. + + Args: + logits(torch.Tensor): logits + sequence_lengths(torch.Tensor): sequence lengths including prompt and output tokens + max_lengths(torch.Tensor): the maximum length for each sequence + eos_token_id(Union[int, List[int]]): forced eos token id + """ + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if isinstance(sequence_lengths, torch.Tensor): + sequence_lengths = sequence_lengths.tolist() + if isinstance(max_lengths, torch.Tensor): + max_lengths = max_lengths.tolist() + + select_indexes = [] + num_sequences = logits.shape[0] + sequence_lengths = sequence_lengths[:num_sequences] + max_lengths = max_lengths[:num_sequences] + for i, (sequence_length, max_out_length) in enumerate(zip(sequence_lengths, max_lengths)): + if sequence_length == max_out_length - 1: + select_indexes.append(i) + if select_indexes: + logits[select_indexes, :] = -float("inf") + logits[select_indexes, eos_token_id] = 0 + + return logits + + +def get_logits_processor(processor: str, logits, *args, **kwargs): + """ + do logit process for given logits. + + Args: + processor(str): the type of logit processor + logits(torch.Tensor): input logits + + Returns: + logits after process + """ + if processor not in _LOGITS_PROCESSOR_MAP: + logging.warning(f"Unsupported processor {processor}. Fall back to the original logits.") + else: + func = _LOGITS_PROCESSOR_MAP[processor] + logits = func(logits, *args, **kwargs) + + return logits diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/modeling/__init__.py similarity index 100% rename from colossalai/inference/quant/smoothquant/__init__.py rename to colossalai/inference/modeling/__init__.py diff --git a/colossalai/inference/modeling/layers/__init__.py b/colossalai/inference/modeling/layers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py new file mode 100644 index 000000000000..43ccdc430ef1 --- /dev/null +++ b/colossalai/inference/modeling/layers/attention.py @@ -0,0 +1,313 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + +def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): + """ + Func: copy key/value into key/value cache. + + Args: key/value(source): shape [bsz,seq_len,num_heads,head_size] + cache: shape [num_blocks, num_kv_heads, head_size, block_size] + lengths: key/value lengths + block_tables + """ + num_blocks, num_heads, block_size, head_size = cache.shape + bsz, max_blocks_per_seq = block_tables.shape + needed_blocks = (lengths + block_size - 1) // block_size + + if type == "prefill": + for i in range(bsz): + seq_len = lengths[i] + block_num = needed_blocks[i] + token_id = 0 + for block_idx in range(block_num - 1): + cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2) + token_id += block_size + cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute( + 1, 0, 2 + ) + elif type == "decoding": + assert source.size(1) == 1, "seq_len should be equal to 1 when decoding." + source = source.squeeze(1) + slot_idx = (lengths + block_size - 1) % block_size + for i in range(bsz): + cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i] + + return cache + + +def convert_kvcache(cache, lengths, block_tables, pad_id=0): + """ + Func: convert key/value cache for calculation + + Args: cache: shape [num_blocks, num_heads, block_size, head_size] + lengths: key/value length + block_tables + pad_id: padded_id + """ + num_blocks, num_heads, block_size, head_size = cache.shape + + needed_blocks = (lengths + block_size - 1) // block_size + num_remaing_tokens = lengths % block_size + num_remaing_tokens[num_remaing_tokens == 0] += block_size + bsz = block_tables.shape[0] + seq_len = max(lengths) + padded_cache = [] + for i in range(bsz): + _cache = torch.cat( + ( + cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size), + cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2), + ), + dim=0, + ) + padding = seq_len - _cache.size(0) + if padding > 0: + _cache = F.pad(_cache, (0, 0, 0, 0, 0, padding), value=pad_id) + padded_cache.append(_cache) + return torch.stack(padded_cache, dim=0) + + +class PagedAttention: + """ + Pure Torch implementation version of paged_attention. + Holds different types of forward function and useful components. + """ + + @staticmethod + def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size): + """ + Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size] + """ + bsz = len(seq_lengths) + padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype) + + token_idx = 0 + for i, seq_len in enumerate(seq_lengths): + seq_tensor = tensor[token_idx : token_idx + seq_len] + padded_tensor[i, :seq_len, :, :] = seq_tensor + token_idx += seq_len + return padded_tensor + + @staticmethod + def generate_padding_mask(lengths, max_seq_len): + range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len) + padding_mask = range_tensor < lengths.unsqueeze(1) + return padding_mask + + @staticmethod + def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor: + """ + Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Args: hidden_states(batch, num_key_value_heads, seqlen, head_dim) + n_rep: times of repeatition. + Output: hidden_states (batch, num_attention_heads, seqlen, head_dim) + """ + if n_rep == 1: + return hidden_states + + batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape + num_attention_heads = n_rep * num_key_value_heads + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim) + + return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim) + + @staticmethod + def nopad_context_forward( + q: torch.Tensor, # [num_tokens, num_heads, head_size] + k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] + v_cache: torch.Tensor, + context_lengths: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + """ + NOTE: q,k,v are projected and applied rotary embedding, all aligned with triton version. + """ + # Fisrt, do shape verification + num_tokens, num_heads, head_size = q.shape + num_kv_heads = k.shape[-2] + + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads + + block_size = k_cache.size(-2) + bsz, max_blocks_per_sequence = block_tables.shape + max_seq_len = max_blocks_per_sequence * block_size + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.shape[0] == k.shape[0] == v.shape[0] + assert context_lengths.shape[0] == block_tables.shape[0] + shape = (bsz, max_seq_len, num_heads, head_size) + input_shape = shape[:2] + + q = PagedAttention.pad_and_reshape( + q, context_lengths, max_seq_len, num_heads, head_size + ) # bsz,seqlen,num_heads,head_size + k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size) + v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size) + + copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables) + + attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len) + + q = q.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) + + # position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device) + # position_ids = position_ids.unsqueeze(0) + # cos, sin = self.rotary_emb(value, max_seq_len) + # query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) + if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.") + + if attn_mask is not None: + attn_weights += attn_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != (bsz, num_heads, max_seq_len, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.") + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1) + + del attn_weights + + return attn_output + + @staticmethod + def pad_context_forward( + q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] + k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size] + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] + v_cache: torch.Tensor, + context_lengths: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + attn_mask: torch.Tensor = None, # [bsz, input_lengths + output_lengths] + ): + # Firt, do shape verification + bsz, seq_len, num_heads, head_size = q.shape + num_kv_heads = k.shape[-2] + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads + block_size = k_cache.size(-2) + assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] + block_tables.shape[-1] * block_size + + # Copy kv to memory(rotary embedded) + copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables) + copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables) + + q = q.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) + + padding_mask = None + + if attn_mask is not None: + padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len) + + attn_mask = AttentionMaskConverter._make_causal_mask( + (bsz, seq_len), q.dtype, q.device, past_key_values_length=seq_len - seq_len + ) + + if padding_mask is not None: + attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min) + + if attn_weights.size() != (bsz, num_heads, seq_len, seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.") + if attn_mask is not None: + attn_weights += attn_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != (bsz, num_heads, seq_len, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.") + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1) + + return attn_output + + @staticmethod + def pad_decoding_forward( + q: torch.Tensor, # [bsz, 1, num_heads, head_size] + k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size] + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] + v_cache: torch.Tensor, + lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + attn_mask: torch.Tensor = None, # [bsz, input_lengths + output_lengths] + ): + # Firt, do shape verification. + bsz, q_length, num_heads, head_size = q.shape + + num_kv_heads = k.shape[-2] + assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads" + num_kv_groups = num_heads // num_kv_heads + seq_len = max(lengths) + + assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0] + + copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding") + copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding") + + k = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen, + v = convert_kvcache(v_cache, lengths, block_tables) + + q = q.transpose(1, 2) + k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups) + v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size) + if attn_weights.size() != (bsz, num_heads, 1, seq_len): + raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.") + + padding_mask = None + if attn_mask is not None: + padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length) + + attn_mask = AttentionMaskConverter._make_causal_mask( + (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length + ) + + if padding_mask is not None: + attn_mask = attn_mask.masked_fill(padding_mask.bool(), torch.finfo(q.dtype).min) + + attn_weights += attn_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != (bsz, num_heads, 1, head_size): + raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.") + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1) + + return attn_output + + @staticmethod + def no_pad_decoding_forward( + self, + q: torch.Tensor, # [num_tokens, num_heads, head_size] + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] + v_cache: torch.Tensor, + lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths + block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence] + ): + return self.pad_decoding_forward( + q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), k_cache, v_cache, lengths, block_tables + ) diff --git a/colossalai/inference/modeling/layers/baichuan_tp_linear.py b/colossalai/inference/modeling/layers/baichuan_tp_linear.py new file mode 100644 index 000000000000..e050dd71c8b2 --- /dev/null +++ b/colossalai/inference/modeling/layers/baichuan_tp_linear.py @@ -0,0 +1,43 @@ +from typing import List, Union + +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.shardformer.layer import Linear1D_Col +from colossalai.shardformer.layer.parallel_module import ParallelModule + + +class BaichuanLMHeadLinear1D_Col(Linear1D_Col): + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + module.in_features = module.weight.size(1) + module.out_features = module.weight.size(0) + module.bias = None + module.weight.data = nn.functional.normalize(module.weight) + + return Linear1D_Col.from_native_module( + module, + process_group, + *args, + **kwargs, + ) + + +class BaichuanWpackLinear1D_Col(Linear1D_Col): + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + in_features = module.in_features * 3 + out_features = module.out_features // 3 + module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features) + module.bias = None + + return Linear1D_Col.from_native_module( + module, + process_group, + *args, + **kwargs, + ) diff --git a/colossalai/inference/modeling/models/__init__.py b/colossalai/inference/modeling/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py new file mode 100644 index 000000000000..7b25f3e7489d --- /dev/null +++ b/colossalai/inference/modeling/models/glide_llama.py @@ -0,0 +1,475 @@ +# This is modified from huggingface transformers +# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py +import warnings +from types import MethodType +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaForCausalLM, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) + +from colossalai.inference.spec import GlideInput +from colossalai.kernel.triton import flash_decoding_attention +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_single_rotary_pos_emb(q, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +def glide_llama_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + glide_input: Optional[GlideInput] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + glide_input=glide_input, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + if not return_dict: + output = (logits,) + outputs[1:] + return output + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def glide_llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + glide_input: GlideInput = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # GlideLlamaDecoderLayer + layer_outputs = decoder_layer( + hidden_states, + glide_input=glide_input, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GlideLlamaConfig(LlamaConfig): + """Configuration class with specific arguments used by GLIDE llama model as a drafter""" + + def __init__( + self, + large_hidden_size=4096, + large_num_attention_heads=32, + **kwargs, + ): + super().__init__(**kwargs) + self.large_hidden_size = large_hidden_size + self.large_num_attention_heads = large_num_attention_heads + + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GlideLlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # large model (verifier) configs + self.large_hidden_size = config.large_hidden_size + self.large_num_heads = config.large_num_attention_heads + self.large_head_dim = self.large_hidden_size // self.large_num_heads + + self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False) + self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.large_head_dim, + max_position_embeddings=self.max_position_embeddings, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.large_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.large_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + glide_input: GlideInput = None, # Used for glimpsing main model's KV caches + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Optional[torch.Tensor]: + bsz, q_len, _ = hidden_states.size() + + block_tables = glide_input.block_tables + large_k_cache = glide_input.large_k_cache + large_v_cache = glide_input.large_v_cache + sequence_lengths = glide_input.sequence_lengths + cache_block_size = large_k_cache.size(-2) + + query_states = self.q_proj(hidden_states) + kv_seq_len = sequence_lengths.max().item() + + query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2) + + # for RoPE + cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32) + query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids) + query_states = query_states.transpose(1, 2) + query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim) + + attn_output = flash_decoding_attention( + q=query_states, + k_cache=large_k_cache, + v_cache=large_v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=cache_block_size, + max_seq_len_in_batch=kv_seq_len, + ) # attn_output: [bsz * q_len, num_heads * head_dim] + + attn_output = attn_output.reshape(bsz, q_len, self.large_hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +# A class to be used to replace LlamaDecoderLayer in a Llama Model as Drafter in speculative decoding. +# Refer to GLIDE with a CAPE https://arxiv.org/pdf/2402.02082.pdf +class GlideLlamaDecoderLayer(nn.Module): + def __init__(self, config: GlideLlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlamaDecoderLayer": + """Build a GlideLlamaDecoderLayer from a native LlamaDecoderLayer""" + config: LlamaConfig = module.mlp.config # XXX + layer_idx = module.self_attn.layer_idx + glide_config = GlideLlamaConfig(**config.to_dict()) + glide_decoder_layer = GlideLlamaDecoderLayer(glide_config, layer_idx=layer_idx) + + return glide_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + glide_input: GlideInput = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + curr_q_len = hidden_states.size(1) + # Cross attention + if glide_input is None or not glide_input.glimpse_ready: + warnings.warn( + "Data used for glimpsing the past KV caches of the main model (verifier) is not complete. " + "Fall back to normal decoder layer modeling (drafter). " + "This might lead to incorrect results when using the Glide Models for speculative decoding." + ) + elif curr_q_len == 1: + # Notice that we skip prefill stage + # always use the output of the main model as the inputs for the next round of speculation + residual = hidden_states + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + glide_input=glide_input, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=True, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class GlideLlamaForCausalLM(LlamaForCausalLM): + def __init__(self, config: GlideLlamaConfig): + super().__init__(config) + self.config = config + bound_method = MethodType(glide_llama_causal_lm_forward, self) + setattr(self, "forward", bound_method) + bound_method = MethodType(glide_llama_model_forward, self.model) + model = getattr(self, "model") + setattr(model, "forward", bound_method) + replaced_layers = nn.ModuleList( + [GlideLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + setattr(model, "layers", replaced_layers) diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py new file mode 100644 index 000000000000..b50e73d6fcf4 --- /dev/null +++ b/colossalai/inference/modeling/models/nopadding_baichuan.py @@ -0,0 +1,421 @@ +# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py +import itertools +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import ( + context_attention_unpadded, + copy_k_to_blocked_cache, + decoding_fused_rotary_embedding, + flash_decoding_attention, + rms_layernorm, + rotary_embedding, +) +from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor + +logger = get_dist_logger(__name__) + +try: + from flash_attn import flash_attn_varlen_func + + use_flash_attn2 = True +except ImportError: + use_flash_attn2 = False + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + +logger = get_dist_logger(__name__) + +try: + from flash_attn import flash_attn_varlen_func + + use_flash_attn2 = True +except ImportError: + use_flash_attn2 = False + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + +inference_ops = InferenceOpsLoader().load() + +logger = get_dist_logger(__name__) + + +# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57 +def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device) + slopes = torch.pow(base, powers) + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +def baichuan_rmsnorm_forward( + self, + hidden_states: torch.Tensor, + norm_output: torch.Tensor, + residual: torch.Tensor = None, + use_cuda_kernel: bool = True, +): + # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b. + if hasattr(self, "variance_epsilon"): + eps = self.variance_epsilon + elif hasattr(self, "epsilon"): + eps = self.epsilon + else: + TypeError( + "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'." + ) + if use_cuda_kernel: + if residual is not None: + inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps) + return hidden_states, residual + + if norm_output is None: + norm_output = torch.empty_like(hidden_states) + inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps) + return norm_output, hidden_states + else: + return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual) + + +class NopadBaichuanAttention(ParallelModule): + def __init__( + self, + config, + attn_qproj_w: torch.Tensor = None, + attn_kproj_w: torch.Tensor = None, + attn_vproj_w: torch.Tensor = None, + attn_oproj: ParallelModule = None, + num_heads: int = None, + hidden_size: int = None, + process_group: ProcessGroup = None, + helper_layout: Layout = None, + ): + """This layer will replace the BaichuanAttention. + + Args: + config (BaichuanConfig): Holding the Baichuan model config. + attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. + attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. + attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. + attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None. + """ + ParallelModule.__init__(self) + self.o_proj = attn_oproj + + self.config = config + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = self.hidden_size // self.num_heads + self.process_group = process_group + qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] + self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) + + self.helper_layout = helper_layout + + self.alibi_slopes = None + self.use_alibi_attn = False + # Used for Baichuan13B + if config.hidden_size == 5120: + slopes_start = self.process_group.rank() * num_heads + self.use_alibi_attn = True + self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[ + slopes_start : slopes_start + num_heads + ].contiguous() + self.alibi_slopes = nn.Parameter(self.alibi_slopes) + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> "NopadBaichuanAttention": + """Used for initialize the weight of NopadBaichuanAttention by origin BaichuanAttention. + + Args: + module (nn.Module): The origin BaichuanAttention layer. + """ + + config = module.config + q_proj_w, k_proj_w, v_proj_w = module.W_pack.weight.view((module.hidden_size, 3, -1)).transpose(0, 1) + + attn_qproj_w = q_proj_w + attn_kproj_w = k_proj_w + attn_vproj_w = v_proj_w + attn_oproj = module.o_proj + + helper_layout = ( + module.W_pack.weight.dist_layout + ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) + + attn_layer = NopadBaichuanAttention( + config=config, + attn_qproj_w=attn_qproj_w, + attn_kproj_w=attn_kproj_w, + attn_vproj_w=attn_vproj_w, + attn_oproj=attn_oproj, + num_heads=module.num_heads, + hidden_size=module.hidden_size, + process_group=process_group, + helper_layout=helper_layout, + ) + + return attn_layer + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + key = "qkv_weight" + qkv_w = state_dict[prefix + "W_pack.weight"] + + in_features = qkv_w.size(1) + out_features = qkv_w.size(0) // 3 + + qkv_w.data = qkv_w.view((3, out_features, -1)).transpose(0, 1).reshape(out_features, in_features * 3) + + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + qkv_w = distribute_tensor(qkv_w, device_mesh, sharding_spec) + + qkv_w = qkv_w.transpose(0, 1).reshape(3, in_features, -1) + input_param = nn.Parameter( + qkv_w + ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + + param = local_state[key] + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + strict = False # to avoid unexpected_keys + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward( + self, + hidden_states: torch.Tensor, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, + is_prompts: bool = True, + is_verifier: bool = False, + tokens_to_verify: int = None, + kv_seq_len: int = 0, + output_tensor: torch.Tensor = None, + sm_scale: int = None, + use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + """ + + token_nums = hidden_states.size(0) + # fused qkv + hidden_states = hidden_states.expand(3, -1, -1) + query_states, key_states, value_states = ( + torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) + ) + + block_size = k_cache.size(-2) + + if is_prompts: + if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. + if not self.use_alibi_attn: + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ) + attn_output = attn_output.view(token_nums, -1) + else: + if not self.use_alibi_attn: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + alibi_slopes=self.alibi_slopes, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, + ) + else: + q_len = tokens_to_verify + 1 if is_verifier else 1 + + if use_cuda_kernel: + if not self.use_alibi_attn: + inference_ops.rotary_embedding_and_cache_copy( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + sequence_lengths, + block_tables, + high_precision, + ) + else: + inference_ops.decode_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables + ) + inference_ops.flash_decoding_attention( + output_tensor, + query_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + block_size, + kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + self.alibi_slopes, + sm_scale, + ) + attn_output = output_tensor + else: + if not is_verifier and not self.use_alibi_attn: + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, + ) + else: + if not self.use_alibi_attn: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + copy_k_to_blocked_cache( + key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + copy_k_to_blocked_cache( + value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + alibi_slopes=self.alibi_slopes, + sm_scale=sm_scale, + q_len=q_len, + ) + + attn_output = attn_output.view(-1, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output + + def extra_repr(self) -> str: + return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" + + +# NOTE This will cause difference as out length increases. +class NopadBaichuanMLP(NopadLlamaMLP): + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + """Used for initialize the weight of NopadBaichuanMLP by origin MLP(Baichuan). + + Args: + module (nn.Module): The origin MLP(Baichuan) layer. + """ + mlp_gproj_w = module.gate_proj.weight + assert is_distributed_tensor( + module.gate_proj.weight + ), "gate_proj.weight must be dtensor so we could get the layout of the weight" + mlp_uproj_w = module.up_proj.weight + mlp_dproj = module.down_proj + + mlp_layer = NopadBaichuanMLP( + config=None, + mlp_gproj_w=mlp_gproj_w, + mlp_uproj_w=mlp_uproj_w, + mlp_dproj=mlp_dproj, + process_group=process_group, + ) + + return mlp_layer diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py new file mode 100644 index 000000000000..f6f160eb7e96 --- /dev/null +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -0,0 +1,697 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +import itertools +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributed import ProcessGroup +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, +) + +from colossalai.inference.config import InputMetaData +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import ( + context_attention_unpadded, + copy_k_to_blocked_cache, + decoding_fused_rotary_embedding, + flash_decoding_attention, + get_xine_cache, + rms_layernorm, + rotary_embedding, +) +from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor + +inference_ops = InferenceOpsLoader().load() + +logger = get_dist_logger(__name__) + +try: + from flash_attn import flash_attn_varlen_func + + use_flash_attn2 = True +except ImportError: + use_flash_attn2 = False + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") + + +def llama_causal_lm_forward( + self: LlamaForCausalLM, + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +) -> torch.Tensor: + """This function will replace the forward function of LlamaForCausalLM. + + Args: + batch (BatchInfo): It stores the necessary input information for this inference. + k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache. + v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + """ + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = llama_model_forward( + self.model, + input_tokens_ids=input_tokens_ids, + output_tensor=output_tensor, + inputmetadata=inputmetadata, + k_caches=k_caches, + v_caches=v_caches, + use_cuda_kernel=inputmetadata.use_cuda_kernel, # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unitest but triton kernel could + high_precision=inputmetadata.high_precision, + ) + + logits = self.lm_head(hidden_states) + return logits + + +def llama_model_forward( + self: LlamaModel, + input_tokens_ids: torch.Tensor, + output_tensor: torch.Tensor, + inputmetadata: InputMetaData, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, + use_cuda_kernel: Optional[bool] = True, + high_precision: bool = False, +) -> torch.Tensor: + """This function will replace the forward function of LlamaModel. + + Args: + batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. + k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. + v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + """ + block_tables = inputmetadata.block_tables + sequence_lengths = inputmetadata.sequence_lengths + kv_seq_len = inputmetadata.kv_seq_len + + # NOTE (yuanheng-zhao): fow now, only triton kernels support verification process + # during speculative-decoding (`q_len > 1`) + # We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled + if inputmetadata.use_spec_dec and use_cuda_kernel: + use_cuda_kernel = False + logger.warning("CUDA kernel is disabled for speculative-decoding.") + + hidden_states = self.embed_tokens(input_tokens_ids) + + cu_seqlens = None + + # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now + if inputmetadata.use_spec_dec: + # For speculative-decoding Prefill and Verifying Stage + if inputmetadata.is_prompts: + # output tensor shape is the same as normal Prefill Stage + rotary_indexes = [torch.arange(0, length) for length in sequence_lengths] + else: + # the number of tokens to be verified in parallel plus the correct token in the last step + n_tokens = inputmetadata.num_tokens_to_verify + 1 + assert n_tokens == hidden_states.size(0) + rotary_indexes = [(length - n_tokens + i).view(-1) for i in range(n_tokens) for length in sequence_lengths] + rotary_indexes = torch.cat(rotary_indexes, dim=-1) + cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes]) + + elif use_cuda_kernel: + if inputmetadata.dtype != torch.float32 and use_flash_attn2: + cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + + hidden_dim = self._cos_cached.size(-1) + total_length = hidden_states.size(0) + cos = torch.empty((total_length, hidden_dim), dtype=self._cos_cached.dtype, device=self._cos_cached.device) + sin = torch.empty((total_length, hidden_dim), dtype=self._sin_cached.dtype, device=self._sin_cached.device) + inference_ops.get_cos_and_sin( + self._cos_cached, self._sin_cached, cos, sin, sequence_lengths, kv_seq_len, inputmetadata.is_prompts + ) + cos_sin = (cos, sin) + else: + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) + + sm_scale = 1.0 / (inputmetadata.head_dim**0.5) + + norm_output = torch.empty_like(hidden_states) + tokens_to_verify = inputmetadata.num_tokens_to_verify if inputmetadata.use_spec_dec else None + residual = None + + for layer_id, decoder_layer in enumerate(self.layers): + hidden_states, residual = decoder_layer( + hidden_states, + residual=residual, + block_tables=block_tables, + k_cache=k_caches[layer_id], + v_cache=v_caches[layer_id], + is_prompts=inputmetadata.is_prompts, + is_verifier=inputmetadata.use_spec_dec, + tokens_to_verify=tokens_to_verify, + sequence_lengths=sequence_lengths, + cos_sin=cos_sin, + fd_inter_tensor=inputmetadata.fd_inter_tensor, + kv_seq_len=kv_seq_len, + output_tensor=output_tensor, + norm_output=norm_output, + sm_scale=sm_scale, + use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, + ) + + if inputmetadata.is_prompts: + seq_len_cumsum = sequence_lengths.cumsum(dim=0) + hidden_states = hidden_states[seq_len_cumsum - 1].contiguous() + residual = residual[seq_len_cumsum - 1].contiguous() + norm_output = torch.empty_like(hidden_states) + hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) + + return hidden_states + + +def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + residual: torch.Tensor, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, + is_prompts: bool = True, + is_verifier: bool = False, + tokens_to_verify: int = None, + kv_seq_len: int = 0, + output_tensor: torch.Tensor = None, + norm_output: torch.Tensor = None, + sm_scale: int = None, + use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """This function will replace the forward function of LlamaDecoderLayer. + + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor]): Holding cos and sin. + fd_inter_tensor (FDIntermTensors): Holding tensors used for + storing intermediate values in flash-decoding. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + """ + + hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + block_tables=block_tables, + k_cache=k_cache, + v_cache=v_cache, + is_prompts=is_prompts, + is_verifier=is_verifier, + tokens_to_verify=tokens_to_verify, + sequence_lengths=sequence_lengths, + cos_sin=cos_sin, + fd_inter_tensor=fd_inter_tensor, + kv_seq_len=kv_seq_len, + output_tensor=output_tensor, + sm_scale=sm_scale, + use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +def llama_rmsnorm_forward( + self: LlamaRMSNorm, + hidden_states: torch.Tensor, + norm_output: torch.Tensor, + residual: torch.Tensor = None, + use_cuda_kernel: bool = True, +): + if use_cuda_kernel: + if residual is not None: + inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon) + return hidden_states, residual + + if norm_output is None: + norm_output = torch.empty_like(hidden_states) + inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon) + return norm_output, hidden_states + else: + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) + + +class NopadLlamaMLP(LlamaMLP, ParallelModule): + def __init__( + self, + config: LlamaConfig, + mlp_gproj_w: torch.Tensor = None, + mlp_uproj_w: torch.Tensor = None, + mlp_dproj: ParallelModule = None, + process_group: ProcessGroup = None, + ): + """A Unified Layer for + + Args: + config (LlamaConfig): Holding the Llama model config. + mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None. + mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None. + mlp_dproj (Linear1D_Row, optional): The Linear1D_Row mlp_dproj weight. Defaults to None. + """ + ParallelModule.__init__(self) + self.config = config + assert is_distributed_tensor( + mlp_gproj_w + ), "mlp_gproj_w must be dtensor so we could get the layout of the weight" + self.helper_layout = ( + mlp_gproj_w.dist_layout + ) # NOTE this is a hack for the right load/shard of gate_up_weight(used in _load_from_state_dict) + self.gate_up_weight = nn.Parameter( + torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0) + ) + self.down_proj = mlp_dproj + self.process_group = process_group + + @staticmethod + def from_native_module( + module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP. + + Args: + module (LlamaMLP): The origin LlamaMLP layer. + """ + + config = module.config + + mlp_gproj_w = module.gate_proj.weight + assert is_distributed_tensor( + module.gate_proj.weight + ), "gate_proj.weight must be dtensor so we could get the layout of the weight" + mlp_uproj_w = module.up_proj.weight + mlp_dproj = module.down_proj + + mlp_layer = NopadLlamaMLP( + config=config, + mlp_gproj_w=mlp_gproj_w, + mlp_uproj_w=mlp_uproj_w, + mlp_dproj=mlp_dproj, + process_group=process_group, + ) + + return mlp_layer + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight) + + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + key = "gate_up_weight" + k1 = "gate_proj.weight" + k2 = "up_proj.weight" + + gate_w = state_dict[prefix + k1] + up_w = state_dict[prefix + k2] + + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec) + up_w = distribute_tensor(up_w, device_mesh, sharding_spec) + + gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0) + + input_param = nn.Parameter( + gate_up_w + ) # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + param = local_state[key] + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + strict = False # to avoid unexpected_keys + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + """ + hidden_states = hidden_states.expand(2, -1, -1) + gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) + act_out = inference_ops.silu_and_mul(gate_up_proj_out) + + return self.down_proj(act_out) + + def extra_repr(self) -> str: + return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False" + + +class NopadLlamaAttention(LlamaAttention, ParallelModule): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + attn_qproj_w: torch.Tensor = None, + attn_kproj_w: torch.Tensor = None, + attn_vproj_w: torch.Tensor = None, + attn_oproj: ParallelModule = None, + process_group: ProcessGroup = None, + num_heads: int = None, + hidden_size: int = None, + num_key_value_heads: int = None, + ): + """This layer will replace the LlamaAttention. + + Args: + config (LlamaConfig): Holding the Llama model config. + layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None. + attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None. + attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None. + attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None. + attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None. + """ + ParallelModule.__init__(self) + self.config = config + self.layer_idx = layer_idx + + self.o_proj = attn_oproj + self.process_group = process_group + + self.attention_dropout = config.attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if self.num_heads == self.num_key_value_heads: + qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)] + self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0)) + self.helper_layout = ( + attn_qproj_w.dist_layout + ) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict) + else: + self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous()) + self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous()) + self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous()) + + @staticmethod + def from_native_module( + module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention. + + Args: + module (LlamaAttention): The origin LlamaAttention layer. + """ + + config = module.config + layer_idx = module.layer_idx + + attn_qproj_w = module.q_proj.weight + attn_kproj_w = module.k_proj.weight + attn_vproj_w = module.v_proj.weight + assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor" + attn_oproj = module.o_proj + + attn_layer = NopadLlamaAttention( + config=config, + layer_idx=layer_idx, + attn_qproj_w=attn_qproj_w, + attn_kproj_w=attn_kproj_w, + attn_vproj_w=attn_vproj_w, + attn_oproj=attn_oproj, + process_group=process_group, + num_heads=module.num_heads, + hidden_size=module.hidden_size, + num_key_value_heads=module.num_key_value_heads, + ) + + return attn_layer + + # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, + is_prompts: bool = True, + is_verifier: bool = False, + tokens_to_verify: int = None, + kv_seq_len: int = 0, + output_tensor: torch.Tensor = None, + sm_scale: int = None, + use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Args: + hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. + fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for + storing intermediate values in flash-decoding. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. + output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. + sm_scale (int, optional): Used for flash attention. Defaults to None. + use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. + """ + + token_nums = hidden_states.size(0) + + if self.num_heads != self.num_key_value_heads: + query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim) + key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) + value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) + else: + # fused qkv + hidden_states = hidden_states.expand(3, -1, -1) + query_states, key_states, value_states = ( + torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) + ) + + block_size = k_cache.size(-2) + + if is_prompts: + if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) + else: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + use_new_kcache_layout=use_cuda_kernel, + ) + else: + q_len = tokens_to_verify + 1 if is_verifier else 1 + + if use_cuda_kernel: + inference_ops.rotary_embedding_and_cache_copy( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + sequence_lengths, + block_tables, + high_precision, + ) + inference_ops.flash_decoding_attention( + output_tensor, + query_states, + k_cache, + v_cache, + sequence_lengths, + block_tables, + block_size, + kv_seq_len, + fd_inter_tensor.mid_output, + fd_inter_tensor.exp_sums, + fd_inter_tensor.max_logits, + None, + sm_scale, + ) + attn_output = output_tensor + else: + if is_verifier: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + copy_k_to_blocked_cache( + key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + copy_k_to_blocked_cache( + value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + else: + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, + ) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + kv_group_num=self.num_key_value_groups, + q_len=q_len, + ) + + attn_output = attn_output.view(-1, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.num_heads == self.num_key_value_heads: + # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight) + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + key = "qkv_weight" + k1 = "q_proj.weight" + k2 = "k_proj.weight" + k3 = "v_proj.weight" + q_w = state_dict[prefix + k1] + k_w = state_dict[prefix + k2] + v_w = state_dict[prefix + k3] + + device_mesh = self.helper_layout.device_mesh + sharding_spec = self.helper_layout.sharding_spec + q_w = distribute_tensor(q_w, device_mesh, sharding_spec) + k_w = distribute_tensor(k_w, device_mesh, sharding_spec) + v_w = distribute_tensor(v_w, device_mesh, sharding_spec) + + qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0) + + input_param = nn.Parameter( + qkv_w + ) # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param) + + param = local_state[key] + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + + strict = False # to avoid unexpected_keys + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def extra_repr(self) -> str: + return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False" diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py new file mode 100644 index 000000000000..fa03955907fe --- /dev/null +++ b/colossalai/inference/modeling/policy/__init__.py @@ -0,0 +1,16 @@ +from .glide_llama import GlideLlamaModelPolicy +from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy +from .nopadding_llama import NoPaddingLlamaModelInferPolicy + +model_policy_map = { + "nopadding_llama": NoPaddingLlamaModelInferPolicy, + "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy, + "glide_llama": GlideLlamaModelPolicy, +} + +__all__ = [ + "NoPaddingLlamaModelInferPolicy", + "NoPaddingBaichuanModelInferPolicy", + "GlideLlamaModelPolicy", + "model_polic_map", +] diff --git a/colossalai/inference/modeling/policy/glide_llama.py b/colossalai/inference/modeling/policy/glide_llama.py new file mode 100644 index 000000000000..817b3324ed7d --- /dev/null +++ b/colossalai/inference/modeling/policy/glide_llama.py @@ -0,0 +1,45 @@ +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel + +from colossalai.inference.modeling.models.glide_llama import ( + GlideLlamaDecoderLayer, + glide_llama_causal_lm_forward, + glide_llama_model_forward, +) +from colossalai.inference.utils import init_to_get_rotary +from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + + +class GlideLlamaModelPolicy(LlamaForCausalLMPolicy): + def module_policy(self): + policy = super().module_policy() + + num_layers = self.model.config.num_hidden_layers + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix=f"layers[{i}]", + target_module=GlideLlamaDecoderLayer, + ) + for i in range(num_layers) + ], + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={"forward": glide_llama_model_forward}, + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={"forward": glide_llama_causal_lm_forward}, + policy=policy, + target_key=LlamaForCausalLM, + ) + + return policy + + def postprocess(self): + for layer in self.model.model.layers: + init_to_get_rotary(layer.cross_attn) + return self.model diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py new file mode 100644 index 000000000000..78268d6e7e85 --- /dev/null +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -0,0 +1,110 @@ +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.layers.baichuan_tp_linear import ( + BaichuanLMHeadLinear1D_Col, + BaichuanWpackLinear1D_Col, +) +from colossalai.inference.modeling.models.nopadding_baichuan import ( + NopadBaichuanAttention, + NopadBaichuanMLP, + baichuan_rmsnorm_forward, +) +from colossalai.inference.modeling.models.nopadding_llama import ( + llama_causal_lm_forward, + llama_decoder_layer_forward, + llama_model_forward, +) +from colossalai.inference.utils import init_to_get_rotary +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + + +class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + else: + decoder_attribute_replacement = None + + # used for Baichuan 7B and 13B for baichuan DecoderLayer + for DecoderLayer in ["DecoderLayer", "BaichuanLayer"]: + policy[DecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp", + target_module=NopadBaichuanMLP, + ), + SubModuleReplacementDescription( + suffix="self_attn.W_pack", + target_module=BaichuanWpackLinear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attn", + target_module=NopadBaichuanAttention, + ), + ], + ) + + self.append_or_create_method_replacement( + description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=DecoderLayer + ) + + policy["BaichuanForCausalLM"] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=BaichuanLMHeadLinear1D_Col, kwargs={"gather_output": True} + ) + ], + ) + + self.append_or_create_method_replacement( + description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM" + ) + self.append_or_create_method_replacement( + description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel" + ) + self.append_or_create_method_replacement( + description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm" + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "NoPaddingBaichuanModelInferPolicy": + return NoPaddingBaichuanModelInferPolicy() diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py new file mode 100644 index 000000000000..24cf7c740b10 --- /dev/null +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -0,0 +1,112 @@ +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.models.nopadding_llama import ( + NopadLlamaAttention, + NopadLlamaMLP, + llama_causal_lm_forward, + llama_decoder_layer_forward, + llama_model_forward, + llama_rmsnorm_forward, +) +from colossalai.inference.utils import init_to_get_rotary +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + + +class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) + else: + decoder_attribute_replacement = None + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp", + target_module=NopadLlamaMLP, + ), + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attn", + target_module=NopadLlamaAttention, + ), + ], + ) + + policy[LlamaForCausalLM] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True} + ) + ], + ) + + # self.shard_config._infer() + self.append_or_create_method_replacement( + description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM + ) + self.append_or_create_method_replacement( + description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel + ) + self.append_or_create_method_replacement( + description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer + ) + self.append_or_create_method_replacement( + description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model, self.model.config.rope_theta) + return self.model + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy": + return NoPaddingLlamaModelInferPolicy() diff --git a/colossalai/inference/quant/__init__.py b/colossalai/inference/quant/__init__.py deleted file mode 100644 index 18e0de9cc9fc..000000000000 --- a/colossalai/inference/quant/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .smoothquant.models.llama import SmoothLlamaForCausalLM diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py deleted file mode 100644 index 4cf1fd658a41..000000000000 --- a/colossalai/inference/quant/gptq/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .cai_gptq import HAS_AUTO_GPTQ - -if HAS_AUTO_GPTQ: - from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear - from .gptq_manager import GPTQManager diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py deleted file mode 100644 index 4ed76293bd81..000000000000 --- a/colossalai/inference/quant/gptq/cai_gptq/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -import warnings - -HAS_AUTO_GPTQ = False -try: - import auto_gptq - - HAS_AUTO_GPTQ = True -except ImportError: - warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ") - HAS_AUTO_GPTQ = False - -if HAS_AUTO_GPTQ: - from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear - from .gptq_op import CaiGPTQLinearOp diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py deleted file mode 100644 index 36339ac88486..000000000000 --- a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py +++ /dev/null @@ -1,354 +0,0 @@ -# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ - -import math -import warnings -from typing import List, Union - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup - -from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import ParallelModule - -from .gptq_op import CaiGPTQLinearOp - -HAS_GPTQ_CUDA = False -try: - from colossalai.kernel.op_builder.gptq import GPTQBuilder - - gptq_cuda = GPTQBuilder().load() - HAS_GPTQ_CUDA = True -except ImportError: - warnings.warn("CUDA gptq is not installed") - HAS_GPTQ_CUDA = False - - -class CaiQuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - super().__init__() - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize if groupsize != -1 else infeatures - - self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) - self.register_buffer( - "qzeros", - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), - ) - self.register_buffer( - "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) - ) - if row_split: - self.register_buffer( - "g_idx", - torch.tensor( - [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 - ), - ) - else: - self.register_buffer( - "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) - ) - - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) - else: - self.bias = None - - self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) - - self.q4 = None - self.empty_tensor = torch.empty((1, 1), device="meta") - self.tp_size = tp_size - self.tp_rank = tp_rank - self.row_split = row_split - - def pack(self, linear, scales, zeros, g_idx=None): - g_idx = ( - g_idx.clone() - if g_idx is not None - else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) - ) - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - half_scales = scales.clone().half() - # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - pbits = 32 - ptype = torch.int32 - unsign_type = np.uint32 - sign_type = np.int32 - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ - :, None - ] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(unsign_type) - qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) - - i = 0 - row = 0 - - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (pbits // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += pbits // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qweight = qweight.astype(sign_type) - qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() # .to("cuda") - self.qweight.data.copy_(qweight1) - - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) - zeros -= 1 - zeros = zeros.numpy().astype(unsign_type) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (pbits // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += pbits // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qzeros = qzeros.astype(sign_type) - qzeros = torch.from_numpy(qzeros) - qzeros = qzeros - self.qzeros.data.copy_(qzeros) - - if torch.equal(self.g_idx.to(g_idx.device), g_idx): - self.g_idx = None - else: - self.g_idx = g_idx - - def init_q4(self): - assert self.qweight.device.type == "cuda" - self.q4_width = self.qweight.shape[1] - if self.g_idx is not None: - if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device, - ), - ): - self.g_idx = None - elif torch.equal( - self.g_idx, - torch.tensor( - [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device - ), - ): - self.g_idx = None - - if self.g_idx is not None: - g_idx = self.g_idx.to("cpu") - else: - g_idx = self.empty_tensor - - self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) - torch.cuda.synchronize() - - def forward(self, x): - outshape = x.shape[:-1] + (self.outfeatures,) - - if HAS_GPTQ_CUDA and self.bits == 4: - if self.q4 is None: - self.init_q4() - - x = x.view(-1, x.shape[-1]) - output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) - gptq_cuda.q4_matmul(x.half(), self.q4, output) - if self.bias is not None and (not self.row_split or self.tp_size == 1): - output.add_(self.bias) - else: - if self.bias is not None and (not self.row_split or self.tp_size == 1): - bias = self.bias - else: - bias = None - output = self.gptq_linear( - x, - self.qweight, - self.scales, - self.qzeros, - g_idx=self.g_idx, - bias=bias, - ) - return output.view(outshape) - - -def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) - qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) - scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) - g_idx = gptq_linear.g_idx - if gptq_linear.bias is not None: - bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) - - cai_split_out_features = cai_linear.outfeatures // split_num - zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num - - for i in range(split_num): - cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ - :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] - cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ - :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block - ] - cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ - :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] - if cai_linear.bias is not None: - cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ - tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] - - cai_linear.g_idx.copy_(g_idx) - - -def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) - qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) - scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) - g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) - - cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num - zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num - idx_split_features = cai_linear.infeatures // split_num - - for i in range(split_num): - cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ - tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : - ] - cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ - tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : - ] - cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ - tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : - ] - cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ - tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features - ] - if cai_linear.bias is not None: - cai_linear.bias.copy_(gptq_linear.bias) - - -class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - super().__init__( - bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split - ) - self.process_group = None - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - in_features = module.in_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if in_features < tp_size: - return module - - if in_features % tp_size != 0: - raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = RowCaiQuantLinear( - module.bits, - module.group_size, - module.in_features // tp_size, - module.out_features, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=True, - ) - linear_1d.process_group = process_group - - split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - return linear_1d - - def forward(self, x): - output = super().forward(x) - if self.tp_size > 1: - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) - if self.bias is not None: - output.add_(self.bias) - return output - - -class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - super().__init__( - bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split - ) - self.process_group = None - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - in_features = module.in_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if in_features < tp_size: - return module - - if in_features % tp_size != 0: - raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = ColCaiQuantLinear( - module.bits, - module.group_size, - module.in_features, - module.out_features // tp_size, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - ) - linear_1d.process_group = process_group - - split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - return linear_1d diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py deleted file mode 100644 index a8902eb35cd0..000000000000 --- a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch - -from colossalai.kernel.triton import gptq_fused_linear_triton - - -class CaiGPTQLinearOp(torch.nn.Module): - def __init__(self, gptq_group_size, gptq_quant_bits): - super(CaiGPTQLinearOp, self).__init__() - self.group_size = gptq_group_size - self.bits = gptq_quant_bits - self.maxq = 2**self.bits - 1 - self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) - - def forward( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scales: torch.Tensor, - weight_zeros: torch.Tensor, - g_idx: torch.Tensor = None, - act_type=0, - bias: torch.Tensor = None, - residual: torch.Tensor = None, - qkv_fused=False, - ): - add_bias = True - if bias is None: - bias = self.empty_tensor - add_bias = False - - add_residual = True - if residual is None: - residual = self.empty_tensor - add_residual = False - x = input.view(-1, input.shape[-1]) - - out = gptq_fused_linear_triton( - x, - weight, - weight_scales, - weight_zeros, - bias, - residual, - self.bits, - self.maxq, - self.group_size, - qkv_fused, - add_bias, - add_residual, - act_type=act_type, - g_idx=g_idx, - ) - if qkv_fused: - out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) - else: - out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) - - return out diff --git a/colossalai/inference/quant/gptq/gptq_manager.py b/colossalai/inference/quant/gptq/gptq_manager.py deleted file mode 100644 index 2d352fbef2b9..000000000000 --- a/colossalai/inference/quant/gptq/gptq_manager.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch - - -class GPTQManager: - def __init__(self, quant_config, max_input_len: int = 1): - self.max_dq_buffer_size = 1 - self.max_inner_outer_dim = 1 - self.bits = quant_config.bits - self.use_act_order = quant_config.desc_act - self.max_input_len = 1 - self.gptq_temp_state_buffer = None - self.gptq_temp_dq_buffer = None - self.quant_config = quant_config - - def post_init_gptq_buffer(self, model: torch.nn.Module) -> None: - from .cai_gptq import CaiQuantLinear - - HAS_GPTQ_CUDA = False - try: - from colossalai.kernel.op_builder.gptq import GPTQBuilder - - gptq_cuda = GPTQBuilder().load() - HAS_GPTQ_CUDA = True - except ImportError: - warnings.warn("CUDA gptq is not installed") - HAS_GPTQ_CUDA = False - - for name, submodule in model.named_modules(): - if isinstance(submodule, CaiQuantLinear): - self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) - - if self.use_act_order: - self.max_inner_outer_dim = max( - self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures - ) - self.bits = submodule.bits - if not (HAS_GPTQ_CUDA and self.bits == 4): - return - - max_input_len = 1 - if self.use_act_order: - max_input_len = self.max_input_len - # The temp_state buffer is required to reorder X in the act-order case. - # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - self.gptq_temp_state_buffer = torch.zeros( - (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() - ) - self.gptq_temp_dq_buffer = torch.zeros( - (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() - ) - - gptq_cuda.prepare_buffers( - torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer - ) - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - torch.cuda.empty_cache() diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py deleted file mode 100644 index 1663028da138..000000000000 --- a/colossalai/inference/quant/smoothquant/models/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -try: - import torch_int - - HAS_TORCH_INT = True -except ImportError: - HAS_TORCH_INT = False - print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") - -if HAS_TORCH_INT: - from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py deleted file mode 100644 index f3afe5d83bb0..000000000000 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ /dev/null @@ -1,494 +0,0 @@ -# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ -# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py -# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py - -import os -import warnings -from abc import abstractmethod -from functools import partial -from os.path import isdir, isfile, join -from typing import Dict, List, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -import transformers -from safetensors.torch import save_file as safe_save -from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel -from transformers.modeling_utils import no_init_weights -from transformers.utils.generic import ContextManagers -from transformers.utils.hub import PushToHubMixin, cached_file - -from colossalai.inference.kv_cache.batch_infer_state import BatchInferState, MemoryManager - -try: - import accelerate - - HAS_ACCELERATE = True -except ImportError: - HAS_ACCELERATE = False - print("accelerate is not installed.") - - -SUPPORTED_MODELS = ["llama"] - - -class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): - layer_type: str = None - - def __init__(self, model: PreTrainedModel, quantized: bool = False): - super().__init__() - - self.model = model - self.model_type = self.model.config.model_type - self._quantized = quantized - self.config = self.model.config - self.cache_manager = None - self.max_total_token_num = 0 - - @property - def quantized(self): - return self._quantized - - def init_cache_manager(self, max_total_token_num=2048): - if self.config.model_type == "llama": - head_num = self.config.num_key_value_heads - layer_num = self.config.num_hidden_layers - head_dim = self.config.hidden_size // head_num - - self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) - self.max_total_token_num = max_total_token_num - - def init_batch_state(self, max_output_len=256, **kwargs): - input_ids = kwargs["input_ids"] - batch_size = len(input_ids) - - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - start_index = 0 - max_len_in_batch = -1 - - for i in range(batch_size): - seq_len = len(input_ids[i]) - seq_lengths[i] = seq_len - seq_start_indexes[i] = start_index - start_index += seq_len - max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch - - if "max_total_token_num" in kwargs.keys(): - max_total_token_num = kwargs["max_total_token_num"] - self.init_cache_manager(max_total_token_num) - - if "max_new_tokens" in kwargs.keys(): - max_output_len = kwargs["max_new_tokens"] - - if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: - max_total_token_num = batch_size * (max_len_in_batch + max_output_len) - warnings.warn(f"reset max tokens to {max_total_token_num}") - self.init_cache_manager(max_total_token_num) - - block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") - batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to("cuda") - batch_infer_state.start_loc = seq_start_indexes.to("cuda") - batch_infer_state.block_loc = block_loc - batch_infer_state.decode_layer_id = 0 - batch_infer_state.is_context_stage = True - batch_infer_state.set_cache_manager(self.cache_manager) - batch_infer_state.cache_manager.free_all() - return batch_infer_state - - @abstractmethod - @torch.inference_mode() - def quantize( - self, - examples: List[Dict[str, Union[List[int], torch.LongTensor]]], - ): - if self.quantized: - raise EnvironmentError("can't execute quantize because the model is quantized.") - - def forward(self, *args, **kwargs): - return self.model(*args, **kwargs) - - def generate(self, **kwargs): - """shortcut for model.generate""" - - batch_infer_state = self.init_batch_state(**kwargs) - if self.config.model_type == "llama": - setattr(self.model.model, "infer_state", batch_infer_state) - - with torch.inference_mode(): - return self.model.generate(**kwargs) - - def prepare_inputs_for_generation(self, *args, **kwargs): - """shortcut for model.prepare_inputs_for_generation""" - return self.model.prepare_inputs_for_generation(*args, **kwargs) - - def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): - for text in tqdm(dataset): - input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) - model(input_ids) - - def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): - pbar = tqdm(dataset) - for text in pbar: - input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) - model(input_ids) - mean_scale = np.mean([v["input"] for v in act_dict.values()]) - pbar.set_description(f"Mean input scale: {mean_scale:.2f}") - - # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py - def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = {} - - def stat_tensor(name, tensor): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float().cpu() - if name in act_scales: - act_scales[name] = torch.max(act_scales[name], comming_max) - else: - act_scales[name] = comming_max - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x) - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear): - hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) - - self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) - - for h in hooks: - h.remove() - - return act_scales - - # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py - @torch.no_grad() - def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): - if not isinstance(fcs, list): - fcs = [fcs] - for fc in fcs: - assert isinstance(fc, nn.Linear) - assert ln.weight.numel() == fc.in_features == act_scales.numel() - - device, dtype = fcs[0].weight.device, fcs[0].weight.dtype - act_scales = act_scales.to(device=device, dtype=dtype) - weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) - weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) - - scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) - - ln.weight.div_(scales) - if hasattr(ln, "bias"): - ln.bias.div_(scales) - - for fc in fcs: - fc.weight.mul_(scales.view(1, -1)) - - @classmethod - def create_quantized_model(model): - raise NotImplementedError("Not implement create_quantized_model method") - - # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py - def save_quantized( - self, - save_dir: str, - model_basename: str, - use_safetensors: bool = False, - safetensors_metadata: Optional[Dict[str, str]] = None, - ): - """save quantized model and configs to local disk""" - os.makedirs(save_dir, exist_ok=True) - - if not self.quantized: - raise EnvironmentError("can only save quantized model, please execute .quantize first.") - - self.model.to("cpu") - - model_base_name = model_basename # or f"smooth-" - if use_safetensors: - model_save_name = model_base_name + ".safetensors" - state_dict = self.model.state_dict() - state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} - if safetensors_metadata is None: - safetensors_metadata = {} - elif not isinstance(safetensors_metadata, dict): - raise TypeError("safetensors_metadata must be a dictionary.") - else: - print(f"Received safetensors_metadata: {safetensors_metadata}") - new_safetensors_metadata = {} - converted_keys = False - for key, value in safetensors_metadata.items(): - if not isinstance(key, str) or not isinstance(value, str): - converted_keys = True - try: - new_key = str(key) - new_value = str(value) - except Exception as e: - raise TypeError( - f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" - ) - if new_key in new_safetensors_metadata: - print( - f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." - ) - new_safetensors_metadata[new_key] = new_value - safetensors_metadata = new_safetensors_metadata - if converted_keys: - print( - f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" - ) - - # Format is required to enable Accelerate to load the metadata - # otherwise it raises an OSError - safetensors_metadata["format"] = "pt" - - safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) - else: - model_save_name = model_base_name + ".bin" - torch.save(self.model.state_dict(), join(save_dir, model_save_name)) - - self.model.config.save_pretrained(save_dir) - - # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py - def save_pretrained( - self, - save_dir: str, - use_safetensors: bool = False, - safetensors_metadata: Optional[Dict[str, str]] = None, - **kwargs, - ): - """alias of save_quantized""" - warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") - self.save_quantized(save_dir, use_safetensors, safetensors_metadata) - - # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str, - max_memory: Optional[dict] = None, - trust_remote_code: bool = False, - torch_dtype: torch.dtype = torch.float16, - **model_init_kwargs, - ): - if not torch.cuda.is_available(): - raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") - - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - - # Parameters related to loading from Hugging Face Hub - cache_dir = model_init_kwargs.pop("cache_dir", None) - force_download = model_init_kwargs.pop("force_download", False) - resume_download = model_init_kwargs.pop("resume_download", False) - proxies = model_init_kwargs.pop("proxies", None) - local_files_only = model_init_kwargs.pop("local_files_only", False) - use_auth_token = model_init_kwargs.pop("use_auth_token", None) - revision = model_init_kwargs.pop("revision", None) - subfolder = model_init_kwargs.pop("subfolder", "") - model_init_kwargs.pop("_commit_hash", None) - - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "use_auth_token": use_auth_token, - "revision": revision, - "subfolder": subfolder, - } - - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) - if config.model_type not in SUPPORTED_MODELS: - raise TypeError(f"{config.model_type} isn't supported yet.") - - # enforce some values despite user specified - model_init_kwargs["torch_dtype"] = torch_dtype - model_init_kwargs["trust_remote_code"] = trust_remote_code - if max_memory: - if "disk" in max_memory: - raise NotImplementedError("disk offload not support yet.") - with accelerate.init_empty_weights(): - model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) - model.tie_weights() - - max_memory = accelerate.utils.get_balanced_memory( - model, - max_memory=max_memory, - no_split_module_classes=[cls.layer_type], - dtype=model_init_kwargs["torch_dtype"], - low_zero=False, - ) - model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( - model, - max_memory=max_memory, - no_split_module_classes=[cls.layer_type], - dtype=model_init_kwargs["torch_dtype"], - ) - model_init_kwargs["low_cpu_mem_usage"] = True - - del model - else: - model_init_kwargs["device_map"] = None - model_init_kwargs["low_cpu_mem_usage"] = False - - torch.cuda.empty_cache() - - merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} - model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) - - model_config = model.config.to_dict() - seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] - if any([k in model_config for k in seq_len_keys]): - for key in seq_len_keys: - if key in model_config: - model.seqlen = model_config[key] - break - else: - warnings.warn("can't get model's sequence length from model config, will set to 4096.") - model.seqlen = 4096 - model.eval() - - return cls(model, False) - - # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py - @classmethod - def from_quantized( - cls, - model_name_or_path: Optional[str], - model_basename: Optional[str] = None, - device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, - max_memory: Optional[dict] = None, - device: Optional[Union[str, int]] = None, - low_cpu_mem_usage: bool = False, - torch_dtype: Optional[torch.dtype] = None, - use_safetensors: bool = False, - trust_remote_code: bool = False, - **kwargs, - ): - """load quantized model from local disk""" - - # Parameters related to loading from Hugging Face Hub - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", "") - commit_hash = kwargs.pop("_commit_hash", None) - - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "use_auth_token": use_auth_token, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - - # == step1: prepare configs and file names == # - config = AutoConfig.from_pretrained( - model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs - ) - - if config.model_type not in SUPPORTED_MODELS: - raise TypeError(f"{config.model_type} isn't supported yet.") - - extensions = [] - if use_safetensors: - extensions.append(".safetensors") - else: - extensions += [".bin", ".pt"] - - model_name_or_path = str(model_name_or_path) - is_local = isdir(model_name_or_path) - - resolved_archive_file = None - if is_local: - model_save_name = join(model_name_or_path, model_basename) - for ext in extensions: - if isfile(model_save_name + ext): - resolved_archive_file = model_save_name + ext - break - else: # remote - for ext in extensions: - resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) - if resolved_archive_file is not None: - break - - if resolved_archive_file is None: # Could not find a model file to use - raise FileNotFoundError(f"Could not find model in {model_name_or_path}") - - model_save_name = resolved_archive_file - - # == step2: convert model to quantized-model (replace Linear) == # - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - - transformers.modeling_utils._init_weights = False - - init_contexts = [no_init_weights()] - if low_cpu_mem_usage: - init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) - - with ContextManagers(init_contexts): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype - ) - cls.create_quantized_model(model) - model.tie_weights() - - # == step3: load checkpoint to quantized-model == # - accelerate.utils.modeling.load_checkpoint_in_model( - model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True - ) - - # == step4: set seqlen == # - model_config = model.config.to_dict() - seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] - if any([k in model_config for k in seq_len_keys]): - for key in seq_len_keys: - if key in model_config: - model.seqlen = model_config[key] - break - else: - warnings.warn("can't get model's sequence length from model config, will set to 4096.") - model.seqlen = 4096 - - return cls( - model, - True, - ) - - def __getattr__(self, item): - try: - return super().__getattr__(item) - except: - return getattr(self.model, item) - - -__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py deleted file mode 100644 index 03d994b32489..000000000000 --- a/colossalai/inference/quant/smoothquant/models/linear.py +++ /dev/null @@ -1,189 +0,0 @@ -# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py - -import torch - -try: - from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 - from torch_int.functional.quantization import quantize_per_tensor_absmax - - HAS_TORCH_INT = True -except ImportError: - HAS_TORCH_INT = False - print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") - - -try: - from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder - - smoothquant_cuda = SmoothquantBuilder().load() - HAS_SMOOTHQUANT_CUDA = True -except: - HAS_SMOOTHQUANT_CUDA = False - print("CUDA smoothquant linear is not installed") - - -class W8A8BFP32O32LinearSiLU(torch.nn.Module): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer( - "weight", - torch.randint( - -127, - 127, - (self.out_features, self.in_features), - dtype=torch.int8, - requires_grad=False, - ), - ) - self.register_buffer( - "bias", - torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), - ) - self.register_buffer("a", torch.tensor(alpha)) - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) - y = y.view(*x_shape[:-1], -1) - return y - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - if module.bias is not None: - int8_module.bias.data.copy_(module.bias.to(torch.float)) - int8_module.a = alpha - return int8_module - - -# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py -class W8A8B8O8Linear(torch.nn.Module): - # For qkv_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer( - "weight", - torch.randint( - -127, - 127, - (self.out_features, self.in_features), - dtype=torch.int8, - requires_grad=False, - ), - ) - self.register_buffer( - "bias", - torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), - ) - self.register_buffer("a", torch.tensor(alpha)) - self.register_buffer("b", torch.tensor(beta)) - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - self.bias = self.bias.to(*args, **kwargs) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) - y = y.view(*x_shape[:-1], -1) - return y - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale, output_scale): - int8_module = W8A8B8O8Linear(module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale / output_scale - int8_module.weight = int8_weight - int8_module.a = alpha - - if module.bias is not None: - int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) - int8_module.bias = int8_bias - beta = bias_scale / output_scale - int8_module.b = beta - - return int8_module - - -# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py -class W8A8BFP32OFP32Linear(torch.nn.Module): - # For fc2 and out_proj - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer( - "weight", - torch.randint( - -127, - 127, - (self.out_features, self.in_features), - dtype=torch.int8, - requires_grad=False, - ), - ) - self.register_buffer( - "bias", - torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False), - ) - self.register_buffer("a", torch.tensor(alpha)) - - def _apply(self, fn): - # prevent the bias from being converted to half - super()._apply(fn) - if self.bias is not None: - self.bias = self.bias.to(torch.float32) - return self - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.weight = self.weight.to(*args, **kwargs) - if self.bias is not None: - self.bias = self.bias.to(*args, **kwargs) - self.bias = self.bias.to(torch.float32) - return self - - @torch.no_grad() - def forward(self, x): - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) - y = y.view(*x_shape[:-1], -1) - return y - - @staticmethod - def from_float(module: torch.nn.Linear, input_scale): - int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) - int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) - alpha = input_scale * weight_scale - int8_module.weight = int8_weight - int8_module.a = alpha - int8_module.input_scale = input_scale - int8_module.weight_scale = weight_scale - - if module.bias is not None: - int8_module.bias = module.bias.to(torch.float32) - - return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py deleted file mode 100644 index bb74dc49d7af..000000000000 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ /dev/null @@ -1,852 +0,0 @@ -import math -import os -import types -from collections import defaultdict -from functools import partial -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LLAMA_INPUTS_DOCSTRING, - LlamaAttention, - LlamaDecoderLayer, - LlamaMLP, - LlamaRotaryEmbedding, - rotate_half, -) -from transformers.utils import add_start_docstrings_to_model_forward - -from colossalai.inference.kv_cache.batch_infer_state import BatchInferState -from colossalai.kernel.triton import ( - copy_kv_cache_to_dest, - int8_rotary_embedding_fwd, - smooth_llama_context_attn_fwd, - smooth_token_attention_fwd, -) - -try: - from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T - - HAS_TORCH_INT = True -except ImportError: - HAS_TORCH_INT = False - print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") - - -from .base_model import BaseSmoothForCausalLM -from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LLamaSmoothquantAttention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." - ) - - self.qk_bmm = BMM_S8T_S8N_F32T(1.0) - self.pv_bmm = BMM_S8T_S8N_S8T(1.0) - - self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) - - self.register_buffer("q_output_scale", torch.tensor([1.0])) - self.register_buffer("k_output_scale", torch.tensor([1.0])) - self.register_buffer("v_output_scale", torch.tensor([1.0])) - self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) - self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) - self.register_buffer("out_input_scale", torch.tensor([1.0])) - self.register_buffer("attn_input_scale", torch.tensor([1.0])) - - self._init_rope() - self.num_key_value_heads = num_heads - - def _init_rope(self): - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=2048, - base=10000.0, - ) - - @staticmethod - def pack( - module: LlamaAttention, - attn_input_scale: float, - q_output_scale: float, - k_output_scale: float, - v_output_scale: float, - q_rotary_output_scale: float, - k_rotary_output_scale: float, - out_input_scale: float, - ): - int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) - - int8_module.attn_input_scale = torch.tensor([attn_input_scale]) - - int8_module.q_output_scale = torch.tensor([q_output_scale]) - int8_module.k_output_scale = torch.tensor([k_output_scale]) - int8_module.v_output_scale = torch.tensor([v_output_scale]) - - int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) - int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) - - int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) - int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) - int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) - int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) - - int8_module.out_input_scale = torch.tensor([out_input_scale]) - - return int8_module - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - @torch.no_grad() - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - cos, sin = infer_state.position_cos, infer_state.position_sin - - int8_rotary_embedding_fwd( - query_states.view(-1, self.num_heads, self.head_dim), - cos, - sin, - self.q_output_scale.item(), - self.q_rotary_output_scale.item(), - ) - int8_rotary_embedding_fwd( - key_states.view(-1, self.num_heads, self.head_dim), - cos, - sin, - self.k_output_scale.item(), - self.k_rotary_output_scale.item(), - ) - - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return - - query_states = query_states.view(-1, self.num_heads, self.head_dim) - key_states = key_states.view(-1, self.num_heads, self.head_dim) - value_states = value_states.view(-1, self.num_heads, self.head_dim) - - if infer_state.is_context_stage: - # first token generation - - # copy key and value calculated in current step to memory manager - _copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.context_mem_index, - infer_state.cache_manager, - ) - - attn_output = torch.empty_like(query_states) - - smooth_llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - self.q_rotary_output_scale.item(), - self.k_rotary_output_scale.item(), - self.v_output_scale.item(), - self.out_input_scale.item(), - infer_state.start_loc, - infer_state.seq_len, - q_len, - ) - - else: - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(key_states) - cache_v.copy_(value_states) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.decode_mem_index, - infer_state.cache_manager, - ) - - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - smooth_token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - self.q_rotary_output_scale.item(), - self.k_rotary_output_scale.item(), - self.v_output_scale.item(), - self.out_input_scale.item(), - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - - attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) - attn_output = self.o_proj(attn_output) - - return attn_output, None, None - - -class LlamaLayerNormQ(torch.nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.input_scale = 1.0 - self.variance_epsilon = eps - self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) - - def forward(self, x): - ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) - ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) - return ln_output_int8 - - @staticmethod - def from_float(module: torch.nn.LayerNorm, output_scale: float): - assert module.weight.shape[0] == module.weight.numel() - q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) - q_module.weight = module.weight / output_scale - return q_module - - -class LlamaSmoothquantMLP(nn.Module): - def __init__(self, intermediate_size, hidden_size): - super().__init__() - self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) - self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) - self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) - self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) - - @staticmethod - def pack( - mlp_module: LlamaMLP, - gate_proj_input_scale: float, - up_proj_input_scale: float, - down_proj_input_scale: float, - ): - int8_module = LlamaSmoothquantMLP( - mlp_module.intermediate_size, - mlp_module.hidden_size, - ) - - int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) - int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) - int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) - int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) - return int8_module - - def forward( - self, - hidden_states: torch.Tensor, - ): - x_shape = hidden_states.shape - gate_out = self.gate_proj(hidden_states) - up_out = self.up_proj(hidden_states) - inter_out = gate_out * up_out - inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) - down_out = self.down_proj(inter_out) - down_out = down_out.view(*x_shape[:-1], -1) - return down_out - - -class LlamaSmoothquantDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) - - self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) - self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) - - self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) - - @staticmethod - def pack( - module: LlamaDecoderLayer, - attn_input_scale: float, - q_output_scale: float, - k_output_scale: float, - v_output_scale: float, - q_rotary_output_scale: float, - k_rotary_output_scale: float, - out_input_scale: float, - gate_input_scale: float, - up_input_scale: float, - down_input_scale: float, - ): - config = module.self_attn.config - int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) - - int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) - int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( - module.self_attn, - attn_input_scale, - q_output_scale, - k_output_scale, - v_output_scale, - q_rotary_output_scale, - k_rotary_output_scale, - out_input_scale, - ) - - int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( - module.post_attention_layernorm, gate_input_scale - ) - - int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( - module.mlp, - gate_input_scale, - up_input_scale, - down_input_scale, - ) - - return int8_decoder_layer - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - infer_state=infer_state, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states, None, None - - -class LlamaApplyRotary(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - x_embed = (x * cos) + (rotate_half(x) * sin) - - return x_embed - - -# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -def llama_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) - key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def init_to_get_rotary(config, base=10000, use_elem=False): - """ - This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer - Args: - base : calculation arg - use_elem : activated when using chatglm-based models - """ - config.head_dim_ = config.hidden_size // config.num_attention_heads - if not hasattr(config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 - - if hasattr(config, "max_sequence_length"): - max_seq_len = config.max_sequence_length - elif hasattr(config, "max_position_embeddings"): - max_seq_len = config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - - # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - try: - ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - print(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula - except: - pass - - n_elem = config.head_dim_ - if use_elem: - n_elem //= 2 - - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - _cos_cached = torch.cos(freqs).to(torch.float) - _sin_cached = torch.sin(freqs).to(torch.float) - return _cos_cached, _sin_cached - - -# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -def llama_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - infer_state = self.infer_state - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - seq_length_with_past = seq_length + past_key_values_length - - # NOTE: differentiate with prefill stage - # block_loc require different value-assigning method for two different stage - # NOTE: differentiate with prefill stage - # block_loc require different value-assigning method for two different stage - if infer_state.is_context_stage: - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}") - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) - padding_mask = None - else: - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - raise NotImplementedError("not implement gradient_checkpointing and training options ") - - if past_key_values_length == 0: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - else: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - infer_state=infer_state, - ) - - hidden_states = layer_outputs[0] - infer_state.decode_layer_id += 1 - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - infer_state.is_context_stage = False - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): - layer_type = "LlamaDecoderLayer" - - def __init__(self, model: PreTrainedModel, quantized: bool = False): - super().__init__(model, quantized) - - # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py - def get_act_dict( - self, - tokenizer, - dataset, - num_samples=512, - seq_len=512, - ): - llama_model = self.model - - llama_model.eval() - device = next(llama_model.parameters()).device - # print("model:", llama_model) - act_dict = defaultdict(dict) - - def stat_io_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - if name not in act_dict or "input" not in act_dict[name]: - act_dict[name]["input"] = x.detach().abs().max().item() - else: - act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) - if isinstance(y, tuple): - y = y[0] - if name not in act_dict or "output" not in act_dict[name]: - act_dict[name]["output"] = y.detach().abs().max().item() - else: - act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) - - for name, m in llama_model.named_modules(): - if isinstance(m, LlamaAttention): - setattr(m, "q_apply_rotary", LlamaApplyRotary()) - setattr(m, "k_apply_rotary", LlamaApplyRotary()) - m.forward = types.MethodType(llama_decoder_layer_forward, m) - - hooks = [] - for name, m in llama_model.named_modules(): - if isinstance(m, LlamaApplyRotary): - hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - if isinstance(m, torch.nn.Linear): - hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - - self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) - - for hook in hooks: - hook.remove() - return act_dict - - def smooth_fn(self, scales, alpha=0.5): - model = self.model - for name, module in model.named_modules(): - if isinstance(module, LlamaDecoderLayer): - attn_ln = module.input_layernorm - qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] - qkv_input_scales = scales[name + ".self_attn.q_proj"] - self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) - - def create_quantized_model(model): - llama_config = model.config - for i, layer in enumerate(model.model.layers): - model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) - - model.model.forward = types.MethodType(llama_model_forward, model.model) - cos, sin = init_to_get_rotary(llama_config) - model.model.register_buffer("_cos_cached", cos) - model.model.register_buffer("_sin_cached", sin) - - def quantized( - self, - tokenizer, - dataset, - num_samples=512, - seq_len=512, - alpha=0.5, - ): - llama_model = self.model - llama_config = llama_model.config - - act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) - - self.smooth_fn(act_scales, alpha) - - act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) - decoder_layer_scales = [] - - for idx in range(llama_config.num_hidden_layers): - scale_dict = {} - scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 - scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 - scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 - scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 - - scale_dict["q_rotary_output_scale"] = ( - act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 - ) - scale_dict["k_rotary_output_scale"] = ( - act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 - ) - - scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 - - scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 - scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 - scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 - - decoder_layer_scales.append(scale_dict) - - for i, layer in enumerate(llama_model.model.layers): - orig_layer = layer - llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) - - llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) - - cos, sin = init_to_get_rotary(llama_config) - llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) - llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/inference/quant/smoothquant/models/parallel_linear.py b/colossalai/inference/quant/smoothquant/models/parallel_linear.py deleted file mode 100644 index 962b687a1d05..000000000000 --- a/colossalai/inference/quant/smoothquant/models/parallel_linear.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import List, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed import ProcessGroup - -from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import ParallelModule - -from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear - - -def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1): - qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0) - if smooth_linear.bias is not None: - bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0) - - smooth_split_out_features = para_linear.out_features // split_num - - for i in range(split_num): - para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][ - tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, : - ] - - if para_linear.bias is not None: - para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][ - :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features - ] - - -def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1): - qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1) - - smooth_split_in_features = para_linear.in_features // split_num - - for i in range(split_num): - para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][ - :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features - ] - - if smooth_linear.bias is not None: - para_linear.bias.copy_(smooth_linear.bias) - - -class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__(in_features, out_features, alpha, beta) - self.process_group = None - self.tp_size = 1 - self.tp_rank = 0 - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - out_features = module.out_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if out_features < tp_size: - return module - - if out_features % tp_size != 0: - raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size) - linear_1d.tp_size = tp_size - linear_1d.tp_rank = tp_rank - linear_1d.process_group = process_group - linear_1d.a = module.a.clone().detach() - linear_1d.b = module.b.clone().detach() - split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - return linear_1d - - -class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__(in_features, out_features, alpha, beta) - self.process_group = None - self.tp_size = 1 - self.tp_rank = 0 - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - in_features = module.in_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if in_features < tp_size: - return module - - if in_features % tp_size != 0: - raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features) - linear_1d.tp_size = tp_size - linear_1d.tp_rank = tp_rank - linear_1d.process_group = process_group - linear_1d.a = torch.tensor(module.a) - linear_1d.b = torch.tensor(module.b) - - split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - if linear_1d.bias is not None: - linear_1d.bias = linear_1d.bias // tp_size - - return linear_1d - - @torch.no_grad() - def forward(self, x): - output = super().forward(x) - if self.tp_size > 1: - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) - return output - - -class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__(in_features, out_features, alpha, beta) - self.process_group = None - self.tp_size = 1 - self.tp_rank = 0 - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - out_features = module.out_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if out_features < tp_size: - return module - - if out_features % tp_size != 0: - raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size) - linear_1d.tp_size = tp_size - linear_1d.tp_rank = tp_rank - linear_1d.process_group = process_group - linear_1d.a = module.a.clone().detach() - - split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - return linear_1d - - -class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__(in_features, out_features, alpha, beta) - self.process_group = None - self.tp_size = 1 - self.tp_rank = 0 - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - out_features = module.out_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if out_features < tp_size: - return module - - if out_features % tp_size != 0: - raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size) - linear_1d.tp_size = tp_size - linear_1d.tp_rank = tp_rank - linear_1d.process_group = process_group - linear_1d.a = module.a.clone().detach() - - split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - return linear_1d - - -class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule): - def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): - super().__init__(in_features, out_features, alpha, beta) - self.process_group = None - self.tp_size = 1 - self.tp_rank = 0 - - @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: - LazyInitContext.materialize(module) - # get the attributes - in_features = module.in_features - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." - process_group = process_group[0] - - tp_size = dist.get_world_size(process_group) - tp_rank = dist.get_rank(process_group) - - if in_features < tp_size: - return module - - if in_features % tp_size != 0: - raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features) - linear_1d.tp_size = tp_size - linear_1d.tp_rank = tp_rank - linear_1d.process_group = process_group - linear_1d.a = module.a.clone().detach() - - split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) - if linear_1d.bias is not None: - linear_1d.bias = linear_1d.bias / tp_size - - return linear_1d - - @torch.no_grad() - def forward(self, x): - output = super().forward(x) - if self.tp_size > 1: - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) - return output diff --git a/colossalai/inference/sampler.py b/colossalai/inference/sampler.py new file mode 100644 index 000000000000..949d979bc247 --- /dev/null +++ b/colossalai/inference/sampler.py @@ -0,0 +1,111 @@ +from typing import List, Optional, Tuple, Union + +import torch +from transformers.generation import GenerationConfig + +from colossalai.inference.logit_processors import get_logits_processor + + +def greedy_sample( + logprobs: torch.Tensor, +) -> torch.Tensor: + """ + Sample tokens greedyly. + """ + results = torch.argmax(logprobs, dim=-1) + return results + + +def multinomial_sample( + probs: torch.Tensor, +) -> torch.Tensor: + """ + Sample tokens in a random phase. + """ + random_results = torch.multinomial(probs, num_samples=1).squeeze(1) + return random_results + + +def beam_search_sample( + beam_width: int, + logprobs: torch.Tensor, + is_prompt: bool = False, +) -> List[Tuple[List[int], List[int]]]: + """ + Sample tokens with beam search. + We sample 2 * beam_width candidates to make sure that with high probability we can get `beam_width` candidates in addition to + the finished sequences for the next iteration. + + ref: + https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 + for details. See also HF reference: + https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 + + # NOTE: this beam search sample function is wrong now. + """ + + results = [] + if is_prompt: + # Prompt phase. + parent_ids = [0] * (2 * beam_width) + _, next_token_ids = torch.topk(logprobs[0], 2 * beam_width) + next_token_ids = next_token_ids.tolist() + else: + # Generation phase. + # cumulative_logprobs = [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids] + cumulative_logprobs = torch.tensor(logprobs, dtype=torch.float, device=seq_group_logprobs.device) + seq_group_logprobs = seq_group_logprobs + cumulative_logprobs.unsqueeze(dim=1) + _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width) + + results.append((next_token_ids, parent_ids)) + return results + + +def search_tokens( + generation_config: Union[GenerationConfig, dict], + logits, + is_prompt: bool = False, + batch_token_ids: Optional[List[List[int]]] = None, +): + """ + Sample tokens for finished requests. + """ + # NOTE: need to decide the granularity to process logits (sequence or batch) + + # convert GenerationConfig to dict + # temporary fix for compatibility with the usage of RPCInferenceEngine + if isinstance(generation_config, GenerationConfig): + generation_config = generation_config.to_dict() + + if (repetition_penalty := generation_config.get("repetition_penalty", 1.0)) != 1.0: + logits = get_logits_processor("repetition_penalty", logits, repetition_penalty, batch_token_ids) + if (no_repeat_ngram_size := generation_config.get("no_repeat_ngram_size", 0)) > 0: + logits = get_logits_processor("no_repeat_ngram_size", logits, no_repeat_ngram_size, batch_token_ids) + if (forced_eos_token_id := generation_config.get("forced_eos_token_id", None)) is not None: + sequence_lengths = [len(batch_token_ids[i]) for i in range(len(batch_token_ids))] + max_out_lengths = [generation_config.max_length for _ in range(len(batch_token_ids))] + logits = get_logits_processor( + "forced_eos_token_id", logits, sequence_lengths, max_out_lengths, forced_eos_token_id + ) + + if generation_config.get("do_sample"): + if (temperature := generation_config.get("temperature", 1.0)) != 1.0: + logits = get_logits_processor("temperature", logits, temperature) + if (top_k := generation_config.get("top_k", 0)) != 0: + logits = get_logits_processor("top_k", logits, top_k) + if (top_p := generation_config.get("top_p", 1.0)) < 1.0: + logits = get_logits_processor("top_p", logits, top_p) + + # calculate probs + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # sample the next tokens + if generation_config.get("num_beams", 1) != 1: + raise NotImplementedError("Beam search is not supported yet.") + if generation_config.get("do_sample", False): + sample_tokens = multinomial_sample(probs) + else: + sample_tokens = greedy_sample(logprobs) + + return sample_tokens diff --git a/colossalai/inference/server/README.md b/colossalai/inference/server/README.md new file mode 100644 index 000000000000..8b5f29fc097d --- /dev/null +++ b/colossalai/inference/server/README.md @@ -0,0 +1,27 @@ +# Online Service +Colossal-Inference supports fast-api based online service. Simple completion and chat are both supported. Follow the commands below and +you can simply construct a server with both completion and chat functionalities. For now we only support `Llama` model, we will fullfill +the blank quickly. + +# Usage +```bash +# First, Lauch an API locally. +python3 -m colossalai.inference.server.api_server --model path of your llama2 model --chat_template "{% for message in messages %} +{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" + + +# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + +# For completion service, you can invoke it +curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' + +# For chat service, you can invoke it +curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"converation": + [{"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"},], + "stream": "False",}' +# If you just want to test a simple generation, turn to generate api +curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}' + +``` +We also support streaming output, simply change the `stream` to `True` in the request body. diff --git a/colossalai/inference/server/__init__.py b/colossalai/inference/server/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py new file mode 100644 index 000000000000..91c77ed35784 --- /dev/null +++ b/colossalai/inference/server/api_server.py @@ -0,0 +1,228 @@ +""" +Doc: + Feature: + - FastAPI based http server for Colossal-Inference + - Completion Service Supported + Usage: (for local user) + - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` + - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hello, who are you? ","stream":"False"}'` + Version: V1.0 +""" + +import argparse +import json + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.server.chat_service import ChatServing +from colossalai.inference.server.completion_service import CompletionServing +from colossalai.inference.server.utils import id_generator +from colossalai.inference.utils import find_available_ports + +from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +supported_models_dict = {"Llama_Models": ("llama2-7b",)} +prompt_template_choices = ["llama", "vicuna"] +async_engine = None +chat_serving = None +completion_serving = None + +app = FastAPI() + + +# NOTE: (CjhHa1) models are still under development, need to be updated +@app.get("/models") +def get_available_models() -> Response: + return JSONResponse(supported_models_dict) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + A request should be a JSON object with the following fields: + - prompts: the prompts to use for the generation. + - stream: whether to stream the results or not. + - other fields: + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", "false") + if isinstance(stream, str): + stream = stream.lower() + request_id = id_generator() + generation_config = get_generation_config(request_dict) + results = engine.generate(request_id, prompt, generation_config=generation_config) + + # Streaming case + def stream_results(): + for request_output in results: + ret = {"text": request_output[len(prompt) :]} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream == "true" or stream == True: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + for request_output in results: + if request.is_disconnected(): + # Abort the request if the client disconnects. + engine.abort(request_id) + return Response(status_code=499) + final_output = request_output[len(prompt) :] + + assert final_output is not None + ret = {"text": final_output} + return JSONResponse(ret) + + +@app.post("/completion") +async def create_completion(request: Request): + request_dict = await request.json() + stream = request_dict.pop("stream", "false") + if isinstance(stream, str): + stream = stream.lower() + generation_config = get_generation_config(request_dict) + result = await completion_serving.create_completion(request, generation_config) + + ret = {"request_id": result.request_id, "text": result.output} + if stream == "true" or stream == True: + return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream") + else: + return JSONResponse(content=ret) + + +@app.post("/chat") +async def create_chat(request: Request): + request_dict = await request.json() + + stream = request_dict.get("stream", "false") + if isinstance(stream, str): + stream = stream.lower() + generation_config = get_generation_config(request_dict) + message = await chat_serving.create_chat(request, generation_config) + if stream == "true" or stream == True: + return StreamingResponse(content=message, media_type="text/event-stream") + else: + ret = {"role": message.role, "text": message.content} + return ret + + +def get_generation_config(request): + generation_config = async_engine.engine.generation_config + for arg in request: + if hasattr(generation_config, arg): + setattr(generation_config, arg, request[arg]) + return generation_config + + +def add_engine_config(parser): + parser.add_argument( + "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use" + ) + # Parallel arguments not supported now + + # KV cache arguments + parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size") + + parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size") + + parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length") + + parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length") + + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + + parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") + + # generation arguments + parser.add_argument( + "--prompt_template", + choices=prompt_template_choices, + default=None, + help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", + ) + return parser + + +def parse_args(): + parser = argparse.ArgumentParser(description="Colossal-Inference API server.") + + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.") + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument( + "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy" + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.", + ) + + parser.add_argument( + "--chat-template", + type=str, + default=None, + help="The file path to the chat template, " "or the template in single-line form " "for the specified model", + ) + parser.add_argument( + "--response-role", + type=str, + default="assistant", + help="The role name to return if " "`request.add_generation_prompt=true`.", + ) + parser = add_engine_config(parser) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + inference_config = InferenceConfig.from_dict(vars(args)) + tokenizer = AutoTokenizer.from_pretrained(args.model) + colossalai_backend_port = find_available_ports(1)[0] + colossalai.launch( + rank=0, + world_size=1, + host=args.host, + port=colossalai_backend_port, + backend="nccl", + ) + model = AutoModelForCausalLM.from_pretrained(args.model) + async_engine = AsyncInferenceEngine( + start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config + ) + engine = async_engine.engine + completion_serving = CompletionServing(async_engine, model.__class__.__name__) + chat_serving = ChatServing( + async_engine, + served_model=model.__class__.__name__, + tokenizer=tokenizer, + response_role=args.response_role, + chat_template=args.chat_template, + ) + app.root_path = args.root_path + uvicorn.run( + app=app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/colossalai/inference/server/chat_service.py b/colossalai/inference/server/chat_service.py new file mode 100644 index 000000000000..d84e82d2989a --- /dev/null +++ b/colossalai/inference/server/chat_service.py @@ -0,0 +1,142 @@ +import asyncio +import codecs +import logging + +from fastapi import Request + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator + +logger = logging.getLogger("colossalai-inference") + + +class ChatServing: + def __init__( + self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None + ): + self.engine = engine + self.served_model = served_model + self.tokenizer = tokenizer + self.response_role = response_role + self._load_chat_template(chat_template) + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_chat(self, request: Request, generation_config): + request_dict = await request.json() + messages = request_dict["messages"] + stream = request_dict.pop("stream", "false").lower() + add_generation_prompt = request_dict.pop("add_generation_prompt", False) + request_id = id_generator() + try: + prompt = self.tokenizer.apply_chat_template( + conversation=messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + except Exception as e: + raise RuntimeError(f"Error in applying chat template from request: {str(e)}") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt) + + if stream == "true": + return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id) + else: + return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id) + + async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int): + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request, request_dict) + n = request_dict.get("n", 1) + echo = request_dict.get("echo", "false").lower() + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role)) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if echo == "true": + last_msg_content = "" + if ( + request_dict["messages"] + and isinstance(request_dict["messages"], list) + and request_dict["messages"][-1].get("content") + and request_dict["messages"][-1].get("role") == role + ): + last_msg_content = request_dict["messages"][-1]["content"] + if last_msg_content: + for i in range(n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, message=DeltaMessage(content=last_msg_content) + ) + data = choice_data.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + result = await result_generator + choice_data = DeltaMessage(content=result.output) + data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {data}\n\n" + + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: Request, + request_dict: dict, + result_generator, + request_id, + ): + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + return {"error_msg": "Client disconnected"} + + result = await result_generator + assert result is not None + role = self.get_chat_request_role(request, request_dict) + choice_data = ChatMessage(role=role, content=result.output) + echo = request_dict.get("echo", "false").lower() + + if echo == "true": + last_msg_content = "" + if ( + request.messages + and isinstance(request.messages, list) + and request.messages[-1].get("content") + and request.messages[-1].get("role") == role + ): + last_msg_content = request.messages[-1]["content"] + + full_message = last_msg_content + choice_data.content + choice_data.content = full_message + + return choice_data + + def get_chat_request_role(self, request: Request, request_dict: dict) -> str: + add_generation_prompt = request_dict.get("add_generation_prompt", False) + if add_generation_prompt: + return self.response_role + else: + return request_dict["messages"][-1]["role"] + + def _load_chat_template(self, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}") + elif self.tokenizer.chat_template is not None: + logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}") + else: + logger.warning("No chat template provided. Chat API will not work.") diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py new file mode 100644 index 000000000000..16111dad4b1e --- /dev/null +++ b/colossalai/inference/server/completion_service.py @@ -0,0 +1,34 @@ +import asyncio + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import id_generator + + +class CompletionServing: + def __init__(self, engine: AsyncInferenceEngine, served_model: str): + self.engine = engine + self.served_model = served_model + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_completion(self, request, generation_config): + request_dict = await request.json() + request_id = id_generator() + + prompt = request_dict.pop("prompt") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt, generation_config=generation_config) + + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(request_id) + raise RuntimeError("Client disconnected") + + final_res = await result_generator + return final_res diff --git a/colossalai/inference/server/utils.py b/colossalai/inference/server/utils.py new file mode 100644 index 000000000000..9eac26576c6c --- /dev/null +++ b/colossalai/inference/server/utils.py @@ -0,0 +1,36 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + +# make it singleton +class NumericIDGenerator: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NumericIDGenerator, cls).__new__(cls) + cls._instance.current_id = 0 + return cls._instance + + def __call__(self): + self.current_id += 1 + return self.current_id + + +id_generator = NumericIDGenerator() + + +class ChatMessage(BaseModel): + role: str + content: Any + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[Any] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + message: DeltaMessage diff --git a/colossalai/inference/spec/__init__.py b/colossalai/inference/spec/__init__.py new file mode 100644 index 000000000000..b1a05f6a407e --- /dev/null +++ b/colossalai/inference/spec/__init__.py @@ -0,0 +1,4 @@ +from .drafter import Drafter +from .struct import DrafterOutput, GlideInput + +__all__ = ["Drafter", "DrafterOutput", "GlideInput"] diff --git a/colossalai/inference/spec/drafter.py b/colossalai/inference/spec/drafter.py new file mode 100644 index 000000000000..3144b2c90c95 --- /dev/null +++ b/colossalai/inference/spec/drafter.py @@ -0,0 +1,121 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers import PreTrainedTokenizer + +from colossalai.utils import get_current_device + +from .struct import DrafterOutput, GlideInput + + +class Drafter: + """Container for the Drafter Model (Assistant Model) used in Speculative Decoding. + + Args: + model (nn.Module): The drafter model. + tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model. + device (torch.device): The device for the drafter model. + """ + + def __init__( + self, + model: nn.Module, + tokenizer: PreTrainedTokenizer, + device: torch.device = None, + dtype: torch.dtype = torch.float16, + ): + self._tokenizer = tokenizer + self._device = device or get_current_device() + self._dtype = dtype + self._drafter_model = model.to(self._device) + self._drafter_model = model.to(self._dtype) + self._drafter_model.eval() + + def get_model(self) -> nn.Module: + return self._drafter_model + + @staticmethod + def trim_kv_cache( + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int + ) -> Tuple[Tuple[torch.FloatTensor]]: + """Trim the last `invalid_token_num` kv caches. + + past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape + num_layers x 2 x (bsz x num_heads x seq_len x head_dim) + invalid_token_num (int): The number of invalid tokens to trim. + """ + if past_key_values is None or invalid_token_num < 1: + return past_key_values + + trimmed_past_key_values = [] + for layer_idx in range(len(past_key_values)): + past_key_value = past_key_values[layer_idx] + trimmed_past_key_values.append( + ( + past_key_value[0][:, :, :-invalid_token_num, :], + past_key_value[1][:, :, :-invalid_token_num, :], + ) + ) + past_key_values = tuple(trimmed_past_key_values) + return past_key_values + + @torch.inference_mode() + def speculate( + self, + input_ids: torch.Tensor, + n_spec_tokens: int, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + glide_input: Optional[GlideInput] = None, + ) -> DrafterOutput: + """Generate n_spec_tokens tokens using the drafter model. + + Args: + input_ids (torch.Tensor): Input token ids. + n_spec_tokens (int): Number of tokens to speculate. + past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence. + glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model, + when using the glide model as a drafter. + """ + assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate" + + # For compatibility with transformers of versions before 4.38.0 + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + + logits = [] + token_ids = [] + + kwargs = {"return_dict": True, "use_cache": True} + if glide_input: + # required only when using glide model + kwargs["glide_input"] = glide_input + + for _ in range(n_spec_tokens): + # update past key values + kwargs["past_key_values"] = past_key_values + + outputs = self._drafter_model(input_ids, **kwargs) + next_token_logits = outputs.logits[:, -1, :] + + # NOTE Only use greedy search for speculating. + # As the drafter model usually has only a few layers with few parameters, + # introducing sampling will make the speculation unstable and lead to worse performance. + next_token_ids = torch.argmax(next_token_logits, dim=-1) + + logits.append(next_token_logits) + token_ids.append(next_token_ids) + if next_token_ids.item() == self._tokenizer.eos_token_id: + # TODO(yuanheng-zhao) support bsz > 1 + break + input_ids = next_token_ids[:, None] + past_key_values = outputs.past_key_values + + speculated_length = len(token_ids) # For now, only support bsz 1 + logits = torch.concat(logits, dim=0) + token_ids = torch.concat(token_ids, dim=-1) + + out = DrafterOutput( + speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values + ) + return out diff --git a/colossalai/inference/spec/struct.py b/colossalai/inference/spec/struct.py new file mode 100644 index 000000000000..143f26d09a59 --- /dev/null +++ b/colossalai/inference/spec/struct.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + + +@dataclass +class DrafterOutput: + """ + Dataclass for drafter model outputs. + + Args: + speculated_length (int): Speculated length of the output sequence + It is always less than or equal to spec_num during drafter's speculation process + logits (torch.FloatTensor): Logits of the output sequence + next_tokens (torch.Tensor): Next token ids + past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence + """ + + speculated_length: int = None + logits: torch.FloatTensor = None + next_tokens: torch.Tensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + def __post_init__(self): + assert self.speculated_length is not None and self.speculated_length >= 0 + if self.past_key_values is not None: + assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple" + assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values]) + + +@dataclass +class GlideInput: + """Dataclass for Glide Models (e.g. `colossalai/inference/modeling/models/glide_llama.py`). + Used for pack data that will be used during glimpsing KV Caches of the main model. + + Args: + block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches. + large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size] + Blocked key cache of the main model + large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache. + sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch. + """ + + block_tables: torch.Tensor = None + large_k_cache: torch.Tensor = None + large_v_cache: torch.Tensor = None + sequence_lengths: torch.Tensor = None + + @property + def glimpse_ready(self): + return all( + attr is not None + for attr in [self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths] + ) diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py new file mode 100644 index 000000000000..1a3094a27e2d --- /dev/null +++ b/colossalai/inference/struct.py @@ -0,0 +1,180 @@ +import enum +from dataclasses import dataclass +from typing import Any, List + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger(__name__) + +""" +The abstraction of request and sequence are defined here. +""" + + +class RequestStatus(enum.Enum): + """ + The status of Sentences + """ + + # running status + WAITING = enum.auto() + RUNNING = enum.auto() + ABORTED = enum.auto() + + # completion status + OVERLENGTH = enum.auto() + COMPLETED = enum.auto() + LENGTH_CAPPED = enum.auto() + + # recycle status + RECYCLED = enum.auto() + + @staticmethod + def is_finished(status: "RequestStatus") -> bool: + return status in [ + RequestStatus.OVERLENGTH, + RequestStatus.COMPLETED, + RequestStatus.LENGTH_CAPPED, + ] + + @staticmethod + def is_running(status: "RequestStatus") -> bool: + return status == RequestStatus.RUNNING + + @staticmethod + def is_waiting(status: "RequestStatus") -> bool: + return status == RequestStatus.WAITING + + +@dataclass +class Sequence: + """Store information of input sequence. + + Args: + request_id (int): The ID of input sequence. + prompt (str): The prompt of input sequence. + input_token_id (List[int]): The tokens ID of input sequence. + block_size (int): The block size of input sequence. + sample_params (SampleParams): The sample_params of input sequence. + block_table (torch.Tensor): The index of input sequence in block_table. + eos_token_id (int): The eos token id for this inference process. + pad_token_id (int): The pad token id for this inference process. + max_output_len (int): Maximum output length. + ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. + output(str): The output of sequence + """ + + request_id: int + prompt: str + input_token_id: List[int] + block_size: int + sample_params: Any # SampleParams needs to be imported later. + eos_token_id: int + pad_token_id: int + max_output_len: int = 256 + # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future. + ignore_eos: bool = False + output: str = None + + def __post_init__(self): + self.output_token_id = [] + self.status = RequestStatus.WAITING + + @property + def sentence_len(self) -> int: + """ + Get length of current sentence. + """ + return len(self.input_token_id) + len(self.output_token_id) + + @property + def input_len(self) -> int: + """ + Get length of input sentence. + """ + return len(self.input_token_id) + + @property + def output_len(self) -> int: + """ + Get length of output sentence. + """ + return len(self.output_token_id) + + def check_finish(self) -> bool: + """ + Check whether the inference is finished. + + Returns: + bool: Whether the inference is finished. + """ + if RequestStatus.is_finished(self.status): + return True + + if self.output_token_id: + if ( + self.output_token_id[-1] == self.eos_token_id and not self.ignore_eos + ) or self.output_len >= self.max_output_len: + self.status = RequestStatus.COMPLETED + return True + + return False + + def revoke_finished_status(self) -> None: + """ + Revoke the finished status of the sequence. + This is only used by speculative decoding for now. + """ + if RequestStatus.is_finished(self.status): + self.status = RequestStatus.RUNNING + + def __hash__(self): + return hash(self.request_id) + + def mark_running(self) -> None: + """ + Set status for prefill reqs. + """ + assert ( + self.status == RequestStatus.WAITING or RequestStatus.RECYCLED + ), "Sequence is not in WAITTING/RECYCLED STATUS" + self.status = RequestStatus.RUNNING + + def mark_finished(self) -> None: + """ + Set status for finished reqs. + """ + self.status = RequestStatus.COMPLETED + + def mark_aborted(self) -> None: + """ + Set status for aborted reqs. + """ + self.status = RequestStatus.ABORTED + + def recycle(self) -> None: + """ + Recycle a running sequnce to waiitting list + """ + assert ( + not self.check_finish() and not self.status == RequestStatus.ABORTED + ), "The running sequence \ + is already done but it still in running list" + self.status = RequestStatus.RECYCLED + + def __repr__(self) -> str: + return ( + f"(request_id={self.request_id}, " + f"prompt={self.prompt},\n" + f"output_token_id={self.output_token_id},\n" + f"output={self.output},\n" + f"status={self.status.name},\n" + f"sample_params={self.sample_params},\n" + f"input_len={self.input_len},\n" + f"output_len={self.output_len})\n" + ) + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return [pad] * (max_len - len(x)) + x diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py new file mode 100644 index 000000000000..072bedec3587 --- /dev/null +++ b/colossalai/inference/utils.py @@ -0,0 +1,115 @@ +""" +Utils for model inference +""" +import os +import re +from pathlib import Path +from typing import Optional, Tuple + +import torch +from torch import nn + +from colossalai.testing import free_port + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(self.dtype).cuda() + self._sin_cached = torch.sin(freqs).to(self.dtype).cuda() + + +def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: + """ + Check whether the checkpoint has an index file. + + Args: + checkpoint_path (str): path to the checkpoint. + + Returns: + Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path) + """ + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + # check if it is .index.json + reg = re.compile("(.*?).index((\..*)?).json") + if reg.fullmatch(checkpoint_path.name) is not None: + return True, checkpoint_path + else: + return False, None + elif checkpoint_path.is_dir(): + index_files = list(checkpoint_path.glob("*.index.*json")) + + for index_file in index_files: + if "safetensors" in index_file.__str__(): + return True, index_file.__str__() # return the safetensors file first + + if len(index_files) == 1: + return True, index_files[0] + else: + assert ( + len(index_files) == 1 + ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" + return False, None + else: + raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.") + + +def get_model_size(model: nn.Module): + """Calculates the total size of the model weights (including biases) in bytes. + Args: + model: The PyTorch model to analyze. + Returns: + The total size of the model weights in bytes. + """ + total_size = 0 + for key, param in model.named_parameters(): + total_size += param.element_size() * param.numel() + return total_size / (1024**3) + + +def find_available_ports(num: int): + try: + free_ports = [free_port() for i in range(num)] + except OSError as e: + print(f"An OS error occurred: {e}") + raise RuntimeError("Error finding available ports") + return free_ports diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 2dff3bcbcc5e..2411b6482ac1 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -8,6 +8,7 @@ FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension, FusedOptimizerCudaExtension, + InferenceOpsCudaExtension, LayerNormCudaExtension, MoeCudaExtension, ScaledMaskedSoftmaxCudaExtension, @@ -21,6 +22,7 @@ "LayerNormLoader", "MoeLoader", "FusedOptimizerLoader", + "InferenceOpsLoader", "ScaledMaskedSoftmaxLoader", "ScaledUpperTriangleMaskedSoftmaxLoader", ] @@ -97,6 +99,10 @@ class FusedOptimizerLoader(KernelLoader): REGISTRY = [FusedOptimizerCudaExtension] +class InferenceOpsLoader(KernelLoader): + REGISTRY = [InferenceOpsCudaExtension] + + class ScaledMaskedSoftmaxLoader(KernelLoader): REGISTRY = [ScaledMaskedSoftmaxCudaExtension] diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 20da71d394bd..4d2c17db1824 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -8,24 +8,24 @@ # There may exist import error even if we have triton installed. if HAS_TRITON: - from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd - from .copy_kv_cache_dest import copy_kv_cache_to_dest - from .fused_layernorm import layer_norm - from .gptq_triton import gptq_fused_linear_triton - from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd - from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd + from .context_attn_unpad import context_attention_unpadded + from .flash_decoding import flash_decoding_attention + from .fused_rotary_embedding import fused_rotary_embedding + from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache + from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding + from .rms_layernorm import rms_layernorm + from .rotary_cache_copy import get_xine_cache from .softmax import softmax - from .token_attention_kernel import token_attention_fwd __all__ = [ - "llama_context_attn_fwd", - "bloom_context_attn_fwd", + "context_attention_unpadded", + "flash_decoding_attention", + "copy_k_to_blocked_cache", + "copy_kv_to_blocked_cache", "softmax", - "layer_norm", - "copy_kv_cache_to_dest", - "token_attention_fwd", - "gptq_fused_linear_triton", - "int8_rotary_embedding_fwd", - "smooth_llama_context_attn_fwd", - "smooth_token_attention_fwd", + "rms_layernorm", + "rotary_embedding", + "fused_rotary_embedding", + "get_xine_cache", + "decoding_fused_rotary_embedding", ] diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py deleted file mode 100644 index 1725581d637c..000000000000 --- a/colossalai/kernel/triton/context_attention.py +++ /dev/null @@ -1,434 +0,0 @@ -import math - -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - """ - this function is modified from - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 - """ - if triton.__version__ < "2.1.0": - - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - else: - # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11 - @triton.jit - def _context_flash_attention_kernel_2( - Q, - K, - V, - sm_scale, - Alibi, - B_Start_Loc, - B_Seqlen, - Out, - kv_group_num, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - if kv_group_num is not None: - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - if kv_group_num is None or kv_group_num == 1: - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - else: - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if Alibi is not None: - alibi_m = tl.load(Alibi + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if Alibi is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @torch.no_grad() - def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - num_warps = 4 if Lk <= 64 else 8 - - if triton.__version__ < "2.1.0": - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - alibi, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - _context_flash_attention_kernel_2[grid]( - q, - k, - v, - sm_scale, - alibi, - b_start_loc, - b_seq_len, - o, - None, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - - return - - @torch.no_grad() - def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - - if triton.__version__ < "2.1.0": - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - kv_group_num = q.shape[1] // k.shape[1] - _context_flash_attention_kernel_2[grid]( - q, - k, - v, - sm_scale, - None, - b_start_loc, - b_seq_len, - o, - kv_group_num, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - - return diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py new file mode 100644 index 000000000000..9c69c4125d62 --- /dev/null +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -0,0 +1,727 @@ +# Applying the FlashAttention V2 as described in: +# "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" +# by Tri Dao, 2023 +# https://github.com/Dao-AILab/flash-attention +# +# Inspired and modified from Triton Tutorial - Fused Attention +# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html + +import torch +import triton +import triton.language as tl + + +# Triton 2.1.0 +@triton.jit +def _fwd_context_paged_attention_kernel( + Q, + K, + V, + O, + KCache, + VCache, + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + context_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the current sequence length from provided context lengths tensor + cur_seq_len = tl.load(context_lengths + cur_seq_idx) + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the current sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + # block table for the current sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the current block table, + # as we have BLOCK_M the same size as the block size. + cur_block_table_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + if block_start_m * BLOCK_M >= cur_seq_len: + return + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0) + offsets_kcachebs = tl.arange(0, BLOCK_SIZE) + offsets_kcache = ( + KCache + + offset_kvcache + + offsets_dmodel[None, :] * stride_cached + + offsets_kcachebs[:, None] * stride_cachebs + ) + tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_vd = offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) + offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here + offsets_vcache = ( + VCache + + offset_kvcache + + offsets_vcachebs[None, :] * stride_cachebs + + offsets_dmodel[:, None] * stride_cached + ) + tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + + return + + +# Triton 2.1.0 +# TODO(yuanheng-zhao): This is a temporary dispatch to use the new layout for kcache +# merge `_fwd_context_paged_attention_kernel_v2` with `_fwd_context_paged_attention_kernel` later +# as the kcache layout has been supported in the whole triton flow. +@triton.jit +def _fwd_context_paged_attention_kernel_v2( + Q, + K, + V, + O, + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_cacheb, # v cache stride(0) - num_blocks + stride_cacheh, # v cache stride(1) - num_kv_heads + stride_cachebs, # v cache stride(2) - block_size + stride_cached, # v cache stride(3) - head_dim + stride_bts, + stride_btb, + context_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, # k stride on the second last dimension + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the current sequence length from provided context lengths tensor + cur_seq_len = tl.load(context_lengths + cur_seq_idx) + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the current sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(block_start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + # block table for the current sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the current block table, + # as we have BLOCK_M the same size as the block size. + cur_block_table_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + if block_start_m * BLOCK_M >= cur_seq_len: + return + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + block_range = tl.arange(0, BLOCK_SIZE) + X_range = tl.arange(0, KCACHE_X) + # unroll the loop aggressively + for split_x in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) + offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0) + # HACK: KCache must be contiguous in order to apply the following offsets calculation + offsets_kcache = ( + KCache + + offset_kvcache + + split_x * BLOCK_SIZE * KCACHE_X + + block_range[:, None] * KCACHE_X + + X_range[None, :] + ) + tl.store(offsets_kcache, k, mask=block_range[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_vd = tl.arange(0, HEAD_DIM) # offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + offsets_n + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) + offsets_vcache = ( + VCache + offset_kvcache + block_range[None, :] * stride_cachebs + offsets_vd[:, None] * stride_cached + ) + tl.store(offsets_vcache, v, mask=block_range[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + + return + + +# Triton 2.1.0 +@triton.jit +def _alibi_fwd_context_paged_attention_kernel( + Q, + K, + V, + O, + KCache, + VCache, + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + alibi_slopes, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + context_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + global_block_start_offest = block_start_m * BLOCK_M + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the current sequence length from provided context lengths tensor + cur_seq_len = tl.load(context_lengths + cur_seq_idx) + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the current sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(global_block_start_offest, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(global_block_start_offest, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + # block table for the current sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the current block table, + # as we have BLOCK_M the same size as the block size. + cur_block_table_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_m = global_block_start_offest + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # load alibi_slope + alibi_slope = tl.load(alibi_slopes + cur_head_idx) + m_alibi_offset = tl.arange(0, BLOCK_M)[:, None] + global_block_start_offest + n_alibi_offset = tl.arange(0, BLOCK_N)[None, :] + + if global_block_start_offest >= cur_seq_len: + return + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + alibi = (n_alibi_offset + block_start_n - m_alibi_offset) * alibi_slope + alibi = tl.where((alibi <= 0) & (m_alibi_offset < cur_seq_len), alibi, float("-inf")) + S_ij += alibi + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_kt = global_block_start_offest + tl.arange(0, BLOCK_M) + offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0) + offsets_kcachebs = tl.arange(0, BLOCK_SIZE) + offsets_kcache = ( + KCache + + offset_kvcache + + offsets_dmodel[None, :] * stride_cached + + offsets_kcachebs[:, None] * stride_cachebs + ) + tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_vd = offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) + offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here + offsets_vcache = ( + VCache + + offset_kvcache + + offsets_vcachebs[None, :] * stride_cachebs + + offsets_dmodel[:, None] * stride_cached + ) + tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + + return + + +def context_attention_unpadded( + q: torch.Tensor, # [num_tokens, num_heads, head_dim] + k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] + v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] + k_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] + v_cache: torch.Tensor, # [num_blocks, num_kv_heads, block_size, head_dim] + context_lengths: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], + block_size: int, + output: torch.Tensor = None, # [num_tokens, num_heads, head_dim] + alibi_slopes: torch.Tensor = None, # [num_heads] + max_seq_len: int = None, + sm_scale: int = None, + # NOTE(yuanheng-zhao): the following flag is used to determine whether to use the new layout for kcache + # [num_blocks, num_kv_heads, head_dim // x, block_size, x] - must be contiguous + use_new_kcache_layout: bool = False, +): + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk == Lv + assert Lk in {32, 64, 128, 256} + assert q.shape[0] == k.shape[0] == v.shape[0] + k_cache_shape = k_cache.shape + v_cache_shape = v_cache.shape + if use_new_kcache_layout: + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == v_cache_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3] + ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + else: + assert k_cache_shape == v_cache_shape, f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + assert context_lengths.shape[0] == block_tables.shape[0] + + num_tokens, num_heads, head_dim = q.shape + num_kv_heads = k.shape[-2] + assert num_kv_heads > 0 and num_heads % num_kv_heads == 0 + num_kv_group = num_heads // num_kv_heads + + num_seqs, max_blocks_per_seq = block_tables.shape + max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len + sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale + output = ( + torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output + ) + + # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with + # the size of physical cache block (i.e. `block_size`) + assert block_size in {16, 32, 64, 128} + BLOCK_M = BLOCK_N = block_size + + # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton + # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) + grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M)) + + if use_new_kcache_layout: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + assert ( + alibi_slopes is None + ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" + x = k_cache_shape[4] # Intuition: 16 // dtype_size + + _fwd_context_paged_attention_kernel_v2[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + KV_GROUPS=num_kv_group, + BLOCK_SIZE=block_size, + HEAD_DIM=Lk, + KCACHE_X=x, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return output + + if alibi_slopes is not None: + _alibi_fwd_context_paged_attention_kernel[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + alibi_slopes, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + num_kv_group, + block_size, + HEAD_DIM=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + _fwd_context_paged_attention_kernel[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + num_kv_group, + block_size, + HEAD_DIM=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + return output diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py deleted file mode 100644 index b8e6ab1d05ad..000000000000 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py - @triton.jit - def _fwd_copy_kv_cache_dest( - kv_cache_ptr, - dest_index_ptr, - out, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - head_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - ): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(dest_index_ptr + cur_index) - - cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] - k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets - - o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - o_ptrs = out + dest_index * stride_o_bs + o_offsets - - k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) - return - - # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py - @torch.no_grad() - def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): - seq_len = dest_index_ptr.shape[0] - head_num = k_ptr.shape[1] - head_dim = k_ptr.shape[2] - assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" - assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" - - num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( - k_ptr, - dest_index_ptr, - out, - k_ptr.stride(0), - k_ptr.stride(1), - k_ptr.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - head_num, - BLOCK_DMODEL=head_dim, - BLOCK_HEAD=triton.next_power_of_2(head_num), - num_warps=num_warps, - num_stages=2, - ) - return diff --git a/colossalai/kernel/triton/custom_autotune.py b/colossalai/kernel/triton/custom_autotune.py deleted file mode 100644 index 17bb1cf0070c..000000000000 --- a/colossalai/kernel/triton/custom_autotune.py +++ /dev/null @@ -1,176 +0,0 @@ -# code from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/nn_modules/triton_utils/custom_autotune.py - -import builtins -import math -import time -from typing import Dict - -import triton - - -class CustomizedTritonAutoTuner(triton.KernelInterface): - def __init__( - self, - fn, - arg_names, - configs, - key, - reset_to_zero, - prune_configs_by: Dict = None, - nearest_power_of_two: bool = False, - ): - if not configs: - self.configs = [triton.Config({}, num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.nearest_power_of_two = nearest_power_of_two - self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] - if "early_config_prune" in prune_configs_by: - early_config_prune = prune_configs_by["early_config_prune"] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - self.fn = fn - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - - try: - # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses - # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) - except triton.compiler.OutOfResources: - return (float("inf"), float("inf"), float("inf")) - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple(args[i] for i in self.key_idx) - - # This reduces the amount of autotuning by rounding the keys to the nearest power of two - # In my testing this gives decent results, and greatly reduces the amount of tuning required - if self.nearest_power_of_two: - key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) - - if key not in self.cache: - # prune configs - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = { - config: self.perf_model( - **self.nargs, - **kwargs, - **config.kwargs, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) - for config in pruned_configs - } - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - self.nargs = None - - -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): - def decorator(fn): - return CustomizedTritonAutoTuner( - fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two - ) - - return decorator - - -def matmul248_kernel_config_pruner(configs, nargs): - """ - The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. - """ - m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16) - n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16) - k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16) - - used = set() - for config in configs: - block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"]) - block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"]) - block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"]) - group_size_m = config.kwargs["GROUP_SIZE_M"] - - if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: - continue - - used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) - yield triton.Config( - { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - }, - num_stages=config.num_stages, - num_warps=config.num_warps, - ) diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index ac733dede3b7..0012f8ec9f30 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -1,47 +1,533 @@ -# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py +# Applying Flash-Decoding as descibed in +# https://pytorch.org/blog/flash-decoding/ +# by Tri Dao, 2023 import torch +import triton +import triton.language as tl -try: - from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1 - from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 - - HAS_LIGHTLLM_KERNEL = True -except: - print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") - HAS_LIGHTLLM_KERNEL = False - - -if HAS_LIGHTLLM_KERNEL: - - def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - - calcu_shape1 = (batch_size, q_head_num, head_dim) - - if getattr(infer_state, "mid_o", None) is None: - infer_state.mid_o = torch.empty( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], - dtype=torch.float32, - device="cuda", - ) - infer_state.mid_o_logexpsum = torch.empty( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) - - mid_o = infer_state.mid_o - mid_o_logexpsum = infer_state.mid_o_logexpsum - - flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.block_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, + +# Triton 2.1.0 +@triton.jit +def _flash_decoding_fwd_kernel( + Q, # [batch_size * q_len, head_num, head_dim] + KCache, # [num_blocks, num_kv_heads, block_size, head_dim] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim], + # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided + block_tables, # [batch_size, max_blocks_per_sequence] + mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size * q_len, head_num, kv_split_num] + kv_seq_len, # [batch_size] + q_len, + batch_size, + kv_group_num, + x, + sm_scale, + stride_qt, + stride_qh, + stride_qd, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcd, + stride_vcb, + stride_vch, + stride_vcs, + stride_vcd, + stride_bts, + stride_btb, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_o_lset, + stride_mid_o_lseh, + stride_mid_o_lseb, + BLOCK_KV: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len + if cur_seq_idx >= batch_size: + return + cur_token_off = (cur_token_idx % q_len) - q_len + 1 + cur_head_idx = tl.program_id(1) + block_start_kv = tl.program_id(2) # for splitting k/v + + # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same + # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) + # and then support calculating multiple kv cache blocks on an instance + tl.static_assert(BLOCK_KV == BLOCK_SIZE) + # get the current (kv) sequence length + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off + if block_start_kv * BLOCK_KV >= cur_kv_seq_len: + return + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_block = tl.arange(0, BLOCK_SIZE) + + # block table for the current sequence + block_table_ptr = block_tables + cur_seq_idx * stride_bts + # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) + # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) + cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb) + cur_occupied_size = tl.where( + (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE + ) + tl.device_assert(cur_occupied_size >= 0) + + offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + q = tl.load(Q + offsets_q) + cur_kv_head_idx = cur_head_idx // kv_group_num + offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch + offsets_k = ( + offset_kvcache + + (offsets_dmodel[None, :] // x) * stride_kcsplit_x + + (offsets_dmodel[None, :] % x) * stride_kcd + + offsets_block[:, None] * stride_kcs + ) + k_cur_block = tl.load(KCache + offsets_k) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kvcache, + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_vcs, stride_vcd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE, HEAD_DIM), + order=(0, 1), + ) + v_cur_block = tl.load(V_block_ptr) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + # use block size of the paged/blocked kv cache + S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, + # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. + # Refer to https://github.com/openai/triton/discussions/895 + S_ij += tl.sum(q[None, :] * k_cur_block, 1) + S_ij *= sm_scale + S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf")) + + m = tl.max(S_ij, 0) + S_ij -= m + p_ij_hat = tl.exp(S_ij) + l_i = tl.sum(p_ij_hat, 0) + p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) + acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) + acc = acc / l_i + + offsets_mid_o = ( + cur_token_idx * stride_mid_ot + + cur_head_idx * stride_mid_oh + + block_start_kv * stride_mid_ob + + offsets_dmodel * stride_mid_od + ) + tl.store(mid_o + offsets_mid_o, acc) + offsets_mid_o_lse = ( + cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + ) + # logsumexp l_i^(j) = m^(j) + log(l_i^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i)) + + +# Triton 2.1.0 +@triton.jit +def _alibi_flash_decoding_fwd_kernel( + Q, # [batch_size * q_len, head_num, head_dim] + KCache, # [num_blocks, num_kv_heads, block_size, head_dim] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + block_tables, # [batch_size, max_blocks_per_sequence] + mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size * q_len, head_num, kv_split_num] + kv_seq_len, # [batch_size] + q_len, + batch_size, + alibi_slopes, + stride_qt, + stride_qh, + stride_qd, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_o_lset, + stride_mid_o_lseh, + stride_mid_o_lseb, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_KV: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len + if cur_seq_idx >= batch_size: + return + cur_token_off = (cur_token_idx % q_len) - q_len + 1 + cur_head_idx = tl.program_id(1) + block_start_kv = tl.program_id(2) # for splitting k/v + + # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same + # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) + # and then support calculating multiple kv cache blocks on an instance + tl.static_assert(BLOCK_KV == BLOCK_SIZE) + # get the current (kv) sequence length + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off + if block_start_kv * BLOCK_KV >= cur_kv_seq_len: + return + + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + q = tl.load(Q + offsets_q) + # block table for the current sequence + block_table_ptr = block_tables + cur_seq_idx * stride_bts + # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) + # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) + cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb) + cur_occupied_size = tl.where( + (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE + ) + tl.device_assert(cur_occupied_size >= 0) + + cur_kv_head_idx = cur_head_idx // KV_GROUPS + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + K_block_ptr = tl.make_block_ptr( + base=KCache + offset_kvcache, + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), + offsets=(0, 0), + block_shape=(BLOCK_SIZE, HEAD_DIM), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kvcache, + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), + offsets=(0, 0), + block_shape=(BLOCK_SIZE, HEAD_DIM), + order=(0, 1), + ) + k_cur_block = tl.load(K_block_ptr) + v_cur_block = tl.load(V_block_ptr) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + # use block size of the paged/blocked kv cache + S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + alibi_slope = tl.load(alibi_slopes + cur_head_idx) + position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) + + # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, + # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. + # Refer to https://github.com/openai/triton/discussions/895 + S_ij += tl.sum(q[None, :] * k_cur_block, 1) + S_ij *= sm_scale + S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset) + S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf")) + + m = tl.max(S_ij, 0) + S_ij -= m + p_ij_hat = tl.exp(S_ij) + l_i = tl.sum(p_ij_hat, 0) + p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) + acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) + acc = acc / l_i + + offsets_mid_o = ( + cur_token_idx * stride_mid_ot + + cur_head_idx * stride_mid_oh + + block_start_kv * stride_mid_ob + + offsets_dmodel * stride_mid_od + ) + tl.store(mid_o + offsets_mid_o, acc) + offsets_mid_o_lse = ( + cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + ) + # logsumexp l_i^(j) = m^(j) + log(l_i^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i)) + + +# Triton 2.1.0 +@triton.jit +def _flash_decoding_fwd_reduce_kernel( + mid_o, # [batch_size, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size, head_num, kv_split_num] + O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] + kv_seq_len, + q_len, + batch_size, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_o_lset, + stride_o_lseh, + stride_o_lseb, + stride_ot, + stride_oh, + stride_od, + BLOCK_KV: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_token_off = (cur_token_idx % q_len) - q_len + 1 + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off + offsets_dmodel = tl.arange(0, HEAD_DIM) + + # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have + # BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks of kv splitted. + kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV + m_i = float("-inf") # max logic + l_i = 0.0 # sum exp + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel + offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh + for block_i in range(0, kv_split_num, 1): + mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob) + lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb) + m_ij = tl.maximum(m_i, lse) + scale = tl.exp(m_i - m_ij) + acc = acc * scale + lse -= m_ij + exp_logic = tl.exp(lse) + acc += exp_logic * mid_o_block + l_i = scale * l_i + exp_logic + m_i = m_ij + + acc = acc / l_i + offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel + tl.store(O + offsets_O, acc.to(O.type.element_ty)) + return + + +# Decoding Stage +# Used with blocked KV Cache (PagedAttention) +def flash_decoding_attention( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_seq_len: torch.Tensor, + block_tables: torch.Tensor, + block_size: int, + max_seq_len_in_batch: int = None, + output: torch.Tensor = None, + mid_output: torch.Tensor = None, + mid_output_lse: torch.Tensor = None, + alibi_slopes: torch.Tensor = None, + sm_scale: int = None, + kv_group_num: int = 1, + q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment. + use_new_kcache_layout: bool = False, +): + """ + Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. + + Args: + q (torch.Tensor): [bsz * q_len, num_heads, head_dim] + q_len > 1 only for verification process in speculative-decoding. + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] + v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] + kv_seq_len (torch.Tensor): [batch_size] + records the (kv) sequence lengths incorporating past kv sequence lengths. + block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] + max_seq_len_in_batch (int): Maximum sequence length in the batch. + output (torch.Tensor): [bsz, num_heads * head_dim] + mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim] + Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. + q_len > 1 only for verification process in speculative-decoding. + mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num] + Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. + q_len > 1 only for verification process in speculative-decoding. + alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding. + block_size (int): Size of each block in the blocked key/value cache. + num_kv_group (int, optional): Number of key/value groups. Defaults to 1. + q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens). + Defaults to 1. + use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False. + + Returns: + Output tensor with shape [bsz * q_len, num_heads * head_dim] + """ + q = q.squeeze() if q.dim() == 4 else q + assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" + n_tokens, num_heads, head_dim = q.shape + assert n_tokens % q_len == 0, "Invalid q_len" + bsz = n_tokens // q_len + + assert head_dim in {32, 64, 128, 256} + assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( + f"Got incompatible batch size (number of seqs):\n" + f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, " + f"batch size {bsz}" + ) + assert k_cache.size(-2) == v_cache.size(-2) == block_size, ( + f"Got incompatible block size on kv caches:\n" + f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, " + f"v_cache block_size {v_cache.size(-2)}" + ) + + # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v + # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`) + assert block_size in {16, 32, 64, 128} + BLOCK_KV = block_size + + sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale + max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch + # For compatibility (TODO revise modeling in future) + kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV + + if mid_output is None: + mid_output = torch.empty( + (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + if mid_output_lse is None: + mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + if output is None: + # A hack to prevent `view` operation in modeling + output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device) + + assert ( + mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num + ), "Incompatible kv split number of intermediate output tensors" + assert ( + mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens + ), f"Incompatible first dimension of output tensors" + + # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton + # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) + grid = lambda META: ( + triton.next_power_of_2(bsz * q_len), + num_heads, + triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]), + ) + + if alibi_slopes is not None: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + assert ( + not use_new_kcache_layout + ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready" + + _alibi_flash_decoding_fwd_kernel[grid]( + q, + k_cache, + v_cache, + block_tables, + mid_output, + mid_output_lse, + kv_seq_len, + q_len, + bsz, + alibi_slopes, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), + sm_scale, + KV_GROUPS=kv_group_num, + BLOCK_KV=block_size, + BLOCK_SIZE=block_size, + HEAD_DIM=head_dim, + ) + else: + # For KCache and VCache with the same layout + x = head_dim + kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3) + # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x] + if use_new_kcache_layout: + assert ( + k_cache.dim() == 5 + and k_cache.shape[1] == v_cache.shape[1] + and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3] + ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}" + x = k_cache.size(-1) + kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:] + + _flash_decoding_fwd_kernel[grid]( + q, + k_cache, + v_cache, + block_tables, + mid_output, + mid_output_lse, + kv_seq_len, + q_len, + bsz, + kv_group_num, + x, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + kcsplit_x_stride, + kcs_stride, + kcd_stride, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), + BLOCK_KV=block_size, + BLOCK_SIZE=block_size, + HEAD_DIM=head_dim, + ) + + grid = (triton.next_power_of_2(bsz * q_len), num_heads) + _flash_decoding_fwd_reduce_kernel[grid]( + mid_output, + mid_output_lse, + output, + kv_seq_len, + q_len, + bsz, + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), + output.stride(0), + head_dim, + 1, + BLOCK_KV=block_size, + HEAD_DIM=head_dim, + ) + + return output diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py deleted file mode 100644 index 24083b050808..000000000000 --- a/colossalai/kernel/triton/fused_layernorm.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - # CREDITS: These functions are adapted from the Triton tutorial - # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html - - @triton.jit - def _layer_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, - ): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute mean - mean = 0 - _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=0) / N - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.0) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask) - b = tl.load(B + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w + b - # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) - - @torch.no_grad() - def layer_norm(x, weight, bias, eps): - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # enqueue kernel - _layer_norm_fwd_fused[(M,)]( - x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps - ) - return y diff --git a/colossalai/kernel/triton/fused_rotary_embedding.py b/colossalai/kernel/triton/fused_rotary_embedding.py new file mode 100644 index 000000000000..cf2a70f7b64e --- /dev/null +++ b/colossalai/kernel/triton/fused_rotary_embedding.py @@ -0,0 +1,181 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_rotary_emb( + q, + k, + cos_cache, + sin_cache, + cumsum_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_dim_stride, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + K_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_ELEMENTS: tl.constexpr, +): + block_head_index = tl.program_id(0) + block_group_index = tl.program_id(1) + group_token_index = tl.program_id(2) + idx = block_group_index * BLOCK_SIZE + group_token_index + + # original seq_idx and pos + cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) + ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) + cos = tl.load( + cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride + ) # [1,HEAD_DIM//2] + sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride) + + cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * q_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_q1 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + off_k0 = ( + idx * k_token_stride + + cur_head_range[None, :, None] * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + idx * q_token_stride + + cur_head_range[None, :, None] * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + q_0 = tl.load( + q + off_q0, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + q_1 = tl.load( + q + off_q1, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + k_0 = tl.load( + k + off_k0, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + k_1 = tl.load( + k + off_k1, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + other=0.0, + ) + + out_q0 = q_0 * cos - q_1 * sin + out_q1 = k_0 * sin + k_1 * cos + + out_k0 = q_0 * cos - q_1 * sin + out_k1 = k_0 * sin + k_1 * cos + # concat + tl.store( + q + off_q0, + out_q0, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + ) + tl.store( + q + off_q1, + out_q1, + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), + ) + + tl.store( + k + off_k0, + out_k0, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + ) + tl.store( + k + off_k1, + out_k1, + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), + ) + + +def fused_rotary_embedding( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + lengths, +): + """ + Args: + q: query tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, head_num, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + lengths [num_seqs] + """ + q_total_tokens, q_head_num, head_dim = q.shape + assert q.size(0) == k.size(0) + BLOCK_HEAD = 4 + BLOCK_SIZE = 8 + cumsum_lens = torch.cumsum(lengths, dim=0) + + grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE) + + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + q_token_stride = q.stride(0) + q_head_stride = q.stride(1) + head_dim_stride = q.stride(2) + + k_token_stride = k.stride(0) + k_head_stride = k.stride(1) + + k_head_num = q.shape[1] + + cos_token_stride = cos.stride(0) + cos_dim_stride = cos.stride(1) + + fused_rotary_emb[grid]( + q, + k, + cos, + sin, + cumsum_lens, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_dim_stride, + q_total_tokens, + Q_HEAD_NUM=q_head_num, + K_HEAD_NUM=k_head_num, + HEAD_DIM=head_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SIZE=BLOCK_SIZE, + N_ELEMENTS=triton.next_power_of_2(q_total_tokens), + num_warps=num_warps, + ) diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py deleted file mode 100644 index 2dc1fe04438a..000000000000 --- a/colossalai/kernel/triton/gptq_triton.py +++ /dev/null @@ -1,543 +0,0 @@ -# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ - -import torch -import triton -import triton.language as tl - -from .custom_autotune import autotune, matmul248_kernel_config_pruner - - -@triton.jit -def tanh(x): - # Tanh is just a scaled sigmoid - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def cosh(x): - exp_x = tl.exp(x) - return (exp_x + 1.0 / exp_x) * 0.5 - - -# a Triton implementation of the most used activations -# See for instance http://arxiv.org/abs/1606.08415 for an overview - - -# ReLU -@triton.jit -def relu(x): - """ - ReLU_ activation function - - .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html - """ - return tl.where(x >= 0, x, 0.0) - - -@triton.jit -def squared_relu(x): - """ - Squared ReLU activation, as proposed in the Primer_ paper. - - .. _Primer: https://arxiv.org/abs/2109.08668 - """ - x_sq = x * x - return tl.where(x > 0.0, x_sq, 0.0) - - -@triton.jit -def star_relu(x): - """ - Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper. - - .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf - """ - x_sq = x * x - return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 - - -# Leaky ReLU -@triton.jit -def leaky_relu(x): - """ - LeakyReLU_ activation - - .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html - """ - return tl.where(x >= 0.0, x, 0.01 * x) - - -@triton.jit -def gelu(x): - """ - GeLU_ activation - Gaussian error linear unit - - .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf - """ - return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) - - -@triton.jit -def smelu(x): - """ - SmeLU_ activation - Smooth ReLU with beta=2.0 - - .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf - """ - beta = 2.0 - - relu = tl.where(x >= beta, x, 0.0) - return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) - - -@triton.jit -def silu(x): - return x * tl.sigmoid(x) - - -@autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def cai_gptq_matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - bias_ptr, - residual_ptr, - M, - N, - K, - bits, - maxq, - gptq_group_size, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - QKV_FUSED: tl.constexpr, - ADD_BIAS: tl.constexpr, - ADD_RESIDUAL: tl.constexpr, - ACT_TYPE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - NK = K - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) - qkv_offset = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - # offs_bk = offs_k + qkv_offset * NK - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = ( - b_ptr - + qkv_offset * N * NK // infearure_per_bits - + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - # g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] - zeros_ptrs = ( - zeros_ptr - + qkv_offset * NK * N // gptq_group_size // infearure_per_bits - + (offs_bn[None, :] // infearure_per_bits) - ) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - g_idx_base = tl.arange(0, BLOCK_SIZE_K) - g_idx_base = g_idx_base // gptq_group_size - g_idx = g_idx_base - # tl.device_print("gidx, ", g_idx) - - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 - - for k in range(0, num_pid_k): - # g_idx = tl.load(g_ptrs) - # if (k + 1) * BLOCK_SIZE_K > currend_group_end: - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros).to(tl.float16) * scales # Scale and shift - accumulator += tl.dot(a, b) - - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size - # if (k + 2) * BLOCK_SIZE_K > currend_group_end: - - c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - - if ADD_BIAS: - bias_mask = offs_bn < N - offs_bn += qkv_offset * N - bias_ptrs = bias_ptr + stride_cn * offs_bn - bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - accumulator += bias[None, :] - - if ACT_TYPE == 1: - accumulator = relu(accumulator) - elif ACT_TYPE == 2: - accumulator = gelu(accumulator) - elif ACT_TYPE == 3: - accumulator = silu(accumulator) - - if ADD_RESIDUAL: - residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - res = tl.load(residual_ptrs, mask=c_mask, other=0.0) - accumulator += res - - tl.store(c_ptrs, accumulator, mask=c_mask) - - -# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ -@autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4 - ), - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, -) -@triton.jit -def cai_gptq_idx_matmul_248_kernel( - a_ptr, - b_ptr, - c_ptr, - scales_ptr, - zeros_ptr, - idx_ptr, - bias_ptr, - residual_ptr, - M, - N, - K, - bits, - maxq, - gptq_group_size, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - QKV_FUSED: tl.constexpr, - ADD_BIAS: tl.constexpr, - ADD_RESIDUAL: tl.constexpr, - ACT_TYPE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - NK = K - - # if QKV_FUSED: - # NK = K//3 - # else: - # NK = K - # NK = K - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) - qkv_offset = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - # offs_bk = offs_k + qkv_offset * NK - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = ( - b_ptr - + qkv_offset * N * NK // infearure_per_bits - + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - # g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] - zeros_ptrs = ( - zeros_ptr - + qkv_offset * NK * N // gptq_group_size // infearure_per_bits - + (offs_bn[None, :] // infearure_per_bits) - ) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - g_ptrs = idx_ptr + offs_k - g_idx = tl.load(g_ptrs) - # tl.device_print("gidx, ", g_idx) - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros).to(tl.float16) * scales # Scale and shift - accumulator += tl.dot(a, b) - - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - - if ADD_BIAS: - bias_mask = offs_bn < N - offs_bn += qkv_offset * N - bias_ptrs = bias_ptr + stride_cn * offs_bn - bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - accumulator += bias[None, :] - - if ACT_TYPE == 1: - accumulator = relu(accumulator) - elif ACT_TYPE == 2: - accumulator = gelu(accumulator) - elif ACT_TYPE == 3: - accumulator = silu(accumulator) - - if ADD_RESIDUAL: - residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - res = tl.load(residual_ptrs, mask=c_mask, other=0.0) - accumulator += res - - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def gptq_fused_linear_triton( - input, - qweight, - scales, - qzeros, - bias, - residual, - bits, - maxq, - gptq_group_size, - qkv_fused, - add_bias, - add_residual, - g_idx=None, - act_type=0, -): - # print("gptq fused ", qkv_fused, add_bias, add_residual) - assert input.is_cuda, "input is not in cuda" - assert qweight.is_cuda, "qweight is not in cuda" - assert scales.is_cuda, "scales is not in cuda" - assert qzeros.is_cuda, "qzeros is not in cuda" - - with torch.cuda.device(input.device): - if qkv_fused: - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]) - * 3, - ) - output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16) - else: - grid = lambda META: ( - triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]), - ) - output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) - # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) - if g_idx is None: - cai_gptq_matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - bias, - residual, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - gptq_group_size, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - QKV_FUSED=qkv_fused, - ADD_BIAS=add_bias, - ADD_RESIDUAL=add_residual, - ACT_TYPE=act_type, - ) - else: - cai_gptq_idx_matmul_248_kernel[grid]( - input, - qweight, - output, - scales, - qzeros, - g_idx, - bias, - residual, - input.shape[0], - qweight.shape[1], - input.shape[1], - bits, - maxq, - gptq_group_size, - input.stride(0), - input.stride(1), - qweight.stride(0), - qweight.stride(1), - output.stride(0), - output.stride(1), - scales.stride(0), - qzeros.stride(0), - QKV_FUSED=qkv_fused, - ADD_BIAS=add_bias, - ADD_RESIDUAL=add_residual, - ACT_TYPE=act_type, - ) - if qkv_fused: - return output.view(3, input.shape[0], qweight.shape[1]) - else: - return output diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py deleted file mode 100644 index 537dd164d1ab..000000000000 --- a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py +++ /dev/null @@ -1,117 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - input_scale, - output_scale, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - q0 = q0.to(tl.float32) * input_scale - q1 = q1.to(tl.float32) * input_scale - - out0 = (q0 * cos - q1 * sin) / output_scale - out1 = (q0 * sin + q1 * cos) / output_scale - - out0 = out0.to(tl.int8) - out1 = out1.to(tl.int8) - - tl.store( - q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - input_scale, - output_scale, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py new file mode 100644 index 000000000000..77397b5cb6cf --- /dev/null +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -0,0 +1,296 @@ +import torch +import triton +import triton.language as tl + + +# Triton 2.1.0 +# supports two types of cache layouts +# 1. [num_blocks, num_kv_heads, block_size, head_dim] +# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x] +@triton.jit +def _copy_to_kcache_seqlen_n_kernel( + K, # K or V + KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x] + BLOCK_TABLES, + seq_lengths, + stride_kt, + stride_kh, + stride_kd, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcx, + stride_bts, + stride_btb, + block_size, + n_tokens, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, +): + # `n_tokens` is used to specify the number of tokens to copy for each sequence + # When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid, + # `seq_lengths` must be the lengths of sequences counting the number of tokens to copy + # E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9] + # for the two sequences, respectively. And the position ids to be copied are [7-11, 9-14]. + # When n_tokens = 1, consider token idx as the sequence idx, since it's only used during regular decoding stage + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // n_tokens + # `cur_token_shift` is only valid and functional when `n_tokens` > 1 + cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1)) + cur_kv_head_idx = tl.program_id(1) + split_x_idx = tl.program_id(2) + + past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift + last_bt_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) + offset_last_block = past_kv_seq_len % block_size + offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X) + offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd + k = tl.load(K + offsets_k) + offsets_kcache = ( + block_id * stride_kcb + + cur_kv_head_idx * stride_kch + + split_x_idx * stride_kcsplit_x + + offset_last_block * stride_kcs + + tl.arange(0, KCACHE_X) + ) + tl.store(KCache + offsets_kcache, k) + return + + +# Triton 2.1.0 +@triton.jit +def _copy_to_kvcache_seqlen1_kernel( + K, + V, + KCache, + VCache, + BLOCK_TABLES, + context_lengths, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_kcb, + stride_kch, + stride_kcsplit_x, + stride_kcs, + stride_kcd, + stride_vcb, + stride_vch, + stride_vcs, + stride_vcd, + stride_bts, + stride_btb, + block_size, + HEAD_DIM: tl.constexpr, + KCACHE_X: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + cur_kv_head_idx = tl.program_id(1) + + past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1 + last_bt_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) + offsets_in_last_block = past_kv_seq_len % block_size + + range_x = tl.arange(0, KCACHE_X) + offsets_dmodel_x_partition = tl.arange(0, KCACHE_X) + + for split_x in tl.static_range(HEAD_DIM // KCACHE_X): + offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X) + offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd + k = tl.load(K + offsets_k) + offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd + v = tl.load(V + offsets_v) + + offsets_kcache = ( + block_id * stride_kcb + + cur_kv_head_idx * stride_kch + + split_x * stride_kcsplit_x + + offsets_in_last_block * stride_kcs + + range_x + ) + tl.store(KCache + offsets_kcache, k) + offsets_vcache = ( + block_id * stride_vcb + + cur_kv_head_idx * stride_vch + + offsets_in_last_block * stride_vcs + + offsets_dmodel_x_partition * stride_vcd + ) + tl.store(VCache + offsets_vcache, v) + return + + +def copy_k_to_blocked_cache( + k: torch.Tensor, + k_cache: torch.Tensor, + kv_lengths: torch.Tensor, + block_tables: torch.Tensor, + n: int = 1, + use_new_kcache_layout: bool = False, +): + """ + Copy keys or values to the blocked key/value cache during decoding stage. + + Args: + k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. + [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache. + new KCache Layout [num_blocks, num_kv_heads, head_dim // x, block_size, x] + kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + n (int): Number of tokens to copy for each sequence. Default to 1. + use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False. + """ + assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." + if k.dim() == 4: + k = k.reshape(-1, k.size(-2), k.size(-1)) + k_shape = k.shape + bsz, num_kv_heads, head_dim = k_shape + # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim] + if n > 1: + assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied" + bsz = bsz // n + + assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( + f"Got incompatible batch size (number of seqs):\n" + f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " + f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}" + ) + + k_cache_shape = k_cache.shape + # Modify if the shape of kv cahce is changed. + block_size = k_cache_shape[-2] + + x = head_dim + stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3) + if use_new_kcache_layout: + # when using kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x] + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == k_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == k_shape[2] + ), f"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}" + x = k_cache.size(-1) + stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:] + + num_warps = 8 if head_dim > 128 else 4 + grid = (bsz * n, num_kv_heads, head_dim // x) + _copy_to_kcache_seqlen_n_kernel[grid]( + k, + k_cache, + block_tables, + kv_lengths, + k.stride(0), + k.stride(1), + k.stride(2), + k_cache.stride(0), + k_cache.stride(1), + stride_kcsplit_x, + stride_kcs, + stride_kcd, + block_tables.stride(0), + block_tables.stride(1), + block_size, + n_tokens=n, + HEAD_DIM=head_dim, + KCACHE_X=x, + num_warps=num_warps, + ) + + +def copy_kv_to_blocked_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_lengths: torch.Tensor, + block_tables: torch.Tensor, + use_new_kcache_layout: bool = False, +): + """ + Copy keys or values to the blocked key/value cache during decoding stage. + + Args: + k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1. + v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1. + k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache. + v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache. + kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. + block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False. + """ + k_cache_shape = k_cache.shape + v_cache_shape = v_cache.shape + + if use_new_kcache_layout: + assert ( + len(k_cache_shape) == 5 + and k_cache_shape[1] == v_cache_shape[1] + and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3] + ), f"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + else: + assert k.size(-1) == k_cache_shape[-1], "Incompatible head dim" + assert ( + k_cache_shape == v_cache_shape + ), f"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}" + assert v.size(-1) == v_cache_shape[-1], "Incompatible head dim" + + k = k.squeeze(1) if k.dim() == 4 else k + assert k.dim() == 3, f"Incompatible k dim {k.dim()}" + v = v.squeeze(1) if v.dim() == 4 else v + assert v.dim() == 3, f"Incompatible v dim {v.dim()}" + + bsz, num_kv_heads, head_dim = k.shape + assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( + f"Got incompatible batch size (number of seqs):\n" + f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " + f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}" + ) + + # Modify if the shape of kv cahce is changed. + block_size = k_cache.size(-2) + + x = head_dim + stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3) + if use_new_kcache_layout: + x = k_cache.size(-1) + stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:] + + num_warps = 8 if head_dim > 128 else 4 + grid = (bsz, num_kv_heads) + _copy_to_kvcache_seqlen1_kernel[grid]( + k, + v, + k_cache, + v_cache, + block_tables, + kv_lengths, + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + k_cache.stride(0), + k_cache.stride(1), + stride_kcsplit_x, + stride_kcs, + stride_kcd, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + block_size, + HEAD_DIM=head_dim, + KCACHE_X=x, + num_warps=num_warps, + ) diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py new file mode 100644 index 000000000000..e0da816bdc90 --- /dev/null +++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py @@ -0,0 +1,659 @@ +import warnings +from typing import Optional + +import torch +import triton +import triton.language as tl + +""" +# Base autotune if needed +@triton.autotune( + configs=[ + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=4), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":8,},num_warps=8), + triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":8,},num_warps=8), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=16), + triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=32), + triton.Config({'BLOCK_HEAD':16,"BLOCK_TOKENS":16,},num_warps=4), + triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":16,},num_warps=8), + ], + key=['HEAD_DIM','q_total_tokens','Q_HEAD_NUM'] +) +""" + + +@triton.jit +def rotary_embedding_kernel( + q, + k, + cos, + sin, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + KV_GROUP_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, # token range length +): + cur_head_idx = tl.program_id(0) + cur_token_block_idx = tl.program_id(1) + + tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride + loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + + off_q0 = ( + tokens_range[:, None, None] * q_token_stride + + cur_head_idx * q_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_q1 = ( + tokens_range[:, None, None] * q_token_stride + + cur_head_idx * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + loaded_q0 = tl.load( + q + off_q0, + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + loaded_q1 = tl.load( + q + off_q1, + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :] + out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :] + + tl.store( + q + off_q0, + out_q0, + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + q + off_q1, + out_q1, + mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + + handle_kv = cur_head_idx % KV_GROUP_NUM == 0 + if handle_kv: + k_head_idx = cur_head_idx // KV_GROUP_NUM + off_k0 = ( + tokens_range[:, None, None] * k_token_stride + + k_head_idx * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + tokens_range[:, None, None] * k_token_stride + + k_head_idx * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + loaded_k0 = tl.load( + k + off_k0, + mask=(tokens_range[:, None, None] < q_total_tokens), + other=0.0, + ) + loaded_k1 = tl.load( + k + off_k1, + mask=(tokens_range[:, None, None] < q_total_tokens), + other=0.0, + ) + out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] + out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] + tl.store( + k + off_k0, + out_k0, + mask=(tokens_range[:, None, None] < q_total_tokens), + ) + tl.store( + k + off_k1, + out_k1, + mask=(tokens_range[:, None, None] < q_total_tokens), + ) + + +@triton.jit +def fused_rotary_embedding_kernel( + q, + k, + cos, + sin, + kv_cache, + BLOCK_TABLES, + context_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + cacheb_stride, + cacheh_stride, + cachebs_stride, + cached_stride, + bts_stride, + btb_stride, + block_size, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + K_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, +): + block_head_index = tl.program_id(0) + block_token_index = tl.program_id(1) + + tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_q1 = ( + tokens_range[:, None, None] * q_token_stride + + head_range[None, :, None] * q_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + off_k0 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + + dim_range0[None, None, :] * head_dim_stride + ) + off_k1 = ( + tokens_range[:, None, None] * k_token_stride + + head_range[None, :, None] * k_head_stride + + dim_range1[None, None, :] * head_dim_stride + ) + + loaded_q0 = tl.load( + q + off_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + loaded_q1 = tl.load( + q + off_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k0 = tl.load( + k + off_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + loaded_k1 = tl.load( + k + off_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + other=0.0, + ) + + off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride + + loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) + + out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :] + out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :] + + out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :] + out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim + + past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride + block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens)) + offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride + + kv_range0 = ( + block_ids[:, None, None, None] * cacheb_stride + + head_range[None, :, None, None] * cacheh_stride + + offsets_in_last_block[:, None, None, None] + + dim_range0[None, None, None, :] * cached_stride + ) + kv_range1 = ( + block_ids[:, None, None, None] * cacheb_stride + + head_range[None, :, None, None] * cacheh_stride + + offsets_in_last_block[:, None, None, None] + + dim_range1[None, None, None, :] * cached_stride + ) + + tl.store( + kv_cache + kv_range0, + out_k0[:, :, None, :], + ) + tl.store( + kv_cache + kv_range1, + out_k1[:, :, None, :], + ) + + # concat + tl.store( + q + off_q0, + out_q0, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + q + off_q1, + out_q1, + mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k0, + out_k0, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + tl.store( + k + off_k1, + out_k1, + mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), + ) + + +@triton.jit +def fused_rotary_embedding_kernel_v2( + q, + k, + cos, + sin, + kv_cache, + BLOCK_TABLES, + context_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + cacheb_stride, + cacheh_stride, + cachebs_stride, + cached_stride, + bts_stride, + btb_stride, + block_size, + q_total_tokens, + Q_HEAD_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + block_head_index = tl.program_id(0) + if block_head_index >= Q_HEAD_NUM: + return + block_token_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride + off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride + off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride + off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride + + loaded_q0 = tl.load( + q + off_q0, + ) + loaded_q1 = tl.load( + q + off_q1, + ) + + loaded_k0 = tl.load( + k + off_k0, + ) + + loaded_k1 = tl.load( + k + off_k1, + ) + + off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride + + loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0) + loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0) + + out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin + out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos + + out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin + out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim + + past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride + block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens)) + offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride + + kv_range0 = ( + block_ids * cacheb_stride + + block_head_index * cacheh_stride + + offsets_in_last_block + + dim_range0 * cached_stride + ) + kv_range1 = ( + block_ids * cacheb_stride + + block_head_index * cacheh_stride + + offsets_in_last_block + + dim_range1 * cached_stride + ) + + tl.store( + kv_cache + kv_range0, + out_k0, + ) + tl.store( + kv_cache + kv_range1, + out_k1, + ) + + # concat + tl.store( + q + off_q0, + out_q0, + ) + tl.store( + q + off_q1, + out_q1, + ) + + +@triton.jit +def decoding_fused_rotary_embedding_kernel( + q, + k, + v, + cos, + sin, + k_cache, + v_cache, + BLOCK_TABLES, + context_lengths, + x, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + kcb_stride, + kch_stride, + kcsplit_x_stride, + kcs_stride, + kcd_stride, + vcb_stride, + vch_stride, + vcs_stride, + vcd_stride, + bts_stride, + btb_stride, + block_size, + KV_GROUP_NUM: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_head_idx = tl.program_id(0) + cur_token_idx = tl.program_id(1) + + dim_range = tl.arange(0, HEAD_DIM) + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride + off_q0 = off_q + dim_range0 * head_dim_stride + off_q1 = off_q + dim_range1 * head_dim_stride + + loaded_q0 = tl.load(q + off_q0) + loaded_q1 = tl.load(q + off_q1) + off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride + loaded_cos = tl.load(cos + off_cos_sin) + loaded_sin = tl.load(sin + off_cos_sin) + + out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin + out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos + tl.store(q + off_q0, out_q0) + tl.store(q + off_q1, out_q1) + + handle_kv = cur_head_idx % KV_GROUP_NUM == 0 + if handle_kv: + cur_k_head_idx = cur_head_idx // KV_GROUP_NUM + off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride + off_k0 = off_kv + dim_range0 * head_dim_stride + off_k1 = off_kv + dim_range1 * head_dim_stride + loaded_k0 = tl.load(k + off_k0) + loaded_k1 = tl.load(k + off_k1) + + out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin + out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos + + # NOTE The precondition here is that it's only for unpadded inputs during decoding stage, + # and so that we could directly use the token index as the sequence index + past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1 + + last_block_idx = past_kv_seq_len // block_size + block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride) + offsets_in_last_block = past_kv_seq_len % block_size + offsets_cache_base = block_ids * kcb_stride + cur_k_head_idx * kch_stride + k_range0 = ( + offsets_cache_base + + offsets_in_last_block * kcs_stride + + (dim_range0 // x) * kcsplit_x_stride + + (dim_range0 % x) * kcd_stride + ) + k_range1 = ( + offsets_cache_base + + offsets_in_last_block * kcs_stride + + (dim_range1 // x) * kcsplit_x_stride + + (dim_range1 % x) * kcd_stride + ) + tl.store(k_cache + k_range0, out_k0) + tl.store(k_cache + k_range1, out_k1) + + off_v = off_kv + dim_range * head_dim_stride + loaded_v = tl.load(v + off_v) + v_range = ( + block_ids * vcb_stride + + cur_k_head_idx * vch_stride + + offsets_in_last_block * vcs_stride + + dim_range * vcd_stride + ) + tl.store(v_cache + v_range, loaded_v) + + +def rotary_embedding( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + k_cache: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + kv_lengths: Optional[torch.Tensor] = None, +): + """ + Args: + q: query tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, kv_head_num, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim] + kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz] + block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence] + """ + q_total_tokens, q_head_num, head_dim = q.shape + assert q.size(0) == k.size(0) + BLOCK_TOKENS = 4 + + if head_dim >= 512: + num_warps = 16 + elif head_dim >= 256: + num_warps = 8 + else: + num_warps = 4 + + k_head_num = k.size(1) + q_token_stride, q_head_stride, head_dim_stride = q.stride() + k_token_stride, k_head_stride, _ = k.stride() + cos_token_stride, cos_stride = cos.stride() + + assert q_head_num % k_head_num == 0 + kv_group_num = q_head_num // k_head_num + + if k_cache == None: + grid = lambda META: ( + q_head_num, + triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]), + ) + rotary_embedding_kernel[grid]( + q, + k, + cos, + sin, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + q_total_tokens, + Q_HEAD_NUM=q_head_num, + KV_GROUP_NUM=kv_group_num, + HEAD_DIM=head_dim, + BLOCK_TOKENS=BLOCK_TOKENS, + num_warps=num_warps, + ) + else: + warnings.warn("Fused rotary embedding Triton kernel will be deprecated as the new kcache layout is supported") + grid = (triton.next_power_of_2(q_head_num), q_total_tokens) + fused_rotary_embedding_kernel_v2[grid]( + q, + k, + cos, + sin, + k_cache, + block_tables, + kv_lengths, + q_token_stride, + q_head_stride, + k_token_stride, + k_head_stride, + head_dim_stride, + cos_token_stride, + cos_stride, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + k_cache.size(-2), + q_total_tokens, + Q_HEAD_NUM=q_head_num, + HEAD_DIM=head_dim, + num_warps=num_warps, + ) + return + + +def decoding_fused_rotary_embedding( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + k_cache: Optional[torch.Tensor] = None, + v_cache: Optional[torch.Tensor] = None, + block_tables: Optional[torch.Tensor] = None, + kv_lengths: Optional[torch.Tensor] = None, + use_new_kcache_layout: bool = False, +): + """ + Args: + q: query tensor, [total_tokens, head_num, head_dim] + k: key tensor, [total_tokens, kv_head_num, head_dim] + v: value tensor, [total tokens, kv_head_num, head_dim] + cos: cosine for rotary embedding, [max_position_len, head_dim] + sin: sine for rotary embedding, [max_position_len, head_dim] + k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim] + v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim] + kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz] + block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence] + """ + q_total_tokens, q_head_num, head_dim = q.shape + assert q.size(0) == k.size(0) == v.size(0) + + if head_dim >= 512: + num_warps = 16 + elif head_dim >= 256: + num_warps = 8 + else: + num_warps = 4 + k_head_num = k.size(1) + kv_group_num = q_head_num // k_head_num + + # For KCache and VCache with the same layout + x = head_dim + kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3) + # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x] + if use_new_kcache_layout: + assert ( + k_cache.dim() == 5 + and k_cache.shape[1] == v_cache.shape[1] + and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3] + ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}" + x = k_cache.size(-1) + kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:] + + grid = (q_head_num, q_total_tokens) + decoding_fused_rotary_embedding_kernel[grid]( + q, + k, + v, + cos, + sin, + k_cache, + v_cache, + block_tables, + kv_lengths, + x, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + k_cache.stride(0), + k_cache.stride(1), + kcsplit_x_stride, + kcs_stride, + kcd_stride, + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + k_cache.size(-2), + KV_GROUP_NUM=kv_group_num, + HEAD_DIM=head_dim, + num_warps=num_warps, + ) + return diff --git a/colossalai/kernel/triton/rms_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py new file mode 100644 index 000000000000..fb320750340f --- /dev/null +++ b/colossalai/kernel/triton/rms_layernorm.py @@ -0,0 +1,116 @@ +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + # CREDITS: These functions are adapted from the Triton tutorial + # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + @triton.jit + def _rmsnorm_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, + ): + # This triton kernel implements Root Mean Square Layer Norm (RMSNorm). + + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + + @triton.jit + def _rmsnorm_with_residual_kernel( + X, # pointer to the input + Y, # pointer to the output + R, # pointer to the residual + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, + ): + # This triton kernel implements Root Mean Square Layer Norm (RMSNorm). + + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + R += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32) + r = tl.where(cols < N, r, 0.0) + x = x + r + _var += x * x + mask = cols < N + tl.store(X + cols, x.to(tl.float16), mask=mask) + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + + def rms_layernorm(x, weight, eps, norm_output=None, residual=None): + # allocate output + y = ( + x * 0 if norm_output is None else norm_output + ) # to make the operation non-functional, store y as the intermediate activation + M, N = x.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > MAX_FUSED_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + # heuristics for number of warps + num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32) + + # enqueue kernel + if residual is None: + _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + else: + _rmsnorm_with_residual_kernel[(M,)]( + x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) + return y, x diff --git a/colossalai/kernel/triton/rotary_cache_copy.py b/colossalai/kernel/triton/rotary_cache_copy.py new file mode 100644 index 000000000000..48dc7de4377e --- /dev/null +++ b/colossalai/kernel/triton/rotary_cache_copy.py @@ -0,0 +1,147 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def prefill_cache_kernel( + cos_cache, + sin_cache, + cumsum_lengths, + cos_output, + sin_output, + cache_stride, + hidden_stride, + total_length, + HIDDEN_DIM: tl.constexpr, + N_ELEMENTS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + idx0 = tl.program_id(axis=0) + idx1 = tl.program_id(axis=1) + idx = idx0 * BLOCK_SIZE + idx1 + + # original seq_idx and pos + cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) + ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) + cos_cache_part = tl.load( + cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length + ) + sin_cache_part = tl.load( + sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length + ) + tl.store( + cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, + cos_cache_part, + mask=idx < total_length, + ) + tl.store( + sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, + sin_cache_part, + mask=idx < total_length, + ) + + +@triton.jit +def decoding_cache_kernel( + cos_cache, + sin_cache, + lengths, + cos_output, + sin_output, + cache_stride, + hidden_stride, + HIDDEN_DIM: tl.constexpr, + NUM_SEQS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,] + cos_cache_part = tl.load( + cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride, + mask=idx[:, None] < NUM_SEQS, + ) + sin_cache_part = tl.load( + sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride, + mask=idx[:, None] < NUM_SEQS, + ) + tl.store( + cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), + cos_cache_part, + mask=idx[:, None] < NUM_SEQS, + ) + tl.store( + sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride), + sin_cache_part, + mask=idx[:, None] < NUM_SEQS, + ) + + +def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False): + """ + Transform cos/sin cache into no pad sequence, with two different modes. + Args: + lengths: shape(num_seqs,), stores lenghth of each sequence. + cache: shape(max_rotary_position(e.g.2048), head_dim), cos/sin cache constrcuted in model. + is_prompts: bool, mark if in prefill mode. + For prefill mode: + cos/sin cache for each sequence is equal to its length. + For decoding mode: + cos/sin cache is only needed for the last token. + """ + assert cos_cache.shape[1] == sin_cache.shape[1] + _, hidden_dim = cos_cache.shape + num_seqs = lengths.numel() + + if hidden_dim >= 256: + num_warps = 16 + elif hidden_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + cache_stride = cos_cache.stride(0) + hidden_stride = cos_cache.stride(1) + + if is_prompts: + BLOCK_SIZE = 16 + total_length = lengths.sum().item() + cumsum_lens = torch.cumsum(lengths, dim=0) + cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device) + sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device) + grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE) + prefill_cache_kernel[grid]( + cos_cache, + sin_cache, + cumsum_lens, + cos_output, + sin_output, + cache_stride, + hidden_stride, + total_length, + HIDDEN_DIM=hidden_dim, + N_ELEMENTS=triton.next_power_of_2(num_seqs), + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + else: + BLOCK_SIZE = 4 + nlengths = torch.as_tensor(lengths) - 1 + cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device) + sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device) + grid = (triton.cdiv(num_seqs, BLOCK_SIZE),) + decoding_cache_kernel[grid]( + cos_cache, + sin_cache, + nlengths, + cos_output, + sin_output, + cache_stride, + hidden_stride, + HIDDEN_DIM=hidden_dim, + NUM_SEQS=num_seqs, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return cos_output, sin_output diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py deleted file mode 100644 index 50d6786bd940..000000000000 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ /dev/null @@ -1,164 +0,0 @@ -import torch - -try: - import triton - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - from .qkv_matmul_kernel import qkv_gemm_4d_kernel - from .softmax import softmax_kernel - - # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 - def self_attention_forward_without_fusion( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float - ): - r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels - Args: - q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) - scale: the float scale value which is used to multiply with Q*K^T before doing softmax - - Return: - output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) - """ - assert len(q.shape) == 4, "the shape of q val must be 4" - batches, M, H, K = q.shape - assert q.shape == k.shape, "the shape of q and the shape of k must be equal" - assert q.shape == v.shape, "the shape of q and the shape of v must be equal" - assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" - - N = k.shape[1] - - # head_size * num_of_head - d_model = q.shape[-1] * q.shape[-2] - - score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - q, - k, - score_output, - M, - N, - K, - q.stride(0), - q.stride(2), - q.stride(1), - q.stride(3), - k.stride(0), - k.stride(2), - k.stride(3), - k.stride(1), - score_output.stride(0), - score_output.stride(1), - score_output.stride(2), - score_output.stride(3), - scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=32, - BLOCK_SIZE_K=32, - GROUP_SIZE_M=8, - ) - - softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype) - score_output_shape = score_output.shape - - score_output = score_output.view(-1, score_output.shape[-1]) - n_rows, n_cols = score_output.shape - - if n_rows <= 350000: - block_size = max(triton.next_power_of_2(n_cols), 2) - num_warps = 4 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: - num_warps = 4 - - softmax_kernel[(n_rows,)]( - softmax_output, - score_output, - score_output.stride(0), - n_cols, - mask_ptr=input_mask, - num_warps=num_warps, - BLOCK_SIZE=block_size, - ) - - else: - # NOTE: change softmax kernel functions to make it suitable for large size dimension - softmax_output = torch.nn.functional.softmax(score_output, dim=-1) - softmax_output = softmax_output.view(*score_output_shape) - - batches, H, M, K = softmax_output.shape - N = v.shape[-1] - - output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - softmax_output, - v, - output, - M, - N, - K, - softmax_output.stride(0), - softmax_output.stride(1), - softmax_output.stride(2), - softmax_output.stride(3), - v.stride(0), - v.stride(2), - v.stride(1), - v.stride(3), - output.stride(0), - output.stride(2), - output.stride(1), - output.stride(3), - BLOCK_SIZE_M=128, - BLOCK_SIZE_N=64, - BLOCK_SIZE_K=64, - GROUP_SIZE_M=8, - scale=-1, - ) - return output.view(batches, -1, d_model) - - # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 - def self_attention_compute_using_triton( - qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False - ): - assert qkv.is_contiguous() - assert alibi is None, "current triton self-attention does not support alibi" - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model : d_model * 2] - v = qkv[:, :, d_model * 2 :] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale) - - return data_output_triton diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py deleted file mode 100644 index 071de58e20c0..000000000000 --- a/colossalai/kernel/triton/smooth_attention.py +++ /dev/null @@ -1,652 +0,0 @@ -import math - -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - """ - this functions are modified from https://github.com/ModelTC/lightllm - """ - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - v = v.to(tl.float16) * v_input_scale.to(tl.float16) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @torch.no_grad() - def smooth_llama_context_attn_fwd( - q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len - ): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk, "context process only supports equal query, key, value length" - assert Lk == Lv, "context process only supports equal query, key, value length" - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / math.sqrt(Lk) - batch, head = b_seq_len.shape[0], q.shape[1] - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - - _context_flash_attention_kernel[grid]( - q, - k, - v, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_1_kernel( - Q, - K, - q_input_scale, - k_input_scale, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - q_input_scale, - k_input_scale, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - q = q.to(tl.float16) * q_input_scale.to(tl.float16) - - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - k = k.to(tl.float16) * k_input_scale.to(tl.float16) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @torch.no_grad() - def token_attn_fwd_1( - q, - k, - attn_out, - q_input_scale, - k_input_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - alibi=None, - ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) - - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 - - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, - k, - q_input_scale, - k_input_scale, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - else: - _token_attn_1_kernel[grid]( - q, - k, - q_input_scale, - k_input_scale, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py - @triton.jit - def _token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, - ) - v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2( - prob, - v, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - ): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def smooth_token_attention_fwd( - q, - k, - v, - attn_out, - q_input_scale, - k_input_scale, - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=None, - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - q_input_scale, - k_input_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) - - prob = torch.empty_like(att_m_tensor) - - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - token_attn_fwd_2( - prob, - v, - attn_out.view(calcu_shape1), - v_input_scale, - pv_output_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = None - - return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py deleted file mode 100644 index d8ac278c77dd..000000000000 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ /dev/null @@ -1,244 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm - - -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -try: - from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import ( - token_att_fwd as lightllm_bloom_token_att_fwd, - ) - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import ( - token_att_fwd as lightllm_llama_token_att_fwd, - ) - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import ( - token_att_fwd2 as lightllm_llama_token_att_fwd2, - ) - from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import ( - token_softmax_fwd as lightllm_llama_token_softmax_fwd, - ) - - HAS_TRITON_TOKEN_ATTENTION = True -except ImportError: - print("unable to import lightllm kernels") - HAS_TRITON_TOKEN_ATTENTION = False - -if HAS_TRITON: - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - if alibi is None: - lightllm_llama_token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - else: - lightllm_bloom_token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = torch.empty_like(att_m_tensor) - - lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - lightllm_llama_token_att_fwd2( - prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch - ) - prob = None - return - - -class Llama2TokenAttentionForwards: - @staticmethod - @triton.jit - - # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 - def _fwd_kernel( - Logics, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_b_loc_b, - stride_b_loc_s, - other_kv_index, # avoid nan information - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s - - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=other_kv_index, - ) - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), - ) - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 - @staticmethod - @torch.no_grad() - def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - Llama2TokenAttentionForwards._fwd_kernel[grid]( - logics, - v, - o, - b_loc, - b_start_loc, - b_seq_len, - max_input_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - b_loc.stride(0), - b_loc.stride(1), - other_kv_index, - kv_group_num, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return - - # this is the interface of llama2 attn forward - @staticmethod - @torch.no_grad() - def token_attn( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index - ): - total_token_num = k.shape[0] - batch_size, head_num, head_dim = q.shape - calcu_shape1 = (batch_size, head_num, head_dim) - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - lightllm_llama_token_att_fwd( - q, - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - if triton.__version__ == "2.0.0": - prob = torch.empty_like(att_m_tensor) - lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - - lightllm_llama_token_att_fwd2( - prob, - v, - attn_out.view(calcu_shape1), - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - ) - - prob = None - return - - elif triton.__version__ >= "2.1.0": - Llama2TokenAttentionForwards.token_softmax_reducev_fwd( - att_m_tensor, - v, - attn_out.view(calcu_shape1), - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - other_kv_index, - ) - else: - raise Exception("not support triton version") diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 21d44d4244d3..736ffc5e4ea2 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -1,3 +1,4 @@ +import copy import os from typing import Callable, Optional, Union @@ -74,6 +75,24 @@ def new_from_pretrained( subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + + kwargs.pop("state_dict", None) + kwargs.pop("from_tf", False) + kwargs.pop("from_flax", False) + kwargs.pop("output_loading_info", False) + kwargs.pop("trust_remote_code", None) + kwargs.pop("low_cpu_mem_usage", None) + kwargs.pop("device_map", None) + kwargs.pop("max_memory", None) + kwargs.pop("offload_folder", None) + kwargs.pop("offload_state_dict", False) + kwargs.pop("load_in_8bit", False) + kwargs.pop("load_in_4bit", False) + kwargs.pop("quantization_config", None) + kwargs.pop("adapter_kwargs", {}) + kwargs.pop("adapter_name", "default") + kwargs.pop("use_flash_attention_2", False) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) if len(kwargs) > 0: @@ -108,6 +127,10 @@ def new_from_pretrained( **kwargs, ) else: + config = copy.deepcopy(config) + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp model_kwargs = kwargs if commit_hash is None: diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py index 019a678ceb02..10239071fe1b 100644 --- a/colossalai/legacy/inference/hybridengine/engine.py +++ b/colossalai/legacy/inference/hybridengine/engine.py @@ -133,7 +133,7 @@ def inference(self, input_list): """ assert isinstance( input_list, (BatchEncoding, dict) - ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}." if isinstance(input_list, BatchEncoding): input_list = input_list.data out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index cb7eceae4d25..9b77774aaeaa 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -249,7 +249,6 @@ class VocabParallelEmbedding1D(PaddingParallelModule): The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: :: - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 6d99efc19bbf..a6d19edf5b53 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -22,6 +22,7 @@ def forward( ignore_index: int, process_group: ProcessGroup, vocab_size: int, + dtype=torch.float32, ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: @@ -34,7 +35,7 @@ def forward( Args: vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is [batch_size, seq_len, vocab_size] - labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is + target (:class:`torch.Tensor`): The labels of the vocabulary, shape is [batch_size, seq_len] Returns: @@ -86,7 +87,7 @@ def forward( dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) exp_logits = vocab_logits torch.exp(vocab_logits, out=exp_logits) - sum_exp_logits = torch.sum(exp_logits, dim=-1) + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) # calculate the loss @@ -97,9 +98,10 @@ def forward( loss = torch.sum(loss).div_(num_non_zero) # calculate the softmax - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype) exp_logits[target == ignore_index] = 0.0 ctx.save_for_backward(exp_logits, mask, masked_target_1d) + ctx.dtype = dtype return loss @@ -114,11 +116,11 @@ def backward(ctx, grad_output): partion_vocab_size = grad_logits.shape[-1] grad_logits_2d = grad_logits.view(-1, partion_vocab_size) - update = 1.0 - mask.view(-1).float() + update = 1.0 - mask.view(-1).float().to(ctx.dtype) grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None, None + return grad_logits, None, None, None, None, None def cross_entropy_1d( @@ -127,5 +129,6 @@ def cross_entropy_1d( ignore_index: int = -100, process_group: ProcessGroup = None, vocab_size: int = None, + dtype: torch.dtype = None, ) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index c4f326364596..bf74d0833cb0 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -10,6 +10,7 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, @@ -27,6 +28,8 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + logger = logging.get_logger(__name__) @@ -354,7 +357,7 @@ def bloom_for_causal_lm_forward( past_key_values = None if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + lm_logits = self.lm_head(hidden_states).contiguous() loss = None if labels is not None: @@ -365,10 +368,21 @@ def bloom_for_causal_lm_forward( shift_labels = labels[..., 1:].contiguous() batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) - ) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = lm_logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, + ) + else: + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels.view(-1)) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1065,3 +1079,79 @@ def forward( ) return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import BloomForCausalLM + + def forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + past_key_values = None + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + new_vocab_size = lm_logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, + ) + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index df3b09c71cbc..a43bdf4814ed 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -14,6 +14,7 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, + CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, @@ -31,6 +32,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: def build_falcon_alibi_tensor( @@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward( loss = None if labels is not None: # Shift so that tokens < n predict n + labels = labels.to(lm_logits.device) shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) - ) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = shift_logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, + ) + else: + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length), + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -747,3 +764,79 @@ def falcon_for_question_answering_forward( else: hidden_states = outputs.get("hidden_states") return {"hidden_states": hidden_states} + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import FalconForCausalLM + + def forward( + self: FalconForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + past_key_values = None + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + new_vocab_size = shift_logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bfa995645ef1..c49458dbdf55 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -389,6 +389,7 @@ def gpt2_lmhead_model_forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, ) else: loss = loss_fct(shift_logits, shift_labels) @@ -1294,6 +1295,7 @@ def forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.transformer.dtype, ) if not return_dict: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 8a6a7cf17e08..d6f10ffafec7 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -332,6 +332,7 @@ def llama_for_causal_lm_forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -768,6 +769,7 @@ def forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, ) if not return_dict: diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 93da71abb4a2..5f96ebe3d5cd 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -281,6 +281,7 @@ def mistral_for_causal_lm_forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -701,6 +702,7 @@ def forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.dtype, ) if not return_dict: diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 5282e2eaac22..f10860fef558 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -348,6 +348,7 @@ def opt_for_causal_lm_forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.decoder.dtype, ) else: loss_fct = CrossEntropyLoss() @@ -988,6 +989,7 @@ def forward( shift_labels, process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, + dtype=self.model.decoder.dtype, ) if not return_dict: diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 4f076d23368b..724a6b77ccc4 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -16,6 +16,7 @@ get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_mlp_forward, + get_lm_forward_with_dist_cross_entropy, ) from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -287,12 +288,18 @@ def module_policy(self): suffix="lm_head", target_module=col_nn.VocabParallelLMHead1D, kwargs=dict( - gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, ), ), policy=policy, target_key=BloomForCausalLM, ) + if self.shard_config.parallel_output: + method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomForCausalLM + ) else: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 23d6efbeb27a..e5c16733752e 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -7,7 +7,12 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward +from ..modeling.falcon import ( + FalconPipelineForwards, + build_falcon_alibi_tensor_fn, + get_lm_forward_with_dist_cross_entropy, + get_tp_falcon_decoder_layer_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["FalconPolicy"] @@ -233,12 +238,19 @@ def module_policy(self): suffix="lm_head", target_module=col_nn.VocabParallelLMHead1D, kwargs=dict( - gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, ), ), policy=policy, target_key=FalconForCausalLM, ) + if self.shard_config.parallel_output: + method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=FalconForCausalLM + ) + else: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 98e72d8b346a..453e8d23ebdb 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -125,9 +125,3 @@ def _turn_on_all_optimization(self): # It may also slow down training when seq len is small. Plz enable manually. # self.enable_sequence_parallelism = True # self.enable_sequence_overlap = True - - def _infer(self): - """ - Set default params for inference. - """ - # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b3991c4f0d9b..b54c5827316e 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,6 +1,7 @@ import os from typing import Dict, List, Tuple +import torch.distributed as dist import torch.nn as nn from torch import Tensor @@ -36,7 +37,11 @@ class ShardFormer: """ def __init__(self, shard_config: ShardConfig): - self.coordinator = DistCoordinator() + self.is_distributed = dist.is_initialized() + if self.is_distributed: + self.coordinator = DistCoordinator() + else: + self.coordinator = None self.shard_config = shard_config def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 41110612c307..5878abbaaa19 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,7 @@ ## 新闻 +* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference) * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source) * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series) * [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) @@ -74,11 +75,9 @@
  • 推理
  • @@ -370,6 +369,19 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 ## 推理 +### Colossal-Inference +

    + +

    + +

    + +

    + + - AI大模型推理速度部分接近翻倍,与vLLM的离线推理性能相比 +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/inference) +[[博客]](https://hpc-ai.com/blog/colossal-inference) + ### Grok-1

    @@ -388,25 +400,6 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 - [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): 开源解决方案打破了多轮对话的 LLM 长度限制,推理性能提高了46% - -

    - -

    - -- [Energon-AI](https://github.com/hpcaitech/EnergonAI) :用相同的硬件推理加速50% - -

    - -

    - -- [OPT推理服务](https://colossalai.org/docs/advanced_tutorials/opt_service): 体验1750亿参数OPT在线推理服务 - -

    - -

    - -- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): 降低1760亿参数BLOOM模型部署推理成本超10倍 -

    (返回顶端)

    ## 安装 diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py deleted file mode 100644 index a23ab500a6c2..000000000000 --- a/examples/inference/benchmark_llama.py +++ /dev/null @@ -1,167 +0,0 @@ -import argparse -import time - -import torch -import torch.distributed as dist -import transformers - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.inference import InferenceEngine -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn - -GIGABYTE = 1024**3 -MEGABYTE = 1024 * 1024 - -CONFIG_MAP = { - "toy": transformers.LlamaConfig(num_hidden_layers=4), - "llama-7b": transformers.LlamaConfig( - hidden_size=4096, - intermediate_size=11008, - num_attention_heads=32, - num_hidden_layers=32, - num_key_value_heads=32, - max_position_embeddings=2048, - ), - "llama-13b": transformers.LlamaConfig( - hidden_size=5120, - intermediate_size=13824, - num_attention_heads=40, - num_hidden_layers=40, - num_key_value_heads=40, - max_position_embeddings=2048, - ), - "llama2-7b": transformers.LlamaConfig( - hidden_size=4096, - intermediate_size=11008, - num_attention_heads=32, - num_hidden_layers=32, - num_key_value_heads=32, - max_position_embeddings=4096, - ), - "llama2-13b": transformers.LlamaConfig( - hidden_size=5120, - intermediate_size=13824, - num_attention_heads=40, - num_hidden_layers=40, - num_key_value_heads=40, - max_position_embeddings=4096, - ), -} - - -def data_gen(batch_size: int = 4, seq_len: int = 512): - input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) - attention_mask = torch.ones_like(input_ids) - data = dict(input_ids=input_ids, attention_mask=attention_mask) - return data - - -def print_details_info(outputs, model_config, args, whole_end2end): - msg: str = "" - - if dist.get_rank() == 0: - msg += "-------Perf Summary-------\n" - if args.verbose: - timestamps = outputs[1] - prefill = [] - encoder = [] - end2end = [] - for timestamp in timestamps: - prefill.append(timestamp[1] - timestamp[0]) - encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) - ) - end2end.append(timestamp[-1] - timestamp[0]) - - mb_avg_end2end = sum(end2end) / len(end2end) - mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size) - - msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n" - msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n" - msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n" - msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n" - - whole_avg_latency = whole_end2end / (args.output_len * args.batch_size) - num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) - num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size - if args.dtype in ["fp16", "bf16"]: - num_bytes = 2 - else: - num_bytes = 4 - - msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n" - msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n" - msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n" - msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" - - if torch.cuda.is_available(): - msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" - msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" - msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" - - print(msg) - - -def benchmark_inference(args): - config = CONFIG_MAP[args.model] - model = transformers.LlamaForCausalLM(config) - if dist.get_rank() == 0: - print("Model loaded") - engine = InferenceEngine( - pp_size=args.pp_size, - tp_size=args.tp_size, - dtype=args.dtype, - micro_batch_size=args.mb_size, - model=model, - verbose=args.verbose, - max_batch_size=args.batch_size, - max_input_len=args.seq_len, - max_output_len=args.output_len, - ) - data = data_gen(args.batch_size, args.seq_len) - - N_WARMUP_STEPS = 2 - - for _ in range(N_WARMUP_STEPS): - engine.generate(data) - - torch.cuda.synchronize() - whole_end2end = time.time() - outputs = engine.generate(data) - torch.cuda.synchronize() - whole_end2end = time.time() - whole_end2end - - print_details_info(outputs, model.config, args, whole_end2end) - - -def hybrid_inference(rank, world_size, port, args): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - benchmark_inference(args) - - -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def benchmark(args): - spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-m", - "--model", - default="toy", - help="the size of model", - choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"], - ) - parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") - parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") - parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") - parser.add_argument("--pp_size", type=int, default=1, help="pipeline size") - parser.add_argument("--tp_size", type=int, default=1, help="pipeline size") - parser.add_argument("--output_len", type=int, default=128, help="Output length") - parser.add_argument("--dtype", type=str, default="fp16", help="data type") - parser.add_argument("-v", "--verbose", default=False, action="store_true") - args = parser.parse_args() - benchmark(args) diff --git a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py new file mode 100644 index 000000000000..18fe76cf0688 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py @@ -0,0 +1,133 @@ +import torch +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + +from colossalai.inference.modeling.layers.attention import PagedAttention +from colossalai.kernel.triton import context_attention_unpadded +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + +HEAD_DIM = 32 +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_LEN"], + x_vals=[2**i for i in range(8, 13)], + # x_vals=[x for x in range(256, 8192, 256)], + line_arg="provider", + line_vals=["torch", "triton", "triton_new_klayout"], + line_names=["Torch", "Triton", "Triton_new_klayout"], + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}", + args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, + ) +] + + +@triton.testing.perf_report(configs) +def bench_kernel( + bsz, + KV_LEN, + provider, + block_size: int, + kv_group_num: int, + same_context_len: bool, +): + num_attn_heads = 16 + max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size) + max_seq_len = block_size * max_num_blocks_per_seq + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + dtype = torch.float16 + device = get_current_device() + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + num_tokens = torch.sum(context_lengths).item() + + qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM) + qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) + q_unpad = q_unpad.contiguous() + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM) + k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM) + v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM) + q_padded, k_padded, v_padded = ( + q_padded.to(device=device), + k_padded.to(device=device), + v_padded.to(device=device), + ) + q_padded = q_padded.transpose(1, 2) + k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num) + v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num) + # This benchmark ignores the padding mask. *Only* use the-same-length inputs for benchmarkings + attn_mask = AttentionMaskConverter._make_causal_mask( + (bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0 + ) + attn_mask = attn_mask.to(device=q_padded.device) + fn = lambda: torch_attn_ref( + q_padded, + k_padded, + v_padded, + attn_mask, + bsz, + max_seq_len, + max_seq_len, + num_attn_heads, + num_kv_heads, + HEAD_DIM, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton": + k_cache_triton = torch.zeros_like(k_cache_ref) + v_cache_triton = torch.zeros_like(v_cache_ref) + fn = lambda: context_attention_unpadded( + q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton_new_klayout": + # NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) + # to be applied around the cuda and triton kernels. + # Here we want to make sure it does not cause downgrade in performance. + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM // x, block_size, x) + k_cache_triton = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) + v_cache_triton = torch.zeros_like(v_cache_ref) + fn = lambda: context_attention_unpadded( + q_unpad, + k_unpad, + v_unpad, + k_cache_triton, + v_cache_triton, + context_lengths, + block_tables, + block_size, + use_new_kcache_layout=True, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + bench_kernel.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_decoding_attn.py b/examples/inference/benchmark_ops/benchmark_decoding_attn.py new file mode 100644 index 000000000000..4471ddadab9c --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_decoding_attn.py @@ -0,0 +1,143 @@ +import torch + +from colossalai.kernel.triton import flash_decoding_attention +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + convert_kv_unpad_to_padded, + create_attention_mask, + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + torch_attn_ref, +) +from tests.test_infer.test_kernels.triton.test_decoding_attn import prepare_data + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + +Q_LEN = 1 +HEAD_DIM = 128 +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_LEN"], + x_vals=[2**i for i in range(8, 14)], + # x_vals=[x for x in range(256, 8192, 256)], + line_arg="provider", + line_vals=["torch", "triton", "triton_new_kcache_layout"], + line_names=["Torch", "Triton", "Triton New KCache Layout"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms", + plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}", + args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, + ) +] + + +@triton.testing.perf_report(configs) +def bench_kernel( + bsz, + KV_LEN, + provider, + block_size: int, + kv_group_num: int, + same_context_len: bool, +): + num_attn_heads = 16 + max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size) + max_seq_len = block_size * max_num_blocks_per_seq + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + q, k_unpad, v_unpad, kv_lengths = prepare_data( + bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device + ) + max_seq_len_in_b = kv_lengths.max().item() # for random lengths + # the maximum block length splitted on kv should be the kv cache block size + kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size + sm_scale = 1.0 / (HEAD_DIM**0.5) + output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device) + mid_output = torch.empty( + size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b) + torch_padding_mask = create_attention_mask(kv_lengths, bsz, Q_LEN, max_seq_len_in_b, q.device) + fn = lambda: torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + bsz, + Q_LEN, + max_seq_len_in_b, + num_attn_heads, + num_kv_heads, + HEAD_DIM, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton": + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + fn = lambda: flash_decoding_attention( + # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), + # refer to attention forward in modeling. + q.squeeze(2), + k_cache, + v_cache, + kv_lengths, + block_tables, + block_size, + max_seq_len_in_b, + output, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + ) # [bsz, 1, num_heads, head_dim] + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + elif provider == "triton_new_kcache_layout": + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + fn = lambda: flash_decoding_attention( + # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), + # refer to attention forward in modeling. + q.squeeze(2), + k_cache, + v_cache, + kv_lengths, + block_tables, + block_size, + max_seq_len_in_b, + output, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + use_new_kcache_layout=True, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + bench_kernel.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py new file mode 100644 index 000000000000..da85f4230ac2 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py @@ -0,0 +1,185 @@ +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import flash_decoding_attention +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + generate_caches_and_block_tables_vllm, +) + +try: + import triton # noqa +except ImportError: + print("please install triton from https://github.com/openai/triton") + +inference_ops = InferenceOpsLoader().load() + +# Triton benchmark plot attributions +configs = [ + triton.testing.Benchmark( + x_names=["MAX_NUM_BLOCKS_PER_SEQ"], + x_vals=[2**i for i in range(2, 8)], + line_arg="provider", + line_vals=[ + "vllm_paged_decoding_attention", + "triton_flash_decoding_attention", + "cuda_flash_decoding_attention", + ], + line_names=[ + "vllm_paged_decoding_attention", + "triton_flash_decoding_attention", + "cuda_flash_decoding_attention", + ], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms", + plot_name=f"FlashDecodingAttention benchmarking results", + args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2}, + ) +] + + +def prepare_data( + BATCH_SIZE: int, + HEAD_SIZE: int, + NUM_ATTN_HEADS: int, + NUM_KV_HEADS: int, + MAX_SEQ_LEN: int, + dtype=torch.float16, + device="cuda", +): + # Use the provided maximum sequence length for each sequence when testing with teh same context length, + # otherwise generate random context lengths. + # returns + # q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE] + # k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE] + kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device) + num_tokens = torch.sum(kv_lengths).item() + + q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths + + +@triton.testing.perf_report(configs) +def benchmark_flash_decoding_attention( + provider: str, + BATCH_SIZE: int, + BLOCK_SIZE: int, + MAX_NUM_BLOCKS_PER_SEQ: int, + HEAD_SIZE: int, + KV_GROUP_NUM: int, +): + try: + from vllm._C import ops as vllm_ops + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + + warmup = 10 + rep = 1000 + + dtype = torch.float16 + + NUM_ATTN_HEADS = 16 + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + triton_k_cache, triton_v_cache, _ = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + alibi_slopes = None + kv_scale = 1.0 + + mid_output = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device + ) + mid_output_lse = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device + ) + exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device) + max_logits = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device) + + if provider == "vllm_paged_decoding_attention": + alibi_slopes = None + fn = lambda: vllm_ops.paged_attention_v1( + output, + q.squeeze(2), + vllm_k_cache, + vllm_v_cache, + NUM_KV_HEADS, + sm_scale, + block_tables, + kv_seq_lengths, + BLOCK_SIZE, + max_seq_len_across_batch, + alibi_slopes, + "auto", + kv_scale, + ) + elif provider == "triton_flash_decoding_attention": + fn = lambda: flash_decoding_attention( + q.squeeze(2), + triton_k_cache, + triton_v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + output, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=KV_GROUP_NUM, + ) # [bsz, 1, num_heads, head_dim] + elif provider == "cuda_flash_decoding_attention": + fn = lambda: inference_ops.flash_decoding_attention( + output, + q.squeeze(2), + k_cache, + v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + mid_output, + exp_sums, + max_logits, + alibi_slopes, + sm_scale, + ) + else: + raise ValueError("Undefined provider.") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + return ms + + +if __name__ == "__main__": + benchmark_flash_decoding_attention.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py new file mode 100644 index 000000000000..80939f5a1e50 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py @@ -0,0 +1,137 @@ +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding +from tests.test_infer.test_kernels.triton.kernel_utils import ( + mock_alloc_block_table_and_kvcache_v2, + mock_alloc_block_table_and_kvcache_v3, + mock_alloc_single_token, +) + +inference_ops = InferenceOpsLoader().load() + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + + +BATCH = 16 +configs = [ + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=[2**i for i in range(4, 11)], + line_arg="provider", + line_vals=[ + "triton_rotary_emb_func", + "triton_fused_rotary_emb_func", + "triton_fused_rotary_emb_func_new_kcache_layout", + "cuda_rotary_emb_func", + "cuda_fused_rotary_emb_func", + ], + line_names=[ + "triton_rotary_emb_func", + "triton_fused_rotary_emb_func", + "triton_fused_rotary_emb_func(new layout)", + "cuda_rotary_emb_func", + "cuda_fused_rotary_emb_func", + ], + styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")], + ylabel="ms", + plot_name=f"rotary_emb-batch-{BATCH}", + args={"num_kv_heads": 16}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_rotary_emb( + provider: str, + num_tokens: int, + num_kv_heads: int, +): + BATCH_SIZE = 16 + SEQ_LEN = num_tokens // BATCH_SIZE + max_num_blocks_per_seq = 8 + block_size = 64 + warmup = 10 + rep = 100 + + head_dim = 4096 + dtype = torch.float16 + + q_shape = (num_tokens, num_kv_heads, head_dim) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (num_tokens, num_kv_heads, head_dim) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + v = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + + cos_shape = (num_tokens, head_dim // 2) + + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda") + x = 16 // torch.tensor([], dtype=dtype).element_size() + new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda") + + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v2( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + _ = mock_alloc_block_table_and_kvcache_v3( + k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda") + new_q = torch.randn_like(new_k) + new_v = torch.randn_like(new_k) + + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + kv_seq_lengths = past_kv_seq_lengths + 1 + block_tables = block_tables.to(device="cuda") + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton_rotary_emb_func": + fn = lambda: [ + rotary_embedding(new_q, new_k, cos, sin), + copy_kv_to_blocked_cache( + new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables + ), + ] + elif provider == "triton_fused_rotary_emb_func": + fn = lambda: decoding_fused_rotary_embedding( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths + ) + elif provider == "triton_fused_rotary_emb_func_new_kcache_layout": + x = 16 // torch.tensor([], dtype=dtype).element_size() + kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v3( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + block_tables = block_tables.to(device="cuda") + fn = lambda: decoding_fused_rotary_embedding( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True + ) + elif provider == "cuda_rotary_emb_func": + fn = lambda: [ + inference_ops.rotary_embedding(new_q, new_k, cos, sin, True), + inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables), + ] + elif provider == "cuda_fused_rotary_emb_func": + fn = lambda: inference_ops.rotary_embedding_and_cache_copy( + new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True + ) + else: + raise ValueError("Undefined provider") + + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles) + return ms, min_ms, max_ms + + +if __name__ == "__main__": + benchmark_rotary_emb.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py new file mode 100644 index 000000000000..0232cb90e677 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py @@ -0,0 +1,91 @@ +import torch + +from colossalai.inference.modeling.layers.attention import copy_to_cache +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import copy_kv_to_blocked_cache +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout +from tests.test_infer.test_kernels.triton.test_kvcache_copy import prepare_data + +try: + import triton # noqa +except ImportError: + print("please install triton from https://github.com/openai/triton") + +inference_ops = InferenceOpsLoader().load() + +HEAD_DIM = 128 +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_SEQ_LEN"], + x_vals=[2**i for i in range(8, 13)], + line_arg="provider", + line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"], + line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}", + args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_kvcache_copy( + provider: str, + bsz: int, + block_size: int, + max_seq_len: int, + KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens) + num_kv_heads: int, + same_context_len: bool, +): + dtype = torch.float16 + device = get_current_device() + + assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len" + + new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data( + bsz, + num_kv_heads, + HEAD_DIM, + block_size, + max_seq_len // block_size, + same_context_len, + KV_SEQ_LEN, + device=device, + dtype=dtype, + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch_copy_func": + fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding") + elif provider == "triton_copy_func": + fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) + elif provider == "triton_new_kcache_layout": + # NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) # update k_cache layout + fn = lambda: copy_kv_to_blocked_cache( + new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True + ) + elif provider == "cuda_copy_func": + _, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout( + bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype + ) + new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k + new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v + fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables) + + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + return ms, min_ms, max_ms + + +if __name__ == "__main__": + benchmark_kvcache_copy.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_rmsnorm.py b/examples/inference/benchmark_ops/benchmark_rmsnorm.py new file mode 100644 index 000000000000..deddac8b127a --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_rmsnorm.py @@ -0,0 +1,87 @@ +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import rms_layernorm + +try: + import triton # noqa +except ImportError: + print("please install triton from https://github.com/openai/triton") + +inference_ops = InferenceOpsLoader().load() + +# Triton benchmark plot attributions +configs = [ + triton.testing.Benchmark( + x_names=["SEQUENCE_TOTAL"], + x_vals=[i for i in range(128, 1025, 128)], + line_arg="provider", + line_vals=[ + "vllm_rms_layernorm", + "triton_rms_layernorm", + "cuda_rms_layernorm", + "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", + ], + line_names=[ + "vllm_rms_layernorm", + "triton_rms_layernorm", + "cuda_rms_layernorm", + "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", + ], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")], + ylabel="ms", + plot_name=f"RMSNorm benchmarking results", + args={"HIDDEN_SIZE": 5120}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_rms_layernorm( + provider: str, + SEQUENCE_TOTAL: int, + HIDDEN_SIZE: int, +): + try: + from vllm.model_executor.layers.layernorm import RMSNorm + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + + warmup = 10 + rep = 1000 + + dtype = torch.float16 + eps = 1e-5 + x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE) + w_shape = (x_shape[-1],) + residual = torch.rand(x_shape, dtype=dtype, device="cuda") + weight = torch.ones(w_shape, dtype=dtype, device="cuda") + vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + if provider == "vllm_rms_layernorm": + fn = lambda: vllm_norm(x) + elif provider == "triton_rms_layernorm": + fn = lambda: rms_layernorm(x, weight, eps=eps) + elif provider == "cuda_rms_layernorm": + out = torch.empty_like(x) + fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps) + elif provider == "vllm_rms_layernorm_with_residual": + fn = lambda: vllm_norm(x, residual=residual) + elif provider == "triton_rms_layernorm_with_residual": + fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual) + elif provider == "cuda_rms_layernorm_with_residual": + fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps) + else: + raise ValueError("Undefined provider.") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + return ms + + +if __name__ == "__main__": + benchmark_rms_layernorm.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_rotary_embedding.py b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py new file mode 100644 index 000000000000..97cf2e0b2451 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py @@ -0,0 +1,76 @@ +import torch +import triton +from vllm._C import ops + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import rotary_embedding + +inference_ops = InferenceOpsLoader().load() + +BATCH = 16 +configs = [ + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=[2**i for i in range(4, 12)], + line_arg="provider", + line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], + ylabel="ms", + plot_name=f"rotary_emb-batch-{BATCH}", + args={"num_kv_heads": 16}, + ) +] + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@triton.testing.perf_report(configs) +def benchmark_rotary_emb( + provider: str, + num_tokens: int, + num_kv_heads: int, +): + warmup = 10 + rep = 100 + + head_dim = 128 + dtype = torch.float16 + q_shape = (num_tokens, num_kv_heads, head_dim) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (num_tokens, num_kv_heads, head_dim) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + cos_shape = (4096, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + cos_sin = torch.stack((cos, sin), dim=1).contiguous() + + positions = torch.arange(num_tokens).cuda() + + if provider == "triton_func": + fn = lambda: rotary_embedding(q, k, cos, sin) + elif provider == "colossal_cuda_func": + fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin) + elif provider == "vllm_cuda_func": + q = q.view(num_tokens, -1) + k = k.view(num_tokens, -1) + fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True) + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + benchmark_rotary_emb.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/benchmark_xine_copy.py b/examples/inference/benchmark_ops/benchmark_xine_copy.py new file mode 100644 index 000000000000..633ceb6f1651 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_xine_copy.py @@ -0,0 +1,54 @@ +import torch + +from colossalai.kernel.triton import get_xine_cache +from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + + +configs = [ + triton.testing.Benchmark( + x_names=["max_num_tokens"], + x_vals=[2**i for i in range(6, 12)], + line_arg="provider", + line_vals=["torch_get_cos_sin", "triton_get_cos_sin"], + line_names=["torch_get_cos_sin", "triton_get_cos_sin"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name="Get_cos-sin_func", + args={"batch_size": 16, "head_dim": 256}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_get_xine_cache( + provider: str, + max_num_tokens: int, + batch_size: int, + head_dim: int, +): + warmup = 10 + rep = 1000 + dtype = torch.float16 + cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda") + + if provider == "torch_get_cos_sin": + fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + elif provider == "triton_get_cos_sin": + fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + benchmark_get_xine_cache.run(save_path=".", print_data=True) diff --git a/examples/inference/benchmark_ops/test_ci.sh b/examples/inference/benchmark_ops/test_ci.sh new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/inference/build_smoothquant_weight.py b/examples/inference/build_smoothquant_weight.py deleted file mode 100644 index d60ce1c1d618..000000000000 --- a/examples/inference/build_smoothquant_weight.py +++ /dev/null @@ -1,59 +0,0 @@ -import argparse -import os - -import torch -from datasets import load_dataset -from transformers import LlamaTokenizer - -from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM - - -def build_model_and_tokenizer(model_name): - tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) - kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} - model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) - model = model.to(torch.float32) - return model, tokenizer - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model-name", type=str, help="model name") - parser.add_argument( - "--output-path", - type=str, - help="where to save the checkpoint", - ) - parser.add_argument( - "--dataset-path", - type=str, - help="location of the calibration dataset", - ) - parser.add_argument("--num-samples", type=int, default=10) - parser.add_argument("--seq-len", type=int, default=512) - args = parser.parse_args() - return args - - -@torch.no_grad() -def main(): - args = parse_args() - model_path = args.model_name - dataset_path = args.dataset_path - output_path = args.output_path - num_samples = args.num_samples - seq_len = args.seq_len - - model, tokenizer = build_model_and_tokenizer(model_path) - if not os.path.exists(dataset_path): - raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}") - dataset = load_dataset("json", data_files=dataset_path, split="train") - - model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) - model = model.cuda() - - model.save_quantized(output_path, model_basename="llama-7b") - - -if __name__ == "__main__": - main() diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py new file mode 100644 index 000000000000..a65c8b667263 --- /dev/null +++ b/examples/inference/client/locustfile.py @@ -0,0 +1,58 @@ +from locust import HttpUser, between, tag, task + + +class QuickstartUser(HttpUser): + wait_time = between(1, 5) + + @tag("online-generation") + @task(5) + def completion(self): + self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) + + @tag("online-generation") + @task(5) + def completion_streaming(self): + self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + + @tag("online-chat") + @task(5) + def chat(self): + self.client.post( + "/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "False", + }, + ) + + @tag("online-chat") + @task(5) + def chat_streaming(self): + self.client.post( + "/chat", + json={ + "converation": [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ], + "stream": "True", + }, + ) + + @tag("offline-generation") + @task(5) + def generate_streaming(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"}) + + @tag("offline-generation") + @task(5) + def generate(self): + self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"}) + + @tag("online-generation", "offline-generation") + @task + def get_models(self): + self.client.get("/models") diff --git a/examples/inference/client/run_locust.sh b/examples/inference/client/run_locust.sh new file mode 100644 index 000000000000..ab0a267de2db --- /dev/null +++ b/examples/inference/client/run_locust.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +#argument1: model_path + +# launch server +model_path=${1:-"lmsys/vicuna-7b-v1.3"} +chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" +echo "Model Path: $model_path" +echo "Chat Tempelate" "${chat_template}" +echo "Starting server..." +python -m colossalai.inference.server.api_server --model $model_path --chat-template "${chat_template}" & +SERVER_PID=$! + +# waiting time +sleep 60 + +# Run Locust +echo "Starting Locust..." +echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information." +echo "Test completion api first" +locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10 +echo "Test chat api" +locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10 +# kill Server +echo "Stopping server..." +kill $SERVER_PID + +echo "Test and server shutdown completely" diff --git a/examples/inference/client/test_ci.sh b/examples/inference/client/test_ci.sh new file mode 100644 index 000000000000..b130fc486bfe --- /dev/null +++ b/examples/inference/client/test_ci.sh @@ -0,0 +1,4 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" + +# bash ./run_benchmark.sh diff --git a/examples/inference/llama/README.md b/examples/inference/llama/README.md new file mode 100644 index 000000000000..cde81a41d839 --- /dev/null +++ b/examples/inference/llama/README.md @@ -0,0 +1,47 @@ +## Run Inference + +The provided example `llama_generation.py` is an example to configure, initialize the engine, and run inference on provided model. We've added `AutoModelForCausalLM` and `NoPaddingLlamaModelInferPolicy` as model class and policy class, and the script is good to run inference with Llama 3. + +For a basic setting, you could run the example by: +```bash +colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --max_length 128 +``` + +Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs: +```bash +colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --max_length 128 --tp_size 2 +``` + +## Run Speculative Decoding + +Colossal-Inference supports speculative decoding using the inference engine, with optimized kernels and cache management for the main model. + +Both a drafter model (small model) and a main model (large model) will be used during speculative decoding process. The drafter model will generate a few tokens sequentially, and then the main model will validate those candidate tokens in parallel and accept validated ones. The decoding process will be speeded up, for the latency of speculating multiple tokens by the drafter model is lower than that by the main model. + +Moreover, Colossal-Inference also supports GLIDE, a modified draft model architecture that reuses key and value caches from the main model, which improves the acceptance rate and increment the speed-up ratio. Details can be found in research paper GLIDE with a CAPE - A Low-Hassle Method to Accelerate Speculative Decoding on [arXiv](https://arxiv.org/pdf/2402.02082.pdf). + +Right now, Colossal-Inference offers a GLIDE model compatible with vicuna7B (https://huggingface.co/lmsys/vicuna-7b-v1.5). You can find the fine-tuned GLIDE drafter model `cxdu/glide-vicuna7b` on the HuggingFace Hub: https://huggingface.co/cxdu/glide-vicuna7b. + +Benchmarking with gsm8k and MT-Bench dataset with batch size 1 on H800, the speed increase for using speculative decoding is around 1.28x, and the speed increase for using speculative decoding with Glide model (as drafter model) is around 1.5x. + +## Usage + +For main model, you might want to use model card `lmsys/vicuna-7b-v1.5` at [HuggingFace Hub](https://huggingface.co/lmsys/vicuna-7b-v1.5). +For regular drafter model, you might want to use model card `JackFram/llama-68m` at [HuggingFace Hub](https://huggingface.co/JackFram/llama-68m). +For the GLIDE drafter model, you could use model card `cxdu/glide-vicuna7b` at [HuggingFace Hub](https://huggingface.co/cxdu/glide-vicuna7b). + + +You could run speculative decoding by +```bash +colossalai run --nproc_per_node 1 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128 +``` + +Run multi-GPU inference (Tensor Parallelism), as in the following example using 2 GPUs. +```bash +colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_model PATH_DRAFTER_MODEL --max_length 128 --tp_size 2 +``` + +If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as drafter model and enable the feature by +```python +engine.enable_spec_dec(drafter_model, use_glide_drafter=True) +``` diff --git a/examples/inference/llama/benchmark_llama.py b/examples/inference/llama/benchmark_llama.py new file mode 100644 index 000000000000..2d24d87adfd1 --- /dev/null +++ b/examples/inference/llama/benchmark_llama.py @@ -0,0 +1,275 @@ +import argparse +import time +from contextlib import nullcontext + +import torch +import torch.distributed as dist +import transformers +from transformers import AutoTokenizer, GenerationConfig +from vllm import LLM, SamplingParams + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +CONFIG_MAP = { + "toy": transformers.LlamaConfig(num_hidden_layers=4), + "llama-7b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=32, + max_position_embeddings=2048, + ), + "llama-13b": transformers.LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_attention_heads=40, + num_hidden_layers=40, + num_key_value_heads=40, + max_position_embeddings=2048, + ), + "llama2-7b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=32, + max_position_embeddings=4096, + ), + "llama2-13b": transformers.LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_attention_heads=40, + num_hidden_layers=40, + num_key_value_heads=40, + max_position_embeddings=4096, + ), + "llama3-8b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=14336, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=8, + max_position_embeddings=8192, + ), + "llama3-70b": transformers.LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_attention_heads=64, + num_hidden_layers=80, + num_key_value_heads=8, + max_position_embeddings=8192, + ), +} + + +def data_gen(batch_size: int = 4, seq_len: int = 512): + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) + return input_ids + + +def print_details_info(model_config, args, whole_end2end, total_token_num): + msg: str = "" + + if dist.get_rank() == 0: + msg += "-------Perf Summary-------\n" + whole_avg_latency = whole_end2end / (total_token_num) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 + if args.dtype in ["fp16", "bf16"]: + num_bytes = 2 + else: + num_bytes = 4 + + msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n" + msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n" + msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n" + msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" + + if torch.cuda.is_available(): + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" + + print(msg) + + +def benchmark_inference(args): + with torch.no_grad(): + config = CONFIG_MAP[args.model] + config.pad_token_id = config.eos_token_id + + if args.mode != "vllm": + if args.test_random_weight: + model = transformers.LlamaForCausalLM(config).cuda() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + else: + assert args.model_path, "When testing pretrained weights, the model path must be provided.'" + model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda() + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + model = model.eval() + + if args.dtype == "fp16": + model = model.half() + elif args.dtype == "bf16": + model = model.to(torch.bfloat16) + + generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, + max_length=args.seq_len + args.output_len, + # max_new_tokens=args.max_output_len, + ) + + if args.continous_batching: + mbsz = args.mbsz + else: + mbsz = args.batch_size + if args.mode == "colossalai": + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=mbsz, + max_input_len=args.seq_len, + max_output_len=args.output_len, + prefill_ratio=1.2, + block_size=32, + tp_size=args.tp_size, + use_cuda_kernel=True, + ) + engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + elif args.mode == "vllm": + engine = LLM( + model=args.model_path, + tokenizer="hf-internal-testing/llama-tokenizer", + max_num_seqs=mbsz, + dtype="float16", + enforce_eager=True, + ) + + sampling_params = SamplingParams( + max_tokens=args.output_len, + ) + else: + engine = model + + data = data_gen(mbsz, args.seq_len) + + if args.mode == "colossalai" or args.mode == "vllm": + data = data.tolist() + + N_WARMUP_STEPS = 2 + + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode), + ) + if args.profile + else nullcontext() + ) + + with ctx: + for _ in range(N_WARMUP_STEPS): + if args.mode == "colossalai": + engine.generate(prompts_token_ids=data, generation_config=generation_config) + elif args.mode == "vllm": + engine.generate(prompt_token_ids=data, sampling_params=sampling_params) + else: + engine.generate(data, generation_config=generation_config) + if args.profile: + ctx.step() + + if args.nsys: + torch.cuda.cudart().cudaProfilerStart() + + torch.cuda.synchronize() + + whole_end2end = time.perf_counter() + + if args.mode == "colossalai": + for _ in range(args.batch_size // mbsz): + output, output_tokens_list = engine.generate( + prompts_token_ids=data, generation_config=generation_config, return_token_ids=True + ) + elif args.mode == "vllm": + for _ in range(args.batch_size // mbsz): + output = engine.generate(prompt_token_ids=data, sampling_params=sampling_params) + else: + for _ in range(args.batch_size // mbsz): + output = engine.generate(data, generation_config=generation_config) + + whole_end2end = time.perf_counter() - whole_end2end + + if args.mode == "colossalai": + total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list]) + elif args.mode == "vllm": + total_token_num = sum([len(out.outputs[0].token_ids) for out in output]) + else: + total_token_num = sum([len(out) for out in output]) + + print("total_token_num: ", total_token_num) + if args.nsys: + torch.cuda.cudart().cudaProfilerStop() + if args.profile: + ctx.step() + print(f"config:batch_size {args.batch_size}, input_len{ args.seq_len}, output_len {args.output_len}") + print_details_info(config, args, whole_end2end, total_token_num) + + +def hybrid_inference(rank, world_size, port, args): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + benchmark_inference(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def benchmark(args): + spawn(hybrid_inference, nprocs=args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + default="toy", + help="the size of model", + choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"], + ) + parser.add_argument("--model_path", type=str, default=None, help="The pretrained weights path") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length") + parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("--output_len", type=int, default=128, help="Output length") + parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument( + "--test_random_weight", default=False, action="store_true", help="whether to test random weight" + ) + parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") + parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler") + parser.add_argument( + "--mode", + default="colossalai", + choices=["colossalai", "transformers", "vllm"], + help="decide which inference framework to run", + ) + parser.add_argument( + "-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching" + ) + args = parser.parse_args() + benchmark(args) diff --git a/examples/inference/llama/benchmark_llama3.py b/examples/inference/llama/benchmark_llama3.py new file mode 100644 index 000000000000..07ebdb2b1bfb --- /dev/null +++ b/examples/inference/llama/benchmark_llama3.py @@ -0,0 +1,216 @@ +import argparse +import time +from contextlib import nullcontext + +import torch +import transformers +from transformers import AutoTokenizer, GenerationConfig + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.cluster import DistCoordinator +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +GIGABYTE = 1024**3 +MEGABYTE = 1024**2 +N_WARMUP_STEPS = 2 + +CONFIG_MAP = { + "toy": transformers.LlamaConfig(num_hidden_layers=4), + "llama-7b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=32, + max_position_embeddings=2048, + ), + "llama-13b": transformers.LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_attention_heads=40, + num_hidden_layers=40, + num_key_value_heads=40, + max_position_embeddings=2048, + ), + "llama2-7b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=32, + max_position_embeddings=4096, + ), + "llama2-13b": transformers.LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_attention_heads=40, + num_hidden_layers=40, + num_key_value_heads=40, + max_position_embeddings=4096, + ), + "llama3-8b": transformers.LlamaConfig( + hidden_size=4096, + intermediate_size=14336, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=8, + max_position_embeddings=8192, + ), + "llama3-70b": transformers.LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_attention_heads=64, + num_hidden_layers=80, + num_key_value_heads=8, + max_position_embeddings=8192, + ), +} + + +def data_gen(batch_size: int = 4, seq_len: int = 512): + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) + return input_ids.tolist() + + +def print_details_info(model_config, whole_end2end, total_token_num, dtype, coordinator=None): + if coordinator is None: + coordinator = DistCoordinator() + msg = "-------Perf Summary-------\n" + whole_avg_latency = whole_end2end / (total_token_num) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 + if dtype in ["fp16", "bf16"]: + num_bytes = 2 + elif dtype == "fp32": + num_bytes = 4 + else: + raise ValueError(f"Unsupported dtype {dtype}") + + msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n" + msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n" + msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n" + msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" + if torch.cuda.is_available(): + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" + + coordinator.print_on_master(msg) + + +def benchmark_inference(args): + coordinator = DistCoordinator() + + config = CONFIG_MAP[args.model] + config.pad_token_id = config.eos_token_id + if args.model_path is not None: + model = transformers.LlamaForCausalLM.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + else: + # Random weights + model = transformers.LlamaForCausalLM(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + if args.dtype == "fp16": + model = model.half() + elif args.dtype == "bf16": + model = model.to(torch.bfloat16) + + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=args.batch_size, + max_input_len=args.max_seq_len, + max_output_len=args.max_output_len, + prefill_ratio=1.2, + block_size=32, + tp_size=args.tp_size, + use_cuda_kernel=True, + ) + engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + + data = data_gen(args.batch_size, args.max_seq_len) + generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, + max_length=args.max_seq_len + args.max_output_len, + # max_new_tokens=args.max_output_len, + ) + coordinator.print_on_master(f"Generation Config: \n{generation_config.to_dict()}") + + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"./tb_log_{args.batch_size}_{args.max_seq_len}_{args.max_output_len}" + ), + ) + if args.profile + else nullcontext() + ) + with ctx: + for _ in range(N_WARMUP_STEPS): + engine.generate(prompts_token_ids=data, generation_config=generation_config) + if args.profile: + ctx.step() + if args.nsys: + torch.cuda.cudart().cudaProfilerStart() + + torch.cuda.synchronize() + whole_end2end = time.perf_counter() + output, output_tokens_list = engine.generate( + prompts_token_ids=data, generation_config=generation_config, return_token_ids=True + ) + torch.cuda.synchronize() + whole_end2end = time.perf_counter() - whole_end2end + + total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list]) + coordinator.print_on_master(f"total_token_num: {total_token_num}") + if args.nsys: + torch.cuda.cudart().cudaProfilerStop() + if args.profile: + ctx.step() + + print_details_info(model.config, whole_end2end, total_token_num, args.dtype, coordinator=coordinator) + + +def inference(rank, world_size, port, args): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + benchmark_inference(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def benchmark(args): + spawn(inference, nprocs=args.tp_size, args=args) + + +# python benchmark_llama3.py -m llama3-8b -b 16 -s 256 -o 256 +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + default="llama3-8b", + help="The version of Llama model", + choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"], + ) + parser.add_argument("-p", "--model_path", type=str, default=None, help="The pretrained weights path") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-s", "--max_seq_len", type=int, default=8, help="input sequence length") + parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Output length") + parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler") + parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler") + + args = parser.parse_args() + + benchmark(args) diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py new file mode 100644 index 000000000000..c0a1a585a1b9 --- /dev/null +++ b/examples/inference/llama/llama_generation.py @@ -0,0 +1,105 @@ +import argparse + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy + +# For Llama 3, we'll use the following configuration +MODEL_CLS = AutoModelForCausalLM +POLICY_CLS = NoPaddingLlamaModelInferPolicy + + +def infer(args): + # ============================== + # Launch colossalai, setup distributed environment + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Load model and tokenizer + # ============================== + model_path_or_name = args.model + model = MODEL_CLS.from_pretrained(model_path_or_name) + tokenizer = AutoTokenizer.from_pretrained(model_path_or_name) + tokenizer.pad_token = tokenizer.eos_token + # coordinator.print_on_master(f"Model Config:\n{model.config}") + + # ============================== + # Initialize InferenceEngine + # ============================== + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=args.max_batch_size, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + prefill_ratio=1.2, + block_size=16, + tp_size=args.tp_size, + use_cuda_kernel=args.use_cuda_kernel, + ) + coordinator.print_on_master(f"Initializing Inference Engine...") + engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True) + + # ============================== + # Generation + # ============================== + generation_config = GenerationConfig( + pad_token_id=tokenizer.eos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_length=args.max_length, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + coordinator.print_on_master(f"Generating...") + out = engine.generate(prompts=[args.prompt], generation_config=generation_config) + coordinator.print_on_master(out) + + # ============================== + # Optionally, load drafter model and proceed speculative decoding + # ============================== + drafter_model_path_or_name = args.drafter_model + if drafter_model_path_or_name is not None: + drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name) + # turn on speculative decoding with the drafter model + engine.enable_spec_dec(drafter_model) + coordinator.print_on_master(f"Generating...") + out = engine.generate(prompts=[args.prompt], generation_config=generation_config) + coordinator.print_on_master(out) + + engine.disable_spec_dec() + + +# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH +# colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2 +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str, help="Path to the model or model name") + parser.add_argument("--drafter_model", type=str, help="Path to the drafter model or model name") + parser.add_argument( + "-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt" + ) + parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size") + parser.add_argument("-i", "--max_input_len", type=int, default=128, help="Max input length") + parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Max output length") + parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") + # Generation configs + parser.add_argument("--max_length", type=int, default=64, help="Max length for generation") + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation") + parser.add_argument("--top_k", type=int, default=50, help="Top k for generation") + parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation") + args = parser.parse_args() + + infer(args) diff --git a/examples/inference/llama/run_benchmark.sh b/examples/inference/llama/run_benchmark.sh new file mode 100755 index 000000000000..1927159765ba --- /dev/null +++ b/examples/inference/llama/run_benchmark.sh @@ -0,0 +1,33 @@ +ROOT=$(realpath $(dirname $0)) +echo $ROOT +PY_SCRIPT=${ROOT}/benchmark_llama.py +GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1) +mode=$1 + +mkdir -p logs + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 + +# benchmark llama2-7b one single GPU +for input_len in 128 512 1024; do + for output_len in 128 256; do + for bsz in 16 32 64; do + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${bsz}_${input_len}_${output_len}_${mode}_${GPU}.txt + done + done +done diff --git a/examples/inference/llama/test_ci.sh b/examples/inference/llama/test_ci.sh new file mode 100644 index 000000000000..b130fc486bfe --- /dev/null +++ b/examples/inference/llama/test_ci.sh @@ -0,0 +1,4 @@ +#!/bin/bash +echo "Skip the test (this test is slow)" + +# bash ./run_benchmark.sh diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh deleted file mode 100755 index 394222ea62b8..000000000000 --- a/examples/inference/run_benchmark.sh +++ /dev/null @@ -1,15 +0,0 @@ -ROOT=$(realpath $(dirname $0)) -PY_SCRIPT=${ROOT}/benchmark_llama.py -GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1) - -mkdir -p logs - -# benchmark llama2-7b one single GPU -for bsz in 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt -done - - -for bsz in 4 8 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt -done diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py deleted file mode 100644 index a4e6fd0a143d..000000000000 --- a/examples/inference/run_llama_inference.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse - -import torch -import torch.distributed as dist -from transformers import LlamaForCausalLM, LlamaTokenizer - -import colossalai -from colossalai.accelerator import get_accelerator -from colossalai.inference import InferenceEngine -from colossalai.testing import spawn - -INPUT_TEXTS = [ - "What is the longest river in the world?", - "Explain the difference between process and thread in compouter science.", -] - - -def run_inference(args): - llama_model_path = args.model_path - llama_tokenize_path = args.tokenizer_path or args.model_path - - max_input_len = args.max_input_len - max_output_len = args.max_output_len - max_batch_size = args.batch_size - micro_batch_size = args.micro_batch_size - tp_size = args.tp_size - pp_size = args.pp_size - rank = dist.get_rank() - - tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left") - tokenizer.pad_token_id = tokenizer.eos_token_id - - if args.quant is None: - model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id) - elif args.quant == "gptq": - from auto_gptq import AutoGPTQForCausalLM - - model = AutoGPTQForCausalLM.from_quantized( - llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device() - ) - elif args.quant == "smoothquant": - from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM - - model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name) - model = model.cuda() - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - micro_batch_size=micro_batch_size, - quant=args.quant, - dtype=args.dtype, - ) - - inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} - outputs = engine.generate(inputs) - - if rank == 0: - output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - for input_text, output_text in zip(INPUT_TEXTS, output_texts): - print(f"Input: {input_text}") - print(f"Output: {output_text}") - - -def run_tp_pipeline_inference(rank, world_size, port, args): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_inference(args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True) - parser.add_argument("-i", "--input", default="What is the longest river in the world?") - parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None) - parser.add_argument( - "-q", - "--quant", - type=str, - choices=["gptq", "smoothquant"], - default=None, - help="quantization type: 'gptq' or 'smoothquant'", - ) - parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name") - parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size") - parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length") - parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size") - parser.add_argument("--dtype", default="fp16", type=str) - - args = parser.parse_args() - spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index fdd8442f506b..5a9e30dd4542 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,20 @@ replace_return_docstrings, ) -from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN +try: + # TODO: remove this after updating openmoe example + # NOTE(yuanheng-zhao): This is a temporary fix for the issue that + # the flash_attention module is not imported correctly for different CI tests. + # We replace the import path `colossalai.kernel.extensions.flash_attention` + # because in the current example test, colossalai version <= 0.3.6 is installed, + # where `colossalai.kernel.extensions.flash_attention` is still valid; + # however in unit test `test_moe_checkpoint`, the lastest version of colossalai is installed, + # where extension has been refactored and the path is not valid. + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/extensions/__init__.py b/extensions/__init__.py index 0dbadba81905..c392a16b5a61 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -1,9 +1,14 @@ -from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension -from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension -from .layernorm import LayerNormCudaExtension -from .moe import MoeCudaExtension -from .optimizer import FusedOptimizerCudaExtension -from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension +from .pybind.cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .pybind.flash_attention import ( + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionSdpaCudaExtension, +) +from .pybind.inference import InferenceOpsCudaExtension +from .pybind.layernorm import LayerNormCudaExtension +from .pybind.moe import MoeCudaExtension +from .pybind.optimizer import FusedOptimizerCudaExtension +from .pybind.softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension ALL_EXTENSIONS = [ CpuAdamArmExtension, @@ -11,6 +16,7 @@ LayerNormCudaExtension, MoeCudaExtension, FusedOptimizerCudaExtension, + InferenceOpsCudaExtension, ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension, FlashAttentionDaoCudaExtension, @@ -24,6 +30,7 @@ "LayerNormCudaExtension", "MoeCudaExtension", "FusedOptimizerCudaExtension", + "InferenceOpsCudaExtension", "ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension", "FlashAttentionDaoCudaExtension", diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py index 3adb65fb8f4e..aaa43f964c25 100644 --- a/extensions/cpp_extension.py +++ b/extensions/cpp_extension.py @@ -25,6 +25,9 @@ def __init__(self, name: str, priority: int = 1): def csrc_abs_path(self, path): return os.path.join(self.relative_to_abs_path("csrc"), path) + def pybind_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("pybind"), path) + def relative_to_abs_path(self, code_path: str) -> str: """ This function takes in a path relative to the colossalai root directory and return the absolute path. @@ -116,6 +119,7 @@ def include_dirs(self) -> List[str]: """ This function should return a list of include files for extensions. """ + return [self.csrc_abs_path("")] @abstractmethod def cxx_flags(self) -> List[str]: diff --git a/extensions/csrc/__init__.py b/extensions/csrc/__init__.py index 0eac28d23e24..e69de29bb2d1 100644 --- a/extensions/csrc/__init__.py +++ b/extensions/csrc/__init__.py @@ -1,11 +0,0 @@ -from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .multihead_attention import MultiHeadAttention -from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax - -__all__ = [ - "LayerNorm", - "MultiHeadAttention", - "FusedScaleMaskSoftmax", - "ScaledUpperTriangMaskedSoftmax", - "AttnMaskType", -] diff --git a/extensions/csrc/common/data_type.h b/extensions/csrc/common/data_type.h new file mode 100644 index 000000000000..7cc7cfabbdaf --- /dev/null +++ b/extensions/csrc/common/data_type.h @@ -0,0 +1,53 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#endif + +namespace colossalAI { +namespace dtype { + +struct bfloat164 { +#ifdef COLOSSAL_WITH_CUDA + __nv_bfloat162 x; + __nv_bfloat162 y; +#endif +}; + +struct bfloat168 { +#ifdef COLOSSAL_WITH_CUDA + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +#endif +}; + +struct half4 { +#ifdef COLOSSAL_WITH_CUDA + half2 x; + half2 y; +#endif +}; + +struct half8 { +#ifdef COLOSSAL_WITH_CUDA + half2 x; + half2 y; + half2 z; + half2 w; +#endif +}; + +struct float8 { +#ifdef COLOSSAL_WITH_CUDA + float2 x; + float2 y; + float2 z; + float2 w; +#endif +}; + +} // namespace dtype +} // namespace colossalAI diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/common/micros.h similarity index 85% rename from extensions/csrc/cuda/type_shim.h rename to extensions/csrc/common/micros.h index 03ccc02635fa..cf7d0ce35c1f 100644 --- a/extensions/csrc/cuda/type_shim.h +++ b/extensions/csrc/common/micros.h @@ -4,9 +4,10 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. */ -#include -#include "compat.h" +#pragma once + +#include #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ @@ -24,6 +25,37 @@ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \ + TYPE, NAME, ...) \ + if (HIGH_PRECISION) { \ + const bool high_precision = true; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + } else { \ + const bool high_precision = false; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + } + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ switch (TYPEIN) { \ case at::ScalarType::Float: { \ @@ -191,89 +223,12 @@ "'"); \ } -template -__device__ __forceinline__ T reduce_block_into_lanes( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = x[tid] + x[tid + i]; - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = x[tid] + x[tid + 32]; - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = final + __shfl_down_sync(0xffffffff, final, i); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} - -template -__device__ __forceinline__ T reduce_block_into_lanes_max_op( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = - fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} +#if defined(COLOSSAL_WITH_CUDA) +#define HOST __host__ +#define DEVICE __device__ +#define HOSTDEVICE __host__ __device__ +#else +#define HOST +#define DEVICE +#define HOSTDEVICE +#endif diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h new file mode 100644 index 000000000000..7a27f26507a5 --- /dev/null +++ b/extensions/csrc/common/mp_type_traits.h @@ -0,0 +1,55 @@ +#pragma once + +#include + +#include "micros.h" + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#endif + +namespace colossalAI { +namespace common { + +template +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait { + using Type = float; +}; + +#if defined(COLOSSAL_WITH_CUDA) +template <> +struct MPTypeTrait { + using Type = float; +}; + +template <> +struct MPTypeTrait<__nv_bfloat16> { + using Type = float; +}; +#endif + +template +struct ScalarTypeTrait { + using Type = + typename std::conditional::Type, + T>::type; +}; + +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/common/target.h b/extensions/csrc/common/target.h new file mode 100644 index 000000000000..ee3072f62d71 --- /dev/null +++ b/extensions/csrc/common/target.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include + +namespace colossalAI { +namespace common { + +class Target { + public: + enum class OS : int { + Unk = -1, + Linux, + Windows, + }; + enum class Arch : int { + Unk = -1, + X86, + Arm, + NVGPU, + AMDGPU, + Ascend, + }; + enum class BitLen : int { + Unk = -1, + k32, + k64, + }; + + explicit Target(OS os, Arch arch, BitLen bitlen) + : os_(os), arch_(arch), bitlen_(bitlen) {} + + bool defined() const { + return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk); + } + + std::string str() const { + std::string s{"OS: "}; + switch (os_) { + case OS::Unk: + s += "Unk"; + break; + case OS::Linux: + s += "Linux"; + break; + case OS::Windows: + s += "Windows"; + break; + default: + throw std::invalid_argument("Invalid OS type!"); + } + s += "\t"; + s += "Arch: "; + + switch (arch_) { + case Arch::Unk: + s += "Unk"; + break; + case Arch::X86: + s += "X86"; + break; + case Arch::Arm: + s += "Arm"; + break; + case Arch::NVGPU: + s += "NVGPU"; + break; + case Arch::AMDGPU: + s += "AMDGPU"; + break; + case Arch::Ascend: + s += "Ascend"; + break; + default: + throw std::invalid_argument("Invalid Arch type!"); + } + s += "\t"; + s += "BitLen: "; + + switch (bitlen_) { + case BitLen::Unk: + s += "Unk"; + break; + case BitLen::k32: + s += "k32"; + break; + case BitLen::k64: + s += "k64"; + break; + default: + throw std::invalid_argument("Invalid target bit length!"); + } + + return s; + } + + OS os() const { return os_; } + Arch arch() const { return arch_; } + BitLen bitlen() const { return bitlen_; } + + static Target DefaultX86Target(); + static Target DefaultArmTarget(); + static Target DefaultRocmTarget(); + static Target DefaultAscendTarget(); + + static Target DefaultCUDATarget() { + return Target(OS::Linux, Arch::NVGPU, BitLen::k64); + } + + friend std::ostream& operator<<(std::ostream& os, const Target& target); + friend bool operator==(const Target& lhs, const Target& rhs); + friend bool operator!=(const Target& lhs, const Target& rhs); + + private: + OS os_{OS::Unk}; + Arch arch_{Arch::Unk}; + BitLen bitlen_{BitLen::Unk}; +}; + +std::ostream& operator<<(std::ostream& os, const Target& target) { + std::cout << target.str() << std::endl; +} +bool operator==(const Target& lhs, const Target& rhs) { + return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) && + (lhs.bitlen_ == rhs.bitlen_); +} +bool operator!=(const Target& lhs, const Target& rhs) { + return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) && + (lhs.bitlen_ != rhs.bitlen_); +} + +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/common/vec_type_traits.h b/extensions/csrc/common/vec_type_traits.h new file mode 100644 index 000000000000..9e12ab71b86c --- /dev/null +++ b/extensions/csrc/common/vec_type_traits.h @@ -0,0 +1,76 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#endif + +#include +#include + +#include "common/data_type.h" + +namespace colossalAI { +namespace common { + +template +struct VecTypeTrait {}; + +template +struct FloatVecTypeTrait {}; + +#define VEC_TYPE_TRAITS_SPECIALIZATION(T, VEC_SIZE, VECT, ARGS...) \ + template \ + struct VecTypeTrait { \ + using Type = VECT; \ + }; + +VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T) + +#if defined(COLOSSAL_WITH_CUDA) + +VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16) +VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162) +VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2) +VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 8, float4) +VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half) +VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2) +VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2) +VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4) + +VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t) +VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t) +VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2) +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162); +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164); +VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4); +VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8); +VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2) +VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4) +VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8) +#endif /* defined(COLOSSAL_WITH_CUDA) */ + +#undef VEC_TYPE_TRAITS_SPECIALIZATION + +#define FLOATVEC_TYPE_TRAITS_SPECIALIZATION(T, FLOATT, ARGS...) \ + template \ + struct FloatVecTypeTrait { \ + using Type = FLOATT; \ + }; + +#if defined(COLOSSAL_WITH_CUDA) +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2) +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4) +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, float4); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, float4); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8); +#endif /* COLOSSAL_WITH_CUDA */ + +#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/cuda/compat.h b/extensions/csrc/cuda/compat.h deleted file mode 100644 index a62beef91a8a..000000000000 --- a/extensions/csrc/cuda/compat.h +++ /dev/null @@ -1,10 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h deleted file mode 100644 index 38103c1734c8..000000000000 --- a/extensions/csrc/cuda/include/block_reduce.h +++ /dev/null @@ -1,312 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Tencent/TurboTransformers - This block_reduce_n is adapted from Tencent/TurboTransformers -*/ -#pragma once -#include -#include -#include - -enum class ReduceType { kMax = 0, kSum }; -const unsigned int WARP_REDUCE_MASK = 0xffffffff; -const float REDUCE_FLOAT_INF_NEG = -100000000.f; -const float REDUCE_FLOAT_INF_POS = 100000000.f; -const unsigned int WARP_REDUCE_SIZE = 32; - -template -__forceinline__ __device__ T warpReduceSum(T val) { - for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) - val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__forceinline__ __device__ T blockReduceSum(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) shared[wid] = val; - __syncthreads(); - - val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f; - val = warpReduceSum(val); - return val; -} - -template -__inline__ __device__ void blockReduce(float *pval); - -// use template to make code more concise -template -__inline__ __device__ void warpReduce(float *pval); - -// static -template <> -__inline__ __device__ void warpReduce(float *pval) { - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32)); -} - -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp; -#define WarpReduceMaxOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval) = max(val0_tmp, *(pval)); \ - *(pval + 1) = max(val1_tmp, *(pval + 1)); - - WarpReduceMaxOneStep(16, 32); - WarpReduceMaxOneStep(8, 32); - WarpReduceMaxOneStep(4, 32); - WarpReduceMaxOneStep(2, 32); - WarpReduceMaxOneStep(1, 32); -#undef WarpReduceMaxOneStep -} - -template <> -__inline__ __device__ void warpReduce(float *pval) { - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32); -} - -/* - * Unorll for loop for warpreduce to - * imporve instruction issue efficiency - * ElemX means there are X numbers to be summed - */ - -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval + 0) += val0_tmp; \ - *(pval + 1) += val1_tmp - - WarpReduceSumOneStep(16, 32); - WarpReduceSumOneStep(8, 32); - WarpReduceSumOneStep(4, 32); - WarpReduceSumOneStep(2, 32); - WarpReduceSumOneStep(1, 32); - -#undef WarpReduceSumOneStep -} - -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp, val2_tmp, val3_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ - val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ - *(pval + 0) += val0_tmp; \ - *(pval + 1) += val1_tmp; \ - *(pval + 2) += val2_tmp; \ - *(pval + 3) += val3_tmp - - WarpReduceSumOneStep(16, 32); - WarpReduceSumOneStep(8, 32); - WarpReduceSumOneStep(4, 32); - WarpReduceSumOneStep(2, 32); - WarpReduceSumOneStep(1, 32); -#undef WarpReduceSumOneStep -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 2; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 4; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} diff --git a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu deleted file mode 100644 index 41781ebc7fe0..000000000000 --- a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,89 +0,0 @@ -/*This code from NVIDIA Megatron: - * with minor changes. */ - -#include -#include -#include -#include -#include -#include -#include - -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, - float scale_factor) { - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, - // seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty( - {batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), scale_factor, - query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); - return softmax_results; -} - -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, - // seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - // Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); - - // backward pass is completely in-place - return output_grads; -} -} // namespace scaled_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 62c56e6f7870..000000000000 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,75 +0,0 @@ -/*This code from NVIDIA Megatron: - * with minor changes. */ - -#include -#include -#include -#include -#include -#include -#include - -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), scale_factor, seq_len, - seq_len, attn_batches);); - return softmax_results; -} - -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - // output grads is a 3d tensor with dimensions [attn_batches, seq_len, - // seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - // Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, seq_len, seq_len, attn_batches);); - - // backward pass is completely in-place - return output_grads; -} -} // namespace scaled_upper_triang_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/extensions/csrc/funcs/binary_functor.h b/extensions/csrc/funcs/binary_functor.h new file mode 100644 index 000000000000..90726a02fcb1 --- /dev/null +++ b/extensions/csrc/funcs/binary_functor.h @@ -0,0 +1,231 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#include +#include +#endif + +#include + +#include "cast_functor.h" +#include "common/data_type.h" +#include "common/micros.h" + +namespace colossalAI { +namespace funcs { + +enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; + +// Note(LiuYang): This file provides base math operation for data type +// include POD and cuda built-in type such as half and __nv_bfloat16. +// Implementation of common and simple binary operators should be placed here, +// otherwise, they should be placed in a new file under functors dir. +template +struct BinaryOpFunctor; + +#define STMTS_WRAPPER(...) __VA_ARGS__ + +#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( \ + LT, RT, RET, BINARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \ + template \ + struct BinaryOpFunctor \ + : public std::binary_function { \ + FUNCTION_MODIFIER RET operator()(LT lhs, RT rhs) STMTS \ + }; + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kAdd, HOSTDEVICE, + STMTS_WRAPPER({ return lhs + rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMinus, + HOSTDEVICE, + STMTS_WRAPPER({ return lhs - rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMul, HOSTDEVICE, + STMTS_WRAPPER({ return lhs * rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kDiv, HOSTDEVICE, + STMTS_WRAPPER({ return lhs / rhs; }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMax, HOSTDEVICE, + STMTS_WRAPPER({ return max(lhs, rhs); }), + typename T) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE, + STMTS_WRAPPER({ return min(lhs, rhs); }), + typename T) + +#if defined(COLOSSAL_WITH_CUDA) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus, + DEVICE, STMTS_WRAPPER({ + return __hsub(lhs, rhs); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd2(lhs, rhs); + })) + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16, + __nv_bfloat16, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd(lhs, rhs); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16, + __nv_bfloat16, BinaryOpType::kMinus, + DEVICE, STMTS_WRAPPER({ + return __hsub(lhs, rhs); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162, + __nv_bfloat162, BinaryOpType::kAdd, + DEVICE, STMTS_WRAPPER({ + return __hadd2(lhs, rhs); + })) +#else +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kAdd, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs)); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs)); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE, + STMTS_WRAPPER({ + return __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs), + __high2float(lhs) + __high2float(rhs)); + })) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul2(lhs, rhs); + })) + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16, + __nv_bfloat16, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul(lhs, rhs); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162, + __nv_bfloat162, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + return __hmul2(lhs, rhs); + })) +#else +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16(__bfloat162float(lhs) * __bfloat162float(rhs)); + })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs), + __high2float(lhs) * __high2float(rhs)); + })) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + float2, float2, float2, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ return make_float2(lhs.x * rhs.x, lhs.y * rhs.y); })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(float4, float4, float4, + BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + return make_float4( + lhs.x * rhs.x, lhs.y * rhs.y, + lhs.z * rhs.z, lhs.w * rhs.w); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, float2, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat162, float2> cast; + BinaryOpFunctor mul; + float2 fa = cast(lhs); + float2 fb = cast(rhs); + return mul(fa, fb); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164, + float4, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + float4 fc; + CastFunctor<__nv_bfloat16, float> cast; + fc.x = cast(lhs.x.x) * cast(rhs.x.x); + fc.y = cast(lhs.x.y) * cast(rhs.x.y); + fc.z = cast(lhs.y.x) * cast(rhs.y.x); + fc.w = cast(lhs.y.y) * cast(rhs.y.y); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul, + DEVICE, STMTS_WRAPPER({ + dtype::float8 fc; + BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + BinaryOpType::kMul> + mul; + fc.x = mul(lhs.x, rhs.x); + fc.y = mul(lhs.y, rhs.y); + fc.z = mul(lhs.z, rhs.z); + fc.w = mul(lhs.w, rhs.w); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + half2, half2, float2, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + BinaryOpFunctor mul; + float2 fa = cast(lhs); + float2 fb = cast(rhs); + return mul(fa, fb); + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4, + BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + float4 fc; + CastFunctor cast; + fc.x = cast(lhs.x.x) * cast(rhs.x.x); + fc.y = cast(lhs.x.y) * cast(rhs.x.y); + fc.z = cast(lhs.y.x) * cast(rhs.y.x); + fc.w = cast(lhs.y.y) * cast(rhs.y.y); + return fc; + })) + +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( + dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + dtype::float8 fc; + BinaryOpFunctor mul; + fc.x = mul(lhs.x, rhs.x); + fc.y = mul(lhs.y, rhs.y); + fc.z = mul(lhs.z, rhs.z); + fc.w = mul(lhs.w, rhs.w); + return fc; + })) + +#endif /* defined(COLOSSAL_WITH_CUDA) */ + +#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION +#undef STMTS_WRAPPER +} // namespace funcs +} // namespace colossalAI diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h new file mode 100644 index 000000000000..588357d6b4bf --- /dev/null +++ b/extensions/csrc/funcs/cast_functor.h @@ -0,0 +1,503 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#include +#include +#include +#endif + +#include +#include + +#include + +#include "common/data_type.h" +#include "common/micros.h" + +// Note(LiuYang): This file provides base math operation for data type +// include POD and cuda built-in type such as half and __nv_bfloat16 + +namespace colossalAI { +namespace funcs { + +template +struct CastFunctor : public std::unary_function { + HOSTDEVICE To operator()(From val) { return static_cast(val); } +}; + +#define STMTS_WRAPPER(...) __VA_ARGS__ + +#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, FUNCTION_MODIFIER, \ + STMTS) \ + template <> \ + struct CastFunctor : public std::unary_function { \ + FUNCTION_MODIFIER TO operator()(FROM val) STMTS \ + }; + +#if defined(COLOSSAL_WITH_CUDA) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, DEVICE, STMTS_WRAPPER({ + return make_float2(val.x, val.y); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, DEVICE, STMTS_WRAPPER({ + return make_float2(val, val); + })) + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, DEVICE, STMTS_WRAPPER({ + return __half22float2(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, DEVICE, STMTS_WRAPPER({ + return __float22half2_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, DEVICE, STMTS_WRAPPER({ + return __float2half_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, DEVICE, STMTS_WRAPPER({ + return __float2half2_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, DEVICE, STMTS_WRAPPER({ + return __half2half2(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, DEVICE, STMTS_WRAPPER({ + return __half2float(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE, + STMTS_WRAPPER({ + dtype::half4 dst; + dst.x = __floats2half2_rn(val.x, val.y); + dst.y = __floats2half2_rn(val.z, val.w); + return dst; + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE, + STMTS_WRAPPER({ + float4 dst; + dst.x = __half2float(val.x.x); + dst.y = __half2float(val.x.y); + dst.z = __half2float(val.y.x); + dst.w = __half2float(val.y.y); + return dst; + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE, + STMTS_WRAPPER({ + dtype::half8 dst; + dst.x = __float22half2_rn(val.x); + dst.y = __float22half2_rn(val.y); + dst.z = __float22half2_rn(val.z); + dst.w = __float22half2_rn(val.w); + return dst; + })) + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat162_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE, + STMTS_WRAPPER({ + return __float2bfloat16_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE, + STMTS_WRAPPER({ + return __bfloat162float(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE, + STMTS_WRAPPER({ + dtype::bfloat164 dst; + dst.x = + __floats2bfloat162_rn(val.x, val.y); + dst.y = + __floats2bfloat162_rn(val.z, val.w); + return dst; + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE, + STMTS_WRAPPER({ + float4 dst; + dst.x = __bfloat162float(val.x.x); + dst.y = __bfloat162float(val.x.y); + dst.z = __bfloat162float(val.y.x); + dst.w = __bfloat162float(val.y.y); + return dst; + })) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE, + STMTS_WRAPPER({ + return __bfloat162bfloat162(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE, + STMTS_WRAPPER({ + return __bfloat1622float2(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE, + STMTS_WRAPPER({ + return __float22bfloat162_rn(val); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE, + STMTS_WRAPPER({ + dtype::bfloat168 dst; + dst.x = __float22bfloat162_rn(val.x); + dst.y = __float22bfloat162_rn(val.y); + dst.z = __float22bfloat162_rn(val.z); + dst.w = __float22bfloat162_rn(val.w); + return dst; + })) +#else +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE, + STMTS_WRAPPER({ + __nv_bfloat162 dst; + dst.x = val; + dst.y = val; + return dst; + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat162, float2, DEVICE, + STMTS_WRAPPER({ + return make_float2(__low2float(val), + __high2float(val)); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE, + STMTS_WRAPPER({ + return __floats2bfloat162_rn(val.x, + val.y); + })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({ + dtype::bfloat168 dst; + dst.x = __floats2bfloat162_rn(val.x.x, val.x.y); + dst.y = __floats2bfloat162_rn(val.y.x, val.y.y); + dst.z = __floats2bfloat162_rn(val.z.x, val.z.y); + dst.w = __floats2bfloat162_rn(val.w.x, val.w.y); + return dst; + })) +#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */ + +// quant utils +// fp8 -> half raw +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({ + __half_raw res = __nv_cvt_fp8_to_halfraw( + val, __NV_E5M2); + return res.x; + })) + +// half raw -> fp8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({ + __half_raw tmp; + tmp.x = val; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8( + tmp, __NV_SATFINITE, __NV_E5M2); + return static_cast(res); + })) + +// fp8x2 -> half2 raw +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({ + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = + __nv_cvt_fp8x2_to_halfraw2( + val, __NV_E5M2); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; + })) + +// fp8x4 -> half2x2 raw +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint32_t, uint2, DEVICE, STMTS_WRAPPER({ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + CastFunctor()(static_cast(val)); + tmp.u32[1] = + CastFunctor()(static_cast(val >> 16U)); + return tmp.u32x2; + })) + +// fp8x8 -> half2x4 raw +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint2, uint4, DEVICE, STMTS_WRAPPER({ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = CastFunctor()(val.x); + tmp.u64[1] = CastFunctor()(val.y); + return tmp.u64x2; + })) + +// fp8 -> half +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({ + __half_raw res = __nv_cvt_fp8_to_halfraw( + val, __NV_E5M2); + return half(res); + })) + +// half -> fp8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, uint8_t, DEVICE, STMTS_WRAPPER({ + __half_raw tmp(val); + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8( + tmp, __NV_SATFINITE, __NV_E5M2); + return static_cast(res); + })) + +// fp8x2 -> half2 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({ + __half2_raw res = + __nv_cvt_fp8x2_to_halfraw2( + val, __NV_E5M2); + return half2(res); + })) + +// half2 -> fp8x2 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, uint16_t, DEVICE, STMTS_WRAPPER({ + __half2_raw tmp(val); + __nv_fp8x2_storage_t res = + __nv_cvt_halfraw2_to_fp8x2( + tmp, __NV_SATFINITE, __NV_E5M2); + return static_cast(res); + })) + +// fp8x4 -> half4 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({ + half2 tmp1, tmp2; + tmp1 = CastFunctor()(static_cast(val)); + tmp2 = CastFunctor()(static_cast(val >> 16U)); + dtype::half4 res; + res.x = tmp1; + res.y = tmp2; + return res; + })) + +// half4 -> fp8x4 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + dtype::half4, uint32_t, DEVICE, STMTS_WRAPPER({ + half2 x, y; + x = val.x; + y = val.y; + uint16_t lo, hi; + lo = CastFunctor()(x); + hi = CastFunctor()(y); + uint32_t res; + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(lo), "h"(hi)); + return res; + })) + +// fp8x8 -> half8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint2, dtype::half8, DEVICE, STMTS_WRAPPER({ + dtype::half4 tmp1, tmp2; + tmp1 = CastFunctor()(val.x); + tmp2 = CastFunctor()(val.y); + dtype::half8 res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; + })) + +// fp8 -> __nv_bfloat16 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint8_t, __nv_bfloat16, DEVICE, STMTS_WRAPPER({ + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(val, __NV_E5M2); + // half -> float -> bf16 + float tmp; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(tmp) : "h"(res.x)); + return __float2bfloat16(tmp); + })) + +// fp8x2 -> __nv_bfloat162 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint16_t, __nv_bfloat162, DEVICE, STMTS_WRAPPER({ + __nv_bfloat162 res; + res.x = CastFunctor()(static_cast(val)); + res.y = CastFunctor()( + static_cast(val >> 8U)); + return res; + })) + +// fp8x4 -> bfloat164 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint32_t, dtype::bfloat164, DEVICE, STMTS_WRAPPER({ + dtype::bfloat164 res; + res.x = + CastFunctor()(static_cast(val)); + res.y = CastFunctor()( + static_cast(val >> 16U)); + return res; + })) + +// fp8x8 -> bfloat168 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint2, dtype::bfloat168, DEVICE, STMTS_WRAPPER({ + dtype::bfloat164 tmp1, tmp2; + tmp1 = CastFunctor()(val.x); + tmp2 = CastFunctor()(val.y); + dtype::bfloat168 res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; + })) + +// fp8 -> float +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint8_t, float, DEVICE, STMTS_WRAPPER({ + // fp8 -> half + uint16_t tmp = CastFunctor()(val); + // half -> float + float res; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(res) : "h"(tmp)); + return res; + })) + +// float -> fp8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({ + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8( + val, __NV_SATFINITE, __NV_E5M2); + return static_cast(res); + })) + +// fp8x2 -> float2 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint16_t, float2, DEVICE, STMTS_WRAPPER({ + // fp8x2 -> half2 + uint32_t tmp = CastFunctor()(val); + // half2 -> float2 + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(tmp)); + float lof, hif; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(lof) : "h"(lo)); + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(hif) : "h"(hi)); + return make_float2(lof, hif); + })) + +// float2 -> fp8x2 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + float2, uint16_t, DEVICE, STMTS_WRAPPER({ + uint16_t tmp1 = + static_cast(CastFunctor()(val.x)); + uint16_t tmp2 = + static_cast(CastFunctor()(val.y)); + uint16_t res = (tmp2 << 8U) | tmp1; + return res; + })) + +// float4 -> fp8x4 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({ + uint32_t a, b, c, d; + a = CastFunctor()(val.x); + b = CastFunctor()(val.y); + c = CastFunctor()(val.z); + d = CastFunctor()(val.w); + return (d << 24U) | (c << 16U) | + (b << 8U) | a; + })) + +// fp8x4 -> float4 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint32_t, float4, DEVICE, STMTS_WRAPPER({ + float4 res; + res.x = CastFunctor()(static_cast(val)); + res.y = CastFunctor()(static_cast(val >> 8U)); + res.z = CastFunctor()(static_cast(val >> 16U)); + res.w = CastFunctor()(static_cast(val >> 24U)); + return res; + })) + +// fp8x8 -> float8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + uint2, dtype::float8, DEVICE, STMTS_WRAPPER({ + dtype::float8 res; + res.x = CastFunctor()(static_cast(val.x)); + res.y = + CastFunctor()(static_cast(val.x >> 16U)); + res.z = CastFunctor()(static_cast(val.y)); + res.w = + CastFunctor()(static_cast(val.y >> 16U)); + return res; + })) + +// bf16 -> fp8 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE, + STMTS_WRAPPER({ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = + __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(val), + __NV_SATFINITE, __NV_E5M2); + return static_cast(res); +#endif + })) + +// bf162 -> fp8x2 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, uint16_t, DEVICE, STMTS_WRAPPER({ + uint16_t a = + static_cast(CastFunctor<__nv_bfloat16, uint8_t>()(val.x)); + uint16_t b = + static_cast(CastFunctor<__nv_bfloat16, uint8_t>()(val.y)); + return (b << 8U) | a; + })) + +// bf164 -> fp8x4 +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + dtype::bfloat164, uint32_t, DEVICE, STMTS_WRAPPER({ + uint32_t res; + uint16_t a, b; + a = CastFunctor<__nv_bfloat162, uint16_t>()(val.x); + b = CastFunctor<__nv_bfloat162, uint16_t>()(val.y); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(a), "h"(b)); + return res; + })) + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(val); + return uint32; + })) + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({ + uint2 b; + float2 c; + c.x = val.x; + c.y = val.y; + b.x = CastFunctor()(c); + + c.x = val.z; + c.y = val.w; + b.y = CastFunctor()(c); + + return b; + })) + +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( + dtype::float8, uint4, DEVICE, STMTS_WRAPPER({ + uint4 b; + b.x = CastFunctor()(val.x); + b.y = CastFunctor()(val.y); + b.z = CastFunctor()(val.z); + b.w = CastFunctor()(val.w); + return b; + })) + +#endif /* defined(COLOSSAL_WITH_CUDA) */ + +#undef STMTS_WRAPPER +#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION +} // namespace funcs +} // namespace colossalAI diff --git a/extensions/csrc/funcs/reduce_function.h b/extensions/csrc/funcs/reduce_function.h new file mode 100644 index 000000000000..58ff1e5bc0cc --- /dev/null +++ b/extensions/csrc/funcs/reduce_function.h @@ -0,0 +1,94 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#include + +#include "binary_functor.h" + +namespace colossalAI { +namespace funcs { + +const float kReduceFloatInfNeg = -100000000.f; +const float kReduceFloatInfPos = 100000000.f; +const unsigned int kWarpReduceMask = 0xffffffff; + +enum class ReduceType { kMax = 0, kSum }; + +template +struct GetOpForReduceType; + +template +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; + +template +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; + +#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ + _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = \ + OP(*(VAL_PTR + offset), \ + __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \ + } + +#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES) \ + _Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ + } + +#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \ + REDUCE_TYPE) \ + __shared__ T shm[LANES][32]; \ + int lane_id = threadIdx.x & 0x1f; \ + int warp_id = threadIdx.x >> 5; \ + \ + warp_reduce(VAL_PTR); \ + if (lane_id == 0) { \ + for (int offset = 0; offset < LANES; ++offset) { \ + shm[offset][warp_id] = *(VAL_PTR + offset); \ + } \ + } \ + __syncthreads(); \ + \ + _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \ + ? shm[offset][lane_id] \ + : static_cast(DEFAULT_VALUE); \ + } \ + warp_reduce(VAL_PTR); + +template +__forceinline__ __device__ void warp_reduce(T* pval) { + typename GetOpForReduceType::Op op; + COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes); +} + +template +__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() { + if constexpr (rtype == ReduceType::kSum) { + return static_cast(0.0f); + } else if constexpr (rtype == ReduceType::kMax) { + return static_cast(kReduceFloatInfNeg); + } +} + +template +__forceinline__ __device__ void block_reduce(T* pval) { + constexpr T kDefaultValue = GetDefaultValueForBlockReduce(); + typename GetOpForReduceType::Op op; + COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype); +} + +#undef COLOSSAL_SHFL_FUNCTION +#undef COLOSSAL_WARP_REDUCE_IMPL +#undef COLOSSAL_BLOCK_REDUCE_IMPL + +} // namespace funcs +} // namespace colossalAI + +#endif /* defined(COLOSSAL_WITH_CUDA) */ diff --git a/extensions/csrc/funcs/ternary_functor.h b/extensions/csrc/funcs/ternary_functor.h new file mode 100644 index 000000000000..8d0c95f10d63 --- /dev/null +++ b/extensions/csrc/funcs/ternary_functor.h @@ -0,0 +1,214 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#include +#include +#endif + +#include + +#include + +#include "cast_functor.h" +#include "common/micros.h" + +namespace colossalAI { +namespace funcs { + +enum class TernaryOpType { kFma = 0 }; + +template +struct TernaryOpFunctor; + +#define STMTS_WRAPPER(...) __VA_ARGS__ + +#define COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( \ + LT, RT, RET, TERNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \ + template \ + struct TernaryOpFunctor { \ + FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \ + }; + +#if defined(COLOSSAL_WITH_CUDA) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float d; + d = fma(a, b, c); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float2, float2, float2, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float2, float2, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float4, float4, float4, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float4, float4, + TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; + })) + +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half, float, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ return __half2float(a) * __half2float(b) + c; })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half2, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + TernaryOpFunctor fma; + float2 fa = cast(a); + float2 fb = cast(b); + return fma(fa, fb, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + CastFunctor cast; + TernaryOpFunctor fma; + return fma(cast(a), b, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 fd; + CastFunctor cast; + TernaryOpFunctor fma; + fd = fma(cast(a), cast(b), c); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float4 fd; + CastFunctor cast0; + CastFunctor cast1; + TernaryOpFunctor fma; + fd = fma(cast0(a), cast1(b), c); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + dtype::float8 fd; + TernaryOpFunctor fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + fd.z = fma(a.z, b.z, c.z); + fd.w = fma(a.w, b.w, c.w); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + dtype::float8 fd; + CastFunctor cast; + TernaryOpFunctor fma; + half2 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + fd.z = fma(s, b.z, c.z); + fd.w = fma(s, b.w, c.w); + return fd; + })) + +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat16, float, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ return __bfloat162float(a) * __bfloat162float(b) + c; })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat162, float2> cast; + TernaryOpFunctor fma; + float2 fa = cast(a); + float2 fb = cast(b); + return fma(fa, fb, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + return fma(cast(a), b, c); + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 fd; + CastFunctor cast; + TernaryOpFunctor fma; + fd = fma(cast(a), cast(b), c); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 fd; + CastFunctor<__nv_bfloat16, float> cast0; + CastFunctor cast1; + TernaryOpFunctor fma; + fd = fma(cast0(a), cast1(b), c); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, + DEVICE, STMTS_WRAPPER({ + dtype::float8 fd; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + fd.x = fma(a.x, b.x, c.x); + fd.y = fma(a.y, b.y, c.y); + fd.z = fma(a.z, b.z, c.z); + fd.w = fma(a.w, b.w, c.w); + return fd; + })) +COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( + __nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + dtype::float8 fd; + CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; + TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, + TernaryOpType::kFma> + fma; + __nv_bfloat162 s = cast(a); + fd.x = fma(s, b.x, c.x); + fd.y = fma(s, b.y, c.y); + fd.z = fma(s, b.z, c.z); + fd.w = fma(s, b.w, c.w); + return fd; + })) + +#endif /* defined(COLOSSAL_WITH_CUDA) */ + +#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION +#undef STMTS_WRAPPER + +} // namespace funcs +} // namespace colossalAI diff --git a/extensions/csrc/funcs/unary_functor.h b/extensions/csrc/funcs/unary_functor.h new file mode 100644 index 000000000000..207a0ff972d4 --- /dev/null +++ b/extensions/csrc/funcs/unary_functor.h @@ -0,0 +1,67 @@ +#pragma once + +#if defined(COLOSSAL_WITH_CUDA) +#include +#include +#include +#include +#endif + +#include + +#include "common/data_type.h" +#include "common/micros.h" + +namespace colossalAI { +namespace funcs { + +// Note(LiuYang): As a retrieved table to check which operation is supported +// already +enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum }; + +// Note(LiuYang): Implementation of common and simple unary operators should be +// placed here, otherwise, they should be placed in a new file under functors +// dir. +template +struct UnaryOpFunctor; + +#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION( \ + FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \ + template \ + struct UnaryOpFunctor \ + : public std::unary_function { \ + FUNCTION_MODIFIER TO operator()(FROM val) STMTS \ + }; + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION( + T, T, UnaryOpType::kAbs, HOSTDEVICE, { return std::abs(val); }, typename T) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil, + HOSTDEVICE, { + int log2_value = 0; + while ((1 << log2_value) < val) + ++log2_value; + return log2_value; + }) + +#if defined(COLOSSAL_WITH_CUDA) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE, + { return val.x + val.y; }) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE, + { return val.x + val.y + val.z + val.w; }) + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8, float, UnaryOpType::kSum, + DEVICE, { + return val.x.x + val.x.y + val.y.x + + val.y.y + val.z.x + val.z.y + + val.w.x + val.w.y; + }) + +#endif /* defined(COLOSSAL_WITH_CUDA) */ + +#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION + +} // namespace funcs +} // namespace colossalAI diff --git a/extensions/csrc/arm/cpu_adam_arm.cpp b/extensions/csrc/kernel/arm/cpu_adam_arm.cpp similarity index 100% rename from extensions/csrc/arm/cpu_adam_arm.cpp rename to extensions/csrc/kernel/arm/cpu_adam_arm.cpp diff --git a/extensions/csrc/arm/cpu_adam_arm.h b/extensions/csrc/kernel/arm/cpu_adam_arm.h similarity index 100% rename from extensions/csrc/arm/cpu_adam_arm.h rename to extensions/csrc/kernel/arm/cpu_adam_arm.h diff --git a/extensions/csrc/kernel/cuda/activation_kernel.cu b/extensions/csrc/kernel/cuda/activation_kernel.cu new file mode 100644 index 000000000000..c69003d84ac9 --- /dev/null +++ b/extensions/csrc/kernel/cuda/activation_kernel.cu @@ -0,0 +1,77 @@ +#include +#include +#include + +#include "common/micros.h" +#include "common/mp_type_traits.h" + +using colossalAI::common::MPTypeTrait; + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + using MT = typename MPTypeTrait::Type; + return static_cast((static_cast(x)) / (static_cast(1.0f) + expf(static_cast(-x)))); +} + +template +__global__ void act_and_mul_kernel( + const scalar_t* __restrict__ ins_data, + scalar_t* __restrict__ outs_data, + const int64_t numel) { + using MT = typename MPTypeTrait::Type; + + int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); + const int64_t grid_size = blockDim.x * gridDim.x; + if(idx > numel) { + return; + } + + for(int64_t i = idx; i < numel; i += grid_size) { + scalar_t x = ins_data[i]; + scalar_t y = ins_data[i+numel]; + outs_data[i] = static_cast(static_cast(ACT_FN(x)) * static_cast(y)); + } +} + +// Note(LiuYang):This func is designed for calculation mode like +// silu(x[:half_1stdim]) * (x[half_1stdim:]) +torch::Tensor silu_and_mul(const torch::Tensor& ins) +{ + // Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api + // to manipulate ins_shape which is IntArrayRef + auto ins_shape = ins.sizes().vec(); + + ins_shape[0] = ins_shape[0]/2; + if (ins_shape[0] == 1) { + ins_shape.erase(ins_shape.begin()); + } + auto outs = torch::zeros(ins_shape,ins.options()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Note(Liuyang): numel of ins must be divisible by 2 + int64_t numel = ((torch::numel(ins)) >> 1); + + // Note(LiuYang): For better performance for special case of which input is [2, 64, 11008], now + // I comment this part code,because it also cost a little time to calculate a better config + // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0); + // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1); + // dim3 grid = config.grid; + // dim3 block = config.block; + + dim3 grid((numel+255)/256); + dim3 block(256); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + ins.scalar_type(), + "silu_and_mul", + act_and_mul_kernel><<>>( + ins.data_ptr(), + outs.data_ptr(), + numel + );) + + AT_CUDA_CHECK(cudaGetLastError()); + return outs; +} diff --git a/extensions/csrc/kernel/cuda/attention/attention_utils.h b/extensions/csrc/kernel/cuda/attention/attention_utils.h new file mode 100644 index 000000000000..732936809937 --- /dev/null +++ b/extensions/csrc/kernel/cuda/attention/attention_utils.h @@ -0,0 +1,210 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2024, The Colossal-AI team. + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "common/vec_type_traits.h" +#include "funcs/binary_functor.h" +#include "funcs/cast_functor.h" +#include "funcs/ternary_functor.h" +#include "funcs/unary_functor.h" + +namespace colossalAI { +namespace cuda { +namespace attention { + +#define WARP_SIZE 32 +#define VEC_SIZE_8 8 + +#define SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) { + using A_vec = typename common::FloatVecTypeTrait::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + funcs::BinaryOpFunctor mul_vect; + funcs::UnaryOpFunctor sum_vect; + funcs::TernaryOpFunctor fma; + + A_vec qk_vec = mul_vect(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ii++) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum_vect(qk_vec); +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_ROUNDS; + mask >>= 1) { + qk += SHFL_XOR_SYNC(qk, mask); + } + +#pragma unroll + for (int mask = (NUM_THREADS_PER_X >> 1); mask > 0; mask >>= 1) { + qk += SHFL_XOR_SYNC(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const VecT (&q)[N], const VecT (&k)[N]) { + return qk_dot_(q, k); + } +}; + +template +inline __device__ float block_max(float* red_smem, float max) { + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +// Perform reduction across the threads in the same warp to get the max value +// for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the +// max value among every NUM_THREADS_PER_TOKEN threads. +#pragma unroll + for (int mask = (NUM_THREADS_PER_ROUNDS >> 1); mask >= NUM_THREADS_PER_X; + mask >>= 1) { + max = fmaxf(max, SHFL_XOR_SYNC(max, mask)); + } + + if (lane == 0) red_smem[warp] = max; + __syncthreads(); + + // The warps compute the final maxs. + max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; + +// Parallel reduction of all tokens from the same sequence inside the warp. +#pragma unroll + for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) { + max = fmaxf(max, SHFL_XOR_SYNC(max, mask)); + } + + // Broadcast to other threads. + return SHFL_SYNC(max, 0); +} + +// here we need another block_sum instead of using block_reduce +// since we need manage shared memory in a explicit way +template +inline __device__ float block_sum(float* red_smem, float sum) { + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +// Compute the sum per warp. +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + if (lane == 0) red_smem[warp] = sum; + __syncthreads(); + + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + +// Parallel reduction of all tokens from the same sequence inside the warp. +#pragma unroll + for (int mask = (NUM_WARPS >> 1); mask > 0; mask >>= 1) { + sum += SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return SHFL_SYNC(sum, 0); +} + +// here VecT is a vector of float, whose size is N +template +inline __device__ void block_sum(float* red_smem, VecT& acc) { + float* acc_ptr = reinterpret_cast(&acc); + int warp = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + +#pragma unroll + for (int i = 0; i < N; i++) { +#pragma unroll + for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_GROUP; + mask >>= 1) { + acc_ptr[i] += SHFL_XOR_SYNC(acc_ptr[i], mask); + } + } + +#pragma unroll + for (int limit = NUM_WARPS; limit > 1; limit >>= 1) { + int mid = limit >> 1; + if (warp >= mid && warp < limit) { + float* dst = red_smem + (warp - mid) * N * NUM_THREADS_PER_GROUP; + if (lane < NUM_THREADS_PER_GROUP) { + if constexpr (N == VEC_SIZE_8) { + VecT* vdst = &((reinterpret_cast(dst))[lane]); + const int idx0 = (lane >> 2) & 0x1; + const int idx1 = idx0 ^ 0x1; + (reinterpret_cast(vdst))[idx0] = + (reinterpret_cast(acc_ptr))[idx0]; + (reinterpret_cast(vdst))[idx1] = + (reinterpret_cast(acc_ptr))[idx1]; + } else { + (reinterpret_cast(dst))[lane] = acc; + } + } + } + __syncthreads(); + + if (warp < mid) { + float* src = red_smem + warp * N * NUM_THREADS_PER_GROUP; + VecT src_reg; + if (lane < NUM_THREADS_PER_GROUP) { + float* src_ptr = reinterpret_cast(&src_reg); + if constexpr (N == VEC_SIZE_8) { + VecT* vsrc = &((reinterpret_cast(src))[lane]); + const int idx0 = (lane >> 2) & 0x1; + const int idx1 = idx0 ^ 0x1; + (reinterpret_cast(src_ptr))[idx0] = + (reinterpret_cast(vsrc))[idx0]; + (reinterpret_cast(src_ptr))[idx1] = + (reinterpret_cast(vsrc))[idx1]; + } else { + src_reg = (reinterpret_cast(src))[lane]; + } +#pragma unroll + for (int j = 0; j < N; j++) { + acc_ptr[j] += src_ptr[j]; + } + } + } + __syncthreads(); + } +} + +#undef SHFL_SYNC +#undef SHFL_XOR_SYNC + +} // namespace attention +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000000..e9b7738b0565 --- /dev/null +++ b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu @@ -0,0 +1,238 @@ +#include +#include + +#include "utils/vec_copy.h" +#include "common/micros.h" + +using colossalAI::cuda::utils::get_vec_size; +using colossalAI::cuda::utils::copy; +using colossalAI::funcs::CastFunctor; + + +template +__global__ void context_kv_cache_memcpy_kernel( + const T* __restrict__ key, + const T* __restrict__ value, + CacheT* __restrict__ key_cache, + CacheT* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cu_seqlens, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int batch_size, + const int block_table_stride, + const int64_t key_stride, + const int64_t value_stride, + const int x +) +{ + const int seq_token_id = blockIdx.x; + const int seq_id = blockIdx.y; + const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size]; + + if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { + return ; + } + + const int block_offset = seq_token_id % block_size; + const int hidden_size = head_num * head_dim; + const int total_token_id = cu_seqlens[seq_id] + seq_token_id; + int head_id; + int head_offset; + int x_id; + int x_offset; + int64_t key_src_id; + int64_t value_src_id; + int64_t target_key_id; + int64_t target_value_id; + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + head_id = i / head_dim; + head_offset = i % head_dim; + x_id = head_offset / x; + x_offset = head_offset % x; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + target_value_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy(key + key_src_id, key_cache + target_key_id); + copy(value + value_src_id, value_cache + target_value_id); + } + + // tail process + if (!Aligned) { + for (; i < hidden_size; ++i ) { + head_id = i / head_dim; + head_offset = i % head_dim; + x_id = head_offset / x; + x_offset = head_offset % x; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + target_value_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_key_id] = CastFunctor()(key[key_src_id]); + value_cache[target_value_id] = CastFunctor()(value[value_src_id]); + } + } + +} + +template +void apply_context_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, head_num, head_dim] + torch::Tensor& value, // [num_tokens, head_num, head_dim] + torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& cu_seqlens, // [batch_size + 1] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + int batch_size = block_tables.size(0); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(max_seq_len_in_batch, batch_size); + dim3 block(std::min(thread_nums, 512)); + +#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + context_kv_cache_memcpy_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + sequence_lengths.data_ptr(), \ + cu_seqlens.data_ptr(), \ + block_tables.data_ptr(), \ + head_num, \ + head_dim, \ + block_size, \ + batch_size, \ + block_table_stride, \ + key_stride, \ + value_stride, \ + x \ + ); \ + } while(0) + +#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ + do { \ + switch (vec_size) { \ + case 1: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", vec_size); \ + break; \ + } \ + } while(0) + + + if (aligned) { + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true); + } + else { + CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false); + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void context_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, head_num, head_dim] + torch::Tensor& value, // [num_tokens, head_num, head_dim] + torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& cu_seqlens, // [batch_size + 1] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + +#define _(T, CacheT) \ + apply_context_kv_cache_memcpy( \ + key, \ + value, \ + key_cache, \ + value_cache, \ + sequence_lengths, \ + cu_seqlens, \ + block_tables, \ + max_seq_len_in_batch \ + ) + + if(key_cache.scalar_type() == at::ScalarType::Byte) + { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, uint8_t); + break; + case at::ScalarType::Half: + _(half, uint8_t); + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, uint8_t); + break; + } + } + else + { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, float); + break; + case at::ScalarType::Half: + _(half, half); + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, __nv_bfloat16); + break; + } + } +#undef _ +} diff --git a/extensions/csrc/kernel/cuda/convert_fp8_kernel.cu b/extensions/csrc/kernel/cuda/convert_fp8_kernel.cu new file mode 100644 index 000000000000..90a45f9aa99a --- /dev/null +++ b/extensions/csrc/kernel/cuda/convert_fp8_kernel.cu @@ -0,0 +1,127 @@ +#include +#include +#include + +#include + +#include "common/micros.h" +#include "utils/vec_copy.h" +#include "funcs/cast_functor.h" + + +using colossalAI::cuda::utils::copy; +using colossalAI::cuda::utils::get_vec_size; +using colossalAI::funcs::CastFunctor; + +template +__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail) +{ + int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); + const int64_t grid_size = blockDim.x * gridDim.x; + if(idx > numel + tail) { + return; + } + + for(int64_t i = idx; i < numel; i += grid_size) { + copy(ins_data + i * VecSize, outs_data + i * VecSize); + } + // Tail process + if(threadIdx.x == 0) + { + for(int i = 0; i < tail; ++i) + { + outs_data[i + numel * VecSize] = CastFunctor()(ins_data[i + numel * VecSize]); + } + } +} + +template +void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output) +{ + const int kVecSize = get_vec_size(input); + const int kNumel = torch::numel(input); + + const int kVecNumel = (kNumel >> static_cast(std::log2(kVecSize))); + const int kTail = kNumel & (kVecSize - 1); + int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(grid_size); + dim3 block(256); + +#define _(VEC_SIZE) \ + convert_fp8_kernel \ + <<>> \ + (reinterpret_cast(input.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + kVecNumel, \ + kTail) + + switch (kVecSize) + { + case 1: + _(1); + break; + case 2: + _(2); + break; + case 4: + _(4); + break; + } +#undef _ + AT_CUDA_CHECK(cudaGetLastError()); +} + +void convert_fp8(torch::Tensor& input, torch::Tensor& output) +{ + TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte, "Data type of Input or Output should be torch.uint8 for convert_fp8!"); + TORCH_CHECK(input.scalar_type() != output.scalar_type(), "Data type of input and output are the same!"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || + input.scalar_type() == at::ScalarType::Float || + input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of input!"); + TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte || + output.scalar_type() == at::ScalarType::Float || + output.scalar_type() == at::ScalarType::Half || + output.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of output!"); + TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!"); + +#define _(InT, OutT) \ + apply_convert_fp8(input, output) + + + if(input.scalar_type() == at::ScalarType::Byte) + { + if(output.scalar_type() == at::ScalarType::Float) + { + _(uint8_t, float); + } + else if(output.scalar_type() == at::ScalarType::Half) + { + _(uint8_t, half); + } + else if(output.scalar_type() == at::ScalarType::BFloat16) + { + _(uint8_t, __nv_bfloat16); + } + } + else + { + if(input.scalar_type() == at::ScalarType::Float) + { + _(float, uint8_t); + } + else if(input.scalar_type() == at::ScalarType::Half) + { + _(half, uint8_t); + } + else if(input.scalar_type() == at::ScalarType::BFloat16) + { + _(__nv_bfloat16, uint8_t); + } + } + +#undef _ +} diff --git a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000000..3d011a4e48ff --- /dev/null +++ b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu @@ -0,0 +1,216 @@ +#include +#include + +#include "utils/vec_copy.h" +#include "funcs/cast_functor.h" +#include "common/micros.h" + +using colossalAI::cuda::utils::get_vec_size; +using colossalAI::cuda::utils::copy; +using colossalAI::funcs::CastFunctor; + + +template +__global__ void decode_kv_cache_memcpy_kernel( + const T* __restrict__ key, + const T* __restrict__ value, + CacheT* __restrict__ key_cache, + CacheT* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int64_t key_stride, + const int64_t value_stride, + const int block_table_stride, + const int x +) +{ + const int seq_id = blockIdx.x; + const int seq_len = sequence_lengths[seq_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size]; + const int hidden_size = head_num * head_dim; + + if ( block_id < 0 ) { + return ; + } + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int x_id = head_offset / x; + const int x_offset = head_offset % x; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + const int64_t target_value_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy(key + key_src_id, key_cache + target_key_id); + copy(value + value_src_id, value_cache + target_value_id); + } + + if (!Aligned) { + for (; i < hidden_size; ++i ) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int x_id = head_offset / x; + const int x_offset = head_offset % x; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_key_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + const int64_t target_value_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_key_id] = CastFunctor()(key[key_src_id]); + value_cache[target_value_id] = CastFunctor()(value[value_src_id]); + } + } + +} + +template +void apply_decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + +#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + decode_kv_cache_memcpy_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + sequence_lengths.data_ptr(), \ + block_tables.data_ptr(), \ + head_num, \ + head_dim, \ + block_size, \ + key_stride, \ + value_stride, \ + block_table_stride, \ + x \ + ); \ + } while(0) + +#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size) \ + do { \ + switch (__vec_size) { \ + case 1: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", __vec_size); \ + break; \ + } \ + } while(0) + + if (aligned) { + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size); + } + else { + DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size); + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + +#define _(T, CacheT) \ + apply_decode_kv_cache_memcpy( \ + key, \ + value, \ + key_cache, \ + value_cache, \ + sequence_lengths, \ + block_tables \ + ) + + if(key_cache.scalar_type() == at::ScalarType::Byte) + { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, uint8_t); + break; + case at::ScalarType::Half: + _(half, uint8_t); + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, uint8_t); + break; + } + } + else + { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, float); + break; + case at::ScalarType::Half: + _(half, half); + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, __nv_bfloat16); + break; + } + } +#undef _ +} diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu new file mode 100644 index 000000000000..08cb06a3316f --- /dev/null +++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu @@ -0,0 +1,831 @@ +/*This code adapted from vllm: + * https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu + */ + +#include +#include +#include + +#include "common/micros.h" +#include "funcs/cast_functor.h" +#include "funcs/ternary_functor.h" +#include "funcs/binary_functor.h" +#include "common/vec_type_traits.h" +#include "attention/attention_utils.h" + +#define WARP_SIZE 32 +#define PARTITION_SIZE 512 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +// 2^n => 2^n, 2^n-d => 2^(n-1) +#define ROUND_DOWN_HIGHEST_POWER_OF_TWO(x) (nextHighestPowerOf2((x - (x + 1) / 2 + 1))) + +// a bit magic, you can ask chatgpt for help +// 2^n => 2^n, 2^n-d => 2^n +constexpr unsigned int nextHighestPowerOf2(unsigned int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ii++) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +using colossalAI::funcs::BinaryOpType; +using colossalAI::funcs::CastFunctor; +using colossalAI::funcs::TernaryOpFunctor; +using colossalAI::funcs::TernaryOpType; +using colossalAI::common::VecTypeTrait; +using colossalAI::common::FloatVecTypeTrait; +using namespace colossalAI::cuda::attention; + +template +__device__ void data_load( + const float4* q_ptr, + float4* q_shared, + scalar_t* q_shared_ptr, + KVecT* q_vecs, // query cached at register for qk_dot, should be constructed with reference to key cache's layout + const int* block_table, + int* block_table_shared, + const int lane, + const int max_num_blocks_per_seq +) { + + #pragma unroll + for (int idx = threadIdx.x; idx < Q_SHARED_SIZE; idx += blockDim.x) { + q_shared[idx] = q_ptr[idx]; + } + + #pragma unroll + for (int idx = threadIdx.x; idx < max_num_blocks_per_seq; idx += blockDim.x) { + block_table_shared[idx] = block_table[idx]; + } + + __syncthreads(); + + // each warp access a whole block + + #pragma unroll + for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = idx % NUM_THREADS_PER_X; + q_vecs[i] = *reinterpret_cast(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE); + } +} + +template +__device__ void qk_gemv( + const cache_t* __restrict__ k_cache, + const KVecT (&q_vecs)[NUM_VECS_PER_THREAD], // Qk_dot needs NUM_VECS_PER_THREAD to do loop unrolling + float* logits, // shared memory to cache Qk_dot results + int* block_table_shared, + const float alibi_slope, + const int context_len, + float &qk_max, + const float scale, + const int kv_head_idx, + const int warp_idx, + const int lane, + const int thread_group_offset, + const int start_block_idx, + const int end_block_idx, + const int start_token_idx, + const int kv_block_stride, + const int kv_head_stride) { + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); + + KVecT k_vecs[NUM_VECS_PER_THREAD]; + + #pragma unroll + for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) { + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + i * x; + #pragma unroll + for (int idx = lane, j = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, j += 1) { + const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS; + const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS; + const int offset2 = idx % NUM_THREADS_PER_X; + k_vecs[j] = CastFunctor()(*reinterpret_cast(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE)); + } + + float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + + if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) { + const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } +} + +template +__device__ void softmax( + float* red_shared_mem, + float* logits, + float &qk_max, + float &exp_sum, + int num_tokens) { + // there exists a __syncthreads within this function + qk_max = block_max(red_shared_mem, qk_max); + + // Get the sum of the exp values. + for (int i = threadIdx.x; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + + exp_sum = block_sum(&red_shared_mem[NUM_WARPS], exp_sum); + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = threadIdx.x; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); +} + +template +__device__ void sv_gemv( + const cache_t* __restrict__ v_cache, + int* block_table_shared, + float* out_shared_mem, // shared memory to cache sv_gemv results + float* logits, + FloatVecT* accs, // registers for accumulation + const int lane, + const int warp_idx, + const int kv_head_idx, + const int start_block_idx, + const int end_block_idx, + const int context_len, + const int start_token_idx, + const int kv_block_stride, + const int kv_head_stride) { + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + zero(accs[i]); + } + + VVecT zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = static_cast(block_table_shared[block_idx]); + scalar_t logit; + + #pragma unroll + for (int idx = lane; idx < BLOCK_SIZE * NUM_VECS_PER_TOKEN; idx += WARP_STRIDE) { + const int token_idx = block_idx * BLOCK_SIZE + idx / NUM_VECS_PER_TOKEN; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + idx * VEC_SIZE; + + VVecT v_vecs[NUM_ROUNDS_PER_TOKEN]; + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = CastFunctor()(*((reinterpret_cast(v_ptr) + i * WARP_SIZE))); + } + + if (token_idx >= context_len) { + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + v_vecs[i] = zero_value; + } + } + + logit = CastFunctor()(logits[token_idx - start_token_idx]); + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + accs[i] = TernaryOpFunctor()(logit, v_vecs[i], accs[i]); + } + } + } + + // must insert a sync since both logits and out_shared_mem occupy the same buffer space + __syncthreads(); + + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + block_sum(out_shared_mem, accs[i]); + } +} + +// We only support head size of { 64, 128, 256 } +// models like Phi-2, whose head size is 80, is not supported right now +template +__global__ void flash_decoding_attention_kernel_v1( + scalar_t* __restrict__ out, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_tokens, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq] + const float* __restrict__ alibi_slopes, // [num_heads] + const int max_seq_len, + const int num_kv_heads, + const float scale, + const int max_num_blocks_per_seq, + const int q_stride, // num_heads * head_size + const int kv_block_stride, + const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int lane = thread_idx % WARP_SIZE; + const int warp_idx = thread_idx / WARP_SIZE; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int x = sizeof(float4) / sizeof(scalar_t); + constexpr int Q_SHARED_SIZE = HEAD_SIZE / x; + // here thread_group does not determine the number of threads responsible for a key + // but only the VEC_SIZE of each thread + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x); + constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE; + constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN; + constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN; + constexpr int NUM_THREADS_PER_X = x / VEC_SIZE; + constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE); + constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE; + + using KVecT = typename VecTypeTrait::Type; + using VVecT = typename VecTypeTrait::Type; + using KQuantVecT = typename VecTypeTrait::Type; + using VQuantVecT = typename VecTypeTrait::Type; + using LVecT = typename VecTypeTrait::Type; + using FloatVecT = typename FloatVecTypeTrait::Type; + + const int context_len = context_lens[seq_idx]; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const int thread_group_offset = lane % NUM_THREADS_PER_X; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + __shared__ float4 q_shared[Q_SHARED_SIZE]; + __shared__ float red_shared_mem[2 * NUM_WARPS]; + extern __shared__ char shared_mem[]; + int* block_table_shared = reinterpret_cast(shared_mem); + float* logits = reinterpret_cast(shared_mem + shared_memory_offset); + float* out_shared_mem = reinterpret_cast(shared_mem + shared_memory_offset); + float qk_max = -FLT_MAX; + float exp_sum = 0.f; + + const float4* q_ptr = reinterpret_cast(q + seq_idx * q_stride + head_idx * HEAD_SIZE); + scalar_t* q_shared_ptr = reinterpret_cast(q_shared); + KVecT q_vecs[NUM_VECS_PER_THREAD]; + + // 1. load query and block_table from global memory to shared memory + data_load(q_ptr, q_shared, q_shared_ptr, q_vecs, block_table, block_table_shared, lane, max_num_blocks_per_seq); + + // 2. compute the dot product of query and key cache + qk_gemv(k_cache, q_vecs, logits, block_table_shared, alibi_slope, context_len, qk_max, scale, kv_head_idx, warp_idx, lane, thread_group_offset, 0, num_context_blocks, 0, kv_block_stride, kv_head_stride); + + // 3. compute the softmax + softmax(red_shared_mem, logits, qk_max, exp_sum, context_len); + + FloatVecT accs[NUM_ROUNDS_PER_TOKEN]; + + // 4. compute the dot product of softmax tensor and value cache + sv_gemv(v_cache, block_table_shared, out_shared_mem, logits, accs, lane, warp_idx, kv_head_idx, 0, num_context_blocks, context_len, 0, kv_block_stride, kv_head_stride); + + // 5. write back to global memory + scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE; + LVecT out_reg; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + if (thread_idx < NUM_THREADS_PER_TOKEN) { + out_reg = CastFunctor()(accs[i]); + (reinterpret_cast(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg; + } + } +} + +#define LAUNCH_FLASH_DECODING_ATTENTION_V1(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_attention_kernel_v1), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + flash_decoding_attention_kernel_v1 \ + <<>>( \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + context_lens.data_ptr(), \ + block_tables.data_ptr(), \ + alibi_slopes_ptr, \ + max_context_len, \ + num_kv_heads, \ + scale, \ + max_num_blocks_per_seq, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void flash_decoding_attention_v1_launcher( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int max_context_len, + float scale, + const c10::optional& alibi_slopes) { + int num_tokens = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int q_stride = query.stride(0); + + int max_num_blocks_per_seq = block_tables.size(1); + + int num_kv_heads = key_cache.size(1); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T)); + const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE; + const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float); + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + dim3 grid(num_heads, num_tokens, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. + case 64: + LAUNCH_FLASH_DECODING_ATTENTION_V1(64); + break; + case 128: + LAUNCH_FLASH_DECODING_ATTENTION_V1(128); + break; + case 256: + LAUNCH_FLASH_DECODING_ATTENTION_V1(256); + break; + default: + AT_ERROR("head size must be 64, 128, 256"); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE) \ + flash_decoding_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + context_lens, \ + block_tables, \ + max_context_len, \ + scale, \ + alibi_slopes); + + +template +__global__ void flash_decoding_attention_kernel_v2( + scalar_t* __restrict__ out, // [num_tokens, num_heads, max_num_partitions, head_size] + float* __restrict__ exp_sums, // [num_tokens, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_tokens, num_heads, max_num_partitions] + const scalar_t* __restrict__ q, // [num_tokens, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq] + const float* __restrict__ alibi_slopes, // [num_heads] + const int max_seq_len, + const int num_kv_heads, + const float scale, + const int max_num_blocks_per_seq, + const int q_stride, // num_heads * head_size + const int tmp_stride, // num_heads * max_num_partitions + const int kv_block_stride, + const int kv_head_stride) { + const int partition_idx = blockIdx.z; + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int thread_idx = threadIdx.x; + const int lane = thread_idx % WARP_SIZE; + const int warp_idx = thread_idx / WARP_SIZE; + const int max_num_partitions = gridDim.z; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int x = sizeof(float4) / sizeof(scalar_t); + constexpr int Q_SHARED_SIZE = HEAD_SIZE / x; + // here thread_group does not determine the number of threads responsible for a key + // but only the VEC_SIZE of each thread + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((HEAD_SIZE / THREAD_GROUP_SIZE)), x); + constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE; + constexpr int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + constexpr int NUM_ROUNDS_PER_TOKEN = NUM_VECS_PER_TOKEN / NUM_THREADS_PER_TOKEN; + constexpr int WARP_STRIDE = WARP_SIZE * NUM_ROUNDS_PER_TOKEN; + constexpr int NUM_THREADS_PER_X = x / VEC_SIZE; + constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE); + constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE; + constexpr int NUM_BLOCKS_PER_PARTITION = PARTITION_SIZE / BLOCK_SIZE; + + using KVecT = typename VecTypeTrait::Type; + using VVecT = typename VecTypeTrait::Type; + using KQuantVecT = typename VecTypeTrait::Type; + using VQuantVecT = typename VecTypeTrait::Type; + using LVecT = typename VecTypeTrait::Type; + using FloatVecT = typename FloatVecTypeTrait::Type; + + const int context_len = context_lens[seq_idx]; + + if (partition_idx * PARTITION_SIZE >= context_len) { + return; + } + + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const int thread_group_offset = lane % NUM_THREADS_PER_X; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = partition_idx * NUM_BLOCKS_PER_PARTITION; + const int end_block_idx = MIN(start_block_idx + NUM_BLOCKS_PER_PARTITION, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int shared_memory_offset = DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + __shared__ float4 q_shared[Q_SHARED_SIZE]; + __shared__ float red_shared_mem[2 * NUM_WARPS]; + extern __shared__ char shared_mem[]; + int* block_table_shared = reinterpret_cast(shared_mem); + float* logits = reinterpret_cast(shared_mem + shared_memory_offset); + float* out_shared_mem = reinterpret_cast(shared_mem + shared_memory_offset); + float qk_max = -FLT_MAX; + float exp_sum = 0.f; + + const float4* q_ptr = reinterpret_cast(q + seq_idx * q_stride + head_idx * HEAD_SIZE); + scalar_t* q_shared_ptr = reinterpret_cast(q_shared); + KVecT q_vecs[NUM_VECS_PER_THREAD]; + + // 1. load query and block_table from global memory to shared memory + data_load(q_ptr, q_shared, q_shared_ptr, q_vecs, block_table, block_table_shared, lane, max_num_blocks_per_seq); + + // 2. compute the dot product of query and key cache + qk_gemv(k_cache, q_vecs, logits, block_table_shared, alibi_slope, context_len, qk_max, scale, kv_head_idx, warp_idx, lane, thread_group_offset, start_block_idx, end_block_idx, start_token_idx, kv_block_stride, kv_head_stride); + + // 3. compute the softmax + softmax(red_shared_mem, logits, qk_max, exp_sum, num_tokens); + + if (thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * tmp_stride + + head_idx * max_num_partitions + + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * tmp_stride + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + *exp_sums_ptr = exp_sum; + } + + FloatVecT accs[NUM_ROUNDS_PER_TOKEN]; + + // 4. compute the dot product of softmax tensor and value cache + sv_gemv(v_cache, block_table_shared, out_shared_mem, logits, accs, lane, warp_idx, kv_head_idx, start_block_idx, end_block_idx, context_len, start_token_idx, kv_block_stride, kv_head_stride); + + // 5. write back to global memory + scalar_t* out_ptr = out + seq_idx * q_stride * max_num_partitions + + head_idx * HEAD_SIZE * max_num_partitions + + partition_idx * HEAD_SIZE; + LVecT out_reg; + #pragma unroll + for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) { + if (thread_idx < NUM_THREADS_PER_TOKEN) { + out_reg = CastFunctor()(accs[i]); + (reinterpret_cast(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg; + } + } +} + +template +__global__ void flash_decoding_reduce_kernel( + scalar_t* __restrict__ out, // [num_tokens, num_heads, head_size] + float* __restrict__ exp_sums, // [num_tokens, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_tokens, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_tokens] + const int out_stride, + const int tmp_stride, + const int max_num_partitions) { + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + + extern __shared__ char shared_mem[]; + __shared__ float red_smem[2 * NUM_WARPS]; + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * tmp_stride + + head_idx * max_num_partitions; + + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float tmp_max_logit = max_logits_ptr[i]; + shared_max_logits[i] = tmp_max_logit; + max_logit = fmaxf(max_logit, tmp_max_logit); + } + + __syncthreads(); + + max_logit = block_max(red_smem, max_logit); + + float* shared_exp_sums = reinterpret_cast(shared_mem + num_partitions * sizeof(float)); + const float* exp_sums_ptr = exp_sums + seq_idx * tmp_stride + + head_idx * max_num_partitions; + + float global_exp_sum = 0.f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float tmp_max_logit = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(tmp_max_logit - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + + __syncthreads(); + + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.f, global_exp_sum + 1e-6f); + + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * out_stride * max_num_partitions + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * out_stride + head_idx * HEAD_SIZE; + + #pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.f; + for (int j = 0; j < num_partitions; j++) { + acc += CastFunctor()(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + out_ptr[i] = CastFunctor()(acc); + } +} + + +#define LAUNCH_FLASH_DECODING_ATTENTION_V2(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_attention_kernel_v2), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + flash_decoding_attention_kernel_v2 \ + <<>>( \ + reinterpret_cast(tmp_out.data_ptr()), \ + reinterpret_cast(exp_sums.data_ptr()), \ + reinterpret_cast(max_logits.data_ptr()), \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + reinterpret_cast(context_lens.data_ptr()), \ + reinterpret_cast(block_tables.data_ptr()), \ + alibi_slopes_ptr, \ + max_context_len, \ + num_kv_heads, \ + scale, \ + max_num_blocks_per_seq, \ + q_stride, \ + tmp_stride, \ + kv_block_stride, \ + kv_head_stride); \ + cudaFuncSetAttribute( \ + ((void*)flash_decoding_reduce_kernel), \ + cudaFuncAttributeMaxDynamicSharedMemorySize, reduce_shared_mem_size); \ + flash_decoding_reduce_kernel \ + <<>>( \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(exp_sums.data_ptr()), \ + reinterpret_cast(max_logits.data_ptr()), \ + reinterpret_cast(tmp_out.data_ptr()), \ + reinterpret_cast(context_lens.data_ptr()), \ + q_stride, \ + tmp_stride, \ + max_num_partitions); + + +template< + typename T, + typename CACHE_T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void flash_decoding_attention_v2_launcher( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int max_context_len, + float scale, + const c10::optional& alibi_slopes) { + int num_tokens = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int q_stride = query.stride(0); + int tmp_stride = exp_sums.stride(0); + + int max_num_blocks_per_seq = block_tables.size(1); + + int num_kv_heads = key_cache.size(1); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + const int VEC_SIZE = MIN(ROUND_DOWN_HIGHEST_POWER_OF_TWO((head_size / THREAD_GROUP_SIZE)), sizeof(float4) / sizeof(T)); + const int NUM_VECS_PER_TOKEN = head_size / VEC_SIZE; + const int NUM_THREADS_PER_TOKEN = MIN(NUM_VECS_PER_TOKEN, WARP_SIZE); + + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * NUM_THREADS_PER_TOKEN * VEC_SIZE * sizeof(float); + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4); + + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + dim3 grid(num_heads, num_tokens, max_num_partitions); + dim3 block(NUM_THREADS); + + dim3 reduce_grid(num_heads, num_tokens); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. + case 64: + LAUNCH_FLASH_DECODING_ATTENTION_V2(64); + break; + case 128: + LAUNCH_FLASH_DECODING_ATTENTION_V2(128); + break; + case 256: + LAUNCH_FLASH_DECODING_ATTENTION_V2(256); + break; + default: + AT_ERROR("head size must be 64, 128, 256"); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE) \ + flash_decoding_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + context_lens, \ + block_tables, \ + max_context_len, \ + scale, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_LAUNCHER_BLOCK_SIZE(Version, T, CACHE_T) \ + switch (block_size) { \ + case 8: \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 8); \ + break; \ + case 16: \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 16); \ + break; \ + case 32: \ + CALL_##Version##_LAUNCHER(T, CACHE_T, 32); \ + break; \ + default: \ + AT_ERROR("block size must be 8, 16, 32"); \ + break; \ + } + +#define CALL_LAUNCHER_DTYPE(Version) \ + if(key_cache.scalar_type() == at::ScalarType::Byte) \ + { \ + switch (query.scalar_type()) { \ + case at::ScalarType::Float: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, float, uint8_t); \ + break; \ + case at::ScalarType::Half: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, half, uint8_t); \ + break; \ + case at::ScalarType::BFloat16: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, __nv_bfloat16, uint8_t); \ + break; \ + } \ + } \ + else \ + { \ + switch (query.scalar_type()) { \ + case at::ScalarType::Float: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, float, float); \ + break; \ + case at::ScalarType::Half: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, half, half); \ + break; \ + case at::ScalarType::BFloat16: \ + CALL_LAUNCHER_BLOCK_SIZE(Version, __nv_bfloat16, __nv_bfloat16); \ + break; \ + } \ + } + +void flash_decoding_attention( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int block_size, + int max_context_len, + torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] + const c10::optional& alibi_slopes, + float scale) { + + int num_tokens = query.size(0); + int num_heads = query.size(1); + + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + // TODO(luoxiang): Need to be tuned + bool use_v1 = max_context_len <= 8192 && (max_num_partitions == 1 || num_tokens * num_heads > 512); + + if (use_v1) { + CALL_LAUNCHER_DTYPE(V1); + } else { + CALL_LAUNCHER_DTYPE(V2); + } +} + + +#undef LAUNCH_FLASH_DECODING_ATTENTION_V1 +#undef CALL_LAUNCHER +#undef CALL_LAUNCHER_BLOCK_SIZE +#undef CALL_LAUNCHER_DTYPE diff --git a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu new file mode 100644 index 000000000000..6dc9495ef7d9 --- /dev/null +++ b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -0,0 +1,526 @@ +// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision +#include +#include + +#include "utils/vec_copy.h" +#include "common/micros.h" +#include "common/mp_type_traits.h" +#include "funcs/cast_functor.h" +#include "funcs/binary_functor.h" + +using colossalAI::cuda::utils::get_vec_size; +using colossalAI::cuda::utils::copy; +using colossalAI::funcs::CastFunctor; +using colossalAI::funcs::BinaryOpFunctor; +using colossalAI::funcs::BinaryOpType; + +template +__device__ void apply_emb_rotary_compute( + T* __restrict__ src, const MT* __restrict__ cos_ptr, + const MT* __restrict__ sin_ptr, const int64_t stride, + const int token_id, const int shard_block_size, const int half_head_dim, + const int head_num, const int head_dim) { + BinaryOpFunctor mul; + BinaryOpFunctor sub; + BinaryOpFunctor add; + CastFunctor t2mt; + CastFunctor mt2t; + + T x[VecSize]; + T y[VecSize]; + T out_x[VecSize]; + T out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * stride + (i / half_head_dim) * head_dim + head_offset; + + copy(src + addr_offset, x); + copy(src + addr_offset + half_head_dim, y); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = mt2t(sub(mul(t2mt(x[j]), cos_ptr[j * 32 + shard_offset]), + mul(t2mt(y[j]), sin_ptr[j * 32 + shard_offset]))); + out_y[j] = mt2t(add(mul(t2mt(y[j]), cos_ptr[j * 32 + shard_offset]), + mul(t2mt(x[j]), sin_ptr[j * 32 + shard_offset]))); + } + + copy(out_x, src + addr_offset); + copy(out_y, src + addr_offset + half_head_dim); + } +} + +template +__device__ void apply_kv_memcopy( + T* __restrict__ src, CacheT* __restrict__ cache, + const int64_t stride, const int token_id, const int block_id, + const int hidden_size, const int block_size, const int block_offset, + const int head_dim, const int half_head_dim) { + for (int i = threadIdx.x * VecSize; i < hidden_size / 2; + i += blockDim.x * VecSize) { + const int head_id = i / half_head_dim; + const int head_offset = i % half_head_dim; + const int64_t src_id = token_id * stride + head_id * head_dim + head_offset; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy(src + src_id, cache + target_id); + copy(src + src_id + half_head_dim, cache + target_id + half_head_dim); + } +} + +template +__device__ void cos_sin_memory_access( + const T* __restrict__ cos, const T* __restrict__ sin, + MT* cos_ptr, MT* sin_ptr, const int token_id, + const int shard_block_size, const int cos_stride, const int sin_stride, + const int half_head_dim) { + for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { + // We assume that the value of head_dim is less than 128*128. + const int shard_offset = (i % shard_block_size) / VecSize; + const int shard_head = + (i / shard_block_size) * shard_block_size + i % VecSize * 32; + cos_ptr[shard_head + shard_offset] = CastFunctor()(cos[token_id * cos_stride + i]); + sin_ptr[shard_head + shard_offset] = CastFunctor()(sin[token_id * sin_stride + i]); + } +} + +template +__device__ void apply_k_rotary_emb_compute( + T* __restrict__ key, T* __restrict__ value, + CacheT* __restrict__ key_cache, CacheT* __restrict__ value_cache, + const MT* __restrict__ cos_ptr, const MT* __restrict__ sin_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, const int64_t key_stride, + const int64_t value_stride, const int token_id, + const int block_table_stride, const int head_num, const int head_dim, + const int kv_head_num, const int block_size, const int x, const int half_head_dim, + const int shard_block_size) { + + BinaryOpFunctor mul; + BinaryOpFunctor sub; + BinaryOpFunctor add; + const int seq_len = sequence_lengths[token_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = + block_tables[token_id * block_table_stride + seq_len / block_size]; + + if (block_id < 0) { + return; + } + + T x0[VecSize]; + T x1[VecSize]; + T out_x[VecSize]; + T out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int half_head_offset = i % half_head_dim; + const int x_id = half_head_offset / x; + const int x_offset = half_head_offset % x; + const int shard_offset = + (half_head_offset / shard_block_size) * shard_block_size + + (half_head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * key_stride + (i / half_head_dim) * head_dim + half_head_offset; + const int64_t target_id = block_id * kv_head_num * head_dim * block_size + + (i / half_head_dim) * block_size * head_dim + + x_id * block_size * x + + block_offset * x + + x_offset; + + copy(key + addr_offset, x0); + copy(key + addr_offset + half_head_dim, x1); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = CastFunctor()(sub(mul(CastFunctor()(x0[j]), cos_ptr[j * 32 + shard_offset]), + mul(CastFunctor()(x1[j]), sin_ptr[j * 32 + shard_offset]))); + out_y[j] = CastFunctor()(add(mul(CastFunctor()(x1[j]), cos_ptr[j * 32 + shard_offset]), + mul(CastFunctor()(x0[j]), sin_ptr[j * 32 + shard_offset]))); + } + + copy(out_x, key_cache + target_id); + copy(out_y, key_cache + target_id + half_head_dim * block_size); + } + + // apply value memcopy + apply_kv_memcopy( + value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim, + block_size, block_offset, head_dim, half_head_dim); +} + +template +__global__ void rotary_embedding_and_cache_copy_kernel( + T* __restrict__ query, + T* __restrict__ key, + T* __restrict__ value, + const T* __restrict__ cos, + const T* __restrict__ sin, + CacheT* __restrict__ key_cache, + CacheT* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int64_t query_stride, + const int64_t key_stride, + const int64_t value_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int block_table_stride, + const int head_num, + const int head_dim, + const int kv_head_num, + const int block_size, + const int x +) { + + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + MT *cos_ptr = reinterpret_cast(shard_ptr); + MT *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key and copy kv + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size); +} + +template +__global__ void rotary_embedding_kernel( + T* __restrict__ query, + T* __restrict__ key, + const T* __restrict__ cos, + const T* __restrict__ sin, + const int64_t query_stride, + const int64_t key_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int head_num, + const int head_dim, + const int kv_head_num +) { + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + MT *cos_ptr = (MT*)shard_ptr; + MT *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); +} + +#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \ + rotary_embedding_and_cache_copy_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(cos.data_ptr()), \ + reinterpret_cast(sin.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + sequence_lengths.data_ptr(), \ + block_tables.data_ptr(), \ + query_stride, \ + key_stride, \ + value_stride, \ + shard_element_num / 2, \ + cos_stride, \ + sin_stride, \ + block_table_stride, \ + head_num, \ + head_dim, \ + kv_head_num, \ + block_size, \ + x); \ + + +template +void apply_rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int64_t query_stride = query.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + int block_table_stride = block_tables.stride(0); + + using MT = typename colossalAI::common::ScalarTypeTrait::Type; + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size; + const int shared_memory_size = shard_element_num * sizeof(MT); + + switch (vec_size) { + case 1: + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(1); + break; + case 2: + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(2); + break; + case 4: + ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(4); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +template +void apply_rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + + int query_stride = query.stride(0); + int key_stride = key.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + + using MT = typename colossalAI::common::ScalarTypeTrait::Type; + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 2: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 4: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision) +{ +#define _(T, CacheT, HIGH_PRECISION) \ + apply_rotary_embedding_and_cache_copy( \ + query, \ + key, \ + value, \ + cos, \ + sin, \ + key_cache, \ + value_cache, \ + sequence_lengths, \ + block_tables); + + if(key_cache.scalar_type() == at::ScalarType::Byte) + { + if(high_precision) { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, uint8_t, true) + break; + case at::ScalarType::Half: + _(half, uint8_t, true) + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, uint8_t, true) + break; + } + } + else { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, uint8_t, false) + break; + case at::ScalarType::Half: + _(half, uint8_t, false) + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, uint8_t, false) + break; + } + } + } + else + { + if(high_precision) { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, float, true) + break; + case at::ScalarType::Half: + _(half, half, true) + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, __nv_bfloat16, true) + break; + } + } + else { + switch (key.scalar_type()) + { + case at::ScalarType::Float: + _(float, float, false) + break; + case at::ScalarType::Half: + _(half, half, false) + break; + case at::ScalarType::BFloat16: + _(__nv_bfloat16, __nv_bfloat16, false) + break; + } + } + } +#undef _ +} + +void rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin, // [total_tokens, head_dim] + bool high_precision +){ + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, + query.scalar_type(), + "rotary_embedding", + apply_rotary_embedding( + query, + key, + cos, + sin + );) +} diff --git a/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu b/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu new file mode 100644 index 000000000000..d5fda83ebb56 --- /dev/null +++ b/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu @@ -0,0 +1,218 @@ +#include +#include + +#include "utils/vec_copy.h" +#include "common/micros.h" + +using colossalAI::cuda::utils::copy; +using colossalAI::cuda::utils::get_vec_size; + + +template +__device__ void apply_cos_and_sin_memcopy( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int head_dim, + const int dest_offset_id, + const int src_offset_id + ) { + + int begin_id = threadIdx.x * VecSize; + + for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){ + copy(cos_cache_ptr + src_offset_id + begin_id, cos + dest_offset_id + begin_id); + copy(sin_cache_ptr + src_offset_id + begin_id, sin + dest_offset_id + begin_id); + } + + if (!Aligned) { + for (; begin_id < head_dim; ++begin_id ) { + cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id]; + sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id]; + } + } +} + +template +__global__ void apply_get_context_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cumsum_lengths, + const int batch_size, + const int head_dim +) { + int token_id = blockIdx.x; + if ( token_id >= sequence_lengths[blockIdx.y] ) { + return ; + } + + int src_offset_id = token_id * head_dim; + int dest_offset_id = src_offset_id; + + if (blockIdx.y > 0) { + dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim; + } + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); + +} + +template +__global__ void apply_get_decode_cos_and_sin_kernel( + scalar_t* __restrict__ cos, + scalar_t* __restrict__ sin, + const scalar_t* __restrict__ cos_cache_ptr, + const scalar_t* __restrict__ sin_cache_ptr, + const int* __restrict__ sequence_lengths, + const int batch_size, + const int head_dim +) { + int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim; + int dest_offset_id = blockIdx.y * head_dim; + + apply_cos_and_sin_memcopy( + cos, + sin, + cos_cache_ptr, + sin_cache_ptr, + sequence_lengths, + head_dim, + dest_offset_id, + src_offset_id + ); +} + +template +void apply_get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + int token_num = cos.size(0); + int head_dim = cos.size(1); + int batch_size = sequence_lengths.size(0); + + at::Tensor cumsum_lengths; + + int vec_size = get_vec_size(cos); + + bool aligned = true; + if (head_dim % vec_size != 0) { + aligned = false; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int block_size_y; + int block_size_x; + + if (is_prompts) { + block_size_y = batch_size; + block_size_x = max_seq_len_in_batch; + // TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on. + cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32); + } + else{ + block_size_y = batch_size; + block_size_x = 1; + } + + int thread_nums = (head_dim + vec_size - 1) / vec_size; + + dim3 grid(block_size_x, block_size_y); + dim3 block(std::min(thread_nums, 512)); + +#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \ + do { \ + if (is_prompts){ \ + apply_get_context_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + cumsum_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + else { \ + apply_get_decode_cos_and_sin_kernel<<>>( \ + cos.data_ptr(), \ + sin.data_ptr(), \ + cos_cache.data_ptr(), \ + sin_cache.data_ptr(), \ + sequence_lengths.data_ptr(), \ + batch_size, \ + head_dim \ + ); \ + } \ + } while(0) + +#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ + do { \ + switch (vec_size) { \ + case 1: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \ + break; \ + case 2: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \ + break; \ + case 4: \ + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \ + break; \ + default: \ + AT_ERROR("Unsupported vectorized size ", vec_size); \ + break; \ + } \ + } while(0) + + if (aligned) { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true); + } + else { + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void get_cos_and_sin( + at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, + bool is_prompts +) { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + cos.scalar_type(), + "get_cos_and_sin", + apply_get_cos_and_sin( + cos_cache, + sin_cache, + cos, + sin, + sequence_lengths, + max_seq_len_in_batch, + is_prompts + );) +} diff --git a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu b/extensions/csrc/kernel/cuda/layer_norm_kernel.cu similarity index 96% rename from extensions/csrc/cuda/layer_norm_cuda_kernel.cu rename to extensions/csrc/kernel/cuda/layer_norm_kernel.cu index 72b84d6ca40f..cd569f741a51 100644 --- a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu +++ b/extensions/csrc/kernel/cuda/layer_norm_kernel.cu @@ -9,7 +9,7 @@ #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/DeviceUtils.cuh" -#include "type_shim.h" +#include "common/micros.h" template __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { @@ -606,11 +606,11 @@ void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, using namespace at; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", - HostApplyLayerNorm(output->DATA_PTR(), - mean->DATA_PTR(), invvar->DATA_PTR(), - input->DATA_PTR(), n1, n2, epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL, - beta != NULL ? beta->DATA_PTR() : NULL);) + HostApplyLayerNorm(output->data_ptr(), + mean->data_ptr(), invvar->data_ptr(), + input->data_ptr(), n1, n2, epsilon, + gamma != NULL ? gamma->data_ptr() : NULL, + beta != NULL ? beta->data_ptr() : NULL);) } template @@ -633,14 +633,14 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, {part_size, n2}, input->options().dtype(at::ScalarType::Float)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<>>( - dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), - part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR()); + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr()); const dim3 threads3(32, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(U); cuComputeGradGammaBeta<<>>( - part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), part_size, n1, n2, grad_gamma, grad_beta); } @@ -651,7 +651,7 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, const dim3 threads1(32, 4, 1); int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; cuComputeGradInput<<>>( - dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), gamma, + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input); } @@ -671,13 +671,13 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel", HostLayerNormGradient( - dout->DATA_PTR(), mean->DATA_PTR(), - invvar->DATA_PTR(), input, n1, n2, + dout->data_ptr(), mean->data_ptr(), + invvar->data_ptr(), input, n1, n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - gamma != NULL ? beta->DATA_PTR() : NULL, epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL);) + gamma != NULL ? gamma->data_ptr() : NULL, + gamma != NULL ? beta->data_ptr() : NULL, epsilon, + grad_input->data_ptr(), + gamma != NULL ? grad_gamma->data_ptr() : NULL, + gamma != NULL ? grad_beta->data_ptr() : NULL);) } diff --git a/extensions/csrc/cuda/moe_cuda_kernel.cu b/extensions/csrc/kernel/cuda/moe_kernel.cu similarity index 93% rename from extensions/csrc/cuda/moe_cuda_kernel.cu rename to extensions/csrc/kernel/cuda/moe_kernel.cu index 66c1e6bd260e..ff74800869d8 100644 --- a/extensions/csrc/cuda/moe_cuda_kernel.cu +++ b/extensions/csrc/kernel/cuda/moe_kernel.cu @@ -4,7 +4,10 @@ #include -#include "block_reduce.h" +#include "funcs/reduce_function.h" + +using colossalAI::funcs::block_reduce; +using colossalAI::funcs::ReduceType; template __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { @@ -157,8 +160,7 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, BlockStore(ts_store).Store(src_row + idx, grad); } - - blockReduce(&thread_sum); + block_reduce(&thread_sum); if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); } @@ -230,7 +232,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, BlockStore(ts_store).Store(src_row2 + idx, sgrad2); } - blockReduce(thread_sum); + block_reduce(thread_sum); if (threadIdx.x == 0) *weight_grad1 = static_cast(thread_sum[0]); @@ -537,7 +539,7 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { // API FUNCTIONS -------------------------------- -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ +#define DISPATCH_FLOAT_AND_HALF_MOE(TYPE, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Float: { \ using scalar_t = float; \ @@ -563,13 +565,13 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); auto k = mask.size(0); - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_MOE( batch_tokens.scalar_type(), "moe dispatch forward", moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + batch_tokens.data_ptr(), res.data_ptr(), + mask[0].data_ptr(), k == 1 ? nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); return res; } @@ -583,13 +585,13 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); auto k = mask.size(0); - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_MOE( expert_grad.scalar_type(), "moe dispatch backward", moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + res.data_ptr(), expert_grad.data_ptr(), + mask[0].data_ptr(), k == 1 ? nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); return res; } @@ -606,13 +608,13 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); auto k = mask.size(0); - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_MOE( expert_tokens.scalar_type(), "moe combine forward", moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + expert_tokens.data_ptr(), res.data_ptr(), + logits.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, h)); return res; @@ -633,14 +635,14 @@ std::vector moe_combine_cuda_backward( {s, e}, torch::dtype(logits.dtype()).device(logits.device())); auto k = mask.size(0); - DISPATCH_FLOAT_AND_HALF( + DISPATCH_FLOAT_AND_HALF_MOE( tokens_grad.scalar_type(), "moe combine backward", moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + tokens_grad.data_ptr(), egrad.data_ptr(), + expert_tokens.data_ptr(), logits.data_ptr(), + wgrad.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, h)); return {egrad, wgrad}; @@ -653,7 +655,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { const int s = mask.size(0), e = mask.size(1); auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); + cumsum_launch(mask.data_ptr(), res.data_ptr(), s, e); return res; } diff --git a/extensions/csrc/cuda/multi_tensor_adam.cu b/extensions/csrc/kernel/cuda/multi_tensor_adam_kernel.cu similarity index 99% rename from extensions/csrc/cuda/multi_tensor_adam.cu rename to extensions/csrc/kernel/cuda/multi_tensor_adam_kernel.cu index 9cc3ae1eac10..e0c2f0b4c819 100644 --- a/extensions/csrc/cuda/multi_tensor_adam.cu +++ b/extensions/csrc/kernel/cuda/multi_tensor_adam_kernel.cu @@ -15,7 +15,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_apply.cuh b/extensions/csrc/kernel/cuda/multi_tensor_apply.cuh similarity index 98% rename from extensions/csrc/cuda/multi_tensor_apply.cuh rename to extensions/csrc/kernel/cuda/multi_tensor_apply.cuh index ec55dd320b40..8c98687ce02d 100644 --- a/extensions/csrc/cuda/multi_tensor_apply.cuh +++ b/extensions/csrc/kernel/cuda/multi_tensor_apply.cuh @@ -12,7 +12,7 @@ #include #include -#include "compat.h" +#include "common/micros.h" // #include @@ -104,7 +104,7 @@ void multi_tensor_apply( if (tensors_full || blocks_full || last_chunk) { // using accscalar_t = acc_type; multi_tensor_apply_kernel<<>>( - chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + chunk_size, noop_flag.data_ptr(), tl, callable, args...); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/kernel/cuda/multi_tensor_l2norm_kernel.cu similarity index 81% rename from extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu rename to extensions/csrc/kernel/cuda/multi_tensor_l2norm_kernel.cu index 85f935152f8a..3596aa3d575c 100644 --- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu +++ b/extensions/csrc/kernel/cuda/multi_tensor_l2norm_kernel.cu @@ -11,11 +11,99 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 + +template +__device__ __forceinline__ T reduce_block_into_lanes( + T* x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op( + T* x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = + fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + template __device__ __forceinline__ bool is_aligned(T *p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; @@ -289,8 +377,8 @@ std::tuple multi_tensor_l2norm_cuda( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - L2NormFunctor(), output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + L2NormFunctor(), output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, per_tensor, max_chunks_per_tensor);) AT_CUDA_CHECK(cudaGetLastError()); @@ -303,10 +391,10 @@ std::tuple multi_tensor_l2norm_cuda( const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); auto stream = at::cuda::getCurrentCUDAStream(); cleanup<<>>( - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - ret.DATA_PTR(), - per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, per_tensor, + output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, + ret.data_ptr(), + per_tensor ? ret_per_tensor.data_ptr() : nullptr, per_tensor, max_chunks_per_tensor); return std::tuple(ret, ret_per_tensor); @@ -349,15 +437,15 @@ void multi_tensor_norm_out_cuda( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - MaxNormFunctor(), output.DATA_PTR(), - output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + MaxNormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) } else { DISPATCH_FLOAT_AND_HALF( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - L2NormFunctor(), output.DATA_PTR(), - output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + L2NormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) } AT_CUDA_CHECK(cudaGetLastError()); @@ -374,8 +462,8 @@ void multi_tensor_norm_out_cuda( const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); auto stream = at::cuda::getCurrentCUDAStream(); cleanup_v2<<>>( - output.DATA_PTR(), output_per_tensor.DATA_PTR(), - ret.DATA_PTR(), out.DATA_PTR(), true, max_chunks_per_tensor, + output.data_ptr(), output_per_tensor.data_ptr(), + ret.data_ptr(), out.data_ptr(), true, max_chunks_per_tensor, norm_type, alpha, beta); return; diff --git a/extensions/csrc/cuda/multi_tensor_lamb.cu b/extensions/csrc/kernel/cuda/multi_tensor_lamb_kernel.cu similarity index 98% rename from extensions/csrc/cuda/multi_tensor_lamb.cu rename to extensions/csrc/kernel/cuda/multi_tensor_lamb_kernel.cu index 63771cf40bcb..05b3d1199937 100644 --- a/extensions/csrc/cuda/multi_tensor_lamb.cu +++ b/extensions/csrc/kernel/cuda/multi_tensor_lamb_kernel.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 @@ -333,7 +333,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, beta3, // 1-beta1 or 1 depends on averaging mode bias_correction1, bias_correction2, epsilon, (adamMode_t)mode, weight_decay, - global_grad_norm.DATA_PTR(), max_grad_norm);) + global_grad_norm.data_ptr(), max_grad_norm);) // Compute update norms auto update_norm_tuple = @@ -346,8 +346,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor(), - std::get<1>(param_norm_tuple).DATA_PTR(), - std::get<1>(update_norm_tuple).DATA_PTR(), + std::get<1>(param_norm_tuple).data_ptr(), + std::get<1>(update_norm_tuple).data_ptr(), lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); diff --git a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu b/extensions/csrc/kernel/cuda/multi_tensor_scale_kernel.cu similarity index 99% rename from extensions/csrc/cuda/multi_tensor_scale_kernel.cu rename to extensions/csrc/kernel/cuda/multi_tensor_scale_kernel.cu index 2f58a0f16dce..a84c93c3b1cd 100644 --- a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu +++ b/extensions/csrc/kernel/cuda/multi_tensor_scale_kernel.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu b/extensions/csrc/kernel/cuda/multi_tensor_sgd_kernel.cu similarity index 99% rename from extensions/csrc/cuda/multi_tensor_sgd_kernel.cu rename to extensions/csrc/kernel/cuda/multi_tensor_sgd_kernel.cu index 7f48dbd5d497..d48bb7053df4 100644 --- a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu +++ b/extensions/csrc/kernel/cuda/multi_tensor_sgd_kernel.cu @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "common/micros.h" #include "multi_tensor_apply.cuh" #define BLOCK_SIZE 512 diff --git a/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu new file mode 100644 index 000000000000..ca359df8d6dc --- /dev/null +++ b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu @@ -0,0 +1,348 @@ +/*This code from FasterTransformer: + * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu + * with minor changes. */ + +#include +#include +#include + + +#include "common/micros.h" +#include "funcs/cast_functor.h" +#include "funcs/binary_functor.h" +#include "funcs/reduce_function.h" +#include "common/vec_type_traits.h" + +using colossalAI::funcs::block_reduce; +using colossalAI::funcs::ReduceType; +using colossalAI::funcs::CastFunctor; +using colossalAI::funcs::BinaryOpFunctor; +using colossalAI::funcs::BinaryOpType; +using colossalAI::common::VecTypeTrait; + +#define RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \ + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \ + input.element_size(), \ + input.scalar_type(), \ + "rms_layernorm_kernel", \ + rms_layernorm_kernel<<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size);) \ + +#define FUSED_ADD_RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \ + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \ + input.element_size(), \ + input.scalar_type(), \ + "fused_add_rms_layernorm_kernel", \ + fused_add_rms_layernorm_kernel<<>>( \ + input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size);) \ + +// optimized for half and bf16 +template +__global__ void rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + using scalar2_t = typename VecTypeTrait::Type; + BinaryOpFunctor mul_scalar2t; + __shared__ float s_variance; + + /* + * since the open-sourced LLM's hidden dimensions mainly range from + * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported + * hidden dimension limit to 8192, and each thread's capacity + * for caching input tensors to 8 (8192 = 8 * 1024) which + * will cause problems for extremely large models, such as + * Megatron-Turing NLG 530B with hidden dimensions up to 20480 + */ + scalar2_t x_local[4]; + + scalar2_t* out_ptr = (scalar2_t*)out; + const scalar2_t* input_ptr = (scalar2_t*)input; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + float v1 = CastFunctor()(x_local[cnt].x); + float v2 = CastFunctor()(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + scalar2_t s_variance_2 = CastFunctor()(s_variance); +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]); + } +} + +template +__global__ void general_rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = (float) input[id]; + variance += x_local[cnt] * x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + +// optimized for half and bf16 +template +__global__ void fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + using scalar2_t = typename VecTypeTrait::Type; + BinaryOpFunctor add_scalar2t; + BinaryOpFunctor mul_scalar2t; + + __shared__ float s_variance; + scalar2_t x_local[4]; + + scalar2_t* input_ptr = (scalar2_t*)input; + scalar2_t* residual_ptr = (scalar2_t*)residual; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]); + float v1 = CastFunctor()(x_local[cnt].x); + float v2 = CastFunctor()(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; + residual_ptr[id] = x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + scalar2_t s_variance_2 = CastFunctor()(s_variance); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]); + } +} + +template +__global__ void general_fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = (float) input[id]; + x_local[cnt] += (float) residual[id]; + variance += x_local[cnt] * x_local[cnt]; + residual[id] = (scalar_t) x_local[cnt]; + } + block_reduce(&variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + + +#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ + if (DATA_SIZE == 2) { \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } else { \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + general_##__VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } \ + + +void rms_layernorm( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + RMSNORM_LAUNCHER(8, hidden_size / 8); + } else { + RMSNORM_LAUNCHER(4, hidden_size / 8); + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + RMSNORM_LAUNCHER(1, block); + break; + case 2: + RMSNORM_LAUNCHER(2, block); + break; + case 3: + RMSNORM_LAUNCHER(3, block); + break; + case 4: + RMSNORM_LAUNCHER(4, block); + break; + case 5: + RMSNORM_LAUNCHER(5, block); + break; + case 8: + RMSNORM_LAUNCHER(8, block); + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8"); + } + } +} + +void fused_add_rms_layernorm( + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + FUSED_ADD_RMSNORM_LAUNCHER(8, hidden_size / 8); + } else { + FUSED_ADD_RMSNORM_LAUNCHER(4, hidden_size / 8); + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + FUSED_ADD_RMSNORM_LAUNCHER(1, block); + break; + case 2: + FUSED_ADD_RMSNORM_LAUNCHER(2, block); + break; + case 3: + FUSED_ADD_RMSNORM_LAUNCHER(3, block); + break; + case 4: + FUSED_ADD_RMSNORM_LAUNCHER(4, block); + break; + case 5: + FUSED_ADD_RMSNORM_LAUNCHER(5, block); + break; + case 8: + FUSED_ADD_RMSNORM_LAUNCHER(8, block); + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8"); + } + } +} + +#undef DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT diff --git a/extensions/csrc/cuda/scaled_masked_softmax.h b/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu similarity index 80% rename from extensions/csrc/cuda/scaled_masked_softmax.h rename to extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu index d3e6f04e6093..00455897ebb3 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax.h +++ b/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu @@ -1,99 +1,30 @@ /*This code from NVIDIA Megatron: * with minor changes. */ -#pragma once +#include +#include +#include +#include +#include +#include +#include #include #include -#include -#include - #include #include -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} +#include "common/micros.h" +#include "utils/vec_copy.h" +#include "funcs/reduce_function.h" +#include "funcs/unary_functor.h" -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} +using colossalAI::funcs::UnaryOpFunctor; +using colossalAI::funcs::UnaryOpType; +using colossalAI::funcs::warp_reduce; +using colossalAI::funcs::ReduceType; +using colossalAI::cuda::utils::copy; -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} /* * Extended softmax (from native aten pytorch) with following additional @@ -156,8 +87,8 @@ __global__ void scaled_masked_softmax_warp_forward( if (element_index < batch_element_count) { int itr_idx = i * element_count + it * WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); + copy(src + itr_idx, temp_data); + copy(mask + itr_idx, temp_mask); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { @@ -187,7 +118,7 @@ __global__ void scaled_masked_softmax_warp_forward( (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } - warp_reduce(max_value); + warp_reduce(max_value); acc_t sum[WARP_BATCH]{0.0f}; #pragma unroll @@ -198,7 +129,7 @@ __global__ void scaled_masked_softmax_warp_forward( sum[i] += elements[i][it]; } } - warp_reduce(sum); + warp_reduce(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; @@ -213,8 +144,8 @@ __global__ void scaled_masked_softmax_warp_forward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); + copy( + out, dst + i * element_count + it * WARP_SIZE); } else { break; } @@ -269,10 +200,10 @@ __global__ void scaled_masked_softmax_warp_backward( for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count + it * WARP_SIZE); + copy( + grad + i * element_count + it * WARP_SIZE, temp_grad); + copy( + output + i * element_count + it * WARP_SIZE, temp_output); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { @@ -296,7 +227,7 @@ __global__ void scaled_masked_softmax_warp_backward( sum[i] += grad_reg[i][it]; } } - warp_reduce(sum); + warp_reduce(sum); // store result #pragma unroll @@ -314,17 +245,17 @@ __global__ void scaled_masked_softmax_warp_backward( (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, out); + copy( + out, gradInput + i * element_count + it * WARP_SIZE); } } } } -} // end of anonymous namespace + int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) { - int log2_elements = log2_ceil(key_seq_len); + int log2_elements = UnaryOpFunctor()(key_seq_len); const int next_power_of_two = 1 << log2_elements; int warp_size = @@ -349,7 +280,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, if (key_seq_len == 0) { return; } else { - int log2_elements = log2_ceil(key_seq_len); + int log2_elements = UnaryOpFunctor()(key_seq_len); const int next_power_of_two = 1 << log2_elements; int batch_count = batches * attn_heads * query_seq_len; @@ -449,7 +380,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, if (key_seq_len == 0) { return; } else { - int log2_elements = log2_ceil(key_seq_len); + int log2_elements = UnaryOpFunctor()(key_seq_len); const int next_power_of_two = 1 << log2_elements; int batch_count = batches * attn_heads * query_seq_len; @@ -536,3 +467,67 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, } } } + +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = torch::empty( + {batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, + // seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); + + // backward pass is completely in-place + return output_grads; +} diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu similarity index 83% rename from extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h rename to extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu index 54c8e9133a1b..42d14b423749 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ b/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu @@ -1,124 +1,30 @@ /*This code from NVIDIA Megatron: * with minor changes. */ -#pragma once - +#include +#include +#include +#include +#include +#include +#include #include #include -#include #include - #include #include -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} +#include "common/micros.h" +#include "utils/vec_copy.h" +#include "funcs/reduce_function.h" +#include "funcs/unary_functor.h" -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} +using colossalAI::funcs::UnaryOpFunctor; +using colossalAI::funcs::UnaryOpType; +using colossalAI::funcs::warp_reduce; +using colossalAI::funcs::ReduceType; +using colossalAI::cuda::utils::copy; +using colossalAI::cuda::utils::copy_zero; /* * Extended softmax (from native aten pytorch) with following additional @@ -169,8 +75,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector( - temp_data, src + i * element_count * stride + it * WARP_SIZE); + copy( + src + i * element_count * stride + it * WARP_SIZE, temp_data); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { @@ -200,7 +106,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } - warp_reduce(max_value); + warp_reduce(max_value); acc_t sum[WARP_BATCH]{0.0f}; #pragma unroll @@ -213,7 +119,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( } } } - warp_reduce(sum); + warp_reduce(sum); + // store result output_t out[ELEMENTS_PER_LDG_STG]; @@ -233,10 +140,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( out[element] = 0; } } - copy_vector( - dst + i * element_count * stride + it * WARP_SIZE, out); + copy( + out, dst + i * element_count * stride + it * WARP_SIZE); } else if (element_index < element_count) { - copy_zero_vector( + copy_zero( dst + i * element_count * stride + it * WARP_SIZE); } else { break; @@ -292,10 +199,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count * stride + it * WARP_SIZE); + copy( + grad + i * element_count * stride + it * WARP_SIZE, temp_grad); + copy( + output + i * element_count * stride + it * WARP_SIZE, temp_output); #pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { @@ -323,7 +230,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( sum[i] += grad_reg[i][it]; } } - warp_reduce(sum); + warp_reduce(sum); // store result #pragma unroll @@ -341,15 +248,13 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } - copy_vector( - gradInput + i * element_count * stride + it * WARP_SIZE, out); + copy( + out, gradInput + i * element_count * stride + it * WARP_SIZE); } } } } -} // end of anonymous namespace - template void dispatch_scaled_upper_triang_masked_softmax_forward( output_t *dst, const input_t *src, const input_t scale, @@ -358,7 +263,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( if (softmax_elements == 0) { return; } else { - int log2_elements = log2_ceil(softmax_elements); + int log2_elements = UnaryOpFunctor()(softmax_elements); const int next_power_of_two = 1 << log2_elements; int seq_len = softmax_elements; int batch_count = attn_batches * seq_len; @@ -483,7 +388,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( if (softmax_elements == 0) { return; } else { - int log2_elements = log2_ceil(softmax_elements); + int log2_elements = UnaryOpFunctor()(softmax_elements); const int next_power_of_two = 1 << log2_elements; int seq_len = softmax_elements; int batch_count = attn_batches * seq_len; @@ -598,3 +503,61 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( } } } + + + + +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), scale_factor, seq_len, + seq_len, attn_batches);); + return softmax_results; +} + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + // output grads is a 3d tensor with dimensions [attn_batches, seq_len, + // seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, seq_len, seq_len, attn_batches);); + + // backward pass is completely in-place + return output_grads; +} diff --git a/extensions/csrc/kernel/cuda/utils/gpu_launch_config.h b/extensions/csrc/kernel/cuda/utils/gpu_launch_config.h new file mode 100644 index 000000000000..b953c6587a64 --- /dev/null +++ b/extensions/csrc/kernel/cuda/utils/gpu_launch_config.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include "nvgpu_dev_info.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +struct GPULaunchConfig { + dim3 block{1, 1, 1}; + dim3 grid{1, 1, 1}; +}; + +static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info, + int64_t numel, int64_t vec_size) { + const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock(); + const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0]; + const int64_t kMinimumSize = 64; + const int64_t kMaximumSize = 512; + int64_t active_threads = (numel + vec_size - 1) / vec_size; + int64_t sm_num = dev_info.GetMultiProcessorCount(); + + // Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally + int64_t expected_threads_per_block = kMaximumSize; + + auto RoundUpToPowerOfTwo = [](int64_t x) { + bool is_power_of_two = false; + int64_t ret = 1; + int64_t y = x; + while (y > 0) { + is_power_of_two = ((ret ^ x) == 0); + y = (x >> 1); + ret = (ret << 1); + if (y > 0) is_power_of_two = false; + } + if (is_power_of_two) return x; + return ret; + }; + + if ((active_threads / (sm_num << 1)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 1)); + } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 2)); + } + + expected_threads_per_block = + std::max(expected_threads_per_block, kMinimumSize); + int64_t expect_block_per_grid = + ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + + if (expect_block_per_grid > max_blocks_per_grid) { + expect_block_per_grid = max_blocks_per_grid; + expected_threads_per_block = + (active_threads + expect_block_per_grid - 1) / expect_block_per_grid; + if (expected_threads_per_block > max_threads_per_block) + throw std::invalid_argument( + "Threads required for current input exceed for current GPU!"); + expected_threads_per_block = + RoundUpToPowerOfTwo(expected_threads_per_block); + expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + } + + GPULaunchConfig config; + config.block.x = expected_threads_per_block; + config.grid.x = expect_block_per_grid; + return config; +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/kernel/cuda/utils/micros.h b/extensions/csrc/kernel/cuda/utils/micros.h new file mode 100644 index 000000000000..aaa2fc1ef1b9 --- /dev/null +++ b/extensions/csrc/kernel/cuda/utils/micros.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include + +#define CUDA_CHECK(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(status)); \ + } \ + } + +#define HOST __host__ +#define DEVICE __device__ +#define HOSTDEVICE __host__ __device__ diff --git a/extensions/csrc/kernel/cuda/utils/nvgpu_dev_info.h b/extensions/csrc/kernel/cuda/utils/nvgpu_dev_info.h new file mode 100644 index 000000000000..f4c017e754c3 --- /dev/null +++ b/extensions/csrc/kernel/cuda/utils/nvgpu_dev_info.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "micros.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +class NVGPUDevInfo { + public: + explicit NVGPUDevInfo(int device_num) : device_num_(device_num) { + CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num)); + } + + std::array GetMaxGridDims() const { + std::array ret; + ret[0] = prop_.maxGridSize[0]; + ret[1] = prop_.maxGridSize[1]; + ret[2] = prop_.maxGridSize[2]; + return ret; + } + + std::array GetMaxBlockDims() const { + std::array ret; + ret[0] = prop_.maxThreadsDim[0]; + ret[1] = prop_.maxThreadsDim[1]; + ret[2] = prop_.maxThreadsDim[2]; + return ret; + } + + std::array GetCapability() const { + std::array ret; + ret[0] = prop_.major; + ret[1] = prop_.minor; + return ret; + } + + int GetMultiProcessorCount() const { return prop_.multiProcessorCount; } + + int GetMaxThreadsPerMultiProcessor() const { + return prop_.maxThreadsPerMultiProcessor; + } + + int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; } + + private: + int device_num_; + cudaDeviceProp prop_; +}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/kernel/cuda/utils/vec_copy.h b/extensions/csrc/kernel/cuda/utils/vec_copy.h new file mode 100644 index 000000000000..465703a743a8 --- /dev/null +++ b/extensions/csrc/kernel/cuda/utils/vec_copy.h @@ -0,0 +1,64 @@ + +#pragma once + +#include "common/vec_type_traits.h" +#include "funcs/cast_functor.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +template +__device__ __inline__ void copy_zero(T *dst) { + using VT = typename common::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = funcs::CastFunctor()(0.0f); +} + +template +__device__ __inline__ void copy(const SrcT *src, DstT *dst) { + using SrcVT = typename common::VecTypeTrait::Type; + using DstVT = typename common::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = funcs::CastFunctor()( + *(reinterpret_cast(src))); +} + +template +__device__ __inline__ void copy(const T *src, T *dst) { + using VT = typename common::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +template <> +__device__ __inline__ void copy(const float *src, float *dst) { + // Since the maximum memory alignment length is 128 bits, we choose float4 + // here. + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst + 4)) = + *(reinterpret_cast(src + 4)); +} + +template +int get_vec_size(const torch::Tensor &tensor) { + uint64_t address = reinterpret_cast(tensor.data_ptr()); + const int max_aligned_size = 128; + const int dtype_size = sizeof(T) * 8; + + const int vec_size = max_aligned_size / sizeof(T) / 8; + + // Note(LiuYang): Performance of situation of which + // vec_size equals to 8 need to be profiled in the future + // if (address % (dtype_size * 8) == 0) { + // return std::min(8, vec_size); + // } + if (address % (dtype_size * 4) == 0) { + return std::min(4, vec_size); + } else if (address % (dtype_size * 2) == 0) { + return std::min(2, vec_size); + } else { + return 1; + } +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/cpu_adam.cpp b/extensions/csrc/kernel/x86/cpu_adam.cpp similarity index 100% rename from extensions/csrc/cuda/cpu_adam.cpp rename to extensions/csrc/kernel/x86/cpu_adam.cpp diff --git a/extensions/csrc/cuda/cpu_adam.h b/extensions/csrc/kernel/x86/cpu_adam.h similarity index 100% rename from extensions/csrc/cuda/cpu_adam.h rename to extensions/csrc/kernel/x86/cpu_adam.h diff --git a/extensions/csrc/scaled_softmax.py b/extensions/csrc/scaled_softmax.py deleted file mode 100644 index 7c220d60dd19..000000000000 --- a/extensions/csrc/scaled_softmax.py +++ /dev/null @@ -1,190 +0,0 @@ -# This code from NVIDIA Megatron: -# with minor changes. - -import enum - -import torch -import torch.nn as nn - -from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader - -try: - from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax -except ImportError: - scaled_masked_softmax = None - scaled_upper_triang_masked_softmax = None - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - paddedcausal = 3 - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - global scaled_upper_triang_masked_softmax - if scaled_upper_triang_masked_softmax: - scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) - - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - - # build and load kernel if not pre-built - global scaled_masked_softmax - if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() - - softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None, None - - -class FusedScaleMaskSoftmax(nn.Module): - """ - Fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: Flag to indicate if input in fp16 data format. - input_in_bf16: Flag to indicate if input in bf16 data format. - attn_mask_type: Attention mask type (pad or causal) - scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion - mask_func: Mask function to be applied. - softmax_in_fp32: If True, softmax in performed at fp32 precision. - scale: Scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super(FusedScaleMaskSoftmax, self).__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - assert not ( - self.input_in_fp16 and self.input_in_bf16 - ), "both fp16 and bf16 flags cannot be active at the same time." - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and mask is not None # mask tensor must not be None - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 2048: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type.value > 1: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - b, np, sq, sk = input.size() - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type.value > 1: - assert sq == sk, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, sq, sk) - input = input.view(-1, sq, sk) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - return probs.view(b, np, sq, sk) - else: - # input is 4D tensor (b, np, sq, sk) - return ScaledMaskedSoftmax.apply(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - def get_batch_per_block(self, sq, sk, b, np): - # build and load kernel if not pre-built - global scaled_masked_softmax - if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() - - return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py index f1e0095b29b6..b722057c9e8b 100644 --- a/extensions/cuda_extension.py +++ b/extensions/cuda_extension.py @@ -21,6 +21,7 @@ def nvcc_flags(self) -> List[str]: """ This function should return a list of nvcc compilation flags for extensions. """ + return ["-DCOLOSSAL_WITH_CUDA"] def is_available(self) -> bool: # cuda extension can only be built if cuda is available @@ -53,6 +54,12 @@ def get_cuda_home_include(self): cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + return super().include_dirs() + [self.get_cuda_home_include()] + def build_jit(self) -> None: from torch.utils.cpp_extension import CUDA_HOME, load diff --git a/extensions/pybind/__init__.py b/extensions/pybind/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/extensions/cpu_adam/__init__.py b/extensions/pybind/cpu_adam/__init__.py similarity index 100% rename from extensions/cpu_adam/__init__.py rename to extensions/pybind/cpu_adam/__init__.py diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/pybind/cpu_adam/cpu_adam_arm.py similarity index 80% rename from extensions/cpu_adam/cpu_adam_arm.py rename to extensions/pybind/cpu_adam/cpu_adam_arm.py index 61c4f3ed0697..9595eda69263 100644 --- a/extensions/cpu_adam/cpu_adam_arm.py +++ b/extensions/pybind/cpu_adam/cpu_adam_arm.py @@ -1,6 +1,7 @@ import platform +from typing import List -from ..cpp_extension import _CppExtension +from ...cpp_extension import _CppExtension class CpuAdamArmExtension(_CppExtension): @@ -20,12 +21,12 @@ def assert_compatible(self) -> None: # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + self.csrc_abs_path("kernel/arm/cpu_adam_arm.cpp"), ] return ret - def include_dirs(self): - return [] + def include_dirs(self) -> List[str]: + return super().include_dirs() def cxx_flags(self): extra_cxx_flags = [ diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/pybind/cpu_adam/cpu_adam_x86.py similarity index 83% rename from extensions/cpu_adam/cpu_adam_x86.py rename to extensions/pybind/cpu_adam/cpu_adam_x86.py index 9bbc8d85126d..525f3abe1a01 100644 --- a/extensions/cpu_adam/cpu_adam_x86.py +++ b/extensions/pybind/cpu_adam/cpu_adam_x86.py @@ -1,7 +1,7 @@ import platform -from ..cuda_extension import _CudaExtension -from ..utils import append_nvcc_threads +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads class CpuAdamX86Extension(_CudaExtension): @@ -21,13 +21,10 @@ def assert_compatible(self) -> None: # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cuda/cpu_adam.cpp"), + self.csrc_abs_path("kernel/x86/cpu_adam.cpp"), ] return ret - def include_dirs(self): - return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] - def cxx_flags(self): extra_cxx_flags = [ "-std=c++14", @@ -50,5 +47,5 @@ def nvcc_flags(self): "-U__CUDA_NO_HALF2_OPERATORS__", "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags() return append_nvcc_threads(ret) diff --git a/extensions/flash_attention/__init__.py b/extensions/pybind/flash_attention/__init__.py similarity index 100% rename from extensions/flash_attention/__init__.py rename to extensions/pybind/flash_attention/__init__.py diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py similarity index 98% rename from extensions/flash_attention/flash_attention_dao_cuda.py rename to extensions/pybind/flash_attention/flash_attention_dao_cuda.py index a2f2a52f1af4..a108377a8dcf 100644 --- a/extensions/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/pybind/flash_attention/flash_attention_dao_cuda.py @@ -1,4 +1,4 @@ -from ..base_extension import _Extension +from ...base_extension import _Extension class FlashAttentionDaoCudaExtension(_Extension): diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/pybind/flash_attention/flash_attention_npu.py similarity index 97% rename from extensions/flash_attention/flash_attention_npu.py rename to extensions/pybind/flash_attention/flash_attention_npu.py index 0e01cefa1112..8a30972b6fba 100644 --- a/extensions/flash_attention/flash_attention_npu.py +++ b/extensions/pybind/flash_attention/flash_attention_npu.py @@ -1,4 +1,4 @@ -from ..base_extension import _Extension +from ...base_extension import _Extension class FlashAttentionNpuExtension(_Extension): diff --git a/extensions/flash_attention/flash_attention_sdpa_cuda.py b/extensions/pybind/flash_attention/flash_attention_sdpa_cuda.py similarity index 97% rename from extensions/flash_attention/flash_attention_sdpa_cuda.py rename to extensions/pybind/flash_attention/flash_attention_sdpa_cuda.py index d3323a6aae27..2f920db61006 100644 --- a/extensions/flash_attention/flash_attention_sdpa_cuda.py +++ b/extensions/pybind/flash_attention/flash_attention_sdpa_cuda.py @@ -1,4 +1,4 @@ -from ..base_extension import _Extension +from ...base_extension import _Extension class FlashAttentionSdpaCudaExtension(_Extension): diff --git a/extensions/pybind/inference/__init__.py b/extensions/pybind/inference/__init__.py new file mode 100644 index 000000000000..c5ea424fa25d --- /dev/null +++ b/extensions/pybind/inference/__init__.py @@ -0,0 +1,3 @@ +from .inference_ops_cuda import InferenceOpsCudaExtension + +__all__ = ["InferenceOpsCudaExtension"] diff --git a/extensions/pybind/inference/inference.cpp b/extensions/pybind/inference/inference.cpp new file mode 100644 index 000000000000..dc7be2349d79 --- /dev/null +++ b/extensions/pybind/inference/inference.cpp @@ -0,0 +1,111 @@ +#include + +void decode_kv_cache_memcpy( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, head_size] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables); // [batch_size, max_seq_len] + +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch); + +void rotary_embedding( + torch::Tensor& query, // [total_tokens, head_num, head_dim] + torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] + torch::Tensor& cos, // [total_tokens, head_dim] + torch::Tensor& sin, // [total_tokens, head_dim] + bool high_precision); + +void rotary_embedding_and_cache_copy( + torch::Tensor& query, // [num_tokens, head_num, head_dim] + torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] + torch::Tensor& value, // [num_tokens, num_heads, head_dim] + torch::Tensor& cos, // [num_tokens, head_dim] + torch::Tensor& sin, // [num_tokens, head_dim] + torch::Tensor& + key_cache, // [num_blocks, head_num, head_dim/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision); + +torch::Tensor silu_and_mul(const torch::Tensor& ins); + +void rms_layernorm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim] + at::Tensor& sin_cache, // [max_rotary_position, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + int max_seq_len_in_batch, bool is_prompts); + +void flash_decoding_attention( + torch::Tensor& out, // [num_tokens, num_heads, head_size] + torch::Tensor& query, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& context_lens, // [num_tokens] + torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq] + int block_size, int max_context_len, + torch::Tensor& + tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size] + torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions] + const c10::optional& alibi_slopes, float scale); + +void convert_fp8(torch::Tensor& input, torch::Tensor& output); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the decode stage."); + + m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the context stage."); + + m.def( + "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, + "Performing Rotary Embedding-related calculations and KVCache Memcopy."); + + m.def("rotary_embedding", &rotary_embedding, + "Performing Rotary Embedding-related calculations."); + + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); + + m.def("rms_layernorm", &rms_layernorm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + + m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm, + "In-place fused Add and RMS Normalization."); + + m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache."); + + m.def("flash_decoding_attention", &flash_decoding_attention, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + + m.def("convert_fp8", &convert_fp8, + "Convert input to fp8 output or convert fp8 input to output."); +} diff --git a/extensions/pybind/inference/inference_ops_cuda.py b/extensions/pybind/inference/inference_ops_cuda.py new file mode 100644 index 000000000000..463a0704d0b7 --- /dev/null +++ b/extensions/pybind/inference/inference_ops_cuda.py @@ -0,0 +1,32 @@ +from ...cuda_extension import _CudaExtension +from ...utils import get_cuda_cc_flag + + +class InferenceOpsCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="inference_ops_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "kernel/cuda/decode_kv_cache_memcpy_kernel.cu", + "kernel/cuda/context_kv_cache_memcpy_kernel.cu", + "kernel/cuda/fused_rotary_emb_and_cache_kernel.cu", + "kernel/cuda/activation_kernel.cu", + "kernel/cuda/rms_layernorm_kernel.cu", + "kernel/cuda/get_cos_and_sin_kernel.cu", + "kernel/cuda/flash_decoding_attention_kernel.cu", + "kernel/cuda/convert_fp8_kernel.cu", + ] + ] + [self.pybind_abs_path("inference/inference.cpp")] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags() diff --git a/extensions/layernorm/__init__.py b/extensions/pybind/layernorm/__init__.py similarity index 100% rename from extensions/layernorm/__init__.py rename to extensions/pybind/layernorm/__init__.py diff --git a/extensions/csrc/cuda/layer_norm_cuda.cpp b/extensions/pybind/layernorm/layer_norm.cpp similarity index 99% rename from extensions/csrc/cuda/layer_norm_cuda.cpp rename to extensions/pybind/layernorm/layer_norm.cpp index 15a07bb0c7ac..77c4e38c8150 100644 --- a/extensions/csrc/cuda/layer_norm_cuda.cpp +++ b/extensions/pybind/layernorm/layer_norm.cpp @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "common/micros.h" namespace { diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/pybind/layernorm/layernorm_cuda.py similarity index 57% rename from extensions/layernorm/layernorm_cuda.py rename to extensions/pybind/layernorm/layernorm_cuda.py index db5f2fce1368..951563e7eec1 100644 --- a/extensions/layernorm/layernorm_cuda.py +++ b/extensions/pybind/layernorm/layernorm_cuda.py @@ -1,5 +1,5 @@ -from ..cuda_extension import _CudaExtension -from ..utils import append_nvcc_threads, get_cuda_cc_flag +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag class LayerNormCudaExtension(_CudaExtension): @@ -7,11 +7,13 @@ def __init__(self): super().__init__(name="layernorm_cuda") def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/layer_norm_kernel.cu"]] + [ + self.pybind_abs_path("layernorm/layer_norm.cpp") + ] return ret def include_dirs(self): - ret = [self.get_cuda_home_include()] + ret = [self.get_cuda_home_include()] + [self.csrc_abs_path("")] return ret def cxx_flags(self): @@ -20,5 +22,5 @@ def cxx_flags(self): def nvcc_flags(self): extra_cuda_flags = ["-maxrregcount=50"] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + super().nvcc_flags() return append_nvcc_threads(ret) diff --git a/extensions/moe/__init__.py b/extensions/pybind/moe/__init__.py similarity index 100% rename from extensions/moe/__init__.py rename to extensions/pybind/moe/__init__.py diff --git a/extensions/csrc/cuda/moe_cuda.cpp b/extensions/pybind/moe/moe.cpp similarity index 100% rename from extensions/csrc/cuda/moe_cuda.cpp rename to extensions/pybind/moe/moe.cpp diff --git a/extensions/moe/moe_cuda.py b/extensions/pybind/moe/moe_cuda.py similarity index 58% rename from extensions/moe/moe_cuda.py rename to extensions/pybind/moe/moe_cuda.py index 52883e97fc3a..898ffe21c19b 100644 --- a/extensions/moe/moe_cuda.py +++ b/extensions/pybind/moe/moe_cuda.py @@ -1,17 +1,15 @@ -from ..cuda_extension import _CudaExtension -from ..utils import append_nvcc_threads, get_cuda_cc_flag +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag class MoeCudaExtension(_CudaExtension): def __init__(self): super().__init__(name="moe_cuda") - def include_dirs(self): - ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] - return ret - def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/moe_kernel.cu"]] + [ + self.pybind_abs_path("moe/moe.cpp") + ] return ret def cxx_flags(self): @@ -25,5 +23,5 @@ def nvcc_flags(self): "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags() return append_nvcc_threads(ret) diff --git a/extensions/optimizer/__init__.py b/extensions/pybind/optimizer/__init__.py similarity index 100% rename from extensions/optimizer/__init__.py rename to extensions/pybind/optimizer/__init__.py diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/pybind/optimizer/fused_optimizer_cuda.py similarity index 50% rename from extensions/optimizer/fused_optimizer_cuda.py rename to extensions/pybind/optimizer/fused_optimizer_cuda.py index e065cf34a17d..13f3281fbfb0 100644 --- a/extensions/optimizer/fused_optimizer_cuda.py +++ b/extensions/pybind/optimizer/fused_optimizer_cuda.py @@ -1,5 +1,5 @@ -from ..cuda_extension import _CudaExtension -from ..utils import get_cuda_cc_flag +from ...cuda_extension import _CudaExtension +from ...utils import get_cuda_cc_flag class FusedOptimizerCudaExtension(_CudaExtension): @@ -10,18 +10,13 @@ def sources_files(self): ret = [ self.csrc_abs_path(fname) for fname in [ - "cuda/colossal_C_frontend.cpp", - "cuda/multi_tensor_sgd_kernel.cu", - "cuda/multi_tensor_scale_kernel.cu", - "cuda/multi_tensor_adam.cu", - "cuda/multi_tensor_l2norm_kernel.cu", - "cuda/multi_tensor_lamb.cu", + "kernel/cuda/multi_tensor_sgd_kernel.cu", + "kernel/cuda/multi_tensor_scale_kernel.cu", + "kernel/cuda/multi_tensor_adam_kernel.cu", + "kernel/cuda/multi_tensor_l2norm_kernel.cu", + "kernel/cuda/multi_tensor_lamb_kernel.cu", ] - ] - return ret - - def include_dirs(self): - ret = [self.get_cuda_home_include()] + ] + [self.pybind_abs_path("optimizer/optimizer.cpp")] return ret def cxx_flags(self): @@ -31,4 +26,4 @@ def cxx_flags(self): def nvcc_flags(self): extra_cuda_flags = ["-lineinfo"] extra_cuda_flags.extend(get_cuda_cc_flag()) - return ["-O3", "--use_fast_math"] + extra_cuda_flags + return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags() diff --git a/extensions/csrc/cuda/colossal_C_frontend.cpp b/extensions/pybind/optimizer/optimizer.cpp similarity index 100% rename from extensions/csrc/cuda/colossal_C_frontend.cpp rename to extensions/pybind/optimizer/optimizer.cpp diff --git a/extensions/softmax/__init__.py b/extensions/pybind/softmax/__init__.py similarity index 100% rename from extensions/softmax/__init__.py rename to extensions/pybind/softmax/__init__.py diff --git a/extensions/csrc/cuda/scaled_masked_softmax.cpp b/extensions/pybind/softmax/scaled_masked_softmax.cpp similarity index 71% rename from extensions/csrc/cuda/scaled_masked_softmax.cpp rename to extensions/pybind/softmax/scaled_masked_softmax.cpp index 8c2982b0cff9..427035d4e88b 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax.cpp +++ b/extensions/pybind/softmax/scaled_masked_softmax.cpp @@ -6,10 +6,6 @@ #include -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); @@ -17,8 +13,8 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads); torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { @@ -46,25 +42,13 @@ torch::Tensor bwd(torch::Tensor const& output_grads, return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, - attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + m.def("forward", &fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + m.def("backward", &bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); - m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax:: - get_batch_per_block, + m.def("get_batch_per_block", &get_batch_per_block, "Return Batch per block size."); } diff --git a/extensions/softmax/scaled_masked_softmax_cuda.py b/extensions/pybind/softmax/scaled_masked_softmax_cuda.py similarity index 66% rename from extensions/softmax/scaled_masked_softmax_cuda.py rename to extensions/pybind/softmax/scaled_masked_softmax_cuda.py index 5b4208dba895..049a8c7b593b 100644 --- a/extensions/softmax/scaled_masked_softmax_cuda.py +++ b/extensions/pybind/softmax/scaled_masked_softmax_cuda.py @@ -1,5 +1,5 @@ -from ..cuda_extension import _CudaExtension -from ..utils import append_nvcc_threads +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): @@ -7,15 +7,11 @@ def __init__(self): super().__init__(name="scaled_masked_softmax_cuda") def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/scaled_masked_softmax_kernel.cu"]] + [ + self.pybind_abs_path("softmax/scaled_masked_softmax.cpp") ] return ret - def include_dirs(self): - return [self.get_cuda_home_include()] - def cxx_flags(self): return ["-O3"] + self.version_dependent_macros @@ -28,5 +24,5 @@ def nvcc_flags(self): "-U__CUDA_NO_HALF2_OPERATORS__", "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags() return append_nvcc_threads(ret) diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp b/extensions/pybind/softmax/scaled_upper_triang_masked_softmax.cpp similarity index 79% rename from extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp rename to extensions/pybind/softmax/scaled_upper_triang_masked_softmax.cpp index cbbc3706497a..bbd65712374d 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp +++ b/extensions/pybind/softmax/scaled_upper_triang_masked_softmax.cpp @@ -6,10 +6,6 @@ #include -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); torch::Tensor bwd_cuda(torch::Tensor const& output_grads, @@ -40,15 +36,9 @@ torch::Tensor bwd(torch::Tensor const& output_grads, return bwd_cuda(output_grads, softmax_results, scale_factor); } -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + m.def("forward", &fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + m.def("backward", &bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); } diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py similarity index 65% rename from extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py rename to extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py index d4f27a9218ff..a179c2ac5450 100644 --- a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py +++ b/extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -1,22 +1,18 @@ -from ..cuda_extension import _CudaExtension -from ..utils import append_nvcc_threads, get_cuda_cc_flag +from ...cuda_extension import _CudaExtension +from ...utils import append_nvcc_threads, get_cuda_cc_flag class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): def __init__(self): super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") - def include_dirs(self): - return [self.get_cuda_home_include()] - def sources_files(self): ret = [ self.csrc_abs_path(fname) for fname in [ - "cuda/scaled_upper_triang_masked_softmax.cpp", - "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + "kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu", ] - ] + ] + [self.pybind_abs_path("softmax/scaled_upper_triang_masked_softmax.cpp")] return ret def cxx_flags(self): @@ -30,5 +26,5 @@ def nvcc_flags(self): "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags() return append_nvcc_threads(ret) diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt deleted file mode 100644 index f85f9d88e629..000000000000 --- a/requirements/requirements-infer.txt +++ /dev/null @@ -1,4 +0,0 @@ -transformers==4.34.0 -auto-gptq==0.5.0 -git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8 -git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9 diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 58c7f780fbb0..e4affc7f5396 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -19,4 +19,3 @@ datasets pydantic ray peft>=0.7.1 -#auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f1815ee18028..741975942b4c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -19,4 +19,5 @@ protobuf transformers==4.36.2 peft>=0.7.1 bitsandbytes>=0.39.0 +rpyc==6.0.0 galore_torch diff --git a/setup.py b/setup.py index c16709ad1c1c..b105c03b717c 100644 --- a/setup.py +++ b/setup.py @@ -111,7 +111,6 @@ def get_version() -> str: "tests", "scripts", "requirements", - "extensions", "*.egg-info", ), ), diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py deleted file mode 100644 index ded70fa43c30..000000000000 --- a/tests/test_gptq/test_gptq_linear.py +++ /dev/null @@ -1,144 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -try: - from auto_gptq.modeling._utils import autogptq_post_init - from auto_gptq.utils.import_utils import dynamically_import_QuantLinear - from exllama_kernels import prepare_buffers, set_tuning_params - - from colossalai.inference.quant.gptq import CaiQuantLinear - - HAS_AUTO_GPTQ = True -except: - HAS_AUTO_GPTQ = False - print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ") - -import warnings - -HAS_GPTQ_CUDA = False -try: - from colossalai.kernel.op_builder.gptq import GPTQBuilder - - gptq_cuda = GPTQBuilder().load() - HAS_GPTQ_CUDA = True -except ImportError: - warnings.warn("CUDA gptq is not installed") - HAS_GPTQ_CUDA = False - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - -max_inner_outer_dim = 1 -max_input_len = 1 -max_dq_buffer_size = 1 -gptq_temp_dq_buffer = None -gptq_temp_state_buffer = None - - -def init_buffer(cai_linear, use_act_order=False): - global max_dq_buffer_size - global max_input_len - global max_dq_buffer_size - global max_inner_outer_dim - global gptq_temp_dq_buffer - global gptq_temp_state_buffer - - max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8) - - if use_act_order: - max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures) - - if use_act_order: - max_input_len = 4096 - # The temp_state buffer is required to reorder X in the act-order case. - # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - gptq_temp_state_buffer = torch.zeros( - (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() - ) - gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) - - gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, - reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq", -) -def test_gptq_linear(): - infeature = 1024 - outfeature = 1024 - group_size = 128 - wbits = 4 - - inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device()) - - device = torch.device("cuda:0") - - linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits) - - linear = linear_class( - bits=4, - group_size=group_size, - infeatures=infeature, - outfeatures=outfeature, - bias=False, - ) - - torch.manual_seed(42) - - linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) - linear.scales = linear.scales + 0.002 - - linear = linear.to(device) - - cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True) - cai_linear.qweight.data.copy_(linear.qweight) - cai_linear.scales = cai_linear.scales + 0.002 - cai_linear = cai_linear.to(device) - - linear = autogptq_post_init(linear, use_act_order=False) - - max_inner_outer_dim = max(infeature, outfeature) - max_dq_buffer_size = linear.infeatures * linear.outfeatures - max_input_len = 2048 - buffers = { - "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), - "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device), - } - - prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) - - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - with torch.no_grad(): - gptq_out = linear(inps) - batch_gptq_out = linear(batch_inps) - torch.cuda.synchronize() - cai_out = cai_linear(inps) - torch.cuda.synchronize() - - batch_cai_out = cai_linear(batch_inps) - torch.cuda.synchronize() - - assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) - assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) - - -if __name__ == "__main__": - test_gptq_linear() diff --git a/tests/test_infer/__init__.py b/tests/test_infer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py old mode 100644 new mode 100755 diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py new file mode 100644 index 000000000000..ac532b1b199d --- /dev/null +++ b/tests/test_infer/test_async_engine/test_async_engine.py @@ -0,0 +1,80 @@ +import asyncio +from dataclasses import dataclass + +import pytest + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + + +@dataclass +class MockSequence: + request_id: int + + +class MockEngine: + def __init__(self): + self.step_calls = 0 + self.add_request_calls = 0 + self.abort_request_calls = 0 + self.request_id = None + + async def async_step(self): + self.step_calls += 1 + return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False) + + def add_single_request(self, **kwargs): + del kwargs + self.add_request_calls += 1 + + def generate(self, request_id): + self.request_id = request_id + + def stop_generating(self): + self.request_id = None + + def add_request(self, **kwargs): + del kwargs # Unused + self.add_request_calls += 1 + + def abort_request(self, request_id): + del request_id # Unused + self.abort_request_calls += 1 + + +class MockAsyncInferenceEngine(AsyncInferenceEngine): + def _init_engine(self, *args, **kwargs): + return MockEngine() + + +@pytest.mark.asyncio +async def test_new_requests_event(): + engine = MockAsyncInferenceEngine() + engine.start_background_loop() + await asyncio.sleep(0.01) + assert engine.engine.step_calls == 0 + + await engine.add_request(1, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 1 + assert engine.engine.step_calls == 1 + + await engine.add_request(2, "", None) + engine.engine.generate(2) + await asyncio.sleep(0) + assert engine.engine.add_request_calls == 2 + assert engine.engine.step_calls == 2 + await asyncio.sleep(0) + assert engine.engine.step_calls == 3 + engine.engine.stop_generating() + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + + await engine.add_request(3, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 diff --git a/tests/test_infer/test_async_engine/test_request_tracer.py b/tests/test_infer/test_async_engine/test_request_tracer.py new file mode 100644 index 000000000000..14bcb96281b3 --- /dev/null +++ b/tests/test_infer/test_async_engine/test_request_tracer.py @@ -0,0 +1,68 @@ +import pytest + +from colossalai.inference.core.async_engine import Tracer +from colossalai.inference.struct import Sequence + + +class SampleEvent: + def __init__(self): + self.flag = False + + def set(self): + self.flag = True + + def clear(self): + self.flag = False + + +def test_request_tracer(): + tracker = Tracer() + tracker.new_requests_event = SampleEvent() + stream_1 = tracker.add_request(1) + assert tracker.new_requests_event.flag + new = tracker.get_new_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 1 + assert new[0]["request_id"] == 1 + assert not stream_1.finished + + stream_2 = tracker.add_request(2) + stream_3 = tracker.add_request(3) + assert tracker.new_requests_event.flag + new = tracker.get_new_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 2 + assert new[0]["request_id"] == 2 + assert new[1]["request_id"] == 3 + assert not stream_2.finished + assert not stream_3.finished + + # request_ids must be unique + with pytest.raises(KeyError): + tracker.add_request(1) + assert not tracker.new_requests_event.flag + + tracker.abort_request(1) + new = tracker.get_new_requests() + assert not new + + stream_4 = tracker.add_request(4) + tracker.abort_request(4) + assert tracker.new_requests_event.flag + new = tracker.get_new_requests() + assert not new + assert stream_4.finished + + stream_5 = tracker.add_request(5) + assert tracker.new_requests_event.flag + tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) + new = tracker.get_new_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 1 + assert new[0]["request_id"] == 5 + assert stream_2.finished + assert not stream_5.finished + + +if __name__ == "__main__": + test_request_tracer() diff --git a/tests/test_infer/test_batch_bucket.py b/tests/test_infer/test_batch_bucket.py new file mode 100644 index 000000000000..f7fd1d4a4986 --- /dev/null +++ b/tests/test_infer/test_batch_bucket.py @@ -0,0 +1,144 @@ +import torch +from transformers.models.llama import LlamaConfig + +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig +from colossalai.inference.kv_cache import KVCacheManager +from colossalai.inference.struct import Sequence +from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize + +logger = get_dist_logger(__name__) + + +@parameterize( + "test_config", + [ + { + "hidden_size": 128, + "num_attention_heads": 4, + "num_layers": 2, + "block_size": 4, + "max_batch_size": 4, + "max_input_len": 32, + "max_output_len": 8, + "dtype": torch.float16, + "tp_size": 1, + } + ], +) +def test_bucket(test_config): + hidden_size = test_config.pop("hidden_size") + num_heads = test_config.pop("num_attention_heads") + num_layers = test_config.pop("num_layers") + model_config = LlamaConfig( + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + ) + inference_config = InferenceConfig(**test_config) + + # Just for testing usage. Don't create multiple cache_manager on the same device. + cache_manager = KVCacheManager(inference_config, model_config) + cache_manager_copy = KVCacheManager(inference_config, model_config) + + seq_lens = [19, 20, 27] + seq1 = Sequence( + request_id=0, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[0])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + seq2 = Sequence( + request_id=1, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[1])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + seq3 = Sequence( + request_id=2, + prompt="", # Dummy for testing usage + input_token_id=list(range(seq_lens[2])), + block_size=4, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=10, + ) + + block_size = test_config["block_size"] + max_batch_size = test_config["max_batch_size"] + max_length = test_config["max_input_len"] + test_config["max_output_len"] + assert max_batch_size >= 2, "max_batch_size should be greater than 1" + + bb = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + bb_copy = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + block_tables = bb.add_seqs([seq1, seq2]) + logger.debug(f"bb information: {bb}") + assert block_tables.shape == (2, cache_manager.max_blocks_per_sequence) + assert torch.all(block_tables < 0), "Initialized block_tables should be negative values" + + cache_manager.allocate_context_from_block_tables(block_tables, bb.seq_lengths[: bb.current_batch_size]) + bb_copy.add_seqs( + [seq1, seq2], alloc_block_tables_fn=cache_manager_copy.allocate_context_from_block_tables + ) # This is just for testing usage. Don't add the same sequence to different buckets. + + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + assert torch.equal(bb.block_tables, bb_copy.block_tables) + + bb.append_batch_tokens(torch.tensor([99, 99])) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + bb.append_batch_tokens(torch.tensor([99, 99])) + + cache_manager.allocate_tokens_from_block_tables(bb.block_tables, bb.seq_lengths, bsz=bb.current_batch_size) + assert bb.seq_lengths.tolist() == [seq1.sentence_len, seq2.sentence_len] + [0] * ( + max_batch_size - bb.current_batch_size + ) + + bb.pop_seq_update_batch(0, free_block_table_fn=cache_manager.free_block_table) + assert bb.seq_lengths.tolist() == [bb.seqs_li[0].sentence_len] + [0] * (max_batch_size - bb.current_batch_size) + assert bb.is_compact + + bb2 = BatchBucket( + num_heads, cache_manager.get_head_size(), max_batch_size, max_length, block_size, kv_max_split_num=2 + ) + block_tables = bb2.add_seqs([seq3]) + cache_manager.allocate_context_from_block_tables(block_tables, bb2.seq_lengths[: bb2.current_batch_size]) + unmerged_ids = bb.merge(bb2) + assert not unmerged_ids + assert bb.is_compact + assert bb2.is_compact + assert bb.current_batch_size == 2 + assert bb2.current_batch_size == 0 + + bb.clear(cache_manager.free_block_tables) + assert bb.current_batch_size == 0 + assert bb.is_compact + assert bb.seq_lengths.tolist() == [0] * max_batch_size + assert torch.all(bb.block_tables < 0) + + +if __name__ == "__main__": + test_bucket() diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py new file mode 100755 index 000000000000..d6f54212949e --- /dev/null +++ b/tests/test_infer/test_config_and_struct.py @@ -0,0 +1,46 @@ +import pytest + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.struct import RequestStatus, Sequence +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_config_and_inference(): + config = InferenceConfig() + assert config.max_batch_size == 8 + sequence = Sequence( + request_id=1, + prompt="abc", + input_token_id=[1, 2, 3], + block_size=16, + sample_params=None, + eos_token_id=2, + pad_token_id=2, + max_output_len=256, + ) + + sequence.mark_running() + assert sequence.status == RequestStatus.RUNNING + sequence.recycle() + assert sequence.status == RequestStatus.RECYCLED + + assert sequence.sentence_len == 3 + assert sequence.input_len == 3 + assert sequence.output_len == 0 + assert sequence.check_finish() == False + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_config_and_inference() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_config_and_inference(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_config_and_inference() diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py new file mode 100644 index 000000000000..a88798619b79 --- /dev/null +++ b/tests/test_infer/test_continuous_batching.py @@ -0,0 +1,103 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def generate_inputs(num_sequences, min_length, max_length): + sequences = [] + for _ in range(num_sequences): + length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item() + # generating randomly lengthed sequences + sequence = torch.randint(10, 30000, size=(length,)) + sequences.append(sequence) + return sequences + + +@parameterize( + "test_config", + [ + { + "max_batch_size": 8, + "max_output_len": 512, + "max_input_len": 64, + "do_sample": False, + } + ], +) +def check_inference_engine(test_config, use_engine=False, prompt_template=None): + setup_seed(20) + max_batch_size = test_config["max_batch_size"] + max_input_len = test_config["max_input_len"] + max_output_len = test_config["max_output_len"] + do_sample = test_config["do_sample"] + top_p = 0.5 + top_k = 50 + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() + model = model.eval() + + inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len) + + if use_engine: + inference_config = InferenceConfig( + max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == max_output_len + inference_engine.add_request(prompts_token_ids=inputs_token_ids) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=max_output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + assert len(outputs) == 10 * max_batch_size + + +@parameterize("prompt_template", [None, "llama"]) +def check_continuous_batching(prompt_template): + check_inference_engine(use_engine=True, prompt_template=prompt_template) + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_continuous_batching() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_continuous_batching(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_continuous_batching() diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py new file mode 100644 index 000000000000..2be188571d9c --- /dev/null +++ b/tests/test_infer/test_cuda_graph.py @@ -0,0 +1,96 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_inference_engine(use_cuda_graph=False, batch_size=32): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + model = ( + LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + ) + ) + .cuda() + .half() + ) + model = model.eval() + + prompts_token_ids = [] + for i in range(batch_size): + prompts_token_ids.append( + np.random.randint(low=0, high=100, size=random.randint(1, max(1024 // batch_size, 32))).tolist() + ) + + input_len = 1024 + output_len = 128 + do_sample = False + top_p = 0.5 + top_k = 50 + + if use_cuda_graph: + inference_config = InferenceConfig( + max_batch_size=batch_size, + max_input_len=input_len, + max_output_len=output_len, + use_cuda_kernel=False, + use_cuda_graph=True, + block_size=16, + ) + else: + inference_config = InferenceConfig( + max_batch_size=batch_size, + max_input_len=input_len, + max_output_len=output_len, + use_cuda_kernel=False, + use_cuda_graph=False, + block_size=16, + ) + + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == output_len + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + outputs = inference_engine.generate(prompts_token_ids=prompts_token_ids, generation_config=generation_config) + + return outputs + + +def check_output_consistency(batch_size): + cuda_graph_output = check_inference_engine(use_cuda_graph=True, batch_size=batch_size) + naive_model_output = check_inference_engine(use_cuda_graph=False, batch_size=batch_size) + + for s1, s2 in zip(cuda_graph_output, naive_model_output): + assert s1 == s2, f"\nCUDA Graph Output: {s1}\nOrigin Output: {s2}" + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_output_consistency(32) + check_output_consistency(64) + check_output_consistency(128) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +def test_cuda_graph_infer(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_cuda_graph_infer() diff --git a/tests/test_infer/test_drafter.py b/tests/test_infer/test_drafter.py new file mode 100644 index 000000000000..3c5dda1578a2 --- /dev/null +++ b/tests/test_infer/test_drafter.py @@ -0,0 +1,74 @@ +import pytest +import torch +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM + +from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM +from colossalai.inference.spec.drafter import Drafter +from colossalai.utils import get_current_device + +NUM_LAYERS = 1 +MAX_LEN = 100 +SPEC_NUM = 5 + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + + +@pytest.mark.parametrize("spec_num", [SPEC_NUM]) +def test_drafter(tokenizer, spec_num: int): + torch.manual_seed(123) + + device = get_current_device() + toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS) + toy_config.pad_token_id = tokenizer.eos_token_id + drafter_model = LlamaForCausalLM(toy_config) + drafter_model = drafter_model.eval().cuda() + + drafter = Drafter(drafter_model, tokenizer, device=device) + + input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device) + out = drafter.speculate(input_ids, spec_num) + past_kv_length = input_ids.size(1) + spec_num - 1 + + assert out.speculated_length == spec_num + assert out.next_tokens.shape == (spec_num,) + assert out.logits.shape == (spec_num, len(tokenizer)) + assert out.past_key_values[0][0].size(2) == past_kv_length + + reject_num = max(0, spec_num - 1) + trimmed_past_key_values = drafter.trim_kv_cache(out.past_key_values, reject_num) + assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num + + +def test_spec_dec(tokenizer): + spec_num = SPEC_NUM + device = get_current_device() + tokenizer.pad_token = tokenizer.eos_token + + # Dummy config for Glide Model + glide_config = GlideLlamaConfig( + intermediate_size=8192, + large_hidden_size=4096, + large_num_attention_heads=32, + num_hidden_layers=NUM_LAYERS, + ) + drafter_model = GlideLlamaForCausalLM(glide_config) + + assert hasattr(drafter_model, "model") + assert hasattr(drafter_model.model, "layers") + for _, layer in enumerate(drafter_model.model.layers): + assert hasattr(layer, "cross_attn") + + # Init the Drafter by providing the sharded drafter model + drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16) + + input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device) + out = drafter.speculate(input_ids, spec_num, past_key_values=None) + + +if __name__ == "__main__": + dummy_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + test_drafter(dummy_tokenizer, spec_num=SPEC_NUM) + test_spec_dec(dummy_tokenizer) diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py deleted file mode 100644 index ef2aac1d1aa7..000000000000 --- a/tests/test_infer/test_hybrid_bloom.py +++ /dev/null @@ -1,121 +0,0 @@ -import importlib.util - -import pytest -import torch -import torch.distributed as dist -import transformers -from packaging import version - -import colossalai -from colossalai.inference import InferenceEngine -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - - -def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -inputs = data_gen() -for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - -def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - model = transformers.BloomForCausalLM( - transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4) - ) - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - ) - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_pipeline_inference_test() - - -def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py deleted file mode 100644 index e80b3477f736..000000000000 --- a/tests/test_infer/test_hybrid_chatglm2.py +++ /dev/null @@ -1,129 +0,0 @@ -import importlib.util - -import pytest -import torch -import torch.distributed as dist -from packaging import version - -import colossalai -from colossalai.inference import InferenceEngine -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - - -def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -inputs = data_gen() -for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - -def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - chatglm_config = ChatGLMConfig( - num_layers=2, - vocab_size=20000, - use_cache=True, - multi_query_attention=True, - multi_query_group_num=2, - num_attention_heads=8, - hidden_size=1024, - ) - model = ChatGLMForConditionalGeneration(chatglm_config) - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - ) - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_pipeline_inference_test() - - -def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py deleted file mode 100644 index a997948178e0..000000000000 --- a/tests/test_infer/test_hybrid_llama.py +++ /dev/null @@ -1,126 +0,0 @@ -import importlib.util - -import pytest -import torch -import torch.distributed as dist -import transformers -from packaging import version - -import colossalai -from colossalai.inference import InferenceEngine -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - -import importlib.util - -HAS_LIGHTLLM_KERNEL = True - -if importlib.util.find_spec("lightllm") is None: - HAS_LIGHTLLM_KERNEL = False - - -def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -inputs = data_gen() -for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - -def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - model = transformers.LlamaForCausalLM( - transformers.LlamaConfig( - vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 - ) - ) - - engine = InferenceEngine( - tp_size=tp_size, - pp_size=pp_size, - model=model, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - ) - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [2]) -@parameterize("max_output_len", [4]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [2]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -@parameterize("tp_size", [1]) -@parameterize("pp_size", [1]) -@parameterize("max_output_len", [2]) -@parameterize("micro_batch_size", [1]) -@clear_cache_before_run() -def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): - pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) - torch.cuda.empty_cache() - - -def check_tp_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_pipeline_inference_test() - - -def check_tp_or_pp_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_tp_inference_test() - run_pipeline_inference_test() - - -def check_single_inference(rank, world_size, port): - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_single_inference_test - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_pipeline_inference(): - spawn(check_tp_pp_inference, nprocs=4) - spawn(check_tp_or_pp_inference, nprocs=2) - spawn(check_single_inference, nprocs=1) - - -if __name__ == "__main__": - test_pipeline_inference() diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py new file mode 100644 index 000000000000..919a10077d24 --- /dev/null +++ b/tests/test_infer/test_inference_engine.py @@ -0,0 +1,208 @@ +import random + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.multiprocessing import Manager +from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM +from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + model = LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, + hidden_size=512, + intermediate_size=1536, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=16, + ) + ).cuda() + model = model.eval() + inputs = [ + "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", + "介绍一下武汉,", + ] + + output_len = 38 + do_sample = do_sample + top_p = 0.5 + top_k = 50 + + if use_engine: + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + dtype="fp32", + use_cuda_kernel=True, + tp_size=dist.get_world_size(), + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts=inputs) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig( + max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k + ) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + dtype="fp32", + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return outputs + + +def run_engine(world_size, **kwargs): + manager = Manager() + result_list = manager.list([-1] * world_size) # Create a shared list + + spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs) + return result_list[0] + + +def check_spec_dec(num_layers, max_length): + torch.manual_seed(123) + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + # Dummy configs for testing + toy_config = LlamaConfig(num_hidden_layers=num_layers) + toy_config.pad_token_id = tokenizer.eos_token_id + drafter_model = LlamaForCausalLM(toy_config) + drafter_model = drafter_model.eval().cuda() + large_config = LlamaConfig( + hidden_size=4096, + intermediate_size=11008, + num_attention_heads=32, + num_hidden_layers=8, + num_key_value_heads=32, + max_position_embeddings=2048, + ) + large_config.pad_token_id = tokenizer.eos_token_id + main_model = LlamaForCausalLM(large_config) + + inference_config = InferenceConfig( + dtype="fp16", + micro_batch_size=1, + max_batch_size=1, + max_input_len=128, + max_output_len=128, + prefill_ratio=1.2, + block_size=16, + ) + engine = InferenceEngine(main_model, tokenizer, inference_config) + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + + dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda") + generation_config = GenerationConfig( + pad_token_id=tokenizer.eos_token_id, + max_length=max_length, + eos_token_id=tokenizer.eos_token_id, + ) + out, out_token_ids = engine.generate( + prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True + ) + engine.disable_spec_dec() + engine.clear_spec_dec() + + assert not engine.use_spec_dec + assert engine.drafter is None and engine.drafter_model is None + + max_new_tokens = max_length - dummy_inputs.size(1) + assert len(out) == 1 + assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens + + # test GLIDE model + glide_config = GlideLlamaConfig( + intermediate_size=8192, + large_hidden_size=4096, + large_num_attention_heads=32, + num_hidden_layers=num_layers, + ) + glide_model = GlideLlamaForCausalLM(glide_config) + engine.enable_spec_dec(glide_model, use_glide_drafter=True) + + out, out_token_ids = engine.generate( + prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True + ) + engine.clear_spec_dec() + + assert len(out) == 1 + assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_new_tokens + + +def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + if ret: + ret[rank] = func_to_run(**kwargs) + else: + func_to_run(**kwargs) + + +@pytest.mark.largedist +@parameterize("prompt_template", [None, "llama"]) +@parameterize("do_sample", [False]) +@rerun_if_address_is_in_use() +def test_tp_engine(prompt_template, do_sample): + kwargs1 = { + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": NoPaddingLlamaModelInferPolicy(), + } + + kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None} + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +@pytest.mark.largedist +@parameterize("num_layers", [1]) +@parameterize("max_length", [64]) +@rerun_if_address_is_in_use() +def test_spec_dec(num_layers, max_length): + spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length) + + +if __name__ == "__main__": + test_tp_engine() + test_spec_dec() diff --git a/tests/test_infer/test_kernels/__init__.py b/tests/test_infer/test_kernels/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_infer/test_kernels/cuda/__init__.py b/tests/test_infer/test_kernels/cuda/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_infer/test_kernels/cuda/test_convert_fp8.py b/tests/test_infer/test_kernels/cuda/test_convert_fp8.py new file mode 100644 index 000000000000..bfcffa713d8d --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_convert_fp8.py @@ -0,0 +1,57 @@ +import random + +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device + +inference_ops = InferenceOpsLoader().load() + +DTYPES = [torch.half, torch.bfloat16, torch.float] +NUM_TOKENS = [42] # Arbitrary values for testing +NUM_LAYERS = [1] # Arbitrary values for testing +NUM_HEADS = [8] # Arbitrary values for testing +HEAD_SIZES = [64, 80, 96, 112, 128, 256] +BLOCK_SIZES = [8, 16, 32] + + +@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, now we skip it's relative test!") +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256]) +@pytest.mark.parametrize("block_size", [8, 16, 32]) +@pytest.mark.parametrize("num_blocks", [1024, 10000]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float]) +@pytest.mark.parametrize("seed", [0]) +@torch.inference_mode() +def test_fp8_conversion( + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + device = get_current_device() + + low = -224.0 + high = 224.0 + shape = (num_blocks, num_heads, head_size, block_size) + cache = torch.empty(shape, dtype=dtype, device=device) + cache.uniform_(low, high) + + cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) + inference_ops.convert_fp8(cache, cache_fp8) + + converted_cache = torch.empty_like(cache) + inference_ops.convert_fp8(cache_fp8, converted_cache) + + assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) + + +if __name__ == "__main__": + test_fp8_conversion(8, 64, 8, 1024, torch.half, 0) diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py new file mode 100644 index 000000000000..38913b8a94f9 --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py @@ -0,0 +1,334 @@ +from itertools import product + +import numpy as np +import pytest +import torch + +from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask + +inference_ops = InferenceOpsLoader().load() + +from tests.test_infer.test_kernels.triton.kernel_utils import ( + convert_kv_unpad_to_padded, + create_attention_mask, + generate_caches_and_block_tables_v3, + generate_caches_and_block_tables_vllm, + torch_attn_ref, +) + +q_len = 1 +PARTITION_SIZE = 512 + + +def prepare_data( + BATCH_SIZE: int, + HEAD_SIZE: int, + NUM_ATTN_HEADS: int, + NUM_KV_HEADS: int, + MAX_SEQ_LEN: int, + dtype=torch.float16, + device="cuda", +): + # Use the provided maximum sequence length for each sequence when testing with teh same context length, + # otherwise generate random context lengths. + # returns + # q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE] + # k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE] + kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device) + num_tokens = torch.sum(kv_lengths).item() + + q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths + + +def numpy_allclose(x, y, rtol, atol): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) +@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) +@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32, 256, 512]) +@pytest.mark.parametrize("HEAD_SIZE", [64, 128]) +@pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) +@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) +def test_flash_decoding_attention( + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + try: + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) + else: + alibi_slopes = None + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + + if use_alibi_slopes: + alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) + torch_padding_mask = torch_padding_mask + alibi_mask + + if len(torch_padding_mask.size()) == 4: + torch_padding_mask = torch_padding_mask[:, :, -1:, :] + else: + torch_padding_mask = torch_padding_mask[:, -1:, :] + + mid_output = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device + ) + exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device) + max_logits = torch.empty( + size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device + ) + + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + high_precision_q = q.to(torch.float32) + high_precision_k_torch = k_torch.to(torch.float32) + high_precision_v_torch = v_torch.to(torch.float32) + out_ref = torch_attn_ref( + high_precision_q, + high_precision_k_torch, + high_precision_v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + out_ref = torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ) + + except torch.cuda.OutOfMemoryError: + pytest.skip("Required GPU memory is larger than capacity.") + + inference_ops.flash_decoding_attention( + output, + q.squeeze(2), + k_cache, + v_cache, + kv_seq_lengths, + block_tables, + BLOCK_SIZE, + max_seq_len_across_batch, + mid_output, + exp_sums, + max_logits, + alibi_slopes, + sm_scale, + ) + + # The alibi may introduce relatively large errors + if use_alibi_slopes: + rtol = 1e0 + + try: + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) + + except AssertionError: + if MAX_NUM_BLOCKS_PER_SEQ >= 256: + pytest.skip("Long sequence length introduce precision error.") + else: + raise + + +try: + from vllm._C import ops as vllm_ops # noqa + + HAS_VLLM = True +except ImportError: + HAS_VLLM = False + print("The subsequent test requires vllm. Please refer to https://github.com/vllm-project/vllm") + + +@pytest.mark.skipif(not HAS_VLLM, reason="requires vllm") +@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32]) +@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32]) +@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32]) +@pytest.mark.parametrize("HEAD_SIZE", [64, 128]) +@pytest.mark.parametrize("NUM_ATTN_HEADS", [16]) +@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) +def test_vllm_flash_decoding_attention( + BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM + assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads." + MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ + device = get_current_device() + + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device + ) + + k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm( + k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device + ) + + block_tables = block_tables.to(device=device) + max_seq_len_across_batch = kv_seq_lengths.max().item() + output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device) + sm_scale = 1.0 / (HEAD_SIZE**0.5) + kv_scale = 1.0 + + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch) + torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device) + + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device) + alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device) + torch_padding_mask = torch_padding_mask + alibi_mask + + if len(torch_padding_mask.size()) == 4: + torch_padding_mask = torch_padding_mask[:, :, -1:, :] + else: + torch_padding_mask = torch_padding_mask[:, -1:, :] + else: + alibi_slopes = None + + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + high_precision_q = q.to(torch.float32) + high_precision_k_torch = k_torch.to(torch.float32) + high_precision_v_torch = v_torch.to(torch.float32) + out_ref = torch_attn_ref( + high_precision_q, + high_precision_k_torch, + high_precision_v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + out_ref = torch_attn_ref( + q, + k_torch, + v_torch, + torch_padding_mask, + BATCH_SIZE, + q_len, + max_seq_len_across_batch, + NUM_ATTN_HEADS, + NUM_KV_HEADS, + HEAD_SIZE, + ) + + vllm_ops.paged_attention_v1( + output, + q.squeeze(2), + k_cache, + v_cache, + NUM_KV_HEADS, + sm_scale, + block_tables, + kv_seq_lengths, + BLOCK_SIZE, + max_seq_len_across_batch, + alibi_slopes, + "auto", + kv_scale, + ) + + # The alibi may introduce relatively large errors + if use_alibi_slopes: + rtol = 1e0 + + numpy_allclose(out_ref, output, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + BATCH_SIZE = [1, 4, 7, 32] + BLOCK_SIZE = [8, 16, 32] + MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32] + HEAD_SIZE = [64, 128] + NUM_ATTN_HEADS = [16] + KV_GROUP_NUM = [1, 2, 16] + DTYPE = [torch.float16, torch.float32] + test_combinations = list( + product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE) + ) + for ( + batch_size, + block_size, + max_num_blocks_per_seq, + head_size, + num_attn_heads, + kv_group_num, + dtype, + ) in test_combinations: + test_flash_decoding_attention( + batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True + ) diff --git a/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py new file mode 100644 index 000000000000..b6ba1a01bd54 --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin + +inference_ops = InferenceOpsLoader().load() + + +def numpy_equal(x, y): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_equal(x_numpy, y_numpy) + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("MAX_SEQ_LEN", [64]) +@pytest.mark.parametrize("HEAD_DIM", [64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_get_cos_and_sin(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): + MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN + cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda").to(torch.int32) + + max_seq_len_in_batch = lengths.max() + + # prefill + cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + + cos = torch.zeros_like(cos_ref) + sin = torch.zeros_like(sin_ref) + + inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, True) + + numpy_equal(cos, cos_ref) + numpy_equal(sin, sin_ref) + + # decoding + ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + + cos = torch.zeros_like(ncos_ref) + sin = torch.zeros_like(nsin_ref) + + inference_ops.get_cos_and_sin(cos_cache, sin_cache, cos, sin, lengths, max_seq_len_in_batch, False) + numpy_equal(cos, ncos_ref) + numpy_equal(sin, nsin_ref) + + +if __name__ == "__main__": + test_get_cos_and_sin(16, 4096, 256, torch.float16) diff --git a/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py new file mode 100644 index 000000000000..d90f64690152 --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py @@ -0,0 +1,157 @@ +import pytest +import torch +import torch.nn.functional as F + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + generate_caches_and_block_tables_v3, + mock_alloc_single_token, +) + +inference_ops = InferenceOpsLoader().load() + +HEAD_DIM = 72 + + +def prepare_data( + bsz, + num_kv_heads, + block_size, + max_num_blocks_per_seq, + context_lengths, + device="cuda", + dtype=torch.float16, +): + num_tokens = torch.sum(context_lengths).item() + + max_seq_len_in_batch = context_lengths.max() + cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + + kv_size = (num_tokens, num_kv_heads, HEAD_DIM) + key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3( + key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + + block_tables = block_tables.to(device=device) + k_cache = torch.zeros_like(k_cache_ref) + v_cache = torch.zeros_like(v_cache_ref) + + return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref + + +def run_decode_copy_kv_to_caches( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + n = 1 + + max_seq_len = block_size * max_num_blocks_per_seq + dtype = torch.float32 + device = get_current_device() + + assert max_seq_len > n, "max_seq_len must be greater than n" + + past_kv_seq_lengths = ( + torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device) + ) + + key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data( + bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype + ) + + new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device) + new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device) + + # mock allocating blocks for the new k/v and update block tables + for _ in range(n): + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + past_kv_seq_lengths += 1 + + inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables) + + past_kv_seq_len = past_kv_seq_lengths - 1 + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :] + k_source = new_k.squeeze() + v_target = v_cache[target_block_ids, :, offsets_in_block, :] + k_target = k_target.reshape(v_target.shape) + v_source = new_v.squeeze() + + assert k_target.shape == k_source.shape + assert torch.equal(k_target, k_source) + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) + + +def run_context_copy_kv_to_cache( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + torch.manual_seed(123) + + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + max_seq_len = max_num_blocks_per_seq * block_size + dtype = torch.float16 + device = get_current_device() + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + + ( + key, + value, + k_cache, + v_cache, + cu_seqlens, + block_tables, + max_seq_len_in_batch, + k_cache_ref, + v_cache_ref, + ) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype) + + inference_ops.context_kv_cache_memcpy( + key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch + ) + + assert torch.equal(k_cache, k_cache_ref) + assert torch.equal(v_cache, v_cache_ref) + + +@pytest.mark.parametrize("bsz", [4, 7, 32]) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) +@pytest.mark.parametrize("num_kv_heads", [16]) +@pytest.mark.parametrize("same_context_len", [True, False]) +def test_kv_cache_memcopy( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len) + run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len) + + +if __name__ == "__main__": + test_kv_cache_memcopy(4, 32, 8, 16, True) diff --git a/tests/test_infer/test_kernels/cuda/test_rms_layernorm.py b/tests/test_infer/test_kernels/cuda/test_rms_layernorm.py new file mode 100644 index 000000000000..0b677fff89e9 --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_rms_layernorm.py @@ -0,0 +1,51 @@ +import pytest +import torch +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device + +inference_ops = InferenceOpsLoader().load() + + +@pytest.mark.parametrize("M", [2, 4, 8, 16]) +@pytest.mark.parametrize("N", [64, 128, 512, 5120]) +def test_rms_layernorm(M: int, N: int): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + device = get_current_device() + + dtype = torch.float16 + eps = 1e-5 + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.ones(w_shape, dtype=dtype, device=device) + residual = torch.rand(x_shape, dtype=dtype, device=device) + residual_copy = residual.clone() + rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda() + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + x_copy = x.clone() + + y_cuda = torch.empty_like(x) + inference_ops.rms_layernorm(y_cuda, x, weight, eps) + y_llama = rms_norm.forward(x).to(dtype) + + assert y_cuda.shape == y_llama.shape + assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3) + + inference_ops.fused_add_rms_layernorm(x, residual, weight, eps) + y_cuda = x + + x = x_copy + residual_copy + y_llama = rms_norm.forward(x).to(dtype) + + assert y_cuda.shape == y_llama.shape + assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3) + assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + test_rms_layernorm(16, 5120) diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py new file mode 100644 index 000000000000..8237384c03fd --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest +import torch +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from colossalai.kernel.kernel_loader import InferenceOpsLoader + +inference_ops = InferenceOpsLoader().load() + +from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3 +from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb + + +def numpy_allclose(x, y, rtol, atol): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("SEQ_LEN", [64]) +@pytest.mark.parametrize("H", [32]) +@pytest.mark.parametrize("K_H", [16, 32]) +@pytest.mark.parametrize("D", [64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype): + torch.manual_seed(10) + TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN + # our crafted op equals to Transformers + x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype) + emb = LlamaRotaryEmbedding(D) + cos, sin = emb(x0, TOTAL_TOKENS) + cos_2 = cos[:, : D // 2] + sin_2 = sin[:, : D // 2] + position_ids = torch.arange(TOTAL_TOKENS) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) + embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + assert torch.allclose(embd_x0, embd_stimulated_x) + + # create data + block_size = 32 + max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size + q_shape = (TOTAL_TOKENS, H, D) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (TOTAL_TOKENS, K_H, D) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + cos_shape = (TOTAL_TOKENS, D // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x) + v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda") + v = torch.randn_like(k) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda") + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v3( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size + ) + new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda") + new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") + new_v = torch.randn_like(new_k) + + kv_seq_lengths = past_kv_seq_lengths + 1 + block_tables = block_tables.to(device="cuda") + + new_q_copy = new_q.clone() + new_k_copy = new_k.clone() + + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + new_q_fp16 = new_q.clone() + new_k_fp16 = new_k.clone() + + high_precision_cos = cos[:BATCH_SIZE].to(torch.float32) + high_precision_sin = sin[:BATCH_SIZE].to(torch.float32) + high_precision_q = new_q.to(torch.float32) + high_precision_k = new_k.to(torch.float32) + q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16) + k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + + inference_ops.rotary_embedding_and_cache_copy( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True + ) + + inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True) + + past_kv_seq_len = kv_seq_lengths - 1 + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze() + k_source = new_k_copy.squeeze() + v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze() + k_target = k_target.reshape(v_target.shape) + v_source = new_v.squeeze() + + numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol) + numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol) + + numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol) + numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol) + + assert k_target.shape == k_source.shape + numpy_allclose(k_target, k_source, rtol=rtol, atol=atol) + + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) + + if dtype == torch.float16: + # After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test. + rtol = 1e-3 + atol = 1e-1 + inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False) + numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol) + numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + test_rotary_emb(16, 64, 32, 16, 128, torch.float16) diff --git a/tests/test_infer/test_kernels/cuda/test_silu_and_mul.py b/tests/test_infer/test_kernels/cuda/test_silu_and_mul.py new file mode 100644 index 000000000000..ced2db7ca048 --- /dev/null +++ b/tests/test_infer/test_kernels/cuda/test_silu_and_mul.py @@ -0,0 +1,33 @@ +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device + +inference_ops = InferenceOpsLoader().load() + + +@pytest.mark.parametrize("SHAPE_X", [2]) +@pytest.mark.parametrize("SHAPE_Y", [64]) +@pytest.mark.parametrize("SHAPE_Z", [11008]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype): + torch.manual_seed(5) + device = get_current_device() + ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device) + origin_input = ref_input.clone() + + act_out = torch.nn.functional.silu(ref_input[0], inplace=True) + ref_out = act_out * ref_input[1] + + origin_out = inference_ops.silu_and_mul(origin_input) + + if dtype == torch.float32: + assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5) + else: + assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + test_silu_and_mul(2, 64, 11008, torch.float32) + test_silu_and_mul(2, 64, 11008, torch.float16) diff --git a/tests/test_infer/test_kernels/triton/__init__.py b/tests/test_infer/test_kernels/triton/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_infer/test_kernels/triton/kernel_utils.py b/tests/test_infer/test_kernels/triton/kernel_utils.py new file mode 100644 index 000000000000..916691228e7c --- /dev/null +++ b/tests/test_infer/test_kernels/triton/kernel_utils.py @@ -0,0 +1,348 @@ +from typing import Tuple + +import torch +from torch.nn import functional as F + + +# This function is adapted from src/transformers/models/llama/modeling_llama.py +# in huggingface transformers repository +# https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/llama/modeling_llama.py#L273 +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The hidden states go from (bsz, num_key_value_heads, seq_len, head_dim) to (bsz, num_attention_heads, seq_len, head_dim) + """ + if n_rep == 1: + return hidden_states + bsz, num_key_value_heads, seq_len, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim) + return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) + + +def create_attention_mask(kv_lengths: torch.Tensor, bsz: int, q_len: int, kv_len: int, device="cuda"): + assert q_len <= kv_len + + causal_mask = torch.full((q_len, q_len), fill_value=float("-inf"), device=device).triu(diagonal=1) + + padding_mask = torch.zeros((bsz, 1, q_len, kv_len), dtype=torch.float32, device=device) + for i in range(bsz): + cur_seq_len = kv_lengths[i].item() + assert cur_seq_len <= kv_len + padding_mask[i, :, :, : kv_len - cur_seq_len] = float("-inf") + + padding_mask[:, :, -q_len:, -q_len:] += causal_mask + + return padding_mask + + +# Attention calculation adapted from HuggingFace transformers repository +# src/transformers/models/llama/modeling_llama.py +# https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350 +def torch_attn_ref( + q: torch.Tensor, # [bsz, num_heads, q_len, head_dim] + k: torch.Tensor, # [bsz, num_heads, kv_len, head_dim] + v: torch.Tensor, # [bsz, num_heads, kv_len, head_dim] + attention_mask: torch.Tensor, # [bsz, 1, q_len, kv_len] + bsz: int, + q_len: int, + kv_len: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, +) -> torch.Tensor: + assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim + + # repeat kv for GQA and MQA + # k/v won't change if kv_group_num is 1 + assert num_heads % num_kv_heads == 0, "Number of heads is not multiple of kv heads" + kv_group_num = num_heads // num_kv_heads + k = repeat_kv(k, kv_group_num) + v = repeat_kv(v, kv_group_num) + + qk = torch.matmul(q, k.transpose(2, 3)) + attn_scores = qk / (head_dim**0.5) + + assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores" + if attention_mask is not None: + attn_scores = attn_scores + attention_mask + + attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) + out = torch.matmul(attn_weights, v) + if out.size() != (bsz, num_heads, q_len, head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)}, but is" f" {out.size()}" + ) + out = out.transpose(1, 2).contiguous() + out = out.view(-1, out.size(-2), out.size(-1)) + # out [bsz * q_len, num_heads, head_dim] + return out + + +def mock_alloc_block_table_and_kvcache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + k_cache[block_id, :, :, :allocated_locs] = k_block + v_cache[block_id, :, :, :allocated_locs] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + +def mock_alloc_block_table_and_kvcache_v2( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + k_block = k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + k_cache[block_id, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :allocated_locs, :] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + +def mock_alloc_block_table_and_kvcache_v3( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + + _, num_kv_heads, head_dim = k.shape + + x = 16 // torch.tensor([], dtype=k.dtype).element_size() + + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + # [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x] + k_block = ( + k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :] + .reshape(allocated_locs, num_kv_heads, head_dim // x, x) + .permute(1, 2, 0, 3) + ) + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2) + k_cache[block_id, :, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :allocated_locs, :] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + +def mock_alloc_block_table_and_kvcache_vllm( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + context_lengths: torch.Tensor, + num_seqs: int, + max_num_blocks_per_seq: int, + block_size: int, +) -> torch.Tensor: + """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" + block_id = 0 + block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) + num_tokens_processed = 0 + + _, num_kv_heads, head_dim = k.shape + + x = 16 // torch.tensor([], dtype=k.dtype).element_size() + + for i, seq_len in enumerate(context_lengths.tolist()): + right_bound = (seq_len + block_size - 1) // block_size # open bound + block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32) + # Manually fill kv caches by copying from k and v + for i in range(right_bound): + if i == right_bound - 1: + allocated_locs = seq_len % block_size or block_size + else: + allocated_locs = block_size + # [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x] + k_block = ( + k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :] + .reshape(allocated_locs, num_kv_heads, head_dim // x, x) + .permute(1, 2, 0, 3) + ) + # [block_size, num_kv_heads, head_dim]->[num_kv_heads, head_dim, block_size] + v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0) + k_cache[block_id, :, :, :allocated_locs, :] = k_block + v_cache[block_id, :, :, :allocated_locs] = v_block + + num_tokens_processed += allocated_locs + block_id += 1 + + return block_tables + + +def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None: + # Allocate 1 token on the block table for each seqs in block tables. + # It won't change provided context_lengths. + # Consider max_block_id as the last physical block allocated + # NOTE It assumes all the blocks preceding this block have been allocated + max_block_id = torch.max(block_tables).item() + # the indices on each block table representing the cache block to be allocated one more token + alloc_local_block_indices = context_lengths // block_size + # offsets of the token to be allocated on the target block (for each seq) + alloc_block_offsets = context_lengths % block_size + + require_new_block = alloc_block_offsets == 0 + new_block_ids = torch.arange( + max_block_id + 1, + max_block_id + 1 + require_new_block.sum(), + dtype=block_tables.dtype, + device=block_tables.device, + ) + + if new_block_ids.numel(): + new_block_alloc_local_indices = alloc_local_block_indices[require_new_block] + block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids + + +def generate_caches_and_block_tables( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + +def generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache_v2( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + +def generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + + x = 16 // torch.tensor([], dtype=dtype).element_size() + + k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache_v3( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + +def generate_caches_and_block_tables_vllm( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + + x = 16 // torch.tensor([], dtype=dtype).element_size() + + k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x) + v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache_vllm( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + +def convert_kv_unpad_to_padded( + k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int +) -> torch.Tensor: + # Rebuild (batched) k/v with padding to be used by torch attention + # input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + # returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device) + prev_len_sum = 0 + for i, seq_len in enumerate(kv_seq_lengths.tolist()): + # left-side padding + k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len] + prev_len_sum += seq_len + k_torch = k_torch.transpose(1, 2) + return k_torch diff --git a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py new file mode 100644 index 000000000000..9d76858ed07f --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py @@ -0,0 +1,179 @@ +import pytest +import torch +from packaging import version + +from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.kernel.triton import context_attention_unpadded +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + torch_attn_ref, +) + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +HEAD_DIM = 32 + + +def _fill_with_neg_inf(t): + return t.float().fill_(float("-inf")).type_as(t) + + +# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py +def generate_alibi_mask(slopes, num_heads, max_seq_len, device): + token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1 + token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1) + diag = torch.diag(token_position[0]) + token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2) + alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position + alibi = alibi.view(num_heads, 1, max_seq_len) + alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1) + alibi_mask = alibi_mask.unsqueeze(0) + alibi + return alibi_mask + + +def torch_attn_unpad( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + context_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + slopes: torch.Tensor = None, +): + # Process sequence one by one and concatenate them together. + # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim] + assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor" + + _, num_heads, head_dim = q.shape + out_torch = [] + start_idx = 0 + for seq_i in range(len(context_lengths)): + end_idx = start_idx + context_lengths[seq_i].item() + seq_len = end_idx - start_idx + mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device) + mask[mask == 0.0] = float("-inf") + + if slopes is not None: + alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device) + mask = mask + alibi_mask + + torch_attn_ref_out = torch_attn_ref( + q[start_idx:end_idx].unsqueeze(0).transpose(1, 2), + k[start_idx:end_idx].unsqueeze(0).transpose(1, 2), + v[start_idx:end_idx].unsqueeze(0).transpose(1, 2), + mask, + 1, # set bsz as 1 as we're processing sequence one by one + seq_len, + seq_len, + num_heads, + num_kv_heads, + head_dim, + ) + out_torch.append(torch_attn_ref_out.squeeze(0)) + start_idx = end_idx + + return torch.cat(out_torch, dim=0) + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("bsz", [7, 32]) +@pytest.mark.parametrize("block_size", [16, 32]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16]) +@pytest.mark.parametrize("num_attn_heads", [16]) +@pytest.mark.parametrize("kv_group_num", [1, 4]) +@pytest.mark.parametrize("same_context_len", [True, False]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) +def test_context_attention( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_attn_heads: int, + kv_group_num: int, + same_context_len: bool, + use_alibi_slopes: bool, + use_new_kcache_layout: bool, +): + if use_new_kcache_layout and use_alibi_slopes: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + # And tests for the alibi kernel using new kcache layout will be added then. + return + + torch.manual_seed(123) + # It's necessary to clear cache here. + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + max_seq_len = max_num_blocks_per_seq * block_size + dtype = torch.float16 + device = get_current_device() + alibi_slopes = None + + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(num_attn_heads, device) + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + num_tokens = torch.sum(context_lengths).item() + + qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM) + qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) + q_unpad = q_unpad.contiguous() + + if use_new_kcache_layout: + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + else: + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + + block_tables = block_tables.to(device=device) + k_cache_triton = torch.zeros_like(k_cache_ref) + v_cache_triton = torch.zeros_like(v_cache_ref) + + _, num_heads, head_dim = q_unpad.shape + + out_triton = context_attention_unpadded( + q_unpad, + k_unpad, + v_unpad, + k_cache_triton, + v_cache_triton, + context_lengths, + block_tables, + block_size, + alibi_slopes=alibi_slopes, + use_new_kcache_layout=use_new_kcache_layout, + ) + + out_triton = out_triton.view(-1, num_heads, head_dim) + out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes) + + assert out_torch.shape == out_triton.shape + assert torch.allclose(out_torch, out_triton, atol=1e-3) + assert torch.equal(k_cache_ref, k_cache_triton) + assert torch.equal(v_cache_ref, v_cache_triton) + + +if __name__ == "__main__": + test_context_attention(4, 32, 8, 16, 1, True, True, True) diff --git a/tests/test_infer/test_kernels/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py new file mode 100644 index 000000000000..e487129c19e7 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py @@ -0,0 +1,197 @@ +import numpy as np +import pytest +import torch +from packaging import version + +from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes +from colossalai.kernel.triton import flash_decoding_attention +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + convert_kv_unpad_to_padded, + create_attention_mask, + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + torch_attn_ref, +) +from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +HEAD_DIM = 128 + + +def numpy_allclose(x, y, rtol, atol): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol) + + +def prepare_data( + bsz: int, + num_attn_heads: int, + num_kv_heads: int, + head_dim: int, + same_context_len: bool, + q_len: int, + max_kv_seq_len: int, + dtype=torch.float16, + device="cuda", +): + # Use the provided maximum sequence length for each sequence when testing with teh same context length, + # otherwise generate random context lengths. + # returns + # q [bsz, num_attn_heads, q_len, head_dim] + # k_unpad/v_unpad [num_tokens, num_kv_heads, head_dim] + kv_lengths = ( + torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_kv_seq_len, size=(bsz,), dtype=torch.int32, device=device) + ) + num_tokens = torch.sum(kv_lengths).item() + + q_size = (bsz, q_len, num_attn_heads, head_dim) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * num_kv_heads, head_dim) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("bsz", [7, 16]) +@pytest.mark.parametrize("block_size", [16, 32]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 16]) +@pytest.mark.parametrize("num_attn_heads", [16]) +@pytest.mark.parametrize("kv_group_num", [1, 4]) +@pytest.mark.parametrize("same_context_len", [True, False]) +@pytest.mark.parametrize("q_len", [1, 5]) +@pytest.mark.parametrize("use_alibi_slopes", [True, False]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) +def test_flash_decoding( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_attn_heads: int, + kv_group_num: int, + same_context_len: bool, + q_len: int, + use_alibi_slopes: bool, + use_new_kcache_layout: bool, +): + if use_new_kcache_layout and use_alibi_slopes: + # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one, + # the code (alibi kernel) will be refactored later to avoid code duplication, when + # the whole triton flow with new k cache layout has been supported and tested. + # And tests for the alibi kernel using new kcache layout will be added then. + pytest.skip("Alibi kernel does not support new kcache layout yet.") + + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + max_seq_len = block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + if use_alibi_slopes: + alibi_slopes = get_alibi_slopes(num_attn_heads, device) + # Currently, alibi flash decoding does not support q_len>1. + q_len = 1 + else: + alibi_slopes = None + + q, k_unpad, v_unpad, kv_lengths = prepare_data( + bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device + ) + # The maximum sequence length in the batch (if context lengths randomly generated) + max_kv_len_in_b = kv_lengths.max().item() + + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b) + attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device) + + if use_alibi_slopes: + alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device) + attention_mask = attention_mask + alibi_mask + + if q_len == 1: + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + + out_torch = torch_attn_ref( + q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM + ) + + if use_new_kcache_layout: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + else: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + # The maximum block length splitted on kv should be the kv cache block size + kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size + output = torch.empty((bsz * q_len, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device) + mid_output = torch.empty( + size=(bsz * q_len, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty( + size=(bsz * q_len, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device + ) + sm_scale = 1.0 / (HEAD_DIM**0.5) + # Here we use different methods to hide the q_len dimension, + # refer to attention forward function in modeling. + if q_len > 1: + q = q.transpose(1, 2).contiguous() # [bsz, q_len, num_heads, head_dim] + q = q.view(-1, q.size(-2), q.size(-1)) # [bsz * q_len, num_heads, head_dim] + else: + q = q.squeeze(2) + assert q.shape == (bsz * q_len, num_attn_heads, HEAD_DIM) + + out_triton = flash_decoding_attention( + q, + k_cache, + v_cache, + kv_lengths, + block_tables, + block_size, + max_kv_len_in_b, + output, + mid_output, + mid_output_lse, + alibi_slopes=alibi_slopes, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + q_len=q_len, + use_new_kcache_layout=use_new_kcache_layout, + ) # [bsz * q_len, num_heads, head_dim] + + assert out_torch.shape == out_triton.shape + + rtol = 1e-4 + # After the shape becomes larger, some data elements are too small, leading to excessively large relative errors. + if bsz >= 16 and use_alibi_slopes: + rtol = 100 + + numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol) + + +if __name__ == "__main__": + test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True) diff --git a/tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py b/tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py new file mode 100644 index 000000000000..787e48986185 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py @@ -0,0 +1,50 @@ +from copy import deepcopy + +import pytest +import torch +from packaging import version + +from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding +from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding +from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +def test_fused_rotary_emb(): + num_tokens = 20 + num_kv_heads = 32 + head_dim = 64 + dtype = torch.float32 + q_shape = (num_tokens, num_kv_heads, head_dim) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + q_copy = deepcopy(q) + + k_shape = (num_tokens, num_kv_heads, head_dim) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + k_copy = deepcopy(k) + + cos_shape = (1024, head_dim) + lengths = torch.tensor([3, 4, 6, 7], device="cuda") + cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + cos, sin = get_xine_cache(lengths, cos_cache[:, : head_dim // 2], sin_cache[:, : head_dim // 2]) + + rotary_embedding(q, k, cos, sin) + fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths) + torch.allclose(q, q_copy) + torch.allclose(k, k_copy) + + +if __name__ == "__main__": + test_fused_rotary_emb() diff --git a/tests/test_infer/test_kernels/triton/test_kvcache_copy.py b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py new file mode 100644 index 000000000000..4aa34ae30649 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py @@ -0,0 +1,168 @@ +import pytest +import torch +from packaging import version + +from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache +from colossalai.utils import get_current_device +from tests.test_infer.test_kernels.triton.kernel_utils import ( + generate_caches_and_block_tables_v2, + generate_caches_and_block_tables_v3, + mock_alloc_single_token, +) + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +HEAD_DIM = 32 + + +def prepare_data( + bsz, + num_kv_heads, + head_dim, + block_size, + max_num_blocks_per_seq, + same_context_len, + max_seq_len, + n=1, + device="cuda", + dtype=torch.float16, + use_new_kcache_layout=False, +): + assert max_seq_len > n, "max_seq_len must be greater than n" + + past_kv_seq_lengths = ( + torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device) + ) + num_tokens = torch.sum(past_kv_seq_lengths).item() + + kv_size = (num_tokens, 2 * num_kv_heads, head_dim) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2) + + if use_new_kcache_layout: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3( + k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device + ) + else: + k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2( + k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device + ) + block_tables = block_tables.to(device=device) + + new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device) + new_v = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device) + # mock allocating blocks for the new k/v and update block tables + for _ in range(n): + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) + past_kv_seq_lengths += 1 + + return new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("bsz", [7, 32]) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [16]) +@pytest.mark.parametrize("num_kv_heads", [16]) +@pytest.mark.parametrize("same_context_len", [True, False]) +@pytest.mark.parametrize("n_tokens", [1, 5]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) +def test_copy_kv_to_caches( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, + n_tokens: int, + use_new_kcache_layout: bool, +): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + max_seq_len = block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data( + bsz, + num_kv_heads, + HEAD_DIM, + block_size, + max_num_blocks_per_seq, + same_context_len, + max_seq_len, + n_tokens, + device=device, + dtype=dtype, + use_new_kcache_layout=use_new_kcache_layout, + ) + k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1)) + v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1)) + k_cache_copy = k_cache.detach().clone() + past_kv_seq_lengths = kv_seq_lengths - n_tokens + target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_lengths // block_size] + offsets_in_block = past_kv_seq_lengths % block_size + + # Copy k (or v) to k (or v) cache + copy_k_to_blocked_cache( + new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout + ) + # Reshape target k from k cache to compare if matching with original tensor + # Mainly to handle cases of n_tokens > 1 + k_target = [] + for i in range(bsz): + block_table = block_tables[i] + curr_kv_len = past_kv_seq_lengths[i].item() + offset = offsets_in_block[i].item() + tokens_left = n_tokens + while tokens_left > 0: + tokens_to_fill = min(block_size - offset, tokens_left) + curr_block_id = block_table[curr_kv_len // block_size] + if use_new_kcache_layout: + k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :]) + else: + k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :]) + curr_kv_len += tokens_to_fill + tokens_left -= tokens_to_fill + offset = 0 + if use_new_kcache_layout: + k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous() + k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM) + else: + k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim] + assert k_target.shape == k_source.shape + assert torch.equal(k_target, k_source) + + if n_tokens == 1: + # Copy k and v to k/v caches + k_cache = k_cache_copy + copy_kv_to_blocked_cache( + new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout + ) + + if use_new_kcache_layout: + k_target = k_cache[target_block_ids, :, :, offsets_in_block, :] + k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM) + else: + k_target = k_cache[target_block_ids, :, offsets_in_block, :] + assert k_target.shape == k_source.shape + assert torch.equal(k_target, k_source) + v_target = v_cache[target_block_ids, :, offsets_in_block, :] + assert v_target.shape == v_source.shape + assert torch.equal(v_target, v_source) + + +if __name__ == "__main__": + test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1) diff --git a/tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py b/tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py new file mode 100644 index 000000000000..20b7ff519541 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py @@ -0,0 +1,55 @@ +import pytest +import torch +from packaging import version +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +from colossalai.kernel.triton import rms_layernorm +from colossalai.testing.utils import parameterize + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@parameterize("M", [2, 4, 8, 16]) +@parameterize("N", [64, 128]) +def test_layer_norm(M, N): + dtype = torch.float16 + eps = 1e-5 + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.ones(w_shape, dtype=dtype, device="cuda") + residual = torch.rand(x_shape, dtype=dtype, device="cuda") + residual_copy = residual.clone() + rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda() + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + x_copy = x.clone() + + y_triton, _ = rms_layernorm(x, weight, eps=eps) + y_llama = rms_norm.forward(x).to(dtype) + + assert y_triton.shape == y_llama.shape + assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) + + y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual) + + x = x_copy + residual_copy + + y_llama = rms_norm.forward(x).to(dtype) + + assert y_triton.shape == y_llama.shape + assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3) + assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + test_layer_norm() diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py new file mode 100644 index 000000000000..570093693447 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py @@ -0,0 +1,100 @@ +import pytest +import torch +from packaging import version +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from colossalai.kernel.triton import decoding_fused_rotary_embedding +from tests.test_infer.test_kernels.triton.kernel_utils import ( + mock_alloc_block_table_and_kvcache_v2, + mock_alloc_block_table_and_kvcache_v3, +) + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("SEQ_LEN", [64]) +@pytest.mark.parametrize("H", [32]) +@pytest.mark.parametrize("D", [64]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("use_new_kcache_layout", [True, False]) +def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout): + TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN + # our crafted op equals to Transformers + x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) + x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D) + emb = LlamaRotaryEmbedding(D) + cos, sin = emb(x0, TOTAL_TOKENS) + cos_2 = cos[:, :32] + sin_2 = sin[:, :32] + position_ids = torch.arange(TOTAL_TOKENS) + embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids) + embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2) + assert torch.allclose(embd_x0, embd_stimulated_x) + + # create data + block_size = 32 + max_num_blocks_per_seq = 4 + q_shape = (TOTAL_TOKENS, H, D) + q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda") + k_shape = (TOTAL_TOKENS, H, D) + k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda") + v = torch.randn_like(k) + new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda") + new_q = torch.randn_like(new_k) + new_v = torch.randn_like(new_k) + + cos_shape = (TOTAL_TOKENS, D // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + + past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda") + v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda") + + if use_new_kcache_layout: + x = 16 // torch.tensor([], dtype=dtype).element_size() + kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x) + k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda") + block_tables = mock_alloc_block_table_and_kvcache_v3( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + else: + k_cache = torch.zeros_like(v_cache) + block_tables = mock_alloc_block_table_and_kvcache_v2( + k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size + ) + kv_seq_lengths = past_kv_seq_lengths + 1 + block_tables = block_tables.to(device="cuda") + q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + + decoding_fused_rotary_embedding( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout + ) + assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True) diff --git a/tests/test_infer/test_kernels/triton/test_xine_copy.py b/tests/test_infer/test_kernels/triton/test_xine_copy.py new file mode 100644 index 000000000000..d8ce78617260 --- /dev/null +++ b/tests/test_infer/test_kernels/triton/test_xine_copy.py @@ -0,0 +1,66 @@ +import pytest +import torch +from packaging import version + +from colossalai.kernel.triton import get_xine_cache + +try: + import triton # noqa + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +@torch.no_grad() +def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype): + """ + Get cos and sin for the cache, and return nopad format. + Args: + lengths: shape(num_seqs,), stores lenghth of each sequence. + cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model. + sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model. + is_prompts: bool, mark if in prefill mode. + dtype: The data type of this inference process. + """ + + if is_prompts: + index_arrays = [torch.arange(length) for length in lengths] + else: + index_arrays = [(length - 1).view(-1) for length in lengths] + indices = torch.cat(index_arrays, dim=-1) + cos_output = cos_cache[indices].to(dtype=dtype) + sin_output = sin_cache[indices].to(dtype=dtype) + + return (cos_output, sin_output) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("MAX_SEQ_LEN", [64]) +@pytest.mark.parametrize("HEAD_DIM", [64]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): + MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN + cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda") + # prefill + cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) + assert torch.allclose(cos, cos_ref) + assert torch.allclose(sin, sin_ref) + # decoding + ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False) + assert torch.allclose(cos, ncos_ref) + assert torch.allclose(sin, nsin_ref) + + +if __name__ == "__main__": + test_get_xine_cache(4, 64, 256, torch.float32) diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py old mode 100644 new mode 100755 index e8765317291a..bca9a1a84f08 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -1,66 +1,179 @@ -import os +import random import pytest import torch -from packaging import version +from transformers.models.llama import LlamaConfig -from colossalai.inference.kv_cache import MemoryManager +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.kv_cache import CacheBlock, KVCacheManager from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -BATCH_SIZE = 4 -INPUT_LEN = 16 -OUTPUT_LEN = 8 -LAYER_NUM = 4 -HEAD_NUM = 32 -HEAD_DIM = 128 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") +@parameterize( + "test_config", + [ + { + "elem_size": 2, + "block_size": 4, + } + ], +) +def test_logical_blocks(test_config): + block = CacheBlock(block_id=0, block_size=test_config["block_size"], elem_size=test_config["elem_size"]) + assert block.is_empty() + assert block.available_space == test_config["block_size"] + assert not block.has_ref() + block.add_ref() + assert block.ref_count == 1 + assert block.has_ref() + block.remove_ref() + assert block.ref_count == 0 + block.allocate(1) + assert block.allocated_size == 1 + block.allocate(test_config["block_size"] - 1) + assert block.available_space < 1 -def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) + +@parameterize( + "test_config", + [ + { + "hidden_size": 512, + "num_attention_heads": 16, + "num_layers": 2, + "block_size": 8, + "max_batch_size": 10, + "max_input_len": 32, + "max_output_len": 32, + "dtype": torch.float32, + "beam_width": 1, + "tp_size": 1, + }, + { + "hidden_size": 128, + "num_attention_heads": 4, + "num_layers": 3, + "block_size": 4, + "max_batch_size": 4, + "max_input_len": 64, + "max_output_len": 32, + "dtype": torch.float16, + "beam_width": 3, + "tp_size": 1, + }, + ], +) +def check_cache_manager(test_config): disable_existing_loggers() - size = batch_size * (input_len + output_len) - kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) - key_buffers = kvcache_manager.key_buffer - value_buffers = kvcache_manager.value_buffer - assert len(key_buffers) == len(value_buffers) == layer_num - assert key_buffers[0].shape == value_buffers[0].shape - # required size exceeds the maximum allocated size - invalid_locs = kvcache_manager.alloc_contiguous(size + 1) - assert invalid_locs is None - # for prefill stage, allocation via alloc and alloc_contiguous should be the same - total_token_prefill = batch_size * input_len - prefill_locs = kvcache_manager.alloc(total_token_prefill) - kvcache_manager.free_all() - prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] - assert torch.equal(prefill_locs, prefill_locs_contiguous) - assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill - kvcache_manager.alloc_contiguous(batch_size) - assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False) - - -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") + assert test_config["max_batch_size"] > 1 + + hidden_size = test_config.pop("hidden_size") + num_layers = test_config.pop("num_layers") + num_attention_heads = test_config.pop("num_attention_heads") + head_size = hidden_size // num_attention_heads + block_size = test_config["block_size"] + max_batch_size = test_config["max_batch_size"] + max_input_length = test_config["max_input_len"] + max_output_length = test_config["max_output_len"] + + inference_config = InferenceConfig(**test_config) + model_config = LlamaConfig( + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_attention_heads, + ) + cache_manager = KVCacheManager(inference_config, model_config) + + num_blocks = cache_manager.total_num_blocks + assert num_blocks > 0 + assert len(cache_manager._cache_blocks) == num_blocks + key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers + assert len(key_caches) == num_layers + expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size) + assert key_caches[0].shape == expected_kv_shape + k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0) + expected_kv_block_shape = expected_kv_shape[1:] + assert k_cache_block0.shape == expected_kv_block_shape + assert v_cache_block0.shape == expected_kv_block_shape + + max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence() + block_tables = torch.tensor( + [[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32 + ) + context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)] + cnt_blocks_used = 0 + # Mock Prefill + for req_i in range(max_batch_size): + cur_seq_len = context_lengths[req_i] + cur_block_table = block_tables[req_i] + cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len) + last_allocated_idx = (cur_seq_len - 1) // block_size + assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0) + cnt_blocks_used += torch.sum(cur_block_table >= 0).item() + assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used + + # Mock Decoding + for req_i in range(max_batch_size): + context_length = context_lengths[req_i] + cur_output_length = random.randint(1, max_output_length) + cur_block_table = block_tables[req_i] + for _ in range(cur_output_length): + cache_manager.allocate_token_from_block_table(cur_block_table, context_length) + context_length += 1 + context_length -= 1 + last_allocated_idx = context_length // block_size + space_allocated_on_last_block = context_length % block_size + 1 + assert space_allocated_on_last_block > 0 + block_id = cur_block_table[last_allocated_idx] + block: CacheBlock = cache_manager._cache_blocks[block_id] + assert block.allocated_size == space_allocated_on_last_block + + # Randomly select a request and clear its cache + req_i = random.randint(0, max_batch_size - 1) + context_length = context_lengths[req_i] + blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item() + prev_available_blocks = cache_manager.num_available_blocks + cache_manager.free_block_table(block_tables[req_i]) + assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks + + k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0) + k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0) + elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size() + expected_stride = block_size * num_attention_heads * head_size * elem_size + assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride + cache_manager.clear_all() + assert cache_manager.num_available_blocks == num_blocks + + for cache_block in cache_manager._cache_blocks: + assert cache_block.available_space == block_size + + # Mock batch operations (Prefill/Decoding updates) + context_lengths = torch.tensor([max_input_length, max_input_length - 1]) + block_tables = torch.tensor( + [[-1 for _ in range(cache_manager.max_blocks_per_sequence)] for _ in range(2)], dtype=torch.int32 + ) + cache_manager.allocate_context_from_block_tables(block_tables, context_lengths) + cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths) + cache_manager.free_block_tables(block_tables) + for cache_block in cache_manager._cache_blocks: + assert cache_block.available_space == block_size + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_cache_manager() + + @pytest.mark.dist @rerun_if_address_is_in_use() -def test_cache_manager_dist(): - spawn( - create_cache_manager, - 4, - batch_size=BATCH_SIZE, - input_len=INPUT_LEN, - output_len=OUTPUT_LEN, - layer_num=LAYER_NUM, - head_num=HEAD_NUM, - head_dim=HEAD_DIM, - ) +def test_cache_manager(): + spawn(run_dist, 1) if __name__ == "__main__": - test_cache_manager_dist() + test_logical_blocks() + test_cache_manager() diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py new file mode 100644 index 000000000000..79ed6675db5f --- /dev/null +++ b/tests/test_infer/test_models/test_attention.py @@ -0,0 +1,145 @@ +import pytest +import torch +from transformers.cache_utils import DynamicCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + +from colossalai.inference.modeling.layers.attention import PagedAttention, convert_kvcache, copy_to_cache + + +@pytest.mark.skip(reason="This test is not used in the current version.") +def test_copy_to_cache(): + key = torch.ones((2, 11, 3, 3)) + key[0, 9, :, :] = 0 + key[1, -2:, :, :] = 0 + cache = torch.zeros(8, 3, 8, 3) + block_tables = torch.tensor([[0, 1], [2, 3]]) + lengths = torch.tensor([9, 8]) + cache = copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="prefill") + assert cache[1, 0, 0, 0] == 1 + assert cache[3, 0, 0, 0] == 0 + + decoding_key = torch.ones((2, 1, 3, 3)) + cache = copy_to_cache(decoding_key, cache=cache, lengths=lengths + 1, block_tables=block_tables, type="decoding") + assert cache[1, 0, 0, 1] == 1 + assert cache[3, 0, 0, 0] == 1 + + +@pytest.mark.skip(reason="This test is not used in the current version.") +def test_convert_kvcache(): + cache = torch.ones(8, 3, 8, 3) + key = torch.ones(2, 1, 3, 3) + 1 + lengths = torch.tensor([10, 9]) + block_tables = torch.tensor([[0, 1], [2, 3]]) + copy_to_cache(key, cache=cache, lengths=lengths, block_tables=block_tables, type="decoding") + converted_cache = convert_kvcache(cache=cache, lengths=lengths, block_tables=block_tables) + assert converted_cache.shape == (2, 10, 3, 3) + + +@pytest.mark.skip(reason="This test is not used in the current version.") +def test_context_attention(): + """ + test config: head_num = 4, head_size = 4 + """ + attn = PagedAttention() + q = k = v = torch.randn(8, 4, 4) + k_cache = torch.empty(8, 4, 8, 4) + v_cache = torch.empty(8, 4, 8, 4) + context_lengths = torch.tensor( + [ + 8, + ] + ) + block_tables = torch.tensor([[0, 1]]) + attn.nopad_context_forward(q, k, v, k_cache, v_cache, context_lengths, block_tables) + # test padded q/k/v + pad_q = pad_k = pad_v = q.unsqueeze(0) + attn.pad_context_forward(pad_q, pad_k, pad_v, k_cache, v_cache, context_lengths, block_tables) + + config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=16) + transformer_attn = LlamaAttention(config) + transformer_attn.training = False + + # test accuracy with LlamaAttention + hidden_states = torch.randn(1, 8, 16) + proj_q = transformer_attn.q_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + proj_k = transformer_attn.k_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + proj_v = transformer_attn.v_proj(hidden_states).view(1, 8, 4, 4).transpose(1, 2) + + position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device) + position_ids = position_ids.unsqueeze(0) + cos, sin = transformer_attn.rotary_emb(proj_v, 8) + proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids) + + pad_attn_output = attn.pad_context_forward( + proj_q.transpose(1, 2), + proj_k.transpose(1, 2), + proj_v.transpose(1, 2), + k_cache, + v_cache, + context_lengths, + block_tables, + ) + pad_attn_output = transformer_attn.o_proj(pad_attn_output) + attn_mask = AttentionMaskConverter._make_causal_mask( + hidden_states.shape[:2], q.dtype, q.device, past_key_values_length=0 + ) + attn_mask += PagedAttention.generate_padding_mask(context_lengths, 8) + attn_output, _, _ = transformer_attn.forward(hidden_states, attention_mask=attn_mask) + assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-3) + + +@pytest.mark.skip(reason="This test is not used in the current version.") +def test_decoding_attention(): + # test the pipeline of decoding attention + attn = PagedAttention() + q = k = v = torch.randn(2, 1, 4, 8) + k_cache = torch.empty(8, 4, 8, 8) + v_cache = torch.empty(8, 4, 8, 8) + past_kv = torch.randn(2, 8, 4, 8) + context_lenghths = torch.tensor([8, 8]) + lengths = context_lenghths + 1 + block_tables = torch.tensor([[0, 1], [2, 3]]) + copy_to_cache(past_kv, k_cache, lengths=context_lenghths, block_tables=block_tables) + copy_to_cache(past_kv, v_cache, lengths=context_lenghths, block_tables=block_tables) + attn.pad_decoding_forward(q, k, v, k_cache, v_cache, lengths=lengths, block_tables=block_tables) + + # test decoding accuracy, past_kv is reused + config = LlamaConfig(num_attention_heads=4, num_key_value_heads=None, hidden_size=32) + transformer_attn = LlamaAttention(config) + transformer_attn.layer_idx = 0 + transformer_attn.training = False + hidden_states = torch.randn(2, 1, 32) + proj_q = transformer_attn.q_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + proj_k = transformer_attn.k_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + proj_v = transformer_attn.v_proj(hidden_states).view(2, 1, 4, 8).transpose(1, 2) + + cos, sin = transformer_attn.rotary_emb(proj_v, 16) + position_ids = lengths - 1 + position_ids = position_ids.unsqueeze(1) # NOTE: this may be wrong + proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids, unsqueeze_dim=2) + + llama_past_kv = DynamicCache() + llama_past_kv.update(key_states=past_kv.transpose(1, 2), value_states=past_kv.transpose(1, 2), layer_idx=0) + + # past_key_value shape in Llama: bsz, num_heads, seq_len, head_dim + pad_attn_output = attn.pad_decoding_forward( + proj_q.transpose(1, 2), proj_k.transpose(1, 2), proj_v.transpose(1, 2), k_cache, v_cache, lengths, block_tables + ) + attn_mask = AttentionMaskConverter._make_causal_mask(q.shape[:2], q.dtype, q.device, past_key_values_length=8) + attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, 9).unsqueeze(1).unsqueeze(2) + + pad_attn_output = transformer_attn.o_proj(pad_attn_output) + position_ids = context_lenghths.unsqueeze(1) + attn_output, _, _ = transformer_attn.forward( + hidden_states, past_key_value=llama_past_kv, position_ids=position_ids, attention_mask=attn_mask + ) + assert torch.allclose(pad_attn_output, attn_output, atol=1e-3, rtol=1e-2) + + +if __name__ == "__main__": + test_copy_to_cache() + test_convert_kvcache() + test_context_attention() + test_decoding_attention() diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py new file mode 100644 index 000000000000..736fab5ff1a3 --- /dev/null +++ b/tests/test_infer/test_models/test_baichuan.py @@ -0,0 +1,138 @@ +import os +import random + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.multiprocessing import Manager +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +import colossalai +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base" + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda() + model = model.eval() + + inputs = [ + "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", + ] + + output_len = 38 + + if do_sample: + top_p = 0.5 + top_k = 50 + else: + top_p = None + top_k = None + + if use_engine: + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + use_cuda_kernel=use_cuda_kernel, + tp_size=dist.get_world_size(), + ) + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts=inputs) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + return outputs + + +def run_engine(world_size, **kwargs): + manager = Manager() + result_list = manager.list([-1] * world_size) # Create a shared list + + spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs) + return result_list[0] + + +def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + if ret: + ret[rank] = func_to_run(**kwargs) + else: + func_to_run(**kwargs) + + +# NOTE(caidi) If do_sample is set to True or use_cuda_kernel is set to False, the inference result will be different from that of the transformer. +@parameterize("prompt_template", [None, "baichuan"]) +@parameterize("do_sample", [False]) +@parameterize("use_cuda_kernel", [True]) +def check_tp_engine(prompt_template, do_sample, use_cuda_kernel): + kwargs1 = { + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": NoPaddingBaichuanModelInferPolicy(), + "use_cuda_kernel": use_cuda_kernel, + } + + kwargs2 = { + "use_engine": False, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": None, + "use_cuda_kernel": use_cuda_kernel, + } + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +@pytest.mark.skipif( + not os.path.exists(BAICHUAN_MODEL_NAME_OR_PATH), + reason="There is no local model address included, please replace this address with a valid one.", +) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +def test_inference_engine(): + check_tp_engine() + + +if __name__ == "__main__": + test_inference_engine() diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py new file mode 100644 index 000000000000..912fdbf112c1 --- /dev/null +++ b/tests/test_infer/test_request_handler.py @@ -0,0 +1,105 @@ +import pytest +from transformers.models.llama import LlamaConfig + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.request_handler import RequestHandler, RunningList +from colossalai.inference.struct import RequestStatus, Sequence +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_running_list(): + """ + Test the RunningList Structure. + """ + running_list = RunningList(prefill_ratio=1.2) + seq1 = Sequence( + request_id=1, + prompt="abc", + input_token_id=[1, 2, 3], + block_size=16, + eos_token_id=0, + pad_token_id=0, + sample_params=None, + ) + seq2 = Sequence( + request_id=2, + prompt="abc", + input_token_id=[1, 2, 3], + block_size=16, + eos_token_id=0, + pad_token_id=0, + sample_params=None, + ) + running_list.append(seq1) + running_list.append(seq2) + assert running_list.ready_for_prefill() + assert len(running_list.decoding) == 0 + assert len(running_list.prefill) > 0 and running_list.prefill[0] == seq1 + + seq = running_list.find_seq(seq1.request_id) + assert seq == seq1 + + running_list.mark_prefill_running() + for seq in running_list.prefill: + assert seq.status == RequestStatus.RUNNING + + running_list.move_prefill_to_decoding([seq1.request_id, seq2.request_id]) + assert len(running_list.prefill) == 0 + assert len(running_list.decoding) > 0 and running_list.decoding[0] == seq1 + + running_list.remove(seq1) + running_list.remove(seq2) + assert running_list.is_empty() + + +def check_request_handler(): + """ + Test main function of RequestHandler + """ + inference_config = InferenceConfig( + max_input_len=10, + max_output_len=10, + block_size=8, + ) + model_config = LlamaConfig( + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + ) + request_handler = RequestHandler(inference_config, model_config) + seq1 = Sequence( + request_id=1, + prompt="abc", + input_token_id=[1, 2, 3, 4, 5], + block_size=16, + eos_token_id=0, + pad_token_id=0, + sample_params=None, + ) + request_handler.add_sequence(seq1) + # the priority should be 1 + assert request_handler.waiting_list[1][0] == seq1 + assert request_handler._has_waiting() + + request_handler.abort_sequence(seq1.request_id) + assert not request_handler._has_waiting() + seq1.status = RequestStatus.WAITING + request_handler.add_sequence(seq1) + request_handler.schedule() + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_running_list() + check_request_handler() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_running_list_and_request_handler(): + spawn(run_dist, 1) + + +if __name__ == "__main__": + test_running_list_and_request_handler() diff --git a/tests/test_infer/test_rpc_engine.py b/tests/test_infer/test_rpc_engine.py new file mode 100644 index 000000000000..12479b49ce50 --- /dev/null +++ b/tests/test_infer/test_rpc_engine.py @@ -0,0 +1,105 @@ +import random + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig +from colossalai.inference.core.rpc_engine import RPCInferenceEngine +from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy +from colossalai.testing import parameterize, rerun_if_address_is_in_use + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + model = "meta-llama/Llama-2-7b-hf" # remote mode path + inputs = [ + "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", + "介绍一下武汉,", + ] + + output_len = 38 + top_p = 0.5 + top_k = 50 + + if use_engine: + inference_config = InferenceConfig( + max_output_len=output_len, + prompt_template=prompt_template, + dtype="fp32", + use_cuda_kernel=True, + tp_size=tp_size, + ) + inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts=inputs) + assert inference_engine.request_handler._has_waiting() + generation_config = GenerationConfig( + max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k + ) + outputs = inference_engine.generate(generation_config=generation_config) + else: + if prompt_template: + # apply prompt template + inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] + model = AutoModelForCausalLM.from_pretrained(model).cuda() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] + inputs = inputs.cuda() + generation_config = GenerationConfig( + do_sample=do_sample, + dtype="fp32", + top_p=top_p, + top_k=top_k, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=output_len, + ) + outputs = model.generate(inputs, generation_config=generation_config) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return outputs + + +def run_engine(tp_size, **kwargs): + return check_inference_engine(tp_size=tp_size, **kwargs) + + +@pytest.mark.largedist +@parameterize("prompt_template", [None, "llama"]) +@parameterize("do_sample", [False]) +@rerun_if_address_is_in_use() +def test_tp_engine(prompt_template, do_sample): + if torch.multiprocessing.get_start_method(allow_none=True) is None: + torch.multiprocessing.set_start_method("spawn") + kwargs1 = { + "use_engine": True, + "prompt_template": prompt_template, + "do_sample": do_sample, + "policy": NoPaddingLlamaModelInferPolicy(), + } + + kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None} + + colossal_tp_1_output = run_engine(1, **kwargs1) + colossal_tp_2_output = run_engine(2, **kwargs1) + transformer_tp_1_output = run_engine(1, **kwargs2) + + for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): + assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" + assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess + test_tp_engine()