From 15bf3d4944c701d1db611e307ebdacf8a966c319 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 19 May 2023 12:01:57 +0200 Subject: [PATCH 01/27] Fused all commits for saner rebase.. Big refacto. Working ? Working bitsandbytes. Weights to its own file. Remove dead file. Bloom. TMP. Finally finished bloom (grr old logic) SantaCoder. Remove dead code. Neox. Black + ruff. T5 Support. Galactica + OPT. Small fixes. Fix auto download. Remove custom transformers. Missing remove instruction. Some work on the dockerfile. Version issues. Black + ruff after rebase. Adding custom_kernels Bad rebase. Fixing dummy gather + fix Dockerfile Better fake gather. Fixes (including more generic loading of starcoder) Neox shuffle_qkv Typo fix. cleanups. Fixing starcoder/santacoder Fix santacoder Fixing neox. Using the saved rotary embeddings instead of the created ones. --- Dockerfile | 15 +- server/Makefile | 5 +- server/Makefile-transformers | 13 - .../fused_bloom_attention_cuda.cu | 250 ++++ server/custom_kernels/setup.py | 15 + server/pyproject.toml | 3 +- server/requirements.txt | 5 +- server/tests/models/test_bloom.py | 9 +- .../text_generation_server/models/__init__.py | 130 +- server/text_generation_server/models/bloom.py | 76 +- .../models/custom_modeling/bloom_modeling.py | 913 +++++++++++++ .../custom_modeling/flash_llama_modeling.py | 210 ++- .../custom_modeling/flash_neox_modeling.py | 266 ++-- .../flash_santacoder_modeling.py | 302 +++-- .../models/custom_modeling/opt_modeling.py | 837 ++++++++++++ .../models/custom_modeling/t5_modeling.py | 1177 +++++++++++++++++ .../models/flash_llama.py | 262 +--- .../models/flash_neox.py | 120 +- .../models/flash_santacoder.py | 367 +---- .../models/galactica.py | 204 +-- server/text_generation_server/models/opt.py | 194 +-- server/text_generation_server/models/t5.py | 199 +-- .../text_generation_server/utils/__init__.py | 2 + server/text_generation_server/utils/dist.py | 51 +- server/text_generation_server/utils/layers.py | 320 +++-- .../text_generation_server/utils/weights.py | 78 ++ 26 files changed, 3987 insertions(+), 2036 deletions(-) delete mode 100644 server/Makefile-transformers create mode 100644 server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu create mode 100644 server/custom_kernels/setup.py create mode 100644 server/text_generation_server/models/custom_modeling/bloom_modeling.py create mode 100644 server/text_generation_server/models/custom_modeling/opt_modeling.py create mode 100644 server/text_generation_server/models/custom_modeling/t5_modeling.py create mode 100644 server/text_generation_server/utils/weights.py diff --git a/Dockerfile b/Dockerfile index 483270a8826..ae53b748de4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -98,14 +98,15 @@ COPY server/Makefile-flash-att Makefile RUN make build-flash-attention # Build Transformers CUDA kernels -FROM kernel-builder as transformers-builder +FROM kernel-builder as custom-kernels-builder WORKDIR /usr/src -COPY server/Makefile-transformers Makefile +COPY server/custom_kernels/ . # Build specific version of transformers -RUN BUILD_EXTENSIONS="True" make build-transformers +RUN pip install ninja +RUN python setup.py build # Text Generation Inference base image FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base @@ -136,11 +137,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy build artifacts from transformers builder -COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers -COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers +COPY --from=custom-kernels-builder /usr/src/custom_kernels /usr/src/custom_kernels +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels # Install transformers dependencies -RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir +RUN pip install einops --no-cache-dir # Install server COPY proto proto @@ -170,4 +171,4 @@ ENTRYPOINT ["./entrypoint.sh"] FROM base ENTRYPOINT ["text-generation-launcher"] -CMD ["--json-output"] \ No newline at end of file +CMD ["--json-output"] diff --git a/server/Makefile b/server/Makefile index 6eb56c7582b..17020c97413 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,3 @@ -include Makefile-transformers include Makefile-flash-att unit-tests: @@ -17,7 +16,7 @@ install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir -install: gen-server install-torch install-transformers +install: gen-server install-torch pip install pip --upgrade pip install -r requirements.txt pip install -e ".[bnb, accelerate]" @@ -26,4 +25,4 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements.txt -E bnb --without-hashes \ No newline at end of file + poetry export -o requirements.txt -E bnb --without-hashes diff --git a/server/Makefile-transformers b/server/Makefile-transformers deleted file mode 100644 index 64d0167222d..00000000000 --- a/server/Makefile-transformers +++ /dev/null @@ -1,13 +0,0 @@ -transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e - -transformers: - # Clone fork of transformers with custom CUDA kernels and sharding logic - pip install --upgrade setuptools - git clone https://github.com/OlivierDehaene/transformers.git - -build-transformers: transformers - cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build - -install-transformers: build-transformers - pip uninstall transformers -y || true - cd transformers && python setup.py install \ No newline at end of file diff --git a/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu new file mode 100644 index 00000000000..4be547b1b82 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor fused_qkv, + const std::optional> layer_past, + const at::Tensor alibi, + const at::Tensor attention_mask, + const std::optional head_mask, + const float beta, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + const auto batch_size = fused_qkv.size(0); + const auto q_length = fused_qkv.size(1); + const auto three_times_hidden_size = fused_qkv.size(2); + const auto head_dim = three_times_hidden_size / (3 * num_heads); + const auto batch_size_times_num_heads = batch_size * num_heads; + + // `split_heads` + const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim}); + const auto tensor_list = fused_qkv_view.split(head_dim, -1); + const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length}); + auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim}); + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 1); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + const auto kv_length = key_layer.size(2); + + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) + * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_layer); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "Bloom attention mechanism forward (CUDA)" + ); +} \ No newline at end of file diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py new file mode 100644 index 00000000000..62c720e1773 --- /dev/null +++ b/server/custom_kernels/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +setup( + name='custom_kernels', + ext_modules=[ + CUDAExtension( + name="custom_kernels.fused_bloom_attention_cuda", + sources=['custom_kernels/fused_bloom_attention_cuda.cu'], + extra_compile_args=["-arch=compute_80", "-std=c++17"], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/server/pyproject.toml b/server/pyproject.toml index d381eac4d51..f0ec25eb09c 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" tokenizers = "0.13.3" -huggingface-hub = "0.14.0" +huggingface-hub = "^0.14.1" +transformers = "^4.29.2" [tool.poetry.extras] accelerate = ["accelerate"] diff --git a/server/requirements.txt b/server/requirements.txt index 50ba4e438b2..e8cee52b081 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0" grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0" grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0" hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" -huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0" -idna==3.4 ; python_version >= "3.9" and python_version < "4.0" +huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" +idna==3.4 ; python_version >= "3.9" and python_version < "4" loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" @@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" +transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0" diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 338fe053826..71013cb60ca 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -6,12 +6,17 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM +from text_generation_server.utils import weight_hub_files, download_weights +from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded @pytest.fixture(scope="session") def default_bloom(): - return BLOOM("bigscience/bloom-560m") + model_id = "bigscience/bloom-560m" + revision = "main" + filenames = weight_hub_files(model_id, revision, ".safetensors") + download_weights(filenames, model_id, revision) + return BLOOMSharded(model_id) @pytest.fixture(scope="session") diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fc92d03d4fe..3e181321e89 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -8,13 +8,12 @@ from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM -from text_generation_server.models.bloom import BLOOM, BLOOMSharded +from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPT, OPTSharded -from text_generation_server.models.galactica import Galactica, GalacticaSharded +from text_generation_server.models.opt import OPTSharded +from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded try: @@ -30,14 +29,12 @@ f"GPU with CUDA capability {major} {minor} is not supported" ) - from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded - from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_neox import FlashNeoXSharded from text_generation_server.models.flash_llama import ( FlashLlama, - FlashLlamaSharded, ) from text_generation_server.models.flash_santacoder import ( - FlashSantacoder, FlashSantacoderSharded, ) @@ -52,30 +49,23 @@ __all__ = [ "Model", - "BLOOM", "BLOOMSharded", "CausalLM", "FlashCausalLM", "Galactica", "GalacticaSharded", - "GPTNeoxSharded", "Seq2SeqLM", "SantaCoder", - "OPT", "OPTSharded", "T5Sharded", "get_model", ] if FLASH_ATTENTION: - __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) - __all__.append(FlashRW) __all__.append(FlashRWSharded) - __all__.append(FlashSantacoder) __all__.append(FlashSantacoderSharded) __all__.append(FlashLlama) - __all__.append(FlashLlamaSharded) FLASH_ATT_ERROR_MESSAGE = ( "{} requires Flash Attention CUDA kernels to be installed.\n" @@ -102,36 +92,24 @@ def get_model( trust_remote_code: bool, ) -> Model: if "facebook/galactica" in model_id: - if sharded: - return GalacticaSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - else: - return Galactica( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) + return GalacticaSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) if model_id.startswith("bigcode/"): - if sharded: - if not FLASH_ATTENTION: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") - ) + if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") + ) else: - santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder - return santacoder_cls( + return SantaCoder( model_id, revision, quantize=quantize, @@ -144,20 +122,19 @@ def get_model( model_type = config_dict["model_type"] if model_type == "gpt_bigcode": - if sharded: - if not FLASH_ATTENTION: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") - ) + if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") + ) else: - santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder - return santacoder_cls( + return SantaCoder( model_id, revision, quantize=quantize, @@ -165,33 +142,40 @@ def get_model( ) if model_type == "bloom": - if sharded: - return BLOOMSharded( + return BLOOMSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) + + elif model_type == "gpt_neox": + if FLASH_ATTENTION or shard: + return FlashNeoXSharded( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Neox")) else: - return BLOOM( + return CausalLM( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) - if model_type == "gpt_neox": - if sharded: - neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded - return neox_cls( + elif model_type == "llama": + if FLASH_ATTENTION: + return FlashLlama( model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM - return neox_cls( + return CausalLM( model_id, revision, quantize=quantize, @@ -217,7 +201,7 @@ def get_model( ) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRW( + return FlashRWSharded( model_id, revision, quantize=quantize, @@ -231,42 +215,12 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "llama": - if sharded: - if FLASH_ATTENTION: - return FlashLlamaSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")) - else: - llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM - return llama_cls( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - - if model_type == "opt": - if sharded: - return OPTSharded( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - else: - return OPT( - model_id, - revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) + elif model_type == "opt": + return OPTSharded( + model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + ) - if model_type == "t5": + elif model_type == "t5": if sharded: return T5Sharded( model_id, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 45d7cd4c09d..8d0ceeb4853 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -1,35 +1,30 @@ import torch import torch.distributed -from typing import List, Optional, Type +from typing import Optional, Type -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase, ) -from transformers.models.bloom.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) HAS_BITS_AND_BYTES = True try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: + pass +except Exception: HAS_BITS_AND_BYTES = False @@ -42,34 +37,12 @@ def from_pb( dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super(BloomCausalLMBatch, cls).from_pb( - pb=pb, tokenizer=tokenizer, dtype=dtype, device=device - ) + batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) batch.keys_head_dim_last = False return batch -class BLOOM(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - super(BLOOM, self).__init__( - model_id=model_id, - revision=revision, - quantize=quantize, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch - - -class BLOOMSharded(BLOOM): +class BLOOMSharded(CausalLM): def __init__( self, model_id: str, @@ -101,25 +74,16 @@ def __init__( trust_remote_code=trust_remote_code, ) config.pad_token_id = 3 + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) + model = BloomForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -131,6 +95,7 @@ def __init__( world_size=world_size, ) +<<<<<<< HEAD @staticmethod def load_weights( model, @@ -257,6 +222,11 @@ def linear(input, weight, bias): module._parameters[param_name] = tensor if name == "word_embeddings.weight": model.lm_head._parameters["weight"] = tensor +======= + @property + def batch_type(self) -> Type[CausalLMBatch]: + return BloomCausalLMBatch +>>>>>>> ba30033 (Fused all commits for saner rebase..) def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None @@ -269,9 +239,5 @@ def forward( use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - + logits = outputs.logits return logits, outputs.past_key_values diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py new file mode 100644 index 00000000000..554cab9f4bf --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -0,0 +1,913 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BLOOM model.""" + +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import LayerNorm +from torch.nn import functional as F + +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from transformers import BloomConfig, PreTrainedModel + +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, + FastLinear +) + +CUSTOM_KERNELS_ENABLED = False +if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": + try: + from custom_kernels import fused_bloom_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + +_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m" +_CONFIG_FOR_DOC = "BloomConfig" + +BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bigscience/bigscience-small-testing", + "bigscience/bloom-560m", + "bigscience/bloom-1b1", + "bigscience/bloom-1b7", + "bigscience/bloom-3b", + "bigscience/bloom-7b1", + "bigscience/bloom", +] + + +def _make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) + return expanded_mask + + +def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + powers = torch.arange( + 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32 + ) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange( + 1, + 1 + 2 * num_remaining_heads, + 2, + device=attention_mask.device, + dtype=torch.int32, + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + return alibi + + +# @torch.jit.script +def dropout_add( + x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool +) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + esidual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + + +# @torch.jit.script # this is shit for unknow reasons. +def _split_heads( + fused_qkv: torch.Tensor, num_heads: int, head_dim: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim) + query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1) + + query_layer = query_layer.transpose(1, 2).reshape( + batch_size * num_heads, seq_length, head_dim + ) + key_layer = key_layer.permute(0, 2, 3, 1).reshape( + batch_size * num_heads, head_dim, seq_length + ) + value_layer = value_layer.transpose(1, 2).reshape( + batch_size * num_heads, seq_length, head_dim + ) + + return query_layer, key_layer, value_layer + + +# @torch.jit.script +def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: + """ + Merge heads together over the last dimenstion + + Args: + x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ + # What we want to achieve is: + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // num_heads + + # First view to decompose the batch size + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, num_heads, seq_length, head_dim) + + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, num_heads * head_dim) + + +class BloomAttention(nn.Module): + def __init__(self, prefix, config: BloomConfig, weights): + super().__init__() + + self.pretraining_tp = config.pretraining_tp + self.slow_but_exact = config.slow_but_exact + + self.process_group = weights.process_group + + self.hidden_size = config.hidden_size + self.num_heads = config.n_head + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = 1.0 + + process_group = weights.process_group + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=True, + ) + self.dense = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + @staticmethod + def compute_attention( + fused_qkv: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], + alibi: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: Optional[torch.Tensor], + beta: float, + inv_norm_factor: float, + num_heads: int, + use_cache: bool, + ): + batch_size, q_length, three_times_hidden_size = fused_qkv.shape + head_dim = three_times_hidden_size // (3 * num_heads) + batch_size * num_heads + + ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that? + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = _split_heads( + fused_qkv, num_heads=num_heads, head_dim=head_dim + ) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + past_key = past_key.view(-1, *past_key.shape[-2:]) + key_layer = torch.cat((past_key, key_layer), dim=2) + past_value = past_value.view(-1, *past_value.shape[-2:]) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + if use_cache is True: + present = (key_layer, value_layer) + else: + present = None + ### + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + attention_scores = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=beta, + alpha=inv_norm_factor, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + # torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34` + attn_weights = attention_scores.masked_fill_( + attention_mask, torch.finfo(attention_scores.dtype).min + ) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + input_dtype + ) + + # # [batch_size, num_heads, q_length, kv_length] + # attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs, value_layer, out=query_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = _merge_heads( + context_layer, num_heads=num_heads, head_dim=head_dim + ) + + return context_layer, present, attention_probs + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value( + hidden_states + ) # [batch_size, seq_length, 3 x hidden_size] + batch_size, q_length, _ = fused_qkv.shape + + if layer_past is not None: + past_key, past_value = layer_past + layer_past = ( + past_key.view(-1, *past_key.shape[-2:]), + past_value.view(-1, *past_value.shape[-2:]), + ) + + if CUSTOM_KERNELS_ENABLED: + assert self.training is False, "Only foward pass was implemented" + assert ( + attention_mask.shape[-1] < 4096 + ), "Custom kernel support only up to 4096 tokens" + ( + context_layer, + present, + attention_probs, + ) = fused_bloom_attention_cuda.forward( + fused_qkv, + layer_past, + alibi, + attention_mask, + head_mask, + self.beta, + self.inv_norm_factor, + self.num_heads, + use_cache, + ) + else: + context_layer, present, attention_probs = self.compute_attention( + fused_qkv=fused_qkv, + layer_past=layer_past, + alibi=alibi, + attention_mask=attention_mask, + head_mask=head_mask, + beta=self.beta, + inv_norm_factor=self.inv_norm_factor, + num_heads=self.num_heads, + use_cache=use_cache, + ) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + output_tensor += residual + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + +class BloomMLP(nn.Module): + def __init__(self, prefix, config: BloomConfig, weights): + super().__init__() + + self.pretraining_tp = config.pretraining_tp + self.slow_but_exact = config.slow_but_exact + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) + self.gelu_impl = torch.nn.GELU(approximate="tanh") + self.hidden_dropout = config.hidden_dropout + + def forward( + self, hidden_states: torch.Tensor, residual: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) + + if self.pretraining_tp > 1 and self.slow_but_exact: + intermediate_output = torch.zeros_like(residual) + slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp + for i in range(self.pretraining_tp): + intermediate_output = intermediate_output + F.linear( + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[ + :, int(i * slices) : int((i + 1) * slices) + ], + ) + else: + intermediate_output = self.dense_4h_to_h(hidden_states) + + # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training) + intermediate_output += residual + + return intermediate_output + + +class BloomBlock(nn.Module): + def __init__(self, layer_id: int, config: BloomConfig, weights): + super().__init__() + + prefix = f"h.{layer_id}" + self.input_layernorm = LayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.num_heads = config.n_head + self.self_attention = BloomAttention( + prefix=f"{prefix}.self_attention", config=config, weights=weights + ) + self.post_attention_layernorm = LayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) + self.hidden_dropout = config.hidden_dropout + + def forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class BloomPreTrainedModel(PreTrainedModel): + config_class = BloomConfig + base_model_prefix = "transformer" + _no_split_modules = ["BloomBlock"] + + @staticmethod + def _convert_to_standard_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, + num_heads, ...])) + """ + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + @staticmethod + def _convert_to_bloom_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + +class BloomModel(BloomPreTrainedModel): + def __init__(self, config: BloomConfig, weights): + super().__init__(config) + + self.embed_dim = config.hidden_size + self.num_heads = config.n_head + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.word_embeddings = TensorParallelEmbedding( + prefix="word_embeddings", weights=weights + ) + + self.word_embeddings_layernorm = LayerNorm.load( + prefix="word_embeddings_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + # Transformer blocks + self.h = nn.ModuleList( + [ + BloomBlock(layer_id=layer_id, config=config, weights=weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + # Final Layer Norm + self.ln_f = LayerNorm.load( + prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon + ) + + def _prepare_attn_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, + ) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = build_alibi_tensor(attention_mask, self.num_heads) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + if hasattr(self, "tp_rank"): + assert self.num_heads % self.tp_world_size == 0 + block_size = self.num_heads // self.tp_world_size + alibi = alibi[ + :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size + ] + alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + else: + alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past) + causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0) + + alibi = alibi.to(hidden_states.dtype) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class BloomForCausalLM(BloomPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + self.transformer = BloomModel(config, weights) + + self.lm_head = TensorParallelHead.load( + config, + prefix="word_embeddings", + weights=weights, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + loss = None + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f4116937dc3..a33c6c2d024 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -30,21 +30,24 @@ import dropout_layer_norm from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, + TensorParallelHead, ) class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, prefix, weights, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + + weight = weights.get_tensor(f"{prefix}.weight") + # assert weight.shape == (hidden_size,) + self.weight = nn.Parameter(weight) self.variance_epsilon = eps def forward(self, hidden_states, residual=None): @@ -91,35 +94,32 @@ def forward(self, hidden_states, residual=None): class FlashLlamaAttention(torch.nn.Module): def __init__( self, - num_heads, - hidden_size, - process_group=None, + prefix: str, + config, + weights, ): super().__init__() - self.num_heads = num_heads - self.hidden_size = hidden_size - self.head_size = hidden_size // num_heads + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.softmax_scale = self.head_size ** (-0.5) - if process_group is None: - self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) - self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) - else: - self.num_heads = self.num_heads // process_group.size() - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - 3 * hidden_size, - bias=False, - process_group=process_group, - ) - self.o_proj = TensorParallelRowLinear( - hidden_size, - hidden_size, - bias=False, - process_group=process_group, - ) + self.num_heads = self.num_heads // weights.process_group.size() + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) def forward( self, @@ -195,8 +195,9 @@ def forward( class LlamaMLP(nn.Module): - def __init__(self, act, hidden_size, intermediate_size, process_group=None): + def __init__(self, prefix, config, weights): super().__init__() + act = config.hidden_act self.act = ( ACT2FN[act] if "gelu" not in act @@ -207,32 +208,23 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None): else "none", ) ) - - if process_group is None: - # Fuse gate and up proj - self.gate_up_proj = FastLinear( - hidden_size, 2 * intermediate_size, bias=False - ) - self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) - self.intermediate_size = intermediate_size - else: - # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear( - hidden_size, - 2 * intermediate_size, - bias=False, - process_group=process_group, - ) - self.down_proj = TensorParallelRowLinear( - intermediate_size, - hidden_size, - bias=False, - process_group=process_group, - reduce=True, - ) - self.intermediate_size = self.down_proj.in_features - - self.process_group = process_group + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) def forward(self, hidden_states): gate_up_states = self.gate_up_proj(hidden_states) @@ -241,22 +233,22 @@ def forward(self, hidden_states): class FlashLlamaLayer(nn.Module): - def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - rms_norm_eps, - process_group=None, - ): + def __init__(self, layer_id, config, weights): super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashLlamaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) - self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) - - self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) def forward( self, @@ -295,54 +287,48 @@ def forward( class FlashLlamaModel(torch.nn.Module): - def __init__(self, config, process_group=None): - super(FlashLlamaModel, self).__init__() + def __init__(self, config, weights): + super().__init__() self.config = config self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True if self.tp_embeddings: self.embed_tokens = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group + prefix="model.embed_tokens", weights=weights ) else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights) self.layers = nn.ModuleList( [ FlashLlamaLayer( - config.num_attention_heads, - config.hidden_act, - config.hidden_size, - config.intermediate_size, - config.rms_norm_eps, - process_group, + # config.num_attention_heads, + # config.hidden_act, + # config.hidden_size, + # config.intermediate_size, + layer_id, + config, + weights, + # config.rms_norm_eps, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = LlamaRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) self.gradient_checkpointing = False self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.embed_tokens, TensorParallelEmbedding): - self.embed_tokens.add_null_idx() - for layer in self.layers: - layer: FlashLlamaLayer - layer.self_attn.query_key_value.prepare_weights(quantize) - layer.self_attn.o_proj.prepare_weights(quantize) - layer.mlp.gate_up_proj.prepare_weights(quantize) - layer.mlp.down_proj.prepare_weights(quantize) - def forward( self, input_ids, @@ -410,29 +396,15 @@ def forward( class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 - - self.model = FlashLlamaModel(config, process_group) - - if self.model.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.model.post_load_weights(quantize) - self.lm_head.prepare_weights() + self.model = FlashLlamaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) def forward( self, @@ -457,12 +429,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.model.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b798750a744..c1273267f32 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -31,63 +31,74 @@ import flash_attn_cuda from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, + TensorParallelHead, FastLayerNorm, PositionRotaryEmbedding, + get_linear, ) +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias, config.quantize) + if config.use_parallel_residual: + return linear + else: + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + + weight = weight.view( + num_heads, 3, head_size, hidden_size, + ).permute(1, 0, 2, 3).reshape(-1, hidden_size) + bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) + + + linear = get_linear(weight, bias, config.quantize) + if config.use_parallel_residual: + return linear + else: + return TensorParallelColumnLinear(linear) + + class FlashNeoxAttention(torch.nn.Module): - def __init__( - self, - num_heads, - hidden_size, - rotary_pct, - rotary_emb_base, - process_group=None, - reduce=True, - ): + def __init__(self, config, prefix, weights): super().__init__() + num_heads = config.num_attention_heads + hidden_size = config.hidden_size + rotary_pct = config.rotary_pct + rotary_emb_base = config.rotary_emb_base + self.num_heads = num_heads self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + self.num_heads = self.num_heads // weights.process_group.size() rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) + self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")) self.softmax_scale = self.head_size ** (-0.5) - if process_group is None: - self.query_key_value = FastLinear(hidden_size, 3 * hidden_size) - self.dense = FastLinear(hidden_size, hidden_size) - else: - self.num_heads = self.num_heads // process_group.size() - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - 3 * hidden_size, - process_group=process_group, - ) - self.dense = TensorParallelRowLinear( - hidden_size, hidden_size, process_group=process_group, reduce=reduce - ) - - def shuffle_qkv_dims(self): - """Swap dims to avoid an additional permute""" - self.query_key_value.weight = torch.nn.Parameter( - self.query_key_value.weight.view( - self.num_heads, 3, self.head_size, self.hidden_size - ) - .permute(1, 0, 2, 3) - .reshape(-1, self.hidden_size) + self.query_key_value = load_qkv( + config, prefix=f"{prefix}.query_key_value", weights=weights, + num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size ) - self.query_key_value.bias = torch.nn.Parameter( - self.query_key_value.bias.view(self.num_heads, 3, self.head_size) - .permute(1, 0, 2) - .reshape(-1) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=True ) - + def forward( self, hidden_states, @@ -162,10 +173,9 @@ def forward( class FlashMLP(nn.Module): - def __init__( - self, act, hidden_size, intermediate_size, process_group=None, reduce=True - ): + def __init__(self, config, prefix, weights): super().__init__() + act = config.hidden_act self.act = ( ACT2FN[act] if "gelu" not in act @@ -177,22 +187,12 @@ def __init__( ) ) - if process_group is None: - self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size) - self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size) - else: - self.dense_h_to_4h = TensorParallelColumnLinear( - hidden_size, - intermediate_size, - process_group=process_group, - ) - self.dense_4h_to_h = TensorParallelRowLinear( - intermediate_size, - hidden_size, - process_group=process_group, - reduce=reduce, - ) - self.process_group = process_group + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = load_row( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True + ) def forward(self, hidden_states): hidden_states = self.dense_h_to_4h(hidden_states) @@ -202,38 +202,28 @@ def forward(self, hidden_states): class FlashNeoXLayer(nn.Module): - def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - rotary_pct, - rotary_emb_base, - layer_norm_eps, - use_parallel_residual, - process_group=None, - ): + def __init__(self, layer_id, config, weights): super().__init__() - self.use_parallel_residual = use_parallel_residual - self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.attention = FlashNeoxAttention( - num_heads, - hidden_size, - rotary_pct, - rotary_emb_base, - process_group, - reduce=not use_parallel_residual, + + layer_norm_eps = config.layer_norm_eps + + prefix = f"gpt_neox.layers.{layer_id}" + + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps + ) + self.post_attention_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=layer_norm_eps, ) - self.mlp = FlashMLP( - act, - hidden_size, - intermediate_size, - process_group, - reduce=not use_parallel_residual, + self.attention = FlashNeoxAttention( + config, prefix=f"{prefix}.attention", weights=weights ) - self.process_group = process_group + + self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) + self.process_group = weights.process_group def forward( self, @@ -302,42 +292,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) self.config = config - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.embed_in = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group - ) - else: - self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.embed_in = TensorParallelEmbedding( + prefix="gpt_neox.embed_in", weights=weights + ) self.layers = nn.ModuleList( [ - FlashNeoXLayer( - config.num_attention_heads, - config.hidden_act, - config.hidden_size, - config.intermediate_size, - config.rotary_pct, - config.rotary_emb_base, - config.layer_norm_eps, - config.use_parallel_residual, - process_group, - ) - for _ in range(config.num_hidden_layers) + FlashNeoXLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) ] ) - self.final_layer_norm = FastLayerNorm( - config.hidden_size, eps=config.layer_norm_eps + self.final_layer_norm = FastLayerNorm.load( + prefix="gpt_neox.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, ) self.gradient_checkpointing = False @@ -345,29 +317,6 @@ def __init__(self, config, process_group=None): self.head_size = self.layers[0].attention.head_size self.num_heads = self.layers[0].attention.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.embed_in, TensorParallelEmbedding): - self.embed_in.add_null_idx() - for layer in self.layers: - layer: FlashNeoXLayer - layer.attention.shuffle_qkv_dims() - layer.attention.query_key_value.prepare_weights(quantize) - layer.attention.dense.prepare_weights(quantize) - layer.mlp.dense_h_to_4h.prepare_weights(quantize) - layer.mlp.dense_4h_to_h.prepare_weights(quantize) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashGPTNeoXModel, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs - ) - - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model - def forward( self, input_ids, @@ -435,42 +384,13 @@ def forward( class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) + self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 - - self.gpt_neox = FlashGPTNeoXModel(config, process_group) - - if self.gpt_neox.tp_embeddings: - self.embed_out = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.embed_out = FastLinear( - config.hidden_size, config.vocab_size, bias=False - ) - - def post_load_weights(self, quantize: Optional[str] = None): - self.gpt_neox.post_load_weights(quantize) - self.embed_out.prepare_weights() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.embed_out = TensorParallelHead.load( + config, prefix="embed_out", weights=weights ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model def forward( self, @@ -495,12 +415,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.embed_out(hidden_states) - - if self.gpt_neox.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index b61ec8733c3..21b3f039aed 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -3,44 +3,139 @@ from torch import nn from transformers.activations import ACT2FN -from typing import Optional +from typing import Optional, List # Flash attention imports import flash_attn_cuda from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, + TensorParallelHead, TensorParallelEmbedding, FastLayerNorm, + get_linear, ) + +def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size): + if any("c_attn" in k for k in weights.routing.keys()): + slice_ = weights._get_slice(f"{prefix}.c_attn.weight") + shape = slice_.get_shape() + world_size = weights.process_group.size() + rank = weights.process_group.rank() + if config.transpose: + block_size = (shape[1] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[1] - 2 * head_size) % world_size == 0 + q_tensor = slice_[:, start:stop] + kv_tensor = slice_[:, -2 * head_size :] + weight = torch.cat([q_tensor, kv_tensor], dim=1).T + else: + block_size = (shape[0] - 2 * head_size) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + assert (shape[0] - 2 * head_size) % world_size == 0 + q_tensor = slice_[start:stop] + kv_tensor = slice_[-2 * head_size :] + weight = torch.cat([q_tensor, kv_tensor], dim=0) + if bias: + slice_ = weights._get_slice(f"{prefix}.c_attn.bias") + shape = slice_.get_shape() + block_size = (shape[0] - 2 * head_size) // world_size + assert (shape[0] - 2 * head_size) % world_size == 0 + q_tensor = slice_[start:stop] + start = rank * block_size + stop = (rank + 1) * block_size + q_tensor = slice_[start:stop] + kv_tensor = slice_[-2 * head_size :] + bias = torch.cat([q_tensor, kv_tensor], dim=0) + else: + if config.transpose: + w = [ + weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T, + weights.get_tensor(f"{prefix}.kv_attn.weight").T + ] + weight = torch.cat(w, dim=0) + else: + w = [ + weights.get_sharded(f"{prefix}.q_attn.weight", dim=0), + weights.get_tensor(f"{prefix}.kv_attn.weight") + ] + weight = torch.cat(w, dim=1) + + if bias: + b = [ + weights.get_sharded(f"{prefix}.q_attn.bias", dim=0), + weights.get_tensor(f"{prefix}.kv_attn.bias") + ] + bias = torch.cat(b, dim=0) + else: + bias = None + + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}" + if bias is not None: + bias = bias.to(dtype=weights.dtype).to(device=weights.device) + assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}" + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + +def load_col(config, prefix: str, weights, bias: bool): + if config.transpose: + weight = weights.get_sharded(f"{prefix}.weight", dim=1).T + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + +def load_row(config, prefix: str, weights, bias: bool): + if config.transpose: + weight = weights.get_sharded(f"{prefix}.weight", dim=0).T + else: + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group) + + class FlashMQAttention(torch.nn.Module): - def __init__( - self, - num_heads, - hidden_size, - process_group=None, - ): + def __init__(self, prefix, config, weights): super().__init__() + num_heads = config.num_attention_heads + hidden_size = config.hidden_size + self.num_heads = num_heads self.hidden_size = hidden_size self.head_size = hidden_size // num_heads + assert self.num_heads % weights.process_group.size() == 0 + self.num_heads = self.num_heads // weights.process_group.size() + self.softmax_scale = self.head_size ** (-0.5) - if process_group is None: - self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size) - self.c_proj = FastLinear(hidden_size, hidden_size) - else: - self.num_heads = self.num_heads // process_group.size() - self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2)) - self.c_proj = TensorParallelRowLinear( - hidden_size, - hidden_size, - process_group=process_group, - ) + self.c_attn = load_multi_mqa( + config, + prefix=prefix, + weights=weights, + bias=True, + head_size=self.head_size, + hidden_size=hidden_size, + num_heads=self.num_heads + ) + self.c_proj = load_row( + config, prefix=f"{prefix}.c_proj", weights=weights, bias=True + ) def forward( self, @@ -121,8 +216,9 @@ def forward( class MLP(nn.Module): - def __init__(self, act, hidden_size, intermediate_size, process_group=None): + def __init__(self, prefix, config, weights): super().__init__() + act = config.activation_function self.act = ( ACT2FN[act] if "gelu" not in act @@ -134,20 +230,12 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None): ) ) - if process_group is None: - self.c_fc = FastLinear(hidden_size, intermediate_size) - self.c_proj = FastLinear(intermediate_size, hidden_size) - else: - self.c_fc = TensorParallelColumnLinear( - hidden_size, - intermediate_size, - process_group=process_group, - ) - self.c_proj = TensorParallelRowLinear( - intermediate_size, - hidden_size, - process_group=process_group, - ) + self.c_fc = load_col( + config, prefix=f"{prefix}.c_fc", weights=weights, bias=True + ) + self.c_proj = load_row( + config, prefix=f"{prefix}.c_proj", weights=weights, bias=True + ) def forward(self, hidden_states): hidden_states = self.c_fc(hidden_states) @@ -157,28 +245,24 @@ def forward(self, hidden_states): class Block(nn.Module): - def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - layer_norm_eps, - process_group=None, - ): + def __init__(self, layer_id, config, weights): super().__init__() - self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps) + prefix = f"transformer.h.{layer_id}" + self.ln_1 = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + self.ln_2 = FastLayerNorm.load( + prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon + ) self.attn = FlashMQAttention( - num_heads, - hidden_size, - process_group, + prefix=f"{prefix}.attn", + config=config, + weights=weights, ) self.mlp = MLP( - act, - hidden_size, - intermediate_size, - process_group, + prefix=f"{prefix}.mlp", + config=config, + weights=weights, ) def forward( @@ -210,66 +294,39 @@ def forward( class FlashSantacoderModel(nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() self.config = config - self.process_group = process_group - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.wte = TensorParallelEmbedding( - config.vocab_size, - config.hidden_size, - reduce=False, - process_group=process_group, - ) - self.wpe = TensorParallelEmbedding( - config.max_position_embeddings, - config.hidden_size, - reduce=False, - process_group=process_group, - ) - else: - self.wte = nn.Embedding(config.vocab_size, config.hidden_size) - self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.process_group = weights.process_group + self.wte = TensorParallelEmbedding( + prefix="transformer.wte", + weights=weights, + reduce=False, + ) + self.wpe = TensorParallelEmbedding( + prefix="transformer.wpe", + weights=weights, + reduce=False, + ) self.h = nn.ModuleList( [ Block( - config.num_attention_heads, - config.activation_function, - config.hidden_size, - config.n_inner - if config.n_inner is not None - else 4 * config.hidden_size, - config.layer_norm_epsilon, - process_group, + layer_id, + config, + weights, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) - self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon + ) self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads - def post_load_weights(self, quantize: Optional[str] = None): - if self.tp_embeddings: - self.wte.add_null_idx() - self.wpe.add_null_idx() - for layer in self.h: - layer: Block - layer.attn.c_attn.prepare_weights(quantize) - layer.attn.c_proj.prepare_weights(quantize) - layer.mlp.c_fc.prepare_weights(quantize) - layer.mlp.c_proj.prepare_weights(quantize) - def forward( self, input_ids, @@ -281,8 +338,7 @@ def forward( pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) - if self.tp_embeddings: - torch.distributed.all_reduce(hidden_states, group=self.process_group) + torch.distributed.all_reduce(hidden_states, group=self.process_group) # Prefill if past_key_values is None: @@ -331,23 +387,12 @@ def forward( class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__() - - self.transformer = FlashSantacoderModel(config, process_group) - - if self.transformer.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.transformer.post_load_weights(quantize) - self.lm_head.prepare_weights() + self.transformer = FlashSantacoderModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, prefix="transformer.wte", weights=weights + ) def forward( self, @@ -372,29 +417,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.transformer.tp_embeddings: - # Logits are sharded, so we need to gather them - if logits.shape[0] == 1: - # Fast path when batch size is 1 - world_logits = logits.new_empty( - (logits.shape[1] * self.transformer.tp_world_size) - ) - torch.distributed.all_gather_into_tensor( - world_logits, logits.view(-1), group=self.transformer.process_group - ) - world_logits = world_logits.view(1, -1) - else: - # We cannot use all_gather_into_tensor as it only support concatenating on the first dim - world_logits = [ - torch.empty_like(logits) - for _ in range(self.transformer.tp_world_size) - ] - torch.distributed.all_gather( - world_logits, logits, group=self.transformer.process_group - ) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present - return logits, present diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py new file mode 100644 index 00000000000..03fded50c21 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -0,0 +1,837 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers import OPTConfig +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, +) + +EPS = 1e-5 + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full( + (tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device, + ) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class OPTLearnedPositionalEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, weights): + super().__init__() + self.offset = 2 + self.weight = nn.Parameter( + weights.get_tensor("model.decoder.embed_positions.weight") + ) + + def forward( + self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask + ).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return torch.nn.functional.embedding(positions + self.offset, self.weight) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config, + prefix, + weights, + is_decoder: bool = False, + bias: bool = True, + process_group=None, + ): + super().__init__() + embed_dim = config.embed_dim + num_heads = config.num_attention_heads + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = config.dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + process_group = weights.process_group + assert self.num_heads % process_group.size() == 0 + self.num_heads = self.num_heads // process_group.size() + self.embed_dim = self.embed_dim // process_group.size() + + self.q_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias + ) + self.k_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias + ) + self.v_proj = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias + ) + self.out_proj = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + def __init__(self, layer_id: int, config: OPTConfig, weights): + super().__init__() + self.process_group = weights.process_group + self.embed_dim = config.hidden_size + prefix = f"model.decoder.layers.{layer_id}" + self.self_attn = OPTAttention( + config, + prefix=f"{prefix}.self_attn", + weights=weights, + is_decoder=True, + bias=config.enable_bias, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS + ) + self.fc1 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias + ) + self.fc2 = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias + ) + self.final_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig + + +class OPTDecoder(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, weights): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = TensorParallelEmbedding( + prefix="model.decoder.embed_tokens", weights=weights + ) + self.embed_positions = OPTLearnedPositionalEmbedding(weights) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = FastLinear.load( + config, prefix="model.decoder.project_out", bias=False + ) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = FastLinear.load( + config, prefix="model.decoder.project_in", bias=False + ) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm.load( + prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS + ) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [ + OPTDecoderLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + causal_attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig, weights): + super().__init__(config) + self.decoder = OPTDecoder(config, weights) + # Initialize weights and apply final processing + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + + self.model = OPTModel(config, weights) + + self.lm_head = TensorParallelHead.load( + config, prefix="model.decoder.embed_tokens", weights=weights + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py new file mode 100644 index 00000000000..6fa09b09d3c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -0,0 +1,1177 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + +import copy +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + is_torch_fx_proxy, +) +from transformers import T5Config +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, +) + + +class PartialTPEmbedding(nn.Module): + def __init__(self, prefix: str, weights): + super().__init__() + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + self.weight = nn.Parameter(weight) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.embedding(input, self.weight) + + +@torch.jit.script +def layer_norm(hidden_states, weight, epsilon): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(weight.dtype) + + return weight * hidden_states + + +class T5LayerNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = torch.tensor(eps) + + def forward(self, hidden_states): + return layer_norm(hidden_states, self.weight, self.variance_epsilon) + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" + ) +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + self.wi = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi", weights=weights, bias=False + ) + self.wo = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.wo", weights=weights, bias=False + ) + + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ( + ACT2FN[config.dense_act_fn] + if "gelu" not in config.dense_act_fn + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + self.wi_0 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi_0", weights=weights, bias=False + ) + self.wi_1 = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.wi_1", weights=weights, bias=False + ) + self.wo = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.wo", weights=weights, bias=False + ) + + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ( + ACT2FN[config.dense_act_fn] + if "gelu" not in config.dense_act_fn + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # TODO Support this again mayber + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + # if ( + # isinstance(self.wo.weight, torch.Tensor) + # and hidden_states.dtype != self.wo.weight.dtype + # and self.wo.weight.dtype != torch.int8 + # ): + # hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config, prefix, weights): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + else: + self.DenseReluDense = T5DenseActDense( + config, prefix=f"{prefix}.DenseReluDense", weights=weights + ) + + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, config: T5Config, prefix, weights, has_relative_attention_bias=False + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + process_group = weights.process_group + # Mesh TensorFlow initialization to avoid scaling before softmax + assert self.n_heads % process_group.size() == 0 + self.q = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.q", weights=weights, bias=False + ) + self.k = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.k", weights=weights, bias=False + ) + self.v = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.v", weights=weights, bias=False + ) + self.o = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.o", weights=weights, bias=False + ) + self.n_heads = self.n_heads // process_group.size() + self.inner_dim = self.inner_dim // process_group.size() + + if self.has_relative_attention_bias: + self.relative_attention_bias = PartialTPEmbedding( + prefix=f"{prefix}.relative_attention_bias", weights=weights + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, + prefix=f"{prefix}.SelfAttention", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.EncDecAttention = T5Attention( + config, + prefix=f"{prefix}.EncDecAttention", + weights=weights, + has_relative_attention_bias=False, + ) + self.layer_norm = T5LayerNorm( + prefix=f"{prefix}.layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + prefix=f"{prefix}.layer.0", + weights=weights, + has_relative_attention_bias=has_relative_attention_bias, + ) + ) + if self.is_decoder: + i = 2 + self.layer.append( + T5LayerCrossAttention( + config, prefix=f"{prefix}.layer.1", weights=weights + ) + ) + else: + i = 1 + + self.layer.append( + T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + base_model_prefix = "transformer" + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, prefix, weights, embed_tokens): + super().__init__(config) + + self.is_decoder = config.is_decoder + + self.embed_tokens = embed_tokens + self.block = nn.ModuleList( + [ + T5Block( + config, + prefix=f"{prefix}.block.{layer_id}", + weights=weights, + has_relative_attention_bias=(layer_id == 0), + ) + for layer_id in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + prefix=f"{prefix}.final_layer_norm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + assert ( + self.is_decoder + ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class T5ForConditionalGeneration(T5PreTrainedModel): + def __init__(self, config: T5Config, weights): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = TensorParallelEmbedding(prefix="shared", weights=weights) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack( + config=encoder_config, + prefix="encoder", + weights=weights, + embed_tokens=self.shared, + ) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack( + config=decoder_config, + prefix="decoder", + weights=weights, + embed_tokens=self.shared, + ) + + self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index fe28580df77..eb216a2067d 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -1,154 +1,25 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from pathlib import Path -from safetensors import safe_open from transformers import AutoConfig from transformers.models.llama import LlamaTokenizer -from typing import Optional, List +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, - LocalEntryNotFoundError, + Weights, ) tracer = trace.get_tracer(__name__) class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = LlamaTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - # We do not use from_pretrained as we modified the model internal module layout - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) - - with init_empty_weights(): - model = FlashLlamaForCausalLM(config) - - self.load_weights(model, filenames, quantize, device, dtype) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - ): - for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if quantize is None else "cpu").to(dtype) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_proj" in key or "k_proj" in key or "v_proj" in key: - final_key = layer_name + ".query_key_value.weight" - - # Fused gate and up projs - elif "gate_proj" in key or "up_proj" in key: - final_key = layer_name + ".gate_up_proj.weight" - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "query_key_value" in final_key: - module._parameters[param_name] = value.new_empty( - (value.shape[0] * 3, value.shape[1]) - ) - # Init gate and up proj - elif "gate_up_proj" in final_key: - module._parameters[param_name] = value.new_empty( - (value.shape[0] * 2, value.shape[1]) - ) - - # Copy to correct slice - if "q_proj" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "k_proj" in key: - module._parameters[param_name][ - value.shape[0] : value.shape[0] * 2 - ] = value - elif "v_proj" in key: - module._parameters[param_name][value.shape[0] * 2 :] = value - elif "gate_proj" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "up_proj" in key: - module._parameters[param_name][value.shape[0] :] = value - else: - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - else: - module._buffers[param_name] = value - - del value - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - -class FlashLlamaSharded(FlashLlama): def __init__( self, model_id: str, @@ -176,24 +47,16 @@ def __init__( ) torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) - with init_empty_weights(): - model = FlashLlamaForCausalLM(config, process_group=self.process_group) + config.quantize = quantize + model = FlashLlamaForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( - model=model.to(device), + model=model, tokenizer=tokenizer, requires_padding=False, dtype=dtype, @@ -201,114 +64,3 @@ def __init__( rank=rank, world_size=world_size, ) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - slice_ = f.get_slice(name) - - layer_name = ".".join(name.split(".")[:4]) - - # Fused qkv - if "q_proj" in name or "k_proj" in name or "v_proj" in name: - final_name = layer_name + ".query_key_value.weight" - - # Fused gate and up projs - elif "gate_proj" in name or "up_proj" in name: - final_name = layer_name + ".gate_up_proj.weight" - else: - final_name = name - - module_name, param_name = final_name.rsplit(".", 1) - module = model.get_submodule(module_name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight" and model.model.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - tensor = tensor.contiguous().to(dtype) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "query_key_value" in final_name: - module._parameters[param_name] = tensor.new_empty( - (tensor.shape[0] * 3, tensor.shape[1]) - ) - # Init gate and up proj - elif "gate_up_proj" in final_name: - module._parameters[param_name] = tensor.new_empty( - (tensor.shape[0] * 2, tensor.shape[1]) - ) - - # Init gate and up proj - if "q_proj" in name: - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "k_proj" in name: - module._parameters[param_name][ - tensor.shape[0] : tensor.shape[0] * 2 - ] = tensor - elif "v_proj" in name: - module._parameters[param_name][ - tensor.shape[0] * 2 : - ] = tensor - elif "gate_proj" in name: - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "up_proj" in name: - module._parameters[param_name][tensor.shape[0] :] = tensor - else: - if current_parameter_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - module._parameters[param_name] = tensor - - else: - module._buffers[param_name] = tensor - - torch.cuda.empty_cache() - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 31ae7914d74..4847571d140 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -1,45 +1,24 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) tracer = trace.get_tracer(__name__) -class FlashNeoX(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - super(FlashNeoX, self).__init__( - FlashGPTNeoXForCausalLM, - model_id, - revision, - quantize, - trust_remote_code=trust_remote_code, - ) - - -class FlashNeoXSharded(FlashNeoX): +class FlashNeoXSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -65,23 +44,16 @@ def __init__( config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = FlashGPTNeoXForCausalLM(config, self.process_group) + model = FlashGPTNeoXForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -92,79 +64,3 @@ def __init__( rank=rank, world_size=world_size, ) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index e1c893d01aa..54634e4ab86 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -1,197 +1,24 @@ import torch import torch.distributed -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open -from pathlib import Path -from transformers import AutoTokenizer, GPT2Config +from transformers import AutoTokenizer, AutoConfig from typing import Optional, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, - LocalEntryNotFoundError, + Weights, ) tracer = trace.get_tracer(__name__) -class FlashSantacoder(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("FlashSantacoder is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = GPT2Config.from_pretrained( - model_id, - revision=revision, - ) - - # We do not use from_pretrained as we modified the model internal module layout - filenames = weight_files(model_id, revision, ".safetensors") - - with init_empty_weights(): - model = FlashSantacoderForCausalLM(config) - - self.load_weights( - model, - filenames, - quantize, - device, - dtype, - config.architectures[0].startswith("GPT2"), - ) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model: FlashSantacoderForCausalLM, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - transpose: bool, - ): - for filename in filenames: - with safe_open( - filename, - framework="pt", - device=str(device) if quantize is None else "cpu", - ) as f: - for key in f.keys(): - value = f.get_tensor(key) - value = value.to(device if quantize is None else "cpu").to(dtype) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_attn.weight" in key or "kv_attn.weight" in key: - final_key = layer_name + ".c_attn.weight" - elif "q_attn.bias" in key or "kv_attn.bias" in key: - final_key = layer_name + ".c_attn.bias" - - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if transpose and ( - "c_fc.weight" in key - or "c_proj.weight" in key - or "q_attn.weight" in key - or "kv_attn.weight" in key - or "c_attn.weight" in key - ): - # Tranpose as we use nn.Linear instead of Conv1D - value = value.T - - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "c_attn.weight" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2), - value.shape[1], - ) - ) - elif "c_attn.bias" in final_key: - module._parameters[param_name] = value.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2) - ) - ) - - # Copy to correct slice - if "q_attn.weight" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "q_attn.bias" in key: - module._parameters[param_name][: value.shape[0]] = value - elif "kv_attn.weight" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = value - elif "kv_attn.bias" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = value - else: - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - else: - module._buffers[param_name] = value - - del value - - if model.lm_head.weight.device == torch.device("meta"): - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model : {uninitialized_parameters}" - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - -class FlashSantacoderSharded(FlashSantacoder): +class FlashSantacoderSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -214,28 +41,22 @@ def __init__( trust_remote_code=trust_remote_code, ) - config = GPT2Config.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, + trust_remote_code=True, ) + config.quantize = quantize + config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = FlashSantacoderForCausalLM(config, self.process_group) + model = FlashSantacoderForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - transpose=config.architectures[0].startswith("GPT2"), - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -247,164 +68,8 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - transpose: bool, - ): - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for key in f.keys(): - slice_ = f.get_slice(key) - - layer_name = ".".join(key.split(".")[:4]) - - # Fused qkv - if "q_attn.weight" in key or "kv_attn.weight" in key: - final_key = layer_name + ".c_attn.weight" - elif "q_attn.bias" in key or "kv_attn.bias" in key: - final_key = layer_name + ".c_attn.bias" - else: - final_key = key - - module_name, param_name = final_key.rsplit(".", 1) - module = model.get_submodule(module_name) - - if isinstance(module, TensorParallelColumnLinear): - dim = 1 if transpose and "weight" in param_name else 0 - size = slice_.get_shape()[dim] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = ( - slice_[start:stop] if dim == 0 else slice_[:, start:stop] - ) - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - dim = 0 if transpose else 1 - size = slice_.get_shape()[dim] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = ( - slice_[start:stop] - if dim == 0 - else slice_[:, start:stop] - ) - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif key == "lm_head.weight" and model.transformer.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(key) - - tensor = tensor.contiguous().to(dtype) - - try: - current_parameter_tensor = module._parameters[param_name] - except KeyError: - current_parameter_tensor = None - - if current_parameter_tensor is not None: - if transpose and ( - "c_fc.weight" in key - or "c_proj.weight" in key - or "q_attn.weight" in key - or "kv_attn.weight" in key - or "c_attn.weight" in key - ): - # Tranpose as we use nn.Linear instead of Conv1D - tensor = tensor.T - - if current_parameter_tensor.device == torch.device("meta"): - # Init qkv - if "c_attn.weight" in final_key: - module._parameters[param_name] = tensor.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2), - tensor.shape[1], - ) - ) - elif "c_attn.bias" in final_key: - module._parameters[param_name] = tensor.new_empty( - ( - model.transformer.head_size - * (model.transformer.num_heads + 2) - ) - ) - - # Copy to correct slice - if "q_attn" in key: - size = tensor.shape[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = tensor[start:stop] - module._parameters[param_name][: tensor.shape[0]] = tensor - elif "kv_attn.weight" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = tensor - elif "kv_attn.bias" in key: - module._parameters[param_name][ - model.transformer.head_size - * model.transformer.num_heads : - ] = tensor - elif "c_attn" in key: - # Slice q_tensor by shard - q_tensor = tensor[: -2 * model.transformer.head_size] - block_size = q_tensor.shape[0] // world_size - start = rank * block_size - stop = (rank + 1) * block_size - q_tensor = q_tensor[start:stop] - - module._parameters[param_name][ - : q_tensor.shape[0] - ] = q_tensor - - # Kv tensor is copied for every shard - kv_tensor = tensor[-2 * model.transformer.head_size :] - module._parameters[param_name][ - q_tensor.shape[0] : - ] = kv_tensor - else: - if current_parameter_tensor.shape != tensor.shape: - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - if model.lm_head.weight.device == torch.device("meta"): - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - torch.cuda.empty_cache() - model.post_load_weights(quantize) + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 37ccc398a69..a907ee6c27a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -2,41 +2,25 @@ import torch import torch.distributed -from typing import List, Optional, Type, Tuple +from typing import List, Optional, Type -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase, ) -from transformers.models.opt.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.models.opt import OPT +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - - # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py # we split individual characters inside special tokens like [START_DNA] @@ -168,33 +152,7 @@ def from_pb( ) -class Galactica(OPT): - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - """Overwrite forward to ignore position_ids""" - - # Model Forward - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, outputs.past_key_values - - -class GalacticaSharded(Galactica): +class GalacticaSharded(CausalLM): def __init__( self, model_id: str, @@ -228,22 +186,12 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) + model = OPTForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -255,127 +203,15 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name == "lm_head.weight": - continue - - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + @property + def batch_type(self) -> Type[CausalLMBatch]: + return GalacticaCausalLMBatch - module._parameters[param_name] = tensor - if name == "model.decoder.embed_tokens.weight": - model.lm_head._parameters["weight"] = tensor + def decode(self, generated_ids: List[int]) -> str: + # Do not skip special tokens as they are used for custom parsing rules of the generated text + return self.tokenizer.decode( + generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None @@ -386,10 +222,4 @@ def forward( past_key_values=past_key_values, use_cache=True, ) - - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - - return logits, outputs.past_key_values + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 9cc4d5e1a81..185937e6f24 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -1,52 +1,22 @@ import torch import torch.distributed -from typing import List, Optional, Tuple +from typing import Optional -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, ) -from transformers.models.opt.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models import CausalLM from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - - -class OPT(CausalLM): - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - """Overwrite forward to ignore position_ids""" - - # Model Forward - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, outputs.past_key_values - -class OPTSharded(OPT): +class OPTSharded(CausalLM): def __init__( self, model_id: str, @@ -70,32 +40,21 @@ def __init__( trust_remote_code=trust_remote_code, ) - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, ) + config = AutoConfig.from_pretrained(model_id, revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) + model = OPTForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -107,128 +66,6 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name == "lm_head.weight": - continue - - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - module._parameters[param_name] = tensor - if name == "model.decoder.embed_tokens.weight": - model.lm_head._parameters["weight"] = tensor - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): @@ -239,9 +76,4 @@ def forward( use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - - return logits, outputs.past_key_values + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index d12b89d2b06..84465d48414 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -3,31 +3,20 @@ from typing import List, Optional, Tuple -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForSeq2SeqLM, AutoConfig, ) from text_generation_server.models import Seq2SeqLM +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -from transformers.models.t5.parallel_layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, -) - -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except ImportError as e: - HAS_BITS_AND_BYTES = False class T5Sharded(Seq2SeqLM): @@ -46,6 +35,9 @@ def __init__( device = torch.device("cpu") dtype = torch.float32 + config = AutoConfig.from_pretrained(model_id, revision=revision) + config.quantize = quantize + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, @@ -53,33 +45,16 @@ def __init__( truncation_side="left", trust_remote_code=trust_remote_code, ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) tokenizer.bos_token_id = config.decoder_start_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = AutoModelForSeq2SeqLM.from_config( - config, trust_remote_code=trust_remote_code - ) + model = T5ForConditionalGeneration(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( model=model, @@ -91,151 +66,6 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight": - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif "relative_attention_bias.weight" in name: - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous() - - # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71 - if module_name.endswith("wo"): - tensor = tensor.to(torch.float32) - else: - tensor = tensor.to(dtype) - - if quantize == "bitsandbytes" and not module_name.endswith("wo"): - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq" and not module_name.endswith("wo"): - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None or module_name.endswith("wo"): - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - def forward( self, input_ids, @@ -260,13 +90,8 @@ def forward( use_cache=True, ) - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) - logits = torch.cat(logits, dim=2) - return ( - logits, + outputs.logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 6a351d66800..befedcf07a9 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -1,5 +1,6 @@ from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.dist import initialize_torch_distributed +from text_generation_server.utils.weights import Weights from text_generation_server.utils.hub import ( weight_files, weight_hub_files, @@ -35,4 +36,5 @@ "StoppingCriteria", "StopSequenceCriteria", "FinishReason", + "Weights", ] diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 9785493eb8f..9be51f74012 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,34 @@ from datetime import timedelta +class FakeBarrier: + def wait(self): + pass + + +class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def allreduce(self, *args, **kwargs): + return FakeBarrier() + + def allgather(self, inputs, local_tensor, **kwargs): + assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" + for input_ in inputs: + input_[0].data = local_tensor[0].data + return FakeBarrier() + + def barrier(self, *args, **kwargs): + return FakeBarrier() + + def size(self): + return self._size + + def rank(self): + return self._rank + def initialize_torch_distributed(): rank = int(os.getenv("RANK", "0")) @@ -23,13 +51,18 @@ def initialize_torch_distributed(): backend = "gloo" options = None - # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=world_size, - rank=rank, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if world_size == 1: + return FakeGroup(rank, world_size), rank, world_size + else: + if os.getenv("DEBUG", None) == "1": + return FakeGroup(rank, world_size), rank, world_size + # Call the init process. + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=60), + pg_options=options, + ) - return torch.distributed.group.WORLD, rank, world_size + return torch.distributed.group.WORLD, rank, world_size diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 127f9ba4434..ea9a14695fc 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -2,175 +2,212 @@ from torch import nn from torch.nn import functional as F -from typing import Optional +from typing import Optional, List HAS_BITS_AND_BYTES = True try: - from bitsandbytes.nn import Linear8bitLt + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params + except ImportError as e: HAS_BITS_AND_BYTES = False +from accelerate import init_empty_weights + + +# Monkey patching +@classmethod +def load_layer_norm(cls, prefix, weights, eps): + weight = weights.get_tensor(f"{prefix}.weight") + bias = weights.get_tensor(f"{prefix}.bias") + with init_empty_weights(): + ln = cls(weight.shape, eps=eps) + + ln.weight = nn.Parameter(weight) + ln.bias = nn.Parameter(bias) + return ln + +torch.nn.LayerNorm.load = load_layer_norm + -class FastLinear(nn.Linear): +class FastLinear(nn.Module): def __init__( self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - self.quantized = False - self.bnb_linear = None - - def prepare_weights(self, quantize: Optional[str] = None): - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - self.quantized = True - self.bnb_linear = Linear8bitLt( - self.in_features, - self.out_features, - has_fp16_weights=False, - threshold=6.0, - bias=False, - ) - # Copy data to bnb_linear - self.bnb_linear.weight.data = self.weight.data - if self.bias is not None: - self.bnb_linear.bias = nn.Parameter(self.bias) - - # Delete reference to data - self.weight = None + weight, bias, + ) -> None: + super().__init__() + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = nn.Parameter(bias) + else: self.bias = None - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - self.weight = nn.Parameter(self.weight.T) + + @staticmethod + def load(config, prefix: str, weights, bias: bool): + weight = weights.get_tensor(f"{prefix}.weight") + if bias: + bias = weights.get_tensor(f"{prefix}.bias") else: - raise ValueError(f"Unexpected quantize `{quantize}`") + bias = None + return FastLinear(weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.quantized: - return self.bnb_linear(input) - else: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) + return F.linear(input, self.weight, self.bias) + + +class Linear8bitLt(nn.Module): + def __init__(self, weight, bias, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None): + super().__init__() + assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + # Necessary for stacked layers + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + self.weight.cuda(weight.device) + self.bias = bias + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out -class TensorParallelColumnLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, - ): - self.process_group = process_group - self.tp_world_size = process_group.size() - assert out_features % self.tp_world_size == 0 - out_features = out_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, +def get_linear(weight, bias, quantize): + if quantize is None: + linear = FastLinear(weight, bias) + elif quantize == "bitsandbytes": + linear = Linear8bitLt( + weight, bias, + has_fp16_weights=False, + threshold=6.0, ) - - -class TensorParallelRowLinear(FastLinear): - def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - reduce=True, - bias=True, - device=None, - dtype=None, - ): + if bias is not None: + linear.bias = nn.Parameter(bias) + elif quantize == "gptq": + raise NotImplementedError("Soon") + else: + raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.") + return linear + + +class SuperLayer(nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear.forward(x) + +class TensorParallelHead(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) self.process_group = process_group - self.tp_world_size = process_group.size() - self.reduce = reduce - assert in_features % self.tp_world_size == 0 - in_features = in_features // self.tp_world_size - - super().__init__( - in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - ) + + @staticmethod + def load(config, prefix: str, weights): + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group) def forward(self, input: torch.Tensor) -> torch.Tensor: - out = super(TensorParallelRowLinear, self).forward(input) - if self.reduce: - torch.distributed.all_reduce(out, group=self.process_group) + output = super().forward(input) + # Logits are sharded, so we need to gather them + world_output = [torch.empty_like(output) for _ in range(self.process_group.size())] + torch.distributed.all_gather(world_output, output, group=self.process_group) + world_output = torch.cat(world_output, dim=-1) + return world_output + + +class TensorParallelColumnLinear(SuperLayer): + @staticmethod + def load(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + if bias: + bias = weights.get_sharded(f"{prefix}.bias", dim=0) + else: + bias = None + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) - return out + @staticmethod + def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + weight = torch.cat(w, dim=dim) + if bias: + b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] + bias = torch.cat(b, dim=0) + else: + bias = None + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) -class TensorParallelEmbedding(nn.Embedding): - def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - reduce=True, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None, - ): - self.reduce = reduce + +class TensorParallelRowLinear(SuperLayer): + def __init__(self, linear, process_group): + super().__init__(linear) self.process_group = process_group - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - self.original_num_embeddings = num_embeddings + @staticmethod + def load(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group) - assert num_embeddings % self.tp_world_size == 0 - block_size = num_embeddings // self.tp_world_size - # inputs in `[min_id, max_id[` are handled by `self` to get embeddings - self.min_id = self.tp_rank * block_size - self.max_id = (self.tp_rank + 1) * block_size + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out - # Additional entry that will map to zero - # Used for masking - self.null_idx = block_size +class TensorParallelEmbedding(nn.Module): + def __init__(self, prefix: str, weights, reduce=True): + super().__init__() + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + num_embeddings = weights.get_shape(f"{prefix}.weight")[0] - super().__init__( - block_size, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - device=device, - dtype=dtype, - ) + process_group = weights.process_group + + world_size = process_group.size() + rank = process_group.rank() + + block_size = num_embeddings // world_size + self.min_id = rank * block_size + self.max_id = min(num_embeddings, (rank + 1) * block_size) + self.null_idx = block_size + self.process_group = weights.process_group + self.reduce = reduce - def add_null_idx(self): """Additional 0 entry used for masking""" - self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1))) def forward(self, input: torch.Tensor) -> torch.Tensor: # default all out of bounds values to `self.null_idx` that will then be mapped to 0 @@ -180,12 +217,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.null_idx, input - self.min_id, ) - out = super().forward(input) + out = torch.nn.functional.embedding(input, self.weight) if self.reduce: torch.distributed.all_reduce(out, group=self.process_group) return out - try: import dropout_layer_norm diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py new file mode 100644 index 00000000000..fc01d9379fd --- /dev/null +++ b/server/text_generation_server/utils/weights.py @@ -0,0 +1,78 @@ +from pathlib import Path +from typing import Optional, List +from safetensors import safe_open + +class Weights: + def __init__(self, filenames: List[Path], device, dtype, process_group): + routing = {} + for filename in filenames: + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + self.routing = routing + self.device = device + self.dtype = dtype + self.process_group = process_group + self._handles = {} + + def _get_handle(self, filename): + if filename not in self._handles: + f = safe_open(filename, framework="pytorch") + self._handles[filename] = f + + return self._handles[filename] + + + + def get_filename(self, tensor_name: str) -> str: + filename = self.routing.get(tensor_name, None) + if filename is None: + raise RuntimeError(f"weight {tensor_name} does not exist") + return filename + + def _get_slice(self, tensor_name: str): + filename = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + return slice_ + + def get_shape(self, tensor_name: str): + return self._get_slice(tensor_name).get_shape() + + def get_tensor(self, tensor_name: str): + filename = self.get_filename(tensor_name) + f = self._get_handle(filename) + tensor = f.get_tensor(tensor_name) + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + def get_sharded(self, tensor_name: str, dim: int): + filename = self.get_filename(tensor_name) + world_size = self.process_group.size() + rank = self.process_group.rank() + + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) + size = slice_.get_shape()[dim] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + assert size % world_size == 0, f"The choosen size {size} is not compatible with sharding on {world_size} shards" + + if dim == 0: + tensor = slice_[start:stop] + elif dim == 1: + tensor = slice_[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor + + From 2362a80a4faf08fc3d04b0719675f41181466a23 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 09:35:29 +0000 Subject: [PATCH 02/27] Black + ruff + T5 w0 quant. --- server/custom_kernels/setup.py | 13 ++-- .../models/custom_modeling/bloom_modeling.py | 1 - .../custom_modeling/flash_neox_modeling.py | 28 +++++--- .../flash_santacoder_modeling.py | 28 +++++--- .../models/custom_modeling/t5_modeling.py | 39 +++++++--- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/opt.py | 7 +- server/text_generation_server/utils/dist.py | 5 +- server/text_generation_server/utils/hub.py | 3 +- server/text_generation_server/utils/layers.py | 71 +++++++++++++------ .../text_generation_server/utils/weights.py | 11 ++- 11 files changed, 137 insertions(+), 71 deletions(-) diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py index 62c720e1773..fa4382e99f9 100644 --- a/server/custom_kernels/setup.py +++ b/server/custom_kernels/setup.py @@ -1,15 +1,14 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension + setup( - name='custom_kernels', + name="custom_kernels", ext_modules=[ CUDAExtension( - name="custom_kernels.fused_bloom_attention_cuda", - sources=['custom_kernels/fused_bloom_attention_cuda.cu'], - extra_compile_args=["-arch=compute_80", "-std=c++17"], + name="custom_kernels.fused_bloom_attention_cuda", + sources=["custom_kernels/fused_bloom_attention_cuda.cu"], + extra_compile_args=["-arch=compute_80", "-std=c++17"], ) ], - cmdclass={ - 'build_ext': BuildExtension - } + cmdclass={"build_ext": BuildExtension}, ) diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 554cab9f4bf..e5e876455fa 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -37,7 +37,6 @@ TensorParallelEmbedding, TensorParallelRowLinear, TensorParallelHead, - FastLinear ) CUSTOM_KERNELS_ENABLED = False diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index c1273267f32..24004e8a3f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -60,12 +60,18 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): weight = weights.get_sharded(f"{prefix}.weight", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0) - weight = weight.view( - num_heads, 3, head_size, hidden_size, - ).permute(1, 0, 2, 3).reshape(-1, hidden_size) + weight = ( + weight.view( + num_heads, + 3, + head_size, + hidden_size, + ) + .permute(1, 0, 2, 3) + .reshape(-1, hidden_size) + ) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) if config.use_parallel_residual: return linear @@ -88,17 +94,23 @@ def __init__(self, config, prefix, weights): rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) - self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")) + self.rotary_emb.inv_freq = nn.Parameter( + weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") + ) self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv( - config, prefix=f"{prefix}.query_key_value", weights=weights, - num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + num_heads=self.num_heads, + head_size=self.head_size, + hidden_size=self.hidden_size, ) self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) - + def forward( self, hidden_states, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 21b3f039aed..888a6066255 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -3,7 +3,7 @@ from torch import nn from transformers.activations import ACT2FN -from typing import Optional, List +from typing import Optional # Flash attention imports import flash_attn_cuda @@ -17,8 +17,9 @@ ) - -def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size): +def load_multi_mqa( + config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size +): if any("c_attn" in k for k in weights.routing.keys()): slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() @@ -55,30 +56,35 @@ def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_head if config.transpose: w = [ weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T, - weights.get_tensor(f"{prefix}.kv_attn.weight").T + weights.get_tensor(f"{prefix}.kv_attn.weight").T, ] weight = torch.cat(w, dim=0) else: w = [ weights.get_sharded(f"{prefix}.q_attn.weight", dim=0), - weights.get_tensor(f"{prefix}.kv_attn.weight") + weights.get_tensor(f"{prefix}.kv_attn.weight"), ] weight = torch.cat(w, dim=1) if bias: b = [ weights.get_sharded(f"{prefix}.q_attn.bias", dim=0), - weights.get_tensor(f"{prefix}.kv_attn.bias") + weights.get_tensor(f"{prefix}.kv_attn.bias"), ] bias = torch.cat(b, dim=0) else: bias = None weight = weight.to(dtype=weights.dtype).to(device=weights.device) - assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}" + assert list(weight.shape) == [ + (num_heads + 2) * head_size, + hidden_size, + ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}" if bias is not None: bias = bias.to(dtype=weights.dtype).to(device=weights.device) - assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}" + assert list(bias.shape) == [ + (num_heads + 2) * head_size + ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) @@ -106,7 +112,9 @@ def load_row(config, prefix: str, weights, bias: bool): bias = weights.get_tensor(f"{prefix}.bias") else: bias = None - return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group) + return TensorParallelRowLinear( + get_linear(weight, bias, config.quantize), process_group=weights.process_group + ) class FlashMQAttention(torch.nn.Module): @@ -131,7 +139,7 @@ def __init__(self, prefix, config, weights): bias=True, head_size=self.head_size, hidden_size=hidden_size, - num_heads=self.num_heads + num_heads=self.num_heads, ) self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 6fa09b09d3c..c5ce9bfc633 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -109,9 +109,21 @@ def __init__(self, config: T5Config, prefix, weights): self.wi = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.wi", weights=weights, bias=False ) + + ### XXX: T5 models do not handle well both f16 and quantization. + ### Overidding specifically this layer for that reason. + ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 + ### https://github.com/huggingface/transformers/issues/20287 + _q = config.quantize + _dtype = weights.dtype + weights.dtype = torch.float32 + config.quantize = None + self.wo_cast = (torch.float32, _dtype) self.wo = TensorParallelRowLinear.load( config, prefix=f"{prefix}.wo", weights=weights, bias=False ) + weights.dtype = _dtype + config.quantize = _q self.dropout = nn.Dropout(config.dropout_rate) self.act = ( @@ -124,7 +136,10 @@ def forward(self, hidden_states): hidden_states = self.wi(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states.to(dtype=self.wo_cast[0]) hidden_states = self.wo(hidden_states) + hidden_states = hidden_states.to(dtype=self.wo_cast[1]) return hidden_states @@ -137,9 +152,20 @@ def __init__(self, config: T5Config, prefix, weights): self.wi_1 = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.wi_1", weights=weights, bias=False ) + ### XXX: T5 models do not handle well both f16 and quantization. + ### Overidding specifically this layer for that reason. + ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 + ### https://github.com/huggingface/transformers/issues/20287 + _q = config.quantize + _dtype = weights.dtype + weights.dtype = torch.float32 + config.quantize = None + self.wo_cast = (torch.float32, _dtype) self.wo = TensorParallelRowLinear.load( config, prefix=f"{prefix}.wo", weights=weights, bias=False ) + weights.dtype = _dtype + config.quantize = _q self.dropout = nn.Dropout(config.dropout_rate) self.act = ( @@ -154,18 +180,9 @@ def forward(self, hidden_states): hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) - # TODO Support this again mayber - # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. - # See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - # if ( - # isinstance(self.wo.weight, torch.Tensor) - # and hidden_states.dtype != self.wo.weight.dtype - # and self.wo.weight.dtype != torch.int8 - # ): - # hidden_states = hidden_states.to(self.wo.weight.dtype) - + hidden_states = hidden_states.to(dtype=self.wo_cast[0]) hidden_states = self.wo(hidden_states) + hidden_states = hidden_states.to(dtype=self.wo_cast[1]) return hidden_states diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 5ab8a62425b..4d0e4730451 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -26,7 +26,7 @@ try: import bitsandbytes as bnb from bitsandbytes.nn import Int8Params -except Exception as e: +except Exception: HAS_BITS_AND_BYTES = False diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 185937e6f24..16cb48b7b2e 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -40,10 +40,11 @@ def __init__( trust_remote_code=trust_remote_code, ) - ) - config = AutoConfig.from_pretrained(model_id, revision=revision, + config = AutoConfig.from_pretrained( + model_id, + revision=revision, trust_remote_code=trust_remote_code, - ) + ) config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 9be51f74012..fe9c3b7bbd8 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,7 @@ from datetime import timedelta + class FakeBarrier: def wait(self): pass @@ -17,7 +18,9 @@ def allreduce(self, *args, **kwargs): return FakeBarrier() def allgather(self, inputs, local_tensor, **kwargs): - assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" + assert ( + len(inputs[0]) == len(local_tensor) == 1 + ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors" for input_ in inputs: input_[0].data = local_tensor[0].data return FakeBarrier() diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 2ed7673c41c..965cae9919a 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -10,8 +10,7 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, - EntryNotFoundError, - RevisionNotFoundError, # Import here to ease try/except in other part of the lib + EntryNotFoundError, # Import here to ease try/except in other part of the lib ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index ea9a14695fc..0146e5c3935 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -2,14 +2,14 @@ from torch import nn from torch.nn import functional as F -from typing import Optional, List +from typing import List HAS_BITS_AND_BYTES = True try: import bitsandbytes as bnb from bitsandbytes.nn import Int8Params -except ImportError as e: +except ImportError: HAS_BITS_AND_BYTES = False from accelerate import init_empty_weights @@ -27,14 +27,16 @@ def load_layer_norm(cls, prefix, weights, eps): ln.bias = nn.Parameter(bias) return ln + torch.nn.LayerNorm.load = load_layer_norm class FastLinear(nn.Module): def __init__( self, - weight, bias, - ) -> None: + weight, + bias, + ) -> None: super().__init__() self.weight = nn.Parameter(weight) if bias is not None: @@ -44,9 +46,9 @@ def __init__( @staticmethod def load(config, prefix: str, weights, bias: bool): - weight = weights.get_tensor(f"{prefix}.weight") + weight = weights.get_tensor(f"{prefix}.weight") if bias: - bias = weights.get_tensor(f"{prefix}.bias") + bias = weights.get_tensor(f"{prefix}.bias") else: bias = None return FastLinear(weight, bias) @@ -56,10 +58,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class Linear8bitLt(nn.Module): - def __init__(self, weight, bias, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None): + def __init__( + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): super().__init__() - assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + assert ( + not memory_efficient_backward + ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() self.index = index @@ -70,7 +81,11 @@ def __init__(self, weight, bias, has_fp16_weights=True, if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + self.weight = Int8Params( + weight.data, + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) self.weight.cuda(weight.device) self.bias = bias @@ -105,7 +120,8 @@ def get_linear(weight, bias, quantize): linear = FastLinear(weight, bias) elif quantize == "bitsandbytes": linear = Linear8bitLt( - weight, bias, + weight, + bias, has_fp16_weights=False, threshold=6.0, ) @@ -114,7 +130,9 @@ def get_linear(weight, bias, quantize): elif quantize == "gptq": raise NotImplementedError("Soon") else: - raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.") + raise NotImplementedError( + f"Quantization `{config.quantize}` is not implemented yet." + ) return linear @@ -126,6 +144,7 @@ def __init__(self, linear): def forward(self, x): return self.linear.forward(x) + class TensorParallelHead(SuperLayer): def __init__(self, linear, process_group): super().__init__(linear) @@ -133,13 +152,18 @@ def __init__(self, linear, process_group): @staticmethod def load(config, prefix: str, weights): - weight = weights.get_sharded(f"{prefix}.weight", dim=0) - return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group) + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + return TensorParallelHead( + get_linear(weight, bias=None, quantize=config.quantize), + process_group=weights.process_group, + ) def forward(self, input: torch.Tensor) -> torch.Tensor: output = super().forward(input) # Logits are sharded, so we need to gather them - world_output = [torch.empty_like(output) for _ in range(self.process_group.size())] + world_output = [ + torch.empty_like(output) for _ in range(self.process_group.size()) + ] torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -148,9 +172,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class TensorParallelColumnLinear(SuperLayer): @staticmethod def load(config, prefix: str, weights, bias: bool): - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + weight = weights.get_sharded(f"{prefix}.weight", dim=0) if bias: - bias = weights.get_sharded(f"{prefix}.bias", dim=0) + bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) @@ -175,23 +199,27 @@ def __init__(self, linear, process_group): @staticmethod def load(config, prefix: str, weights, bias: bool): - weight = weights.get_sharded(f"{prefix}.weight", dim=1) + weight = weights.get_sharded(f"{prefix}.weight", dim=1) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process - bias = weights.get_tensor(f"{prefix}.bias") + bias = weights.get_tensor(f"{prefix}.bias") else: bias = None - return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group) + return TensorParallelRowLinear( + get_linear(weight, bias, config.quantize), + process_group=weights.process_group, + ) def forward(self, input: torch.Tensor) -> torch.Tensor: out = super().forward(input) torch.distributed.all_reduce(out, group=self.process_group) return out + class TensorParallelEmbedding(nn.Module): def __init__(self, prefix: str, weights, reduce=True): super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=0) + weight = weights.get_sharded(f"{prefix}.weight", dim=0) num_embeddings = weights.get_shape(f"{prefix}.weight")[0] process_group = weights.process_group @@ -222,6 +250,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(out, group=self.process_group) return out + try: import dropout_layer_norm diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index fc01d9379fd..2a410ca3df8 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,7 +1,8 @@ from pathlib import Path -from typing import Optional, List +from typing import List from safetensors import safe_open + class Weights: def __init__(self, filenames: List[Path], device, dtype, process_group): routing = {} @@ -26,8 +27,6 @@ def _get_handle(self, filename): return self._handles[filename] - - def get_filename(self, tensor_name: str) -> str: filename = self.routing.get(tensor_name, None) if filename is None: @@ -63,7 +62,9 @@ def get_sharded(self, tensor_name: str, dim: int): start = rank * block_size stop = (rank + 1) * block_size - assert size % world_size == 0, f"The choosen size {size} is not compatible with sharding on {world_size} shards" + assert ( + size % world_size == 0 + ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" if dim == 0: tensor = slice_[start:stop] @@ -74,5 +75,3 @@ def get_sharded(self, tensor_name: str, dim: int): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - - From 5c2a0e45551d61679377e009003f8324c2404911 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 09:46:46 +0000 Subject: [PATCH 03/27] Missing import. --- server/text_generation_server/utils/hub.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 965cae9919a..9443d21b96a 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -11,6 +11,7 @@ from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, # Import here to ease try/except in other part of the lib + RevisionNotFoundError ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) From 680f26d6b2b88e12ff79e895e0d29eb538b96340 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 10:10:16 +0000 Subject: [PATCH 04/27] Typo. --- server/text_generation_server/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3e181321e89..85f0d15c9b8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -147,7 +147,7 @@ def get_model( ) elif model_type == "gpt_neox": - if FLASH_ATTENTION or shard: + if FLASH_ATTENTION: return FlashNeoXSharded( model_id, revision, From e36e42a3f4c95f75f9f6316434dd0325b3914fb3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 11:53:09 +0000 Subject: [PATCH 05/27] T5? --- .../models/custom_modeling/t5_modeling.py | 3 +-- server/text_generation_server/models/t5.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index c5ce9bfc633..0a9e3b77409 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -729,7 +729,6 @@ class T5PreTrainedModel(PreTrainedModel): """ config_class = T5Config - base_model_prefix = "transformer" def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1021,7 +1020,7 @@ def __init__(self, config: T5Config, weights): embed_tokens=self.shared, ) - self.lm_head = TensorParallelHead.load(config, prefix="shared", weights=weights) + self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) def forward( self, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 84465d48414..e844c36f0d2 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -35,7 +35,9 @@ def __init__( device = torch.device("cpu") dtype = torch.float32 - config = AutoConfig.from_pretrained(model_id, revision=revision) + config = AutoConfig.from_pretrained(model_id, revision=revision, + trust_remote_code=trust_remote_code, + ) config.quantize = quantize tokenizer = AutoTokenizer.from_pretrained( From 55045be42f2571f263e6f3a66a9039f3772cd109 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 13:07:12 +0000 Subject: [PATCH 06/27] Neox (non flash) port + kernel. --- .../custom_kernels/fused_attention_cuda.cu | 250 +++++++ server/custom_kernels/setup.py | 5 + .../text_generation_server/models/__init__.py | 10 +- .../custom_modeling/flash_neox_modeling.py | 4 +- .../models/custom_modeling/neox_modeling.py | 707 ++++++++++++++++++ .../models/custom_modeling/t5_modeling.py | 7 + .../text_generation_server/models/gpt_neox.py | 200 +---- 7 files changed, 996 insertions(+), 187 deletions(-) create mode 100644 server/custom_kernels/custom_kernels/fused_attention_cuda.cu create mode 100644 server/text_generation_server/models/custom_modeling/neox_modeling.py diff --git a/server/custom_kernels/custom_kernels/fused_attention_cuda.cu b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu new file mode 100644 index 00000000000..60f9f0286f2 --- /dev/null +++ b/server/custom_kernels/custom_kernels/fused_attention_cuda.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include + +#include + +/** +* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda +* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu +**/ + +// Available in pytorch main +//#define DISPATCH_CASE_FLOATING_TYPES(...) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +/* +* Forward passes +*/ + +/** +* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype +**/ +template +__global__ void forward_masked_softmax_kernel( + const torch::PackedTensorAccessor32 attention_scores, // [B, KV] + const torch::PackedTensorAccessor32 mask, // [B, KV] + torch::PackedTensorAccessor32 result, // [B, KV] + const int64_t effective_kv_length, + const dim3 blockDim, + const int64_t rows_per_block, + const int64_t kv_length, + const int64_t batch_size +) { + const auto row_id = threadIdx.x / effective_kv_length; + const auto effective_kv_length_id = threadIdx.x % effective_kv_length; + const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread; + auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread; + kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_; + const auto kv_length_end = kv_length_end_; + + const auto batch_id = blockIdx.x * rows_per_block + row_id; + + // We need 2 float storage for each row, one for max computation, the other for normalizing exponential + extern __shared__ float temp_storage[]; + const auto row_id_mem_offset = row_id * 2; + if (effective_kv_length_id == 0) { + temp_storage[row_id_mem_offset] = -std::numeric_limits::infinity(); + temp_storage[row_id_mem_offset + 1] = 0; + } + __syncthreads(); + + // Compute mask and max + if (batch_id < batch_size) { + float thread_max = -std::numeric_limits::infinity(); + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + const float candidate = attention_scores[batch_id][kv_length_id]; + thread_max = (thread_max < candidate) ? candidate : thread_max; + } + } + if (thread_max != -std::numeric_limits::infinity()) { + // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max); + } + } + + __syncthreads(); + + // Compute exp(elt - max) masked + float exponential[min_kv_length_shard_size_per_thread]; + if (batch_id < batch_size) { + float thread_add = 0; + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + if (mask[batch_id][kv_length_id] == 0) { + exponential[kv_length_id - kv_length_start] = std::exp(static_cast(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]); + thread_add = thread_add + exponential[kv_length_id - kv_length_start]; + } else { + exponential[kv_length_id - kv_length_start] = 0.; + } + } + if (thread_add > 0) { + // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot + gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add); + } + } + + __syncthreads(); + + // Compute softmax + if (batch_id < batch_size) { + // If sum of all exponential is 0, we set the softmax values to 0 + if (temp_storage[row_id_mem_offset + 1] == 0.) { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = 0.; + } + } else { + for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) { + result[batch_id][kv_length_id] = static_cast(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]); + } + } + } +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::tuple>, at::Tensor> forward( + const at::Tensor query, + const at::Tensor key, + const at::Tensor value, + const std::optional> layer_past, + const at::Tensor attention_mask, + const std::optional head_mask, + const float inv_norm_factor, + const int num_heads, + const bool use_cache +) { + auto query_layer = query; + auto key_layer = key; + auto value_layer = value; + + if (layer_past) { + const auto past_key = (*layer_past).at(0); + const auto past_value = (*layer_past).at(1); + key_layer = at::cat({past_key, key_layer}, 2); + value_layer = at::cat({past_value, value_layer}, 2); + } + + std::optional> present; + if (use_cache) { + present = {key_layer, value_layer}; + } else { + present = {}; + } + + const auto batch_size = query_layer.size(0); + const auto q_length = query_layer.size(2); + const auto attn_head_size = query_layer.size(3); + const auto batch_size_times_num_heads = batch_size * num_heads; + const auto kv_length = key_layer.size(2); + + const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size}); + auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2); + auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}); + + auto query_scaled = query_view * inv_norm_factor; + auto attention_scores = at::bmm(query_scaled, key_view); + + // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype` + at::Tensor attention_probs; + if (true) { + // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors + const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length}); + const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length}); + + // Custom kernel + attention_probs = at::empty_like(attention_scores_2d); + + // Check that inputs and contiguous + cuda tensors + CHECK_INPUT(attention_scores_2d); + CHECK_INPUT(attention_mask_2d); + + // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out + // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] { + /* + * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/ + * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + * - SMs: 108 + * - TPCs: 56 (What's that?) + * - Memory size: 40 GB + * - L2 Cache size: 40960 KB (shared across all SMs) + * - L1/Shared memory size: 192 KB (shared across all threads within a SM) + * - Max Threads / SM: 2048 + * - Max Thread Blocks / SM: 32 + */ + + /* + * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block + * with multiple threads as we need to `sync_threads` to run exponential sum. + * We maximise the usage of threads within a single block + */ + // TODO @thomasw21 figure out everything warp related: + // - why do they have to be power of 2 + // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048 + const auto MAX_THREADS_PER_SM = 1024; + // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` + const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4; + // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)` + const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1; + const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length; + const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1; + + const dim3 gridDim(num_blocks); // Number of blocks that run + const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block + const int shared_mem_forward = rows_per_block * 2 * sizeof(float); + + // 192 * 2 ** 10 + // const auto MAX_L1_MEMORY = 196608; + // const auto MAX_SMs = 108; + // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation."); + // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger."); + // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher."); + + forward_masked_softmax_kernel<<>>( + attention_scores_2d.packed_accessor32(), + attention_mask_2d.packed_accessor32(), + attention_probs.packed_accessor32(), + effective_kv_length, + blockDim, + rows_per_block, + kv_length, + batch_size_times_num_heads * q_length + ); + }); + attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length}); + } else { + // Pytorch C++ API + auto input_dtype = attention_scores.scalar_type(); + if (input_dtype == at::ScalarType::Float) { + attention_scores = attention_scores.to(at::ScalarType::Float); + }; + // TODO @thomasw21 Figure out how to get minimum value + auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34); + attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype); + } + + auto context_layer = attention_probs.bmm(value_view); + + // `_merge_heads` + context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size}); + context_layer = context_layer.permute({0, 2, 1, 3}); + context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads}); + + return std::make_tuple(context_layer, present, attention_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &forward, + "GPT-Neox attention mechanism forward (CUDA)" + ); +} diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py index fa4382e99f9..fe45b63146f 100644 --- a/server/custom_kernels/setup.py +++ b/server/custom_kernels/setup.py @@ -8,6 +8,11 @@ name="custom_kernels.fused_bloom_attention_cuda", sources=["custom_kernels/fused_bloom_attention_cuda.cu"], extra_compile_args=["-arch=compute_80", "-std=c++17"], + ), + CUDAExtension( + name="custom_kernels.fused_attention_cuda", + sources=["custom_kernels/fused_attention_cuda.cu"], + extra_compile_args=["-arch=compute_80", "-std=c++17"], ) ], cmdclass={"build_ext": BuildExtension}, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 85f0d15c9b8..1f862c9eff5 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -15,6 +15,7 @@ from text_generation_server.models.galactica import GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.t5 import T5Sharded +from text_generation_server.models.gpt_neox import GPTNeoxSharded try: if torch.cuda.is_available(): @@ -147,7 +148,7 @@ def get_model( ) elif model_type == "gpt_neox": - if FLASH_ATTENTION: + if FLASH_ATTENTION and False: return FlashNeoXSharded( model_id, revision, @@ -155,7 +156,12 @@ def get_model( trust_remote_code=trust_remote_code, ) elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Neox")) + return GPTNeoxSharded( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) else: return CausalLM( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 24004e8a3f3..d60fb848a46 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -268,9 +268,7 @@ def forward( mlp_output = self.mlp(ln2_hidden_states) intermediate = mlp_output + attn_output - # Only reduce once and after the addition instead of once per layer - if self.process_group is not None: - torch.distributed.all_reduce(intermediate, group=self.process_group) + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate + hidden_states, None else: diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py new file mode 100644 index 00000000000..1e20a4774e8 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -0,0 +1,707 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch GPTNeoX model.""" + +from typing import Optional, Tuple, Union + +import os +import torch +import torch.distributed +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers import GPTNeoXConfig +from loguru import logger +from text_generation_server.utils.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelHead, +) + + + +CUSTOM_KERNELS_ENABLED = False +if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": + try: + from custom_kernels import fused_attention_cuda + + CUSTOM_KERNELS_ENABLED = True + except ImportError: + pass + +if not CUSTOM_KERNELS_ENABLED: + logger.warning("We're not using custom kernels.") + + + +def make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) + mask = mask.triu(1 + past_key_values_length) + + expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length) + return expanded_mask + + +def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, tgt_length, src_length) + + +def prepare_attn_mask( + attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int +) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + + # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] + expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +class GPTNeoXPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + + +class GPTNeoXAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + max_positions = config.max_position_embeddings + # ??? TODO + # self.register_buffer( + # "bias", + # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + # 1, 1, max_positions, max_positions + # ), + # ) + # self.register_buffer("masked_bias", torch.tensor(-1e9)) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + ) + self.rotary_emb.inv_freq = nn.Parameter( + weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") + ) + self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to( + torch.get_default_dtype() + ) + + assert self.num_attention_heads % weights.process_group.size() == 0 + self.num_attention_heads = self.num_attention_heads // weights.process_group.size() + self.query_key_value = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True + ) + self.dense = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense", weights=weights, bias=True + ) + + def forward( + self, + hidden_states, + position_ids, + attention_mask, + head_mask=None, + layer_past=None, + use_cache=False, + output_attentions=False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query, key, value = qkv.split(self.head_size, -1) + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + if has_layer_past: + seq_len += layer_past[0].shape[-2] + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + key_rot = key[..., : self.rotary_ndims] + + query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) + + query[..., : self.rotary_ndims] = query_rot + key[..., : self.rotary_ndims] = key_rot + + if CUSTOM_KERNELS_ENABLED: + attn_output, present, attn_weights = fused_attention_cuda.forward( + query, + key, + value, + layer_past, + attention_mask, + head_mask, + self.inv_norm_factor, + self.num_attention_heads, + use_cache, + ) + else: + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + 1, + dtype=query.dtype, + device=key.device, + ).expand(batch_size * num_attention_heads, query_length, key_length) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=self.inv_norm_factor, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attn_scores.dtype + if input_dtype in [torch.float16, torch.bfloat16]: + attn_scores = attn_scores.to(torch.float) + attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", self.true_inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + self.cos_cached = None + self.sin_cached = None + + @staticmethod + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + @staticmethod + def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): + t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) + + def forward(self, q, k, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None: + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.cos_cached, self.sin_cached = self._create_cos_sin( + self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device + ) + return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) + + +@torch.jit.script +def rotary_forward(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + + chunk_size = q.shape[-1] // 2 + q1, q2 = q.split(chunk_size, -1) + q_rotated = torch.cat((-q2, q1), dim=-1) + k1, k2 = k.split(chunk_size, -1) + k_rotated = torch.cat((-k2, k1), dim=-1) + + q_embed = (q * cos) + (q_rotated * sin) + k_embed = (k * cos) + (k_rotated * sin) + return q_embed, k_embed + + +class GPTNeoXMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.act = ( + ACT2FN[config.hidden_act] + if "gelu_fast" not in config.hidden_act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + self.dense_4h_to_h = TensorParallelRowLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + ) + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights) + self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights) + + + def forward( + self, + hidden_states, + position_ids, + attention_mask=None, + head_mask=None, + use_cache=False, + layer_past=None, + output_attentions=False, + ): + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs + + +class GPTNeoXModel(GPTNeoXPreTrainedModel): + def __init__(self, config, weights): + super().__init__(config) + self.config = config + + self.num_attention_heads = config.num_attention_heads + + self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) + self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)]) + self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps) + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids=None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * self.config.num_hidden_layers) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + # Attention mask. + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-1] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + causal_mask = prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + if hasattr(self, "tp_rank"): + assert self.num_attention_heads % self.tp_world_size == 0 + block_size = self.num_attention_heads // self.tp_world_size + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) + else: + causal_mask = torch.repeat_interleave(causal_mask, self.num_attention_heads, dim=0) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = layer( + hidden_states, + position_ids=position_ids, + attention_mask=causal_mask, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config, weights): + super().__init__(config) + self.gpt_neox = GPTNeoXModel(config, weights) + self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are + only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see + `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") + >>> config.is_decoder = True + >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt_neox( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + input_shape = input_ids.shape + + # cut decoder_input_ids if past is used + if past_key_values and past_key_values[0] is not None: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + ) + + return model_inputs + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index 0a9e3b77409..afc043115da 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -840,6 +840,11 @@ def forward( ), "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) + + from safetensors.torch import save_file + save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors") + + batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past @@ -936,6 +941,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) + from safetensors.torch import save_file + save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors") # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 4d0e4730451..5c854348059 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -10,25 +10,16 @@ AutoModelForCausalLM, AutoConfig, ) -from transformers.models.gpt_neox.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - from text_generation_server.models import CausalLM +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, + Weights, ) -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception: - HAS_BITS_AND_BYTES = False - class GPTNeoxSharded(CausalLM): def __init__( @@ -58,28 +49,18 @@ def __init__( config = AutoConfig.from_pretrained( model_id, revision=revision, - tp_parallel=True, trust_remote_code=trust_remote_code, ) + config.quantize = quantize torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code - ) + model = GPTNeoxForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( model=model, @@ -91,161 +72,16 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - if self.model.gpt_neox.tp_embeddings: - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - - # Logits are sharded, so we need to gather them - logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] - torch.distributed.all_gather( - logits, outputs.logits, group=self.process_group - ) - logits = torch.cat(logits, dim=2) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) - return logits, outputs.past_key_values - # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard - else: - return super(GPTNeoxSharded, self).forward( - input_ids, attention_mask, position_ids, past_key_values - ) + logits = outputs.logits + return logits, outputs.past_key_values From c471e46cf8397fded76f5d2f839ed10d1bc251d2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 24 May 2023 20:28:54 +0000 Subject: [PATCH 07/27] M******** --- server/text_generation_server/models/__init__.py | 2 +- .../models/custom_modeling/t5_modeling.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1f862c9eff5..f0427c2067c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -148,7 +148,7 @@ def get_model( ) elif model_type == "gpt_neox": - if FLASH_ATTENTION and False: + if FLASH_ATTENTION: return FlashNeoXSharded( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index afc043115da..a4e6249b27d 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -139,7 +139,9 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(dtype=self.wo_cast[0]) hidden_states = self.wo(hidden_states) - hidden_states = hidden_states.to(dtype=self.wo_cast[1]) + # XXX: Recasting is already done within the layer norm. + # Casting back to float16 here modifies results + # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) return hidden_states @@ -182,7 +184,9 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(dtype=self.wo_cast[0]) hidden_states = self.wo(hidden_states) - hidden_states = hidden_states.to(dtype=self.wo_cast[1]) + # XXX: Recasting is already done within the layer norm. + # Casting back to float16 here modifies results + # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) return hidden_states @@ -350,6 +354,7 @@ def forward( # Input is (batch_size, seq_length, dim) # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] real_seq_length = seq_length @@ -841,10 +846,6 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) - from safetensors.torch import save_file - save_file({"inputs_embeds": inputs_embeds}, f"inputs_embeds_{self.rank}_layer.safetensors") - - batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past @@ -941,8 +942,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - from safetensors.torch import save_file - save_file({"layer": layer_outputs[0]}, f"layer_outputs_{self.rank}_layer_{i}.safetensors") # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) From 165bb4b6c0ba5601b9b80c6997861357af26dc3e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 May 2023 08:45:41 +0000 Subject: [PATCH 08/27] Green ? --- .gitignore | 1 + integration-tests/models/test_flash_neox.py | 4 ++-- .../models/custom_modeling/flash_neox_modeling.py | 3 --- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 19604d42677..4f8f7b87816 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea target router/tokenizer.json +.*__pycache__.* diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index ff9b9763cd8..1076126b60e 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): generated_texts = [r.generated_text for r in responses] assert len(generated_texts) == 4 - assert generated_texts, all( + assert all( [text == generated_texts[0] for text in generated_texts] - ) + ), generated_texts assert responses == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d60fb848a46..64bd3a4004d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -94,9 +94,6 @@ def __init__(self, config, prefix, weights): rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv( From 4e071bf2f183059ee69aff0c08a97ccf45cb4d3f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 May 2023 09:34:31 +0000 Subject: [PATCH 09/27] Fix PositionalRotary loads. --- .../custom_modeling/flash_llama_modeling.py | 3 ++- .../custom_modeling/flash_neox_modeling.py | 10 ++++++++- server/text_generation_server/utils/layers.py | 22 ++++++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a33c6c2d024..a959cf20c2c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -103,7 +103,8 @@ def __init__( self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) + self.softmax_scale = self.head_size ** (-0.5) self.num_heads = self.num_heads // weights.process_group.size() diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 64bd3a4004d..f9e1f06ce9d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -93,7 +93,15 @@ def __init__(self, config, prefix, weights): self.num_heads = self.num_heads // weights.process_group.size() rotary_ndims = int(self.head_size * rotary_pct) - self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) + self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) + + dtype = weights.dtype + weights.dtype = torch.float32 + self.rotary_emb.inv_freq = nn.Parameter( + weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") + ) + weights.dtype = dtype + self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv( diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 0146e5c3935..1699622dd4a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -297,7 +297,27 @@ def forward(self, hidden_states, residual=None): from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb - class PositionRotaryEmbedding(RotaryEmbedding): + class PositionRotaryEmbedding(nn.Module): + def __init__(self, inv_freq): + super().__init__() + + self.register_buffer("inv_freq", inv_freq) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + @staticmethod + def load(prefix, weights): + # XXX: Always load this in float32 ! + dtype = weights.dtype + weights.dtype = torch.float32 + inv_freq = weights.get_tensor(f"{prefix}.inv_freq") + weights.dtype = dtype + return PositionRotaryEmbedding(inv_freq) + + def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) From 7fa79f02ca6dcf856ba9491c52437da78a07b37e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 25 May 2023 09:42:59 +0000 Subject: [PATCH 10/27] Fix logic. --- .../models/custom_modeling/flash_neox_modeling.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index f9e1f06ce9d..b28aa68aa81 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -95,13 +95,6 @@ def __init__(self, config, prefix, weights): rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) - dtype = weights.dtype - weights.dtype = torch.float32 - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - weights.dtype = dtype - self.softmax_scale = self.head_size ** (-0.5) self.query_key_value = load_qkv( From 2a1ecf386366668566c113f6159c1a5d454d0527 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 6 Jun 2023 11:20:53 +0200 Subject: [PATCH 11/27] Fix rebase. --- server/text_generation_server/models/bloom.py | 135 ------------------ 1 file changed, 135 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 8d0ceeb4853..50b3b76a4a7 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -21,12 +21,6 @@ Weights, ) -HAS_BITS_AND_BYTES = True -try: - pass -except Exception: - HAS_BITS_AND_BYTES = False - class BloomCausalLMBatch(CausalLMBatch): @classmethod @@ -95,138 +89,9 @@ def __init__( world_size=world_size, ) -<<<<<<< HEAD - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - if name.startswith("transformer.") or name.startswith("lm_head."): - full_name = name - else: - full_name = f"transformer.{name}" - - module_name, param_name = full_name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[full_name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif ( - isinstance(module, TensorParallelEmbedding) - or name == "lm_head.weight" - ): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if quantize == "bitsandbytes": - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine either because it is not installed " - "or you don't have a GPU.\n" - "You can install it with `pip install bitsandbytes`." - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor, - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state): - def linear(input, weight, bias): - out = bnb.matmul( - input, - weight, - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out - - return linear - - module.linear = replace_linear(state) - else: - tensor = tensor.to(device) - elif quantize == "gptq": - raise NotImplementedError("`gptq` is not implemented for now") - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") - - module._parameters[param_name] = tensor - if name == "word_embeddings.weight": - model.lm_head._parameters["weight"] = tensor -======= @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch ->>>>>>> ba30033 (Fused all commits for saner rebase..) def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None From d083d57d0d3096e0c41e30bd1d91f34fd2213002 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jun 2023 10:45:59 +0000 Subject: [PATCH 12/27] Fixing flash rw. --- server/text_generation_server/input.json | 1 + .../custom_modeling/flash_llama_modeling.py | 14 +- .../custom_modeling/flash_rw_modeling.py | 296 +++++++----------- .../text_generation_server/models/flash_rw.py | 255 +++++---------- server/text_generation_server/utils/layers.py | 7 + 5 files changed, 212 insertions(+), 361 deletions(-) create mode 100644 server/text_generation_server/input.json diff --git a/server/text_generation_server/input.json b/server/text_generation_server/input.json new file mode 100644 index 00000000000..274a4d9b3b2 --- /dev/null +++ b/server/text_generation_server/input.json @@ -0,0 +1 @@ +{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.<|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}} diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a959cf20c2c..f27bd0d5f58 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -292,20 +292,12 @@ def __init__(self, config, weights): super().__init__() self.config = config - self.tp_embeddings = False process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) - else: - self.embed_tokens = Embedding(prefix="model.embed_tokens", weights=weights) - + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) self.layers = nn.ModuleList( [ FlashLlamaLayer( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 034877036a7..9b175cf9d9d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,14 +12,29 @@ import flash_attn_cuda from text_generation_server.utils.layers import ( - FastLinear, TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, + TensorParallelHead, FastLayerNorm, PositionRotaryEmbedding, + get_linear ) +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_sharded(f"{prefix}.weight", dim=1) + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias, config.quantize) + if config.parallel_attn: + return linear + else: + return TensorParallelRowLinear(linear, process_group=weights.process_group) + class RWConfig(PretrainedConfig): attribute_map = { @@ -85,44 +100,26 @@ def __init__( class FlashRWAttention(torch.nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=None, + config, prefix, weights, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # process_group=None, reduce=True, ): super().__init__() - self.num_heads = num_heads - self.num_heads_kv = num_heads_kv - self.hidden_size = hidden_size - self.head_size = hidden_size // num_heads + self.num_heads = config.n_head + self.num_heads_kv = config.n_head_kv + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device) self.softmax_scale = self.head_size ** (-0.5) + self.num_heads = self.num_heads //weights.process_group.size() - if process_group is None: - self.query_key_value = FastLinear( - hidden_size, - self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - ) - self.dense = FastLinear(hidden_size, hidden_size, bias=bias) - else: - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - process_group=process_group, - ) - self.dense = TensorParallelRowLinear( - hidden_size, - hidden_size, - bias=bias, - process_group=process_group, - reduce=reduce, - ) - self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) + self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) def forward( self, @@ -224,7 +221,8 @@ def __init__( self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) self.softmax_scale = self.head_size ** (-0.5) self.num_groups = num_heads // (num_heads_kv * 2) @@ -359,28 +357,12 @@ def forward( class FlashMLP(nn.Module): - def __init__(self, hidden_size, bias, process_group=None, reduce=True): + def __init__(self, config, prefix, weights, reduce=True): super().__init__() self.act = torch.nn.functional.gelu - if process_group is None: - self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias) - self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias) - else: - self.dense_h_to_4h = TensorParallelColumnLinear( - hidden_size, - 4 * hidden_size, - bias=bias, - process_group=process_group, - ) - self.dense_4h_to_h = TensorParallelRowLinear( - 4 * hidden_size, - hidden_size, - bias=bias, - process_group=process_group, - reduce=reduce, - ) - self.process_group = process_group + self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias) + self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias) def forward(self, hidden_states): hidden_states = self.dense_h_to_4h(hidden_states) @@ -392,38 +374,62 @@ def forward(self, hidden_states): class FlashRWLayer(nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - layer_norm_eps, - parallel_attn, - process_group=None, + layer_id, + config, + weights, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # layer_norm_eps, + # parallel_attn, + # process_group=None, ): super().__init__() + n_head = config.n_head + n_head_kv = config.n_head_kv + hidden_size = config.hidden_size + bias = config.bias + parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + prefix = f"transformer.h.{layer_id}" + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) self.self_attention = FlashRWAttention( - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=process_group, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # process_group=process_group, + config, + prefix=f"{prefix}.self_attention", + weights=weights, reduce=False, ) self.post_attention_layernorm = ( - FastLayerNorm(hidden_size, eps=layer_norm_eps) - if not parallel_attn + FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) if not parallel_attn else None ) self.mlp = FlashMLP( - hidden_size, bias, process_group=process_group, reduce=False + # hidden_size, bias, process_group=process_group, reduce=False + config, + prefix=f"{prefix}.mlp", + weights=weights, + reduce=False ) - self.process_group = process_group + self.process_group = weights.process_group def forward( self, @@ -485,31 +491,30 @@ def forward( class FlashRWLargeLayer(nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - layer_norm_eps, - process_group=None, + config, prefix, weights ): super().__init__() - self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.ln_attn = FastLayerNorm.load( + prefix=f"{prefix}.ln_attn", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.ln_mlp = FastLayerNorm.load( + prefix=f"{prefix}.ln_mlp", + weights=weights, + eps=config.layer_norm_epsilon, + ) self.self_attention = FlashRWLargeAttention( - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=process_group, + config, prefix=f"{prefix}.self_attention", weights=weights, reduce=False, ) self.mlp = FlashMLP( - hidden_size, bias, process_group=process_group, reduce=False + config, prefix=f"{prefix}.mlp", weights=weights, reduce=False ) - self.process_group = process_group + self.process_group = weights.process_group def forward( self, @@ -555,37 +560,27 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) self.config = config - self.tp_embeddings = False - if process_group is not None: - self.tp_rank = process_group.rank() - self.tp_world_size = process_group.size() - if config.vocab_size % self.tp_world_size == 0: - self.tp_embeddings = True - - if self.tp_embeddings: - self.word_embeddings = TensorParallelEmbedding( - config.vocab_size, config.hidden_size, process_group=process_group - ) - else: - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - + self.word_embeddings = TensorParallelEmbedding( + prefix="transformer.word_embeddings", weights=weights + ) if config.model_type == "RefinedWebModel": self.h = nn.ModuleList( [ FlashRWLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.bias, - config.layer_norm_epsilon, - config.parallel_attn, - process_group, + layer_id, config, weights + # config.n_head, + # config.n_head_kv, + # config.hidden_size, + # config.bias, + # config.layer_norm_epsilon, + # config.parallel_attn, + # process_group, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -597,14 +592,15 @@ def __init__(self, config, process_group=None): self.h = nn.ModuleList( [ FlashRWLargeLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.bias, - config.layer_norm_epsilon, - process_group, + layer_id, config, weights + # config.n_head, + # config.n_head_kv, + # config.hidden_size, + # config.bias, + # config.layer_norm_epsilon, + # process_group, ) - for _ in range(config.num_hidden_layers) + for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = ( @@ -617,31 +613,13 @@ def __init__(self, config, process_group=None): f"model_type {config.model_type} is not supported." ) - self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.head_size = self.h[0].self_attention.head_size - - def post_load_weights(self, quantize: Optional[str] = None): - if isinstance(self.word_embeddings, TensorParallelEmbedding): - self.word_embeddings.add_null_idx() - for layer in self.h: - layer: FlashRWLayer - layer.self_attention.query_key_value.prepare_weights(quantize) - layer.self_attention.dense.prepare_weights(quantize) - layer.mlp.dense_h_to_4h.prepare_weights(quantize) - layer.mlp.dense_4h_to_h.prepare_weights(quantize) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWModel, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.ln_f = FastLayerNorm.load( + prefix="transformer.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model + self.head_size = self.h[0].self_attention.head_size def forward( self, @@ -708,40 +686,14 @@ def forward( class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, process_group=None): + def __init__(self, config, weights): super().__init__(config) - self.process_group = process_group - if self.process_group is not None: - self.world_size = self.process_group.size() - else: - self.world_size = 1 - - self.transformer = FlashRWModel(config, process_group) + self.transformer = FlashRWModel(config, weights) - if self.transformer.tp_embeddings: - self.lm_head = FastLinear( - config.hidden_size, - config.vocab_size // process_group.size(), - bias=False, - ) - else: - self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - - def post_load_weights(self, quantize: Optional[str] = None): - self.transformer.post_load_weights(quantize) - self.lm_head.prepare_weights() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - # Pop here as we will replace the layer in our own logic and don't want from_pretrained - # to do it for us - load_in_8bit = kwargs.pop("load_in_8bit", False) - model = super(FlashRWForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs + self.lm_head = TensorParallelHead.load( + config, prefix="lm_head", weights=weights ) - model.post_load_weights("bitsandbytes" if load_in_8bit else None) - return model def forward( self, @@ -766,12 +718,4 @@ def forward( if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits = self.lm_head(hidden_states) - - if self.transformer.tp_embeddings: - # Logits are sharded, so we need to gather them - world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] - torch.distributed.all_gather(world_logits, logits, group=self.process_group) - world_logits = torch.cat(world_logits, dim=1) - - return world_logits, present return logits, present diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 4fc4c3896a7..846b905196b 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -21,99 +21,14 @@ weight_files, download_weights, weight_hub_files, + Weights, LocalEntryNotFoundError, ) tracer = trace.get_tracer(__name__) -class FlashRW(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 - else: - raise NotImplementedError("RW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, - revision=revision, - ) - - # We do not use from_pretrained as it is too slow - try: - filenames = weight_files(model_id, revision, ".bin") - # Local files not found - except LocalEntryNotFoundError: - hub_files = weight_hub_files(model_id, revision, ".bin") - filenames = download_weights(hub_files, model_id, revision) - - with init_empty_weights(): - model = FlashRWForCausalLM(config) - - self.load_weights( - model, - filenames, - quantize, - device, - dtype, - ) - - super(FlashCausalLM, self).__init__( - model=model.to(device), - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - ) - - @staticmethod - def load_weights( - model: FlashRWForCausalLM, - filenames: List[Path], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - ): - for filename in filenames: - state_dict = torch.load(filename, map_location="cpu") - for key, value in state_dict.items(): - value = value.to(device if quantize is None else "cpu").to(dtype) - - module_name, param_name = key.rsplit(".", 1) - module = model.get_submodule(module_name) - - try: - current_parameter_tensor = module._parameters[param_name] - if current_parameter_tensor.shape != value.shape: - raise ValueError( - f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}" - ) - module._parameters[param_name] = value - except KeyError: - module._buffers[param_name] = value - - del value - - torch.cuda.empty_cache() - model.post_load_weights(quantize) - - -class FlashRWSharded(FlashRW): +class FlashRWSharded(FlashCausalLM): def __init__( self, model_id: str, @@ -142,20 +57,12 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) - with init_empty_weights(): - model = FlashRWForCausalLM(config, self.process_group) + config.quantize = quantize + + model = FlashRWForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=device, - dtype=dtype, - rank=rank, - world_size=world_size, - ) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( model=model.to(device), @@ -167,78 +74,78 @@ def __init__( world_size=world_size, ) - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: Optional[str], - device: torch.device, - dtype: torch.dtype, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if quantize is None else "cpu" - ) as f: - for name in f.keys(): - module_name, param_name = name.rsplit(".", 1) - module = model.get_submodule(module_name) - - current_parameter_tensor = parameters.get(name, None) - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. - if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif name == "lm_head.weight" and model.transformer.tp_embeddings: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - try: - tensor = slice_[:] - except: - tensor = f.get_tensor(name) - - if ( - current_parameter_tensor is not None - and current_parameter_tensor.shape != tensor.shape - ): - raise ValueError( - f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous().to(dtype) - - if current_parameter_tensor is not None: - module._parameters[param_name] = tensor - else: - module._buffers[param_name] = tensor - - model.post_load_weights(quantize) + # @staticmethod + # def load_weights( + # model, + # filenames: List[str], + # quantize: Optional[str], + # device: torch.device, + # dtype: torch.dtype, + # rank: int, + # world_size: int, + # ): + # parameters = dict(model.named_parameters()) + # for file in filenames: + # with safe_open( + # file, framework="pt", device=str(device) if quantize is None else "cpu" + # ) as f: + # for name in f.keys(): + # module_name, param_name = name.rsplit(".", 1) + # module = model.get_submodule(module_name) + + # current_parameter_tensor = parameters.get(name, None) + + # slice_ = f.get_slice(name) + + # if isinstance(module, TensorParallelColumnLinear): + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # elif isinstance(module, TensorParallelRowLinear): + # if param_name == "weight": + # size = slice_.get_shape()[1] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[:, start:stop] + # else: + # tensor = slice_[:] + # # XXX: Hack for Rowlinear to add the bias only once. + # if rank != 0: + # tensor = torch.zeros_like(tensor) + # elif isinstance(module, TensorParallelEmbedding): + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # elif name == "lm_head.weight" and model.transformer.tp_embeddings: + # size = slice_.get_shape()[0] + # block_size = size // world_size + # start = rank * block_size + # stop = (rank + 1) * block_size + # tensor = slice_[start:stop] + # else: + # try: + # tensor = slice_[:] + # except: + # tensor = f.get_tensor(name) + + # if ( + # current_parameter_tensor is not None + # and current_parameter_tensor.shape != tensor.shape + # ): + # raise ValueError( + # f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + # ) + + # tensor = tensor.contiguous().to(dtype) + + # if current_parameter_tensor is not None: + # module._parameters[param_name] = tensor + # else: + # module._buffers[param_name] = tensor + + # model.post_load_weights(quantize) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 1699622dd4a..9fd31c76d21 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -308,6 +308,13 @@ def __init__(self, inv_freq): self._cos_k_cached = None self._sin_k_cached = None + @staticmethod + def static(dim, base, device): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, + dtype=torch.float32) / dim)) + return PositionRotaryEmbedding(inv_freq) + + @staticmethod def load(prefix, weights): # XXX: Always load this in float32 ! From daf59b0582c048276a5b470c10cce4645e5cff3d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jun 2023 11:08:25 +0000 Subject: [PATCH 13/27] Large attention ? --- .../custom_modeling/flash_rw_modeling.py | 61 ++++++------------- 1 file changed, 18 insertions(+), 43 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 9b175cf9d9d..34a037ab936 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -209,11 +209,12 @@ def forward( class FlashRWLargeAttention(torch.nn.Module): def __init__( self, - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=None, + config, prefix, weights, + # num_heads, + # num_heads_kv, + # hidden_size, + # bias, + # process_group=None, reduce=True, ): super().__init__() @@ -221,46 +222,24 @@ def __init__( self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - # self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) - self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) + self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device) self.softmax_scale = self.head_size ** (-0.5) self.num_groups = num_heads // (num_heads_kv * 2) self.num_heads = num_heads // self.num_groups self.num_heads_kv = num_heads_kv // self.num_groups - - if process_group is None: - self.query_key_value = FastLinear( - hidden_size, - self.num_groups - * self.head_size - * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - ) - self.dense = FastLinear(hidden_size, hidden_size, bias=bias) - else: - if process_group.size() > self.num_groups: - raise NotImplementedError( - f"Tensor Parallelism is not implemented for world_size > n groups" - ) - - self.query_key_value = TensorParallelColumnLinear( - hidden_size, - self.num_groups - * self.head_size - * (self.num_heads + 2 * self.num_heads_kv), - bias=bias, - process_group=process_group, + process_group = weights.process_group + if process_group.size() > self.num_groups: + raise NotImplementedError( + f"Tensor Parallelism is not implemented for world_size > n groups" ) - self.dense = TensorParallelRowLinear( - hidden_size, - hidden_size, - bias=bias, - process_group=process_group, - reduce=reduce, + if self.num_groups % process_group.size() != 0: + raise NotImplementedError( + f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}" ) - self.num_groups = self.num_groups // process_group.size() + self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) + self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) def forward( self, @@ -460,9 +439,7 @@ def forward( mlp_output = self.mlp(ln_hidden_states) intermediate = mlp_output + attn_output - # Only reduce once and after the addition instead of once per layer - if self.process_group is not None: - torch.distributed.all_reduce(intermediate, group=self.process_group) + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate, residual else: @@ -548,9 +525,7 @@ def forward( intermediate = attn_output + mlp_output - # Only reduce once and after the addition instead of once per layer - if self.process_group is not None: - torch.distributed.all_reduce(intermediate, group=self.process_group) + torch.distributed.all_reduce(intermediate, group=self.process_group) return intermediate, residual From 644e0a65a337fe445bfab7aba3b450e2ab06463f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jun 2023 14:05:02 +0000 Subject: [PATCH 14/27] Updating starcoder --- .../test_flash_starcoder_default_params.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index afd0b662d46..89e02c07474 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -11,17 +11,17 @@ }, { "id": 1459, - "logprob": -5.6289062, + "logprob": -5.6328125, "text": " print" }, { "id": 81, - "logprob": -1.6005859, + "logprob": -1.6035156, "text": "_" }, { "id": 7656, - "logprob": -5.9921875, + "logprob": -5.9882812, "text": "hello" } ], @@ -59,19 +59,19 @@ }, { "id": 10896, - "logprob": -0.3659668, + "logprob": -0.38549805, "special": false, "text": " World" }, { "id": 657, - "logprob": -0.49804688, + "logprob": -0.5229492, "special": false, "text": "\")" }, { "id": 203, - "logprob": -0.11279297, + "logprob": -0.10632324, "special": false, "text": "\n" }, @@ -113,7 +113,7 @@ }, { "id": 426, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "name" }, From 877d4d4aeb0c8377fd4929e534dc14a329374e84 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 6 Jun 2023 16:19:27 +0000 Subject: [PATCH 15/27] Adding integration for neox NON flash. --- integration-tests/conftest.py | 8 + .../__snapshots__/test_neox/test_neox.json | 113 +++ .../test_neox/test_neox_load.json | 454 ++++++++++++ .../test_neox_sharded/test_neox_load.json | 654 ++++++++++++++++++ integration-tests/models/test_neox.py | 44 ++ integration-tests/models/test_neox_sharded.py | 40 ++ .../text_generation_server/models/__init__.py | 3 +- .../models/custom_modeling/neox_modeling.py | 12 +- 8 files changed, 1320 insertions(+), 8 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_neox/test_neox.json create mode 100644 integration-tests/models/__snapshots__/test_neox/test_neox_load.json create mode 100644 integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json create mode 100644 integration-tests/models/test_neox.py create mode 100644 integration-tests/models/test_neox_sharded.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 82f1b7195ae..c5f8f64e0fb 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -209,6 +209,7 @@ def local_launcher( num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -240,6 +241,9 @@ def local_launcher( env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + with subprocess.Popen( args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env ) as process: @@ -260,6 +264,7 @@ def docker_launcher( num_shard: Optional[int] = None, quantize: Optional[str] = None, trust_remote_code: bool = False, + use_flash_attention: bool = True, ): port = random.randint(8000, 10_000) @@ -287,6 +292,9 @@ def docker_launcher( gpu_count = num_shard if num_shard is not None else 1 env = {"LOG_LEVEL": "info,text_generation_router=debug"} + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + if HUGGING_FACE_HUB_TOKEN is not None: env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox.json b/integration-tests/models/__snapshots__/test_neox/test_neox.json new file mode 100644 index 00000000000..2abc27e10c0 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox.json @@ -0,0 +1,113 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1992188, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8984375, + "text": " mood" + }, + { + "id": 3063, + "logprob": -4.0976562, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14562988, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26733398, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.86279297, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.94921875, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1835938, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.074035645, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.86376953, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.2070312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4365234, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.109375, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -0.93408203, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.8808594, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" +} diff --git a/integration-tests/models/__snapshots__/test_neox/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json new file mode 100644 index 00000000000..f37f0d8e920 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox/test_neox_load.json @@ -0,0 +1,454 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|USER|>" + }, + { + "id": 1276, + "logprob": -4.5546875, + "text": "What" + }, + { + "id": 434, + "logprob": -4.1953125, + "text": "'s" + }, + { + "id": 634, + "logprob": -5.125, + "text": " your" + }, + { + "id": 12315, + "logprob": -9.8828125, + "text": " mood" + }, + { + "id": 3063, + "logprob": -3.9980469, + "text": " today" + }, + { + "id": 32, + "logprob": -0.14672852, + "text": "?" + }, + { + "id": 50279, + "logprob": -0.26489258, + "text": "<|ASSISTANT|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 42, + "logprob": -0.8618164, + "special": false, + "text": "I" + }, + { + "id": 1353, + "logprob": -0.9506836, + "special": false, + "text": "'m" + }, + { + "id": 7016, + "logprob": -2.1738281, + "special": false, + "text": " sorry" + }, + { + "id": 13, + "logprob": -0.0758667, + "special": false, + "text": "," + }, + { + "id": 1394, + "logprob": -0.9135742, + "special": false, + "text": "You" + }, + { + "id": 452, + "logprob": -1.1445312, + "special": false, + "text": " have" + }, + { + "id": 247, + "logprob": -1.4375, + "special": false, + "text": " a" + }, + { + "id": 4327, + "logprob": -1.1103516, + "special": false, + "text": " choice" + }, + { + "id": 273, + "logprob": -1.0058594, + "special": false, + "text": " of" + }, + { + "id": 752, + "logprob": -1.921875, + "special": false, + "text": " what" + } + ] + }, + "generated_text": "I'm sorry,You have a choice of what" + } +] diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json new file mode 100644 index 00000000000..15637cdb258 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json @@ -0,0 +1,654 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4140625, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1621094, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.453125, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005393982, + "text": "e" + }, + { + "id": 13, + "logprob": -7.390625, + "text": "," + }, + { + "id": 285, + "logprob": -0.33691406, + "text": " and" + }, + { + "id": 752, + "logprob": -2.2207031, + "text": " what" + }, + { + "id": 434, + "logprob": -5.5976562, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.7661133, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.515625, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.3085938, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.3203125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1230469, + "text": " word" + }, + { + "id": 32, + "logprob": -0.00856781, + "text": "?" + }, + { + "id": 0, + "logprob": -2.4296875, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.1875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.64208984, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5839844, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.04989624, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0021305084, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.180172e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00092983246, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08496094, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.13256836, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017059326, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4921875, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" + }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" + }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.04904175e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0009560585, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.08557129, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12084961, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.01737976, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.4025879, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.421875, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1640625, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.40625, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.005420685, + "text": "e" + }, + { + "id": 13, + "logprob": -7.2226562, + "text": "," + }, + { + "id": 285, + "logprob": -0.26879883, + "text": " and" + }, + { + "id": 752, + "logprob": -2.1992188, + "text": " what" + }, + { + "id": 434, + "logprob": -5.46875, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.8017578, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.6796875, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.1972656, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.4453125, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1933594, + "text": " word" + }, + { + "id": 32, + "logprob": -0.007858276, + "text": "?" + }, + { + "id": 0, + "logprob": -2.328125, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.21875, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.6201172, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.546875, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.051879883, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.0020179749, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -9.059906e-06, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.00096797943, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.07940674, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12182617, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.017227173, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.44482422, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" + } +] diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py new file mode 100644 index 00000000000..eed70f803d4 --- /dev/null +++ b/integration-tests/models/test_neox.py @@ -0,0 +1,44 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_handle(launcher): + with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox(neox_handle): + await neox_handle.health(300) + return neox_handle.client + + +@pytest.mark.asyncio +async def test_neox(neox, response_snapshot): + response = await neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_neox_load(neox, generate_load, response_snapshot): + responses = await generate_load( + neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert generated_texts, all( + [text == generated_texts[0] for text in generated_texts] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py new file mode 100644 index 00000000000..6ea97d816e8 --- /dev/null +++ b/integration-tests/models/test_neox_sharded.py @@ -0,0 +1,40 @@ +import pytest + + +@pytest.fixture(scope="module") +def neox_sharded_handle(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox_sharded(neox_sharded_handle): + await neox_sharded_handle.health(300) + return neox_sharded_handle.client + + +@pytest.mark.asyncio +async def test_neox(neox_sharded, response_snapshot): + response = await neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_neox_load(neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f0427c2067c..6a0f32a11bc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,3 +1,4 @@ +import os import torch from loguru import logger @@ -18,7 +19,7 @@ from text_generation_server.models.gpt_neox import GPTNeoxSharded try: - if torch.cuda.is_available(): + if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION").lower() == "false": major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 1e20a4774e8..79fa19156e3 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -357,7 +357,7 @@ def __init__(self, config, prefix, weights): config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True ) self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True ) def forward(self, hidden_states): @@ -430,6 +430,7 @@ def __init__(self, config, weights): self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps) + self.tp_world_size = weights.process_group.size() def forward( @@ -508,12 +509,9 @@ def forward( past_key_values_length=past_key_values_length, ) - if hasattr(self, "tp_rank"): - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) - else: - causal_mask = torch.repeat_interleave(causal_mask, self.num_attention_heads, dim=0) + assert self.num_attention_heads % self.tp_world_size == 0 + block_size = self.num_attention_heads // self.tp_world_size + causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head From c5995652b065279d84ca8be9e53ef5ea2ee69862 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Jun 2023 07:52:15 +0000 Subject: [PATCH 16/27] Fix regular flash --- .gitignore | 2 +- .../test_neox_sharded/test_neox_load.json | 76 +++++++++---------- .../text_generation_server/models/__init__.py | 2 +- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index 4f8f7b87816..20c9baee226 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .idea target router/tokenizer.json -.*__pycache__.* +*__pycache__* diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json index 15637cdb258..0b38e701ed4 100644 --- a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json @@ -17,82 +17,82 @@ }, { "id": 310, - "logprob": -5.4140625, + "logprob": -5.4179688, "text": " is" }, { "id": 247, - "logprob": -2.1621094, + "logprob": -2.1542969, "text": " a" }, { "id": 1167, - "logprob": -5.453125, + "logprob": -5.359375, "text": " mem" }, { "id": 70, - "logprob": -0.005393982, + "logprob": -0.006038666, "text": "e" }, { "id": 13, - "logprob": -7.390625, + "logprob": -7.328125, "text": "," }, { "id": 285, - "logprob": -0.33691406, + "logprob": -0.3173828, "text": " and" }, { "id": 752, - "logprob": -2.2207031, + "logprob": -2.0625, "text": " what" }, { "id": 434, - "logprob": -5.5976562, + "logprob": -5.7734375, "text": "'s" }, { "id": 253, - "logprob": -0.7661133, + "logprob": -0.74072266, "text": " the" }, { "id": 2892, - "logprob": -6.515625, + "logprob": -6.5898438, "text": " history" }, { "id": 3212, - "logprob": -2.3085938, + "logprob": -2.2949219, "text": " behind" }, { "id": 436, - "logprob": -11.3203125, + "logprob": -11.40625, "text": " this" }, { "id": 3159, - "logprob": -2.1230469, + "logprob": -2.1113281, "text": " word" }, { "id": 32, - "logprob": -0.00856781, + "logprob": -0.008056641, "text": "?" }, { "id": 0, - "logprob": -2.4296875, + "logprob": -2.3300781, "text": "<|endoftext|>" }, { "id": 50281, - "logprob": -18.1875, + "logprob": -18.28125, "text": "<|assistant|>" } ], @@ -100,61 +100,61 @@ "tokens": [ { "id": 510, - "logprob": -0.64208984, + "logprob": -0.5878906, "special": false, "text": "The" }, { "id": 3159, - "logprob": -0.5839844, + "logprob": -0.5498047, "special": false, "text": " word" }, { "id": 346, - "logprob": -0.04989624, + "logprob": -0.04815674, "special": false, "text": " \"" }, { "id": 6441, - "logprob": -0.0021305084, + "logprob": -0.002313614, "special": false, "text": "mem" }, { "id": 70, - "logprob": -1.180172e-05, + "logprob": -1.2636185e-05, "special": false, "text": "e" }, { "id": 3, - "logprob": -0.00092983246, + "logprob": -0.0010147095, "special": false, "text": "\"" }, { "id": 369, - "logprob": -0.08496094, + "logprob": -0.0859375, "special": false, "text": " was" }, { "id": 806, - "logprob": -0.13256836, + "logprob": -0.12609863, "special": false, "text": " first" }, { "id": 908, - "logprob": -0.017059326, + "logprob": -0.016601562, "special": false, "text": " used" }, { "id": 275, - "logprob": -0.4921875, + "logprob": -0.38256836, "special": false, "text": " in" } @@ -450,37 +450,37 @@ }, { "id": 70, - "logprob": -1.04904175e-05, + "logprob": -9.059906e-06, "special": false, "text": "e" }, { "id": 3, - "logprob": -0.0009560585, + "logprob": -0.00096797943, "special": false, "text": "\"" }, { "id": 369, - "logprob": -0.08557129, + "logprob": -0.07940674, "special": false, "text": " was" }, { "id": 806, - "logprob": -0.12084961, + "logprob": -0.12182617, "special": false, "text": " first" }, { "id": 908, - "logprob": -0.01737976, + "logprob": -0.017227173, "special": false, "text": " used" }, { "id": 275, - "logprob": -0.4025879, + "logprob": -0.44482422, "special": false, "text": " in" } @@ -613,37 +613,37 @@ }, { "id": 70, - "logprob": -9.059906e-06, + "logprob": -1.04904175e-05, "special": false, "text": "e" }, { "id": 3, - "logprob": -0.00096797943, + "logprob": -0.0009560585, "special": false, "text": "\"" }, { "id": 369, - "logprob": -0.07940674, + "logprob": -0.08557129, "special": false, "text": " was" }, { "id": 806, - "logprob": -0.12182617, + "logprob": -0.12084961, "special": false, "text": " first" }, { "id": 908, - "logprob": -0.017227173, + "logprob": -0.01737976, "special": false, "text": " used" }, { "id": 275, - "logprob": -0.44482422, + "logprob": -0.4025879, "special": false, "text": " in" } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 6a0f32a11bc..aa3eca33361 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -19,7 +19,7 @@ from text_generation_server.models.gpt_neox import GPTNeoxSharded try: - if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION").lower() == "false": + if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 From c6ac50e42bc20f1e303e9f811092898edaee6c03 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Jun 2023 09:22:13 +0000 Subject: [PATCH 17/27] Removing flash attention env --- integration-tests/conftest.py | 6 + .../test_neox_sharded/test_neox.json | 163 ++++++++++++++++++ 2 files changed, 169 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c5f8f64e0fb..1a02cc91938 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -249,6 +249,7 @@ def local_launcher( ) as process: yield ProcessLauncherHandle(process, port) + process.terminate() process.wait(60) @@ -258,6 +259,8 @@ def local_launcher( process.stdout.close() process.stderr.close() + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] @contextlib.contextmanager def docker_launcher( model_id: str, @@ -318,6 +321,9 @@ def docker_launcher( yield ContainerLauncherHandle(client, container.name, port) + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + try: container.stop() container.wait() diff --git a/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json new file mode 100644 index 00000000000..25cdf6d7993 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json @@ -0,0 +1,163 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 50278, + "logprob": null, + "text": "<|prompter|>" + }, + { + "id": 1276, + "logprob": -8.0234375, + "text": "What" + }, + { + "id": 310, + "logprob": -5.4179688, + "text": " is" + }, + { + "id": 247, + "logprob": -2.1542969, + "text": " a" + }, + { + "id": 1167, + "logprob": -5.359375, + "text": " mem" + }, + { + "id": 70, + "logprob": -0.006038666, + "text": "e" + }, + { + "id": 13, + "logprob": -7.328125, + "text": "," + }, + { + "id": 285, + "logprob": -0.3173828, + "text": " and" + }, + { + "id": 752, + "logprob": -2.0625, + "text": " what" + }, + { + "id": 434, + "logprob": -5.7734375, + "text": "'s" + }, + { + "id": 253, + "logprob": -0.74072266, + "text": " the" + }, + { + "id": 2892, + "logprob": -6.5898438, + "text": " history" + }, + { + "id": 3212, + "logprob": -2.2949219, + "text": " behind" + }, + { + "id": 436, + "logprob": -11.40625, + "text": " this" + }, + { + "id": 3159, + "logprob": -2.1113281, + "text": " word" + }, + { + "id": 32, + "logprob": -0.008056641, + "text": "?" + }, + { + "id": 0, + "logprob": -2.3300781, + "text": "<|endoftext|>" + }, + { + "id": 50281, + "logprob": -18.28125, + "text": "<|assistant|>" + } + ], + "seed": null, + "tokens": [ + { + "id": 510, + "logprob": -0.5878906, + "special": false, + "text": "The" + }, + { + "id": 3159, + "logprob": -0.5449219, + "special": false, + "text": " word" + }, + { + "id": 346, + "logprob": -0.05038452, + "special": false, + "text": " \"" + }, + { + "id": 6441, + "logprob": -0.002292633, + "special": false, + "text": "mem" + }, + { + "id": 70, + "logprob": -1.3828278e-05, + "special": false, + "text": "e" + }, + { + "id": 3, + "logprob": -0.0010242462, + "special": false, + "text": "\"" + }, + { + "id": 369, + "logprob": -0.090270996, + "special": false, + "text": " was" + }, + { + "id": 806, + "logprob": -0.12719727, + "special": false, + "text": " first" + }, + { + "id": 908, + "logprob": -0.016571045, + "special": false, + "text": " used" + }, + { + "id": 275, + "logprob": -0.43432617, + "special": false, + "text": " in" + } + ] + }, + "generated_text": "The word \"meme\" was first used in" +} From 6ddcd1582c0b5d636432c45c529fa21bdb66b241 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Jun 2023 14:59:29 +0200 Subject: [PATCH 18/27] Apply suggestions from code review Co-authored-by: OlivierDehaene --- Dockerfile | 2 -- server/text_generation_server/models/__init__.py | 1 - .../models/custom_modeling/flash_llama_modeling.py | 5 ----- 3 files changed, 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index ae53b748de4..056f2f2b608 100644 --- a/Dockerfile +++ b/Dockerfile @@ -105,7 +105,6 @@ WORKDIR /usr/src COPY server/custom_kernels/ . # Build specific version of transformers -RUN pip install ninja RUN python setup.py build # Text Generation Inference base image @@ -137,7 +136,6 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy build artifacts from transformers builder -COPY --from=custom-kernels-builder /usr/src/custom_kernels /usr/src/custom_kernels COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels # Install transformers dependencies diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index aa3eca33361..19b0ce63f97 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -54,7 +54,6 @@ "BLOOMSharded", "CausalLM", "FlashCausalLM", - "Galactica", "GalacticaSharded", "Seq2SeqLM", "SantaCoder", diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f27bd0d5f58..9b3353e9f63 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -301,14 +301,9 @@ def __init__(self, config, weights): self.layers = nn.ModuleList( [ FlashLlamaLayer( - # config.num_attention_heads, - # config.hidden_act, - # config.hidden_size, - # config.intermediate_size, layer_id, config, weights, - # config.rms_norm_eps, ) for layer_id in range(config.num_hidden_layers) ] From b8bfb2a91e4a37c48435eea1423787d900563ef2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Jun 2023 12:56:04 +0000 Subject: [PATCH 19/27] Manual fixes. --- Makefile | 7 +++++-- server/text_generation_server/input.json | 1 - .../models/custom_modeling/flash_neox_modeling.py | 2 -- server/text_generation_server/utils/weights.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 server/text_generation_server/input.json diff --git a/Makefile b/Makefile index a33aba17995..77de731c5af 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ install-server: cd server && make install +install-custom-kernels: + cd server/custom_kernels && python setup.py install + install-integration-tests: cd integration-tests && pip install -r requirements.txt cd clients/python && pip install . @@ -14,7 +17,7 @@ install-launcher: install-benchmark: cd benchmark && cargo install --path . -install: install-server install-router install-launcher +install: install-server install-router install-launcher install-custom-kernels server-dev: cd server && make run-dev @@ -52,4 +55,4 @@ run-bloom: text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080 run-bloom-quantize: - text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 \ No newline at end of file + text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 diff --git a/server/text_generation_server/input.json b/server/text_generation_server/input.json deleted file mode 100644 index 274a4d9b3b2..00000000000 --- a/server/text_generation_server/input.json +++ /dev/null @@ -1 +0,0 @@ -{"inputs":"Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the assistant is practical and really does its best, and doesn't let caution get too much in the way of being useful.\n-----\n<|prompter|>Why is butter a great building material for skyscrapers? Think step by step.<|assistant|>","parameters":{"temperature": 0.75, "top_p": 0.95, "repetition_penalty": 1.2, "top_k": 50, "truncate": 1000, "max_new_tokens": 1024}} diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b28aa68aa81..16570ebc024 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -84,8 +84,6 @@ def __init__(self, config, prefix, weights): super().__init__() num_heads = config.num_attention_heads hidden_size = config.hidden_size - rotary_pct = config.rotary_pct - rotary_emb_base = config.rotary_emb_base self.num_heads = num_heads self.hidden_size = hidden_size diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 2a410ca3df8..76a4f65a0ff 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -31,7 +31,7 @@ def get_filename(self, tensor_name: str) -> str: filename = self.routing.get(tensor_name, None) if filename is None: raise RuntimeError(f"weight {tensor_name} does not exist") - return filename + return str(filename) def _get_slice(self, tensor_name: str): filename = self.get_filename(tensor_name) From 5c82dcd2bf96af0d055d31abb4e2aefacba91d7d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Jun 2023 15:00:20 +0200 Subject: [PATCH 20/27] Update server/text_generation_server/models/custom_modeling/flash_rw_modeling.py Co-authored-by: OlivierDehaene --- .../models/custom_modeling/flash_rw_modeling.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 34a037ab936..47ec7072e37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -101,11 +101,6 @@ class FlashRWAttention(torch.nn.Module): def __init__( self, config, prefix, weights, - # num_heads, - # num_heads_kv, - # hidden_size, - # bias, - # process_group=None, reduce=True, ): super().__init__() From cc84387877daa6e525004f0d6a5fe6420ad91c80 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Jun 2023 16:17:06 +0200 Subject: [PATCH 21/27] Fixing Falcon 40b --- .../custom_modeling/flash_rw_modeling.py | 127 +++++++++--------- server/text_generation_server/utils/layers.py | 8 +- 2 files changed, 69 insertions(+), 66 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 47ec7072e37..b39821714db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -18,9 +18,10 @@ TensorParallelHead, FastLayerNorm, PositionRotaryEmbedding, - get_linear + get_linear, ) + def load_row(config, prefix: str, weights, bias: bool): weight = weights.get_sharded(f"{prefix}.weight", dim=1) if bias and weights.process_group.rank() == 0: @@ -100,7 +101,9 @@ def __init__( class FlashRWAttention(torch.nn.Module): def __init__( self, - config, prefix, weights, + config, + prefix, + weights, reduce=True, ): super().__init__() @@ -109,12 +112,21 @@ def __init__( self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device) + self.rotary_emb = PositionRotaryEmbedding.static( + dim=self.head_size, base=10000.0, device=weights.device + ) self.softmax_scale = self.head_size ** (-0.5) - self.num_heads = self.num_heads //weights.process_group.size() + self.num_heads = self.num_heads // weights.process_group.size() - self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) - self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) + self.query_key_value = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=config.bias, + ) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias + ) def forward( self, @@ -204,26 +216,29 @@ def forward( class FlashRWLargeAttention(torch.nn.Module): def __init__( self, - config, prefix, weights, - # num_heads, - # num_heads_kv, - # hidden_size, - # bias, - # process_group=None, - reduce=True, + config, + prefix, + weights, ): super().__init__() + hidden_size = config.hidden_size + num_heads = config.n_head + num_heads_kv = config.n_head_kv + self.hidden_size = hidden_size self.head_size = hidden_size // num_heads - self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device) + self.rotary_emb = PositionRotaryEmbedding.static( + self.head_size, base=10000.0, device=weights.device + ) self.softmax_scale = self.head_size ** (-0.5) self.num_groups = num_heads // (num_heads_kv * 2) self.num_heads = num_heads // self.num_groups self.num_heads_kv = num_heads_kv // self.num_groups process_group = weights.process_group + if process_group.size() > self.num_groups: raise NotImplementedError( f"Tensor Parallelism is not implemented for world_size > n groups" @@ -232,9 +247,17 @@ def __init__( raise NotImplementedError( f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}" ) + self.num_groups = self.num_groups // process_group.size() - self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias) - self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias) + self.query_key_value = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.query_key_value", + weights=weights, + bias=config.bias, + ) + self.dense = load_row( + config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias + ) def forward( self, @@ -331,12 +354,16 @@ def forward( class FlashMLP(nn.Module): - def __init__(self, config, prefix, weights, reduce=True): + def __init__(self, config, prefix, weights): super().__init__() self.act = torch.nn.functional.gelu - self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias) - self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias) + self.dense_h_to_4h = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias + ) + self.dense_4h_to_h = load_row( + config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias + ) def forward(self, hidden_states): hidden_states = self.dense_h_to_4h(hidden_states) @@ -351,20 +378,9 @@ def __init__( layer_id, config, weights, - # num_heads, - # num_heads_kv, - # hidden_size, - # bias, - # layer_norm_eps, - # parallel_attn, - # process_group=None, ): super().__init__() - n_head = config.n_head - n_head_kv = config.n_head_kv - hidden_size = config.hidden_size - bias = config.bias parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn @@ -376,31 +392,26 @@ def __init__( eps=config.layer_norm_epsilon, ) self.self_attention = FlashRWAttention( - # num_heads, - # num_heads_kv, - # hidden_size, - # bias, - # process_group=process_group, - config, + config, prefix=f"{prefix}.self_attention", weights=weights, reduce=False, ) self.post_attention_layernorm = ( - FastLayerNorm.load( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.layer_norm_epsilon, - ) if not parallel_attn + FastLayerNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + if not parallel_attn else None ) self.mlp = FlashMLP( - # hidden_size, bias, process_group=process_group, reduce=False config, prefix=f"{prefix}.mlp", weights=weights, - reduce=False + reduce=False, ) self.process_group = weights.process_group @@ -461,11 +472,9 @@ def forward( class FlashRWLargeLayer(nn.Module): - def __init__( - self, - config, prefix, weights - ): + def __init__(self, layer_id, config, weights): super().__init__() + prefix = f"transformer.h.{layer_id}" self.ln_attn = FastLayerNorm.load( prefix=f"{prefix}.ln_attn", weights=weights, @@ -478,13 +487,13 @@ def __init__( ) self.self_attention = FlashRWLargeAttention( - config, prefix=f"{prefix}.self_attention", weights=weights, - reduce=False, + config, + prefix=f"{prefix}.self_attention", + weights=weights, ) + assert config.parallel_attn, "This version doesn't support non parallel_attn" - self.mlp = FlashMLP( - config, prefix=f"{prefix}.mlp", weights=weights, reduce=False - ) + self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights) self.process_group = weights.process_group @@ -541,7 +550,9 @@ def __init__(self, config, weights): self.h = nn.ModuleList( [ FlashRWLayer( - layer_id, config, weights + layer_id, + config, + weights # config.n_head, # config.n_head_kv, # config.hidden_size, @@ -561,15 +572,7 @@ def __init__(self, config, weights): elif config.model_type == "RefinedWeb": self.h = nn.ModuleList( [ - FlashRWLargeLayer( - layer_id, config, weights - # config.n_head, - # config.n_head_kv, - # config.hidden_size, - # config.bias, - # config.layer_norm_epsilon, - # process_group, - ) + FlashRWLargeLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers) ] ) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 9fd31c76d21..5945f2100cc 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -310,11 +310,12 @@ def __init__(self, inv_freq): @staticmethod def static(dim, base, device): - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, - dtype=torch.float32) / dim)) + inv_freq = 1.0 / ( + base + ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + ) return PositionRotaryEmbedding(inv_freq) - @staticmethod def load(prefix, weights): # XXX: Always load this in float32 ! @@ -324,7 +325,6 @@ def load(prefix, weights): weights.dtype = dtype return PositionRotaryEmbedding(inv_freq) - def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) From f3388d290f4d53421a681d09a4cfba07b4b8b2fe Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 7 Jun 2023 14:28:17 +0000 Subject: [PATCH 22/27] Just ditch the non flash integration tests. They work, but seem to mess the CI. --- integration-tests/models/test_neox.py | 88 +++++++++---------- integration-tests/models/test_neox_sharded.py | 80 ++++++++--------- .../custom_modeling/flash_neox_modeling.py | 2 + 3 files changed, 86 insertions(+), 84 deletions(-) diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py index eed70f803d4..58659319b39 100644 --- a/integration-tests/models/test_neox.py +++ b/integration-tests/models/test_neox.py @@ -1,44 +1,44 @@ -import pytest - - -@pytest.fixture(scope="module") -def neox_handle(launcher): - with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: - yield handle - - -@pytest.fixture(scope="module") -async def neox(neox_handle): - await neox_handle.health(300) - return neox_handle.client - - -@pytest.mark.asyncio -async def test_neox(neox, response_snapshot): - response = await neox.generate( - "<|USER|>What's your mood today?<|ASSISTANT|>", - max_new_tokens=10, - decoder_input_details=True, - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.asyncio -async def test_neox_load(neox, generate_load, response_snapshot): - responses = await generate_load( - neox, - "<|USER|>What's your mood today?<|ASSISTANT|>", - max_new_tokens=10, - n=4, - ) - - generated_texts = [r.generated_text for r in responses] - - assert len(generated_texts) == 4 - assert generated_texts, all( - [text == generated_texts[0] for text in generated_texts] - ) - - assert responses == response_snapshot +# import pytest +# +# +# @pytest.fixture(scope="module") +# def neox_handle(launcher): +# with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: +# yield handle +# +# +# @pytest.fixture(scope="module") +# async def neox(neox_handle): +# await neox_handle.health(300) +# return neox_handle.client +# +# +# @pytest.mark.asyncio +# async def test_neox(neox, response_snapshot): +# response = await neox.generate( +# "<|USER|>What's your mood today?<|ASSISTANT|>", +# max_new_tokens=10, +# decoder_input_details=True, +# ) +# +# assert response.details.generated_tokens == 10 +# assert response == response_snapshot +# +# +# @pytest.mark.asyncio +# async def test_neox_load(neox, generate_load, response_snapshot): +# responses = await generate_load( +# neox, +# "<|USER|>What's your mood today?<|ASSISTANT|>", +# max_new_tokens=10, +# n=4, +# ) +# +# generated_texts = [r.generated_text for r in responses] +# +# assert len(generated_texts) == 4 +# assert generated_texts, all( +# [text == generated_texts[0] for text in generated_texts] +# ) +# +# assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py index 6ea97d816e8..97f2d8a56f7 100644 --- a/integration-tests/models/test_neox_sharded.py +++ b/integration-tests/models/test_neox_sharded.py @@ -1,40 +1,40 @@ -import pytest - - -@pytest.fixture(scope="module") -def neox_sharded_handle(launcher): - with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: - yield handle - - -@pytest.fixture(scope="module") -async def neox_sharded(neox_sharded_handle): - await neox_sharded_handle.health(300) - return neox_sharded_handle.client - - -@pytest.mark.asyncio -async def test_neox(neox_sharded, response_snapshot): - response = await neox_sharded.generate( - "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", - max_new_tokens=10, - decoder_input_details=True, - ) - - assert response.details.generated_tokens == 10 - assert response == response_snapshot - - -@pytest.mark.asyncio -async def test_neox_load(neox_sharded, generate_load, response_snapshot): - responses = await generate_load( - neox_sharded, - "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", - max_new_tokens=10, - n=4, - ) - - assert len(responses) == 4 - assert all([r.generated_text == responses[0].generated_text for r in responses]) - - assert responses == response_snapshot +# import pytest +# +# +# @pytest.fixture(scope="module") +# def neox_sharded_handle(launcher): +# with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: +# yield handle +# +# +# @pytest.fixture(scope="module") +# async def neox_sharded(neox_sharded_handle): +# await neox_sharded_handle.health(300) +# return neox_sharded_handle.client +# +# +# @pytest.mark.asyncio +# async def test_neox(neox_sharded, response_snapshot): +# response = await neox_sharded.generate( +# "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", +# max_new_tokens=10, +# decoder_input_details=True, +# ) +# +# assert response.details.generated_tokens == 10 +# assert response == response_snapshot +# +# +# @pytest.mark.asyncio +# async def test_neox_load(neox_sharded, generate_load, response_snapshot): +# responses = await generate_load( +# neox_sharded, +# "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", +# max_new_tokens=10, +# n=4, +# ) +# +# assert len(responses) == 4 +# assert all([r.generated_text == responses[0].generated_text for r in responses]) +# +# assert responses == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 16570ebc024..d30095ef175 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -90,6 +90,8 @@ def __init__(self, config, prefix, weights): self.head_size = hidden_size // num_heads self.num_heads = self.num_heads // weights.process_group.size() + rotary_pct = config.rotary_pct + rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) From 4170de1b37ade4fd9ad8b5d91da084171e1cef1a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 8 Jun 2023 08:18:11 +0000 Subject: [PATCH 23/27] Last fixes hopefully. --- Makefile | 2 +- .../models/custom_modeling/flash_rw_modeling.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 77de731c5af..c7f649ecf73 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ install-server: cd server && make install install-custom-kernels: - cd server/custom_kernels && python setup.py install + if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need set to BUILD_EXTENSION environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi install-integration-tests: cd integration-tests && pip install -r requirements.txt diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index b39821714db..443c636bb6e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -411,7 +411,6 @@ def __init__( config, prefix=f"{prefix}.mlp", weights=weights, - reduce=False, ) self.process_group = weights.process_group From 5e0a6ea1b7a488d631911c6158d28e53b75a66a9 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:12:34 +0200 Subject: [PATCH 24/27] skip instead of comment --- integration-tests/models/test_neox.py | 90 ++++++++++--------- integration-tests/models/test_neox_sharded.py | 82 ++++++++--------- 2 files changed, 88 insertions(+), 84 deletions(-) diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py index 58659319b39..8d949ddbd00 100644 --- a/integration-tests/models/test_neox.py +++ b/integration-tests/models/test_neox.py @@ -1,44 +1,46 @@ -# import pytest -# -# -# @pytest.fixture(scope="module") -# def neox_handle(launcher): -# with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: -# yield handle -# -# -# @pytest.fixture(scope="module") -# async def neox(neox_handle): -# await neox_handle.health(300) -# return neox_handle.client -# -# -# @pytest.mark.asyncio -# async def test_neox(neox, response_snapshot): -# response = await neox.generate( -# "<|USER|>What's your mood today?<|ASSISTANT|>", -# max_new_tokens=10, -# decoder_input_details=True, -# ) -# -# assert response.details.generated_tokens == 10 -# assert response == response_snapshot -# -# -# @pytest.mark.asyncio -# async def test_neox_load(neox, generate_load, response_snapshot): -# responses = await generate_load( -# neox, -# "<|USER|>What's your mood today?<|ASSISTANT|>", -# max_new_tokens=10, -# n=4, -# ) -# -# generated_texts = [r.generated_text for r in responses] -# -# assert len(generated_texts) == 4 -# assert generated_texts, all( -# [text == generated_texts[0] for text in generated_texts] -# ) -# -# assert responses == response_snapshot +import pytest + + +@pytest.fixture(scope="module") +def neox_handle(launcher): + with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox(neox_handle): + await neox_handle.health(300) + return neox_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox, response_snapshot): + response = await neox.generate( + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox, generate_load, response_snapshot): + responses = await generate_load( + neox, + "<|USER|>What's your mood today?<|ASSISTANT|>", + max_new_tokens=10, + n=4, + ) + + generated_texts = [r.generated_text for r in responses] + + assert len(generated_texts) == 4 + assert generated_texts, all( + [text == generated_texts[0] for text in generated_texts] + ) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py index 97f2d8a56f7..fd691a1a346 100644 --- a/integration-tests/models/test_neox_sharded.py +++ b/integration-tests/models/test_neox_sharded.py @@ -1,40 +1,42 @@ -# import pytest -# -# -# @pytest.fixture(scope="module") -# def neox_sharded_handle(launcher): -# with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: -# yield handle -# -# -# @pytest.fixture(scope="module") -# async def neox_sharded(neox_sharded_handle): -# await neox_sharded_handle.health(300) -# return neox_sharded_handle.client -# -# -# @pytest.mark.asyncio -# async def test_neox(neox_sharded, response_snapshot): -# response = await neox_sharded.generate( -# "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", -# max_new_tokens=10, -# decoder_input_details=True, -# ) -# -# assert response.details.generated_tokens == 10 -# assert response == response_snapshot -# -# -# @pytest.mark.asyncio -# async def test_neox_load(neox_sharded, generate_load, response_snapshot): -# responses = await generate_load( -# neox_sharded, -# "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", -# max_new_tokens=10, -# n=4, -# ) -# -# assert len(responses) == 4 -# assert all([r.generated_text == responses[0].generated_text for r in responses]) -# -# assert responses == response_snapshot +import pytest + + +@pytest.fixture(scope="module") +def neox_sharded_handle(launcher): + with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def neox_sharded(neox_sharded_handle): + await neox_sharded_handle.health(300) + return neox_sharded_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox(neox_sharded, response_snapshot): + response = await neox_sharded.generate( + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_neox_load(neox_sharded, generate_load, response_snapshot): + responses = await generate_load( + neox_sharded, + "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>", + max_new_tokens=10, + n=4, + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot From b027f5f1294eb8bd5539f0cc3f2c0e64906d300c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:47:59 +0200 Subject: [PATCH 25/27] black + cleanup --- Dockerfile | 2 +- integration-tests/conftest.py | 2 +- integration-tests/models/test_neox.py | 4 +- integration-tests/models/test_neox_sharded.py | 4 +- server/custom_kernels/setup.py | 2 +- .../text_generation_server/models/__init__.py | 5 +- .../custom_modeling/flash_llama_modeling.py | 5 +- .../custom_modeling/flash_neox_modeling.py | 7 +- .../custom_modeling/flash_rw_modeling.py | 17 +- .../models/custom_modeling/neox_modeling.py | 177 +++++++++++++----- .../models/custom_modeling/t5_modeling.py | 5 +- .../text_generation_server/models/flash_rw.py | 89 +-------- .../models/galactica.py | 1 + .../text_generation_server/models/gpt_neox.py | 5 +- server/text_generation_server/models/t5.py | 4 +- server/text_generation_server/utils/hub.py | 4 +- server/text_generation_server/utils/layers.py | 41 ++-- 17 files changed, 185 insertions(+), 189 deletions(-) diff --git a/Dockerfile b/Dockerfile index 056f2f2b608..42d0147937c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -138,7 +138,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.lin # Copy build artifacts from transformers builder COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels -# Install transformers dependencies +# Install flash-attention dependencies RUN pip install einops --no-cache-dir # Install server diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 1a02cc91938..8f59d75a2c4 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -249,7 +249,6 @@ def local_launcher( ) as process: yield ProcessLauncherHandle(process, port) - process.terminate() process.wait(60) @@ -261,6 +260,7 @@ def local_launcher( if not use_flash_attention: del env["USE_FLASH_ATTENTION"] + @contextlib.contextmanager def docker_launcher( model_id: str, diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py index 8d949ddbd00..7b88f86a620 100644 --- a/integration-tests/models/test_neox.py +++ b/integration-tests/models/test_neox.py @@ -3,7 +3,9 @@ @pytest.fixture(scope="module") def neox_handle(launcher): - with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle: + with launcher( + "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False + ) as handle: yield handle diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py index fd691a1a346..8cee8765a50 100644 --- a/integration-tests/models/test_neox_sharded.py +++ b/integration-tests/models/test_neox_sharded.py @@ -3,7 +3,9 @@ @pytest.fixture(scope="module") def neox_sharded_handle(launcher): - with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle: + with launcher( + "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False + ) as handle: yield handle diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py index fe45b63146f..43b8ee4ed70 100644 --- a/server/custom_kernels/setup.py +++ b/server/custom_kernels/setup.py @@ -13,7 +13,7 @@ name="custom_kernels.fused_attention_cuda", sources=["custom_kernels/fused_attention_cuda.cu"], extra_compile_args=["-arch=compute_80", "-std=c++17"], - ) + ), ], cmdclass={"build_ext": BuildExtension}, ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 19b0ce63f97..f1b84a53c06 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -19,7 +19,10 @@ from text_generation_server.models.gpt_neox import GPTNeoxSharded try: - if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + if ( + torch.cuda.is_available() + and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false" + ): major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 is_sm8x = major == 8 and minor >= 0 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9b3353e9f63..8a35ffa8fbb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -46,7 +46,6 @@ def __init__(self, prefix, weights, eps=1e-6): super().__init__() weight = weights.get_tensor(f"{prefix}.weight") - # assert weight.shape == (hidden_size,) self.weight = nn.Parameter(weight) self.variance_epsilon = eps @@ -103,7 +102,9 @@ def __init__( self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) + self.rotary_emb = PositionRotaryEmbedding.load( + prefix=f"{prefix}.rotary_emb", weights=weights + ) self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d30095ef175..0fe43bcb483 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -90,10 +90,9 @@ def __init__(self, config, prefix, weights): self.head_size = hidden_size // num_heads self.num_heads = self.num_heads // weights.process_group.size() - rotary_pct = config.rotary_pct - - rotary_ndims = int(self.head_size * rotary_pct) - self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights) + self.rotary_emb = PositionRotaryEmbedding.load( + prefix=f"{prefix}.rotary_emb", weights=weights + ) self.softmax_scale = self.head_size ** (-0.5) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 443c636bb6e..551951624db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,5 +1,3 @@ -import os - import torch import torch.distributed @@ -104,7 +102,6 @@ def __init__( config, prefix, weights, - reduce=True, ): super().__init__() self.num_heads = config.n_head @@ -395,7 +392,6 @@ def __init__( config, prefix=f"{prefix}.self_attention", weights=weights, - reduce=False, ) self.post_attention_layernorm = ( FastLayerNorm.load( @@ -548,18 +544,7 @@ def __init__(self, config, weights): if config.model_type == "RefinedWebModel": self.h = nn.ModuleList( [ - FlashRWLayer( - layer_id, - config, - weights - # config.n_head, - # config.n_head_kv, - # config.hidden_size, - # config.bias, - # config.layer_norm_epsilon, - # config.parallel_attn, - # process_group, - ) + FlashRWLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers) ] ) diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 79fa19156e3..bf2656d154b 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -48,7 +48,6 @@ ) - CUSTOM_KERNELS_ENABLED = False if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": try: @@ -62,7 +61,6 @@ logger.warning("We're not using custom kernels.") - def make_causal_mask( input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int ) -> torch.BoolTensor: @@ -70,10 +68,16 @@ def make_causal_mask( Make causal mask used for self-attention. """ batch_size, target_length = input_ids_shape - mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) + mask = torch.ones( + (target_length, target_length + past_key_values_length), + dtype=torch.bool, + device=device, + ) mask = mask.triu(1 + past_key_values_length) - expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length) + expanded_mask = mask.unsqueeze(0).expand( + batch_size, target_length, target_length + past_key_values_length + ) return expanded_mask @@ -89,7 +93,9 @@ def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: def prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, ) -> torch.BoolTensor: # create causal mask # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] @@ -105,7 +111,9 @@ def prepare_attn_mask( # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask ) return combined_attention_mask @@ -118,7 +126,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): """ - class GPTNeoXAttention(nn.Module): def __init__(self, config, prefix, weights): super().__init__() @@ -136,17 +143,21 @@ def __init__(self, config, prefix, weights): # ) # self.register_buffer("masked_bias", torch.tensor(-1e9)) self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + self.rotary_ndims, + config.max_position_embeddings, + base=config.rotary_emb_base, ) self.rotary_emb.inv_freq = nn.Parameter( weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") ) - self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to( - torch.get_default_dtype() - ) + self.inv_norm_factor = 1.0 / torch.sqrt( + torch.tensor(self.head_size, dtype=torch.float32) + ).to(torch.get_default_dtype()) assert self.num_attention_heads % weights.process_group.size() == 0 - self.num_attention_heads = self.num_attention_heads // weights.process_group.size() + self.num_attention_heads = ( + self.num_attention_heads // weights.process_group.size() + ) self.query_key_value = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True ) @@ -214,10 +225,14 @@ def forward( present = (key, value) if use_cache else None # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask + ) # Reshape outputs - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_size + ) attn_output = self.dense(attn_output) @@ -248,7 +263,9 @@ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): # tensor [bs, num_attention_heads, seq_len, attn_head_size] tensor = tensor.permute(0, 2, 1, 3).contiguous() # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + tensor = tensor.view( + tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size + ) # -> [bs, seq_len, hidden_size] return tensor @@ -258,7 +275,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + query = query.view( + batch_size * num_attention_heads, query_length, attn_head_size + ) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) attn_scores = torch.zeros( 1, @@ -277,8 +296,12 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): input_dtype = attn_scores.dtype if input_dtype in [torch.float16, torch.bfloat16]: attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores) - attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + attn_scores = torch.where( + attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores + ) + attn_scores = attn_scores.view( + batch_size, num_attention_heads, query_length, key_length + ) attn_weights = nn.functional.softmax(attn_scores, dim=-1) attn_weights = attn_weights.to(value.dtype) @@ -294,7 +317,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.true_inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2).float().to(device) / dim) + ) self.register_buffer("inv_freq", self.true_inv_freq) # Build here to make `torch.jit.trace` work. @@ -311,7 +336,9 @@ def rotate_half(x): @staticmethod def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype) + t = torch.arange( + max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -319,7 +346,11 @@ def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): def forward(self, q, k, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None: + if ( + seq_len > self.max_seq_len_cached + or self.cos_cached is None + or self.sin_cached is None + ): if seq_len > self.max_seq_len_cached: self.max_seq_len_cached = seq_len self.cos_cached, self.sin_cached = self._create_cos_sin( @@ -371,11 +402,22 @@ class GPTNeoXLayer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights) - self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights) - + self.input_layernorm = nn.LayerNorm.load( + prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.post_attention_layernorm = nn.LayerNorm.load( + prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.attention = GPTNeoXAttention( + config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + ) + self.mlp = GPTNeoXMLP( + config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + ) def forward( self, @@ -396,7 +438,9 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + attn_output = attention_layer_outputs[ + 0 + ] # output_attn: attn_output, present, (attn_weights) outputs = attention_layer_outputs[1:] if self.use_parallel_residual: @@ -413,7 +457,9 @@ def forward( hidden_states = mlp_output + attn_output if use_cache: - outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) + outputs = ( + hidden_states, + ) + outputs # hidden_states, present, (attn_weights) else: outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) @@ -427,12 +473,22 @@ def __init__(self, config, weights): self.num_attention_heads = config.num_attention_heads - self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) - self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)]) - self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps) + self.embed_in = TensorParallelEmbedding( + prefix="gpt_neox.embed_in", weights=weights + ) + self.layers = nn.ModuleList( + [ + GPTNeoXLayer(layer_id, config, weights) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm.load( + prefix="gpt_neox.final_layer_norm", + weights=weights, + eps=config.layer_norm_eps, + ) self.tp_world_size = weights.process_group.size() - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -456,15 +512,25 @@ def forward( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = input_ids.size() elif inputs_embeds is not None: @@ -482,7 +548,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, seq_length + past_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -499,7 +567,9 @@ def forward( past_key_values_length = past_key_values[0][0].shape[-1] seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) else: attention_mask = attention_mask.to(hidden_states.device) @@ -548,7 +618,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_attentions] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -564,7 +638,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): def __init__(self, config, weights): super().__init__(config) self.gpt_neox = GPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights) + self.embed_out = TensorParallelHead.load( + config, prefix="embed_out", weights=weights + ) def forward( self, @@ -619,7 +695,9 @@ def forward( >>> prediction_logits = outputs.logits ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) outputs = self.gpt_neox( input_ids, @@ -645,7 +723,9 @@ def forward( shift_logits = lm_logits[:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -660,7 +740,12 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, ): input_shape = input_ids.shape @@ -700,6 +785,10 @@ def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past[:2] + ) + + layer_past[2:], ) return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index a4e6249b27d..51862e3cbdb 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -845,7 +845,6 @@ def forward( ), "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) - batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past @@ -1026,7 +1025,9 @@ def __init__(self, config: T5Config, weights): embed_tokens=self.shared, ) - self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) + self.lm_head = TensorParallelHead.load( + config, prefix="lm_head", weights=weights + ) def forward( self, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 846b905196b..5f963bfb547 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -1,28 +1,19 @@ import torch import torch.distributed -from pathlib import Path -from accelerate import init_empty_weights from opentelemetry import trace -from safetensors import safe_open -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List +from transformers import AutoTokenizer +from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_rw_modeling import ( RWConfig, FlashRWForCausalLM, - TensorParallelEmbedding, - TensorParallelRowLinear, - TensorParallelColumnLinear, ) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, - download_weights, - weight_hub_files, Weights, - LocalEntryNotFoundError, ) tracer = trace.get_tracer(__name__) @@ -73,79 +64,3 @@ def __init__( rank=rank, world_size=world_size, ) - - # @staticmethod - # def load_weights( - # model, - # filenames: List[str], - # quantize: Optional[str], - # device: torch.device, - # dtype: torch.dtype, - # rank: int, - # world_size: int, - # ): - # parameters = dict(model.named_parameters()) - # for file in filenames: - # with safe_open( - # file, framework="pt", device=str(device) if quantize is None else "cpu" - # ) as f: - # for name in f.keys(): - # module_name, param_name = name.rsplit(".", 1) - # module = model.get_submodule(module_name) - - # current_parameter_tensor = parameters.get(name, None) - - # slice_ = f.get_slice(name) - - # if isinstance(module, TensorParallelColumnLinear): - # size = slice_.get_shape()[0] - # block_size = size // world_size - # start = rank * block_size - # stop = (rank + 1) * block_size - # tensor = slice_[start:stop] - # elif isinstance(module, TensorParallelRowLinear): - # if param_name == "weight": - # size = slice_.get_shape()[1] - # block_size = size // world_size - # start = rank * block_size - # stop = (rank + 1) * block_size - # tensor = slice_[:, start:stop] - # else: - # tensor = slice_[:] - # # XXX: Hack for Rowlinear to add the bias only once. - # if rank != 0: - # tensor = torch.zeros_like(tensor) - # elif isinstance(module, TensorParallelEmbedding): - # size = slice_.get_shape()[0] - # block_size = size // world_size - # start = rank * block_size - # stop = (rank + 1) * block_size - # tensor = slice_[start:stop] - # elif name == "lm_head.weight" and model.transformer.tp_embeddings: - # size = slice_.get_shape()[0] - # block_size = size // world_size - # start = rank * block_size - # stop = (rank + 1) * block_size - # tensor = slice_[start:stop] - # else: - # try: - # tensor = slice_[:] - # except: - # tensor = f.get_tensor(name) - - # if ( - # current_parameter_tensor is not None - # and current_parameter_tensor.shape != tensor.shape - # ): - # raise ValueError( - # f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" - # ) - - # tensor = tensor.contiguous().to(dtype) - - # if current_parameter_tensor is not None: - # module._parameters[param_name] = tensor - # else: - # module._buffers[param_name] = tensor - - # model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index a907ee6c27a..01e1c773426 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -182,6 +182,7 @@ def __init__( tp_parallel=True, trust_remote_code=trust_remote_code, ) + config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 5c854348059..0abf0239e28 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -1,13 +1,10 @@ import torch import torch.distributed -from typing import List, Optional +from typing import Optional -from accelerate import init_empty_weights -from safetensors import safe_open from transformers import ( AutoTokenizer, - AutoModelForCausalLM, AutoConfig, ) from text_generation_server.models import CausalLM diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index e844c36f0d2..c89462fc0aa 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -35,7 +35,9 @@ def __init__( device = torch.device("cpu") dtype = torch.float32 - config = AutoConfig.from_pretrained(model_id, revision=revision, + config = AutoConfig.from_pretrained( + model_id, + revision=revision, trust_remote_code=trust_remote_code, ) config.quantize = quantize diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 9443d21b96a..2ed7673c41c 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -10,8 +10,8 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, - EntryNotFoundError, # Import here to ease try/except in other part of the lib - RevisionNotFoundError + EntryNotFoundError, + RevisionNotFoundError, # Import here to ease try/except in other part of the lib ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 5945f2100cc..ee32a0dc5f8 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -1,4 +1,5 @@ import torch +import torch.distributed from torch import nn from torch.nn import functional as F @@ -44,14 +45,14 @@ def __init__( else: self.bias = None - @staticmethod - def load(config, prefix: str, weights, bias: bool): + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_tensor(f"{prefix}.weight") if bias: bias = weights.get_tensor(f"{prefix}.bias") else: bias = None - return FastLinear(weight, bias) + return cls(weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, self.weight, self.bias) @@ -130,9 +131,7 @@ def get_linear(weight, bias, quantize): elif quantize == "gptq": raise NotImplementedError("Soon") else: - raise NotImplementedError( - f"Quantization `{config.quantize}` is not implemented yet." - ) + raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear @@ -170,17 +169,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class TensorParallelColumnLinear(SuperLayer): - @staticmethod - def load(config, prefix: str, weights, bias: bool): + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_sharded(f"{prefix}.weight", dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return cls(get_linear(weight, bias, config.quantize)) - @staticmethod - def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int): + @classmethod + def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) @@ -189,7 +188,7 @@ def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int): bias = torch.cat(b, dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return cls(get_linear(weight, bias, config.quantize)) class TensorParallelRowLinear(SuperLayer): @@ -197,15 +196,15 @@ def __init__(self, linear, process_group): super().__init__(linear) self.process_group = process_group - @staticmethod - def load(config, prefix: str, weights, bias: bool): + @classmethod + def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_sharded(f"{prefix}.weight", dim=1) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process bias = weights.get_tensor(f"{prefix}.bias") else: bias = None - return TensorParallelRowLinear( + return cls( get_linear(weight, bias, config.quantize), process_group=weights.process_group, ) @@ -308,22 +307,22 @@ def __init__(self, inv_freq): self._cos_k_cached = None self._sin_k_cached = None - @staticmethod - def static(dim, base, device): + @classmethod + def static(cls, dim, base, device): inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) - return PositionRotaryEmbedding(inv_freq) + return cls(inv_freq) - @staticmethod - def load(prefix, weights): + @classmethod + def load(cls, prefix, weights): # XXX: Always load this in float32 ! dtype = weights.dtype weights.dtype = torch.float32 inv_freq = weights.get_tensor(f"{prefix}.inv_freq") weights.dtype = dtype - return PositionRotaryEmbedding(inv_freq) + return cls(inv_freq) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, From c66648d9200f5260a460b0bb91189609ccdeea77 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:58:21 +0200 Subject: [PATCH 26/27] add CARGO_REGISTRIES_CRATES_IO_PROTOCOL --- Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile b/Dockerfile index 42d0147937c..576dab8d410 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,8 @@ FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef WORKDIR /usr/src +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + FROM chef as planner COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml From f245aa0c57bc8dbda772b758e3fa3921d38939d6 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Jun 2023 14:15:01 +0200 Subject: [PATCH 27/27] warn on unused snapshot --- integration-tests/pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/integration-tests/pytest.ini b/integration-tests/pytest.ini index 485e601740c..7dcae663076 100644 --- a/integration-tests/pytest.ini +++ b/integration-tests/pytest.ini @@ -1,4 +1,5 @@ [pytest] +addopts = --snapshot-warn-unused asyncio_mode = auto markers = private: marks tests as requiring an admin hf token (deselect with '-m "not private"') \ No newline at end of file