Skip to content

Commit

Permalink
MI300 compatibility (#1764)
Browse files Browse the repository at this point in the history
Adds support for AMD Instinct MI300 in TGI.

Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from pytorch/pytorch#124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
ROCm/vllm@3489ce7
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1

By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) in order
to avoid to have to rerun the tuning at each `docker run`.

Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
  • Loading branch information
fxmarty and mht-sharma committed May 17, 2024
1 parent a60fa84 commit 232e8d5
Show file tree
Hide file tree
Showing 29 changed files with 1,326 additions and 179 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ jobs:
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
outputs:
# env is not available in the later `container:`, but previous job outputs are.
short_sha: ${{ env.GITHUB_SHA_SHORT }}
steps:
- name: Checkout repository
uses: actions/checkout@v3
Expand Down Expand Up @@ -392,3 +395,37 @@ jobs:
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

integration-tests-rocm:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs:
- start-runner
- build-and-push-image
- integration-tests
- build-and-push-image-rocm
- stop-runner
runs-on: [self-hosted, docker-gpu, amd-gpu, multi-gpu, mi300]
container:
image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
env:
DOCKER_VOLUME: /cache
steps:
- name: ROCM-SMI
run: |
rocm-smi
- name: ROCM-INFO
run: |
rocminfo | grep "Agent" -A 14
- name: Show ROCR environment
run: |
echo "ROCR: $ROCR_VISIBLE_DEVICES"
- name: Install
run: |
make install-integration-tests
- name: Run tests
run: |
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv integration-tests
73 changes: 58 additions & 15 deletions Dockerfile_amd
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:5.7 as base
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
Expand All @@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev && \
hipblas-dev \
hipblaslt-dev \
rocblas-dev \
hiprand-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3-dev && \
rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
Expand All @@ -75,12 +86,43 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh

# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir

RUN conda install intel::mkl-static intel::mkl-include
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .

RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir

ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
USE_ROCM="1" \
BUILD_TEST="0" \
USE_FBGEMM="0" \
USE_NNPACK="0" \
USE_QNNPACK="0" \
USE_XNNPACK="0" \
USE_FLASH_ATTENTION="1" \
USE_MEM_EFF_ATTENTION="0"

RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install

# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# On MI300, performances for flash with Triton FA is very competitive (actually better than CK)
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=1

FROM base AS kernel-builder

# Build vllm kernels
# # Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

Expand All @@ -102,21 +144,21 @@ RUN make build-flash-attention-v2-rocm
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
RUN python setup.py build

# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
RUN python setup.py build

# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
RUN python setup.py build

FROM base as base-copy

Expand All @@ -140,9 +182,6 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
Expand All @@ -160,7 +199,8 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# AWS Sagemaker compatible image
FROM base-copy as sagemaker
FROM base as sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

Expand All @@ -169,5 +209,8 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy

ENTRYPOINT ["text-generation-launcher"]
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
10 changes: 9 additions & 1 deletion docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
title: Text Generation Inference
- local: quicktour
title: Quick Tour
- local: installation_nvidia
title: Using TGI with Nvidia GPUs
- local: installation_amd
title: Using TGI with AMD GPUs
- local: installation_gaudi
title: Using TGI with Intel Gaudi
- local: installation_inferentia
title: Using TGI with AWS Inferentia
- local: installation
title: Installation
title: Installation from source
- local: supported_models
title: Supported Models and Hardware
- local: messages_api
Expand Down
2 changes: 1 addition & 1 deletion docs/source/basic_tutorials/gated_model_access.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \
--model-id $model
```
8 changes: 6 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Installation
# Installation from source

This section explains how to install the CLI tool as well as installing TGI from source. **The strongly recommended approach is to use Docker, as it does not require much setup. Check [the Quick Tour](./quicktour) to learn how to run TGI with Docker.**
<Tip warning={true}>

Installing TGI from source is not the recommended usage. We strongly recommend to use TGI through Docker, check the [Quick Tour](./quicktour), [Installation for Nvidia GPUs](./installation_nvidia) and [Installation for AMD GPUs](./installation_amd) to learn how to use TGI with Docker.

</Tip>

## Install CLI

Expand Down
38 changes: 38 additions & 0 deletions docs/source/installation_amd.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Using TGI with AMD GPUs

TGI is supported and tested on [AMD Instinct MI210](https://www.amd.com/en/products/accelerators/instinct/mi200/mi210.html), [MI250](https://www.amd.com/en/products/accelerators/instinct/mi200/mi250.html) and [MI300](https://www.amd.com/en/products/accelerators/instinct/mi300.html) GPUs. The support may be extended in the future. The recommended usage is through Docker. Make sure to check the [AMD documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) on how to use Docker with AMD GPUs.

On a server powered by AMD GPUs, TGI can be launched with the following command:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3-rocm \
--model-id $model
```

The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.

## TunableOp

TGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable), which allows to do an additional warmup to select the best performing matrix multiplication (GEMM) kernel from rocBLAS or hipBLASLt.

Experimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.

TunableOp is enabled by default, the warmup may take 1-2 minutes. In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED="0"` when launcher TGI's docker container.

## Flash attention implementation

Two implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/flash_attn_triton.py).

By default, as its performances have experimentally been better, Triton implementation is used. It can be disabled (using CK implementation instead) by passing `--env ROCM_USE_FLASH_ATTN_V2_TRITON="0"` when launching TGI's docker container.

## Unsupported features

The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Kernel for sliding window attention (Mistral)
3 changes: 3 additions & 0 deletions docs/source/installation_gaudi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Using TGI with Intel Gaudi

Check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index).
3 changes: 3 additions & 0 deletions docs/source/installation_inferentia.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Using TGI with Inferentia

Check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
18 changes: 18 additions & 0 deletions docs/source/installation_nvidia.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Using TGI with Nvidia GPUs

TGI optimized models are supported on NVIDIA [H100](https://www.nvidia.com/en-us/data-center/h100/), [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it.

For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

TGI can be used on NVIDIA GPUs through its official docker image:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3 \
--model-id $model
```

The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
23 changes: 10 additions & 13 deletions docs/source/quicktour.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,27 @@

The easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).

Let's say you want to deploy [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI. Here is an example on how to do that:
## Launching TGI

Let's say you want to deploy [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI on an Nvidia GPU. Here is an example on how to do that:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.0.3 \
--model-id $model
```

<Tip warning={true}>

To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher.
### Supported hardware

</Tip>
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:

```bash
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
```
## Consuming TGI

Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.


<inferencesnippet>
<python>

Expand Down Expand Up @@ -91,7 +88,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.

```bash
docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
docker run ghcr.io/huggingface/text-generation-inference:2.0.3 --help
```

</Tip>
14 changes: 0 additions & 14 deletions docs/source/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,3 @@ If you wish to serve a supported model that already exists on a local folder, ju
```bash
text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
``````


## Supported Hardware

TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.

TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
* Kernel for sliding window attention (Mistral)

TGI is also supported on the following AI hardware accelerators:
- *Habana first-gen Gaudi and Gaudi2:* check out this [repository](https://github.com/huggingface/tgi-gaudi) to serve models with TGI on Gaudi and Gaudi2 with [Optimum Habana](https://huggingface.co/docs/optimum/habana/index)
* *AWS Inferentia2:* check out this [guide](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) on how to serve models with TGI on Inferentia2.
6 changes: 3 additions & 3 deletions server/Makefile-flash-att-v2
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6


flash-attention-v2-cuda:
Expand All @@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
flash-attention-v2-rocm:
# Clone flash attention
pip install -U packaging ninja --no-cache-dir
git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2

build-flash-attention-v2-rocm: flash-attention-v2-rocm
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
6 changes: 3 additions & 3 deletions server/Makefile-vllm
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ install-vllm-cuda: build-vllm-cuda
vllm-rocm:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone https://github.com/fxmarty/vllm-public.git vllm
git clone https://github.com/fxmarty/rocm-vllm.git vllm

build-vllm-rocm: vllm-rocm
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
cd vllm && python setup.py build
cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true
Expand Down
Loading

0 comments on commit 232e8d5

Please sign in to comment.