
mlc-ai/llm-perf-bench


LLM Performance Benchmarking

Performance

All experiments use fp16 activations and fp16 compute, and decode 256 tokens from the prompt "What is the meaning of life?". All numbers are measured with GPUs connected over PCIe, not NVLink.

Single GPU, 4-bit

| Model | GPU | MLC LLM (tok/sec) | Exllama V2 (tok/sec) | Llama.cpp (tok/sec) |
|---|---|---|---|---|
| Llama2-7B | RTX 3090 Ti | 186.7 | 161.67 | 144.93 |
| Llama2-13B | RTX 3090 Ti | 107.4 | 92.11 | 86.65 |
| Llama2-7B | RTX 4090 | 204.8 | 177.46 | 151.1 |
| Llama2-13B | RTX 4090 | 113.5 | 105.94 | 88.0 |
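To read the table as relative speedups, divide the MLC LLM figure by each other framework's figure. The numbers below are copied from the table above; the `speedups` helper is ours, written only for illustration.

```python
# Decode speed in tok/sec, copied from the "Single GPU, 4-bit" table.
RESULTS = {
    ("Llama2-7B", "RTX 3090 Ti"): {"MLC LLM": 186.7, "Exllama V2": 161.67, "Llama.cpp": 144.93},
    ("Llama2-13B", "RTX 3090 Ti"): {"MLC LLM": 107.4, "Exllama V2": 92.11, "Llama.cpp": 86.65},
    ("Llama2-7B", "RTX 4090"): {"MLC LLM": 204.8, "Exllama V2": 177.46, "Llama.cpp": 151.1},
    ("Llama2-13B", "RTX 4090"): {"MLC LLM": 113.5, "Exllama V2": 105.94, "Llama.cpp": 88.0},
}


def speedups(row, baseline="MLC LLM"):
    """Return how many times faster `baseline` decodes than every other framework in `row`."""
    base = row[baseline]
    return {name: round(base / tps, 2) for name, tps in row.items() if name != baseline}


for (model, gpu), row in RESULTS.items():
    print(model, gpu, speedups(row))
```

For example, for Llama2-7B on the RTX 3090 Ti this gives 1.15x over Exllama V2 and 1.29x over Llama.cpp.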

Multiple NVIDIA GPUs, FP16

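| Model | GPU | MLC LLM (tok/sec) | Exllama V2 (tok/sec) | Llama.cpp (tok/sec) | vLLM (tok/sec) |
|---|---|---|---|---|---|
| Llama2-70B | A100 x 2 | 17.0 | N/A | 10.46 | 15.27 |
| Llama2-70B | A100 x 4 | 26.6 | N/A | 11.07 | 17.64 |
| Llama2-70B | A100 x 8 | 38.8 | N/A | 9.37 | 14.32 |
| Llama2-70B | A10G x 8 | 21.8 | N/A | 6.91 | 13.9 |
| CodeLlama-34B | A10G x 4 | 24.8 | N/A | 14.37 | 16.67 |
| CodeLlama-34B | A10G x 8 | 41.3 | N/A | 11.83 | 23.5 |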

Exllama V2 does not support fp16, so its entries in this table are N/A.
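One way to compare multi-GPU results is scaling efficiency: the speedup over the 2-GPU run divided by the ideal, linear speedup. This metric is our own framing for illustration, not one the benchmark reports; the figures are the Llama2-70B A100 rows of the FP16 table above.

```python
# Llama2-70B FP16 decode speed (tok/sec) on A100s, keyed by GPU count.
A100_70B_FP16 = {
    "MLC LLM": {2: 17.0, 4: 26.6, 8: 38.8},
    "Llama.cpp": {2: 10.46, 4: 11.07, 8: 9.37},
}


def scaling_efficiency(tps_by_gpus, base_gpus=2):
    """Speedup relative to the `base_gpus` run, divided by the ideal linear speedup."""
    base = tps_by_gpus[base_gpus]
    return {
        n: round((tps / base) / (n / base_gpus), 2)
        for n, tps in tps_by_gpus.items()
        if n != base_gpus
    }


for name, series in A100_70B_FP16.items():
    print(name, scaling_efficiency(series))
```

With these numbers MLC LLM keeps about 78% of ideal scaling going from 2 to 4 A100s and 57% at 8, while Llama.cpp decodes more slowly on 8 GPUs than on 2.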

Multiple NVIDIA GPUs, 4-bit

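| Model | GPU | MLC LLM (tok/sec) | Exllama (tok/sec) | Llama.cpp (tok/sec) | vLLM (tok/sec) |
|---|---|---|---|---|---|
| Llama2-70B | A100 x 2 | 40.9 | 32.64 | 17.35 | 21.4 |
| Llama2-70B | A100 x 4 | 55.8 | 30.36 | 15.45 | 21.36 |
| Llama2-70B | A100 x 8 | 59.4 | 32.23 | 11.2 | 17.6 |
| Llama2-70B | A10G x 2 | 19.8 | 13.48 | 11.98 | 12.89 |
| Llama2-70B | A10G x 4 | 34.3 | 13.48 | 13.37 | 16.91 |
| Llama2-70B | A10G x 8 | 47.7 | 13.48 | 8.01 | 20.79 |
| Llama2-70B | RTX 4090 x 2 | 34.5 | 24.39 | 17.55 | 23.8 |
| CodeLlama-34B | A10G x 2 | 38.4 | 25.86 | 21.93 | 23.67 |
| CodeLlama-34B | A10G x 4 | 61.2 | 25.84 | 23.53 | 29.83 |
| CodeLlama-34B | A10G x 8 | 84.2 | 25.82 | 13.25 | N/A |
| CodeLlama-34B | RTX 4090 x 2 | 64.9 | 45.59 | 31.78 | 26.16 |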

Multiple AMD GPUs, 4-bit

| Model | GPU | MLC LLM (tok/sec) |
|---|---|---|
| Llama2-70B | 7900 XTX x 2 | 29.9 |
| CodeLlama-34B | 7900 XTX x 2 | 56.5 |

Instructions

Prerequisites

GPU Docker. Before proceeding, make sure Docker can access your GPUs. For NVIDIA GPUs, install NVIDIA Docker by following the NVIDIA Docker Installation Guide. Then check the setup with the command for your platform:

CUDA:

docker run --gpus all \
  nvidia/cuda:12.1.1-devel-ubuntu22.04 nvidia-smi

ROCm:

docker run --device=/dev/kfd --device=/dev/dri \
           --security-opt seccomp=unconfined \
           --group-add video \
           rocm/rocm-terminal rocm-smi

Repository Setup. Clone the repository, as all subsequent steps assume you are in the repository root:

git clone https://github.com/mlc-ai/llm-perf-bench
cd llm-perf-bench

Now you are ready to proceed with the next steps in the repository.


MLC LLM

In this section, we use int4 quantized Llama2 as an example.

Step 1. Download pre-quantized weights from HuggingFace, then build the Docker image and log into the container:

git lfs install
git clone https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-Llama-2-13b-chat-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-Llama-2-70b-chat-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-Llama-2-70b-chat-hf-q0f16
# git clone https://huggingface.co/mlc-ai/mlc-chat-CodeLlama-7b-Instruct-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-CodeLlama-13b-Instruct-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-CodeLlama-34b-Instruct-hf-q4f16_1
# git clone https://huggingface.co/mlc-ai/mlc-chat-CodeLlama-34b-Instruct-hf-q0f16
CUDA:

docker build --no-cache -t llm-perf-mlc:v0.1    \
    -f ./docker/Dockerfile.cu121.mlc .
./docker/bash.sh llm-perf-mlc:v0.1

ROCm:

docker build --no-cache -t llm-perf-mlc:v0.1    \
    -f ./docker/Dockerfile.rocm57.mlc .
./docker/bash.sh --amd llm-perf-mlc:v0.1

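Step 2. Stay logged in, activate the Python environment, and set some environment variables for convenient scripting. Change MODEL_NAME, QUANTIZATION, and NUM_SHARDS (the number of GPUs to shard the model across) to benchmark other configurations: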

conda activate python311
MODEL_NAME=Llama-2-7b-chat-hf
QUANTIZATION=q4f16_1
NUM_SHARDS=1
PATH_COMPILE=/tmp/model/
PATH_TEST=/tmp/test/

MODEL_CONFIG=./model_configs/${MODEL_NAME}.json
WEIGHT_PATH=$(pwd)/mlc-chat-${MODEL_NAME}-${QUANTIZATION}/

if [ -e "$WEIGHT_PATH/mlc-chat-config.json" ]; then
	sed -i "/\"num_shards\"/c\ \"num_shards\": ${NUM_SHARDS}," $WEIGHT_PATH/mlc-chat-config.json
else
	echo "Path '$WEIGHT_PATH/mlc-chat-config.json' does not exist."
	exit
fi

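# Start from clean test and compile directories.
rm -rf $PATH_TEST && mkdir $PATH_TEST
rm -rf $PATH_COMPILE && mkdir $PATH_COMPILE
# Link the downloaded weights and copy the model config into place.
ln -s ${WEIGHT_PATH} ${PATH_TEST}/params
cp $MODEL_CONFIG $PATH_COMPILE/config.json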
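The sed command above rewrites the whole "num_shards" line and always appends a trailing comma, which would produce invalid JSON if "num_shards" were the last key. As an alternative, here is a minimal sketch that parses the file with the Python standard library instead. `set_num_shards` is a hypothetical helper, not part of MLC LLM.

```python
import json
import tempfile
from pathlib import Path


def set_num_shards(weight_path, num_shards):
    """Set "num_shards" in <weight_path>/mlc-chat-config.json and return the new config."""
    config_file = Path(weight_path) / "mlc-chat-config.json"
    if not config_file.exists():
        raise FileNotFoundError(f"Path '{config_file}' does not exist.")
    config = json.loads(config_file.read_text())
    # Unlike the sed one-liner, this also adds the key if it is missing.
    config["num_shards"] = num_shards
    config_file.write_text(json.dumps(config, indent=2))
    return config


if __name__ == "__main__":
    # Demo on a throwaway directory with a minimal, made-up config.
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "mlc-chat-config.json").write_text('{"num_shards": 1}')
        print(set_num_shards(tmp, 2))  # {'num_shards': 2}
```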

Step 3. Stay logged in and compile the MLC model library. This may take a few seconds:

CUDA:

python -m mlc_llm.build \
	--model $PATH_COMPILE \
	--artifact-path $PATH_COMPILE \
	--quantization $QUANTIZATION \
	--max-seq-len 2048 \
	--num-shards $NUM_SHARDS \
	--target cuda --use-cuda-graph --build-model-only
mv $PATH_COMPILE/model-${QUANTIZATION}/model-${QUANTIZATION}-cuda.so \
                    $PATH_TEST/${MODEL_NAME}-${QUANTIZATION}-cuda.so

ROCm:

python -m mlc_llm.build \
	--model $PATH_COMPILE \
	--artifact-path $PATH_COMPILE \
	--quantization $QUANTIZATION \
	--max-seq-len 2048 \
	--num-shards $NUM_SHARDS \
	--target rocm --build-model-only
mv $PATH_COMPILE/model-${QUANTIZATION}/model-${QUANTIZATION}-rocm.so \
                    $PATH_TEST/${MODEL_NAME}-${QUANTIZATION}-rocm.so
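The CUDA and ROCm mv commands differ only in the target suffix. The sketch below spells out the naming convention they rely on, using a hypothetical helper `model_lib_paths`; the layout is read off the commands above, not taken from MLC LLM documentation.

```python
from pathlib import Path


def model_lib_paths(path_compile, path_test, model_name, quantization, target):
    """Return (built library, destination library) as used by the mv commands.

    `target` is "cuda" or "rocm", matching the --target flag passed to mlc_llm.build.
    """
    src = Path(path_compile) / f"model-{quantization}" / f"model-{quantization}-{target}.so"
    dst = Path(path_test) / f"{model_name}-{quantization}-{target}.so"
    return src, dst


src, dst = model_lib_paths("/tmp/model/", "/tmp/test/", "Llama-2-7b-chat-hf", "q4f16_1", "cuda")
print(src)  # /tmp/model/model-q4f16_1/model-q4f16_1-cuda.so
print(dst)  # /tmp/test/Llama-2-7b-chat-hf-q4f16_1-cuda.so
```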

Step 4. Stay logged in and run the benchmark:

CUDA:

python -m mlc_chat.cli.benchmark \
	--model ${PATH_TEST}/params \
	--device "cuda:0" \
	--prompt "What is the meaning of life?" \
	--generate-length 256

ROCm:

python -m mlc_chat.cli.benchmark \
	--model ${PATH_TEST}/params \
	--device "rocm:0" \
	--prompt "What is the meaning of life?" \
	--generate-length 256

Exllama V2

In this section, we use a Llama2 GPTQ model as an example.

Step 1. Download pre-quantized weights from HuggingFace, build the Docker image, then log into the container and activate the Python environment:

git lfs install
git clone https://huggingface.co/TheBloke/Llama-2-7B-GPTQ
# git clone https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ
# git clone https://huggingface.co/TheBloke/CodeLlama-34B-Instruct-GPTQ

docker build --no-cache -t llm-perf-exllama-v2:v0.1    \
    -f ./docker/Dockerfile.cu121.exllama_v2 .
./docker/bash.sh llm-perf-exllama-v2:v0.1
conda activate python311

NOTE. Building the ExllamaV2 Docker image is particularly memory-intensive on some GPU instances. If the machine starts lagging or the screen freezes, kill the build process promptly.

Step 2. Stay logged in and run the benchmark.

For a single GPU:

MODEL_PATH=/workspace/Llama-2-7B-GPTQ/
OUTPUT_LEN=256
cd /exllamav2
python test_inference.py -m $MODEL_PATH -p "What is the meaning of life?" -t $OUTPUT_LEN

For multiple GPUs:

MODEL_PATH=/workspace/Llama-2-7B-GPTQ/
OUTPUT_LEN=256
GPU_SPLIT="17,17" # VRAM (in GB) to use on each GPU; adjust to your GPUs
cd /exllamav2
python test_inference.py -m $MODEL_PATH -p "What is the meaning of life?" -gs $GPU_SPLIT -t $OUTPUT_LEN
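We assume the `-gs` value is a comma-separated list of VRAM budgets in GB, one per GPU, as described in exllamav2's own help text. A minimal sketch of deriving it from each card's memory minus a fixed headroom follows. `gpu_split` is a hypothetical helper, and the 7 GB headroom that turns two 24 GB cards into "17,17" is our guess at how such a value could be chosen, not something this benchmark states.

```python
def gpu_split(vram_gb_per_gpu, reserve_gb):
    """Build a value for exllamav2's -gs flag from per-GPU VRAM (GB).

    `reserve_gb` is headroom left free on every GPU (an assumption, e.g. for
    cache and activations); the remainder is offered to the model weights.
    """
    budgets = [max(vram - reserve_gb, 0) for vram in vram_gb_per_gpu]
    return ",".join(str(b) for b in budgets)


print(gpu_split([24, 24], reserve_gb=7))  # 17,17
```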

Llama.cpp

Step 1. Build Docker image:

docker build --no-cache -t llm-perf-llama-cpp:v0.1 -f ./docker/Dockerfile.cu121.llama_cpp .

Step 2. Download the quantized GGUF models from HuggingFace:

mkdir -p ./llama_cpp_models
wget -O ./llama_cpp_models/llama-2-7b-chat.Q4_0.gguf https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf
wget -O ./llama_cpp_models/llama-2-70b-chat.Q4_0.gguf https://huggingface.co/TheBloke/Llama-2-70B-chat-GGUF/resolve/main/llama-2-70b-chat.Q4_0.gguf
wget -O ./llama_cpp_models/codellama-34b.Q4_0.gguf https://huggingface.co/TheBloke/CodeLlama-34B-GGUF/resolve/main/codellama-34b.Q4_0.gguf
# wget -O ./llama_cpp_models/llama-2-13b-chat.Q4_0.gguf https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q4_0.gguf

Step 3. Log into the Docker container and run the CLI tool to see the performance numbers. Adjust CUDA_VISIBLE_DEVICES to select how many GPUs each experiment uses.

./docker/bash.sh llm-perf-llama-cpp:v0.1
cd /llama.cpp
# run the quantized Llama-2-7B model on a single GPU.
CUDA_VISIBLE_DEVICES=0 ./build/bin/main -m /workspace/llama_cpp_models/llama-2-7b-chat.Q4_0.gguf -p "What is the meaning of life?" -n 256 -ngl 999 --ignore-eos
# run the quantized Llama-2-70B model on 2 GPUs.
CUDA_VISIBLE_DEVICES=0,1 ./build/bin/main -m /workspace/llama_cpp_models/llama-2-70b-chat.Q4_0.gguf -p "What is the meaning of life?" -n 256 -ngl 999 --ignore-eos
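The commands pass --ignore-eos so that all 256 tokens are generated even if the model emits an end-of-sequence token early. Llama.cpp reports decode speed in its timing summary at the end of the run. The sketch below extracts the "eval time" tokens-per-second figure from a captured log. The line format in the comment is our assumption about llama.cpp builds from around the pinned commit, so verify it against your own output. Capturing both streams (e.g. `> run.log 2>&1`) is the safe choice because the summary may go to stderr.

```python
import re

# Assumed shape of the decode timing line (the "prompt eval time" and
# "sample time" lines are deliberately not matched):
# llama_print_timings:        eval time =  1248.35 ms /   255 runs   (    4.90 ms per token,   204.27 tokens per second)
EVAL_LINE = re.compile(
    r"llama_print_timings:\s+eval time\s*=\s*([\d.]+) ms\s*/\s*(\d+) runs"
    r".*?([\d.]+) tokens per second"
)


def decode_tokens_per_second(log_text):
    """Return decode tokens/sec from llama.cpp output, or None if no eval line is found."""
    for line in log_text.splitlines():
        match = EVAL_LINE.search(line)
        if match:
            return float(match.group(3))
    return None
```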

Note. For float16 models, stay logged in and first convert the HuggingFace models (downloaded as in the HuggingFace Transformer section) to the GGUF FP16 format.

cd /llama.cpp
conda activate python311
# convert the weights using llama.cpp's convert.py script
python3 convert.py /path/to/Llama-2-70b-hf/ \
    --outfile /workspace/llama_cpp_models/llama-2-70b.fp16.gguf
# run fp16 models on 4 GPUs.
CUDA_VISIBLE_DEVICES=0,1,2,3 ./build/bin/main -m /workspace/llama_cpp_models/llama-2-70b.fp16.gguf -p "What is the meaning of life?" -n 256 -ngl 999 --ignore-eos
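A back-of-the-envelope estimate shows why the FP16 70B model is spread across several GPUs: weight storage is roughly parameters × bits per parameter / 8 bytes. `weight_memory_gb` is a hypothetical helper. It ignores the KV cache, activations, and the extra scale storage of 4-bit formats, so real usage is higher.

```python
def weight_memory_gb(num_params_billion, bits_per_param):
    """Approximate weight storage in GB (1 GB = 1e9 bytes), weights only."""
    return num_params_billion * bits_per_param / 8


print(weight_memory_gb(70, 16))  # 140.0 (FP16)
print(weight_memory_gb(70, 4))   # 35.0 (4-bit, before quantization overhead)
```

At about 140 GB in FP16, Llama2-70B exceeds the memory of any single GPU in the tables above. Even at 4 bits, its roughly 35 GB does not fit on one 24 GB card.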

HuggingFace Transformer

Step 1. Build Docker image:

docker build -t llm-perf-hf:v0.1 -f ./docker/Dockerfile.cu121.hf .

Step 2. Download the Llama-2 weights from HuggingFace:

git lfs install
git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
# git clone https://huggingface.co/meta-llama/Llama-2-13b-hf
# git clone https://huggingface.co/meta-llama/Llama-2-70b-hf

Step 3. Log into the Docker container and run the Python script to see the performance numbers. Adjust CUDA_VISIBLE_DEVICES to select how many GPUs each experiment uses:

./docker/bash.sh llm-perf-hf:v0.1
conda activate python311
# run the fp16 Llama-2-7b model on a single GPU.
CUDA_VISIBLE_DEVICES=0 python scripts/benchmark_hf.py --model-path ./Llama-2-7b-hf --format q0f16 --prompt "What is the meaning of life?" --max-new-tokens 256
# run the int4 quantized Llama-2-70b model on two GPUs.
CUDA_VISIBLE_DEVICES=0,1 python scripts/benchmark_hf.py --model-path ./Llama-2-70b-hf --format q4f16 --prompt "What is the meaning of life?" --max-new-tokens 256
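scripts/benchmark_hf.py is not reproduced in this README, so the following is not that script. It is a generic sketch of the timing arithmetic such a benchmark performs: run generation for a fixed number of new tokens, time it, and divide. The `generate` callable and its return value (the number of tokens actually produced) are assumptions for illustration. With transformers you would wrap something like `model.generate(..., max_new_tokens=256)`. Counting produced tokens matters because generation can stop early at an end-of-sequence token.

```python
import time
from typing import Callable


def measure_decode_speed(
    generate: Callable[[int], int],
    max_new_tokens: int = 256,
    warmup: int = 1,
    repeats: int = 3,
):
    """Time `generate(max_new_tokens)` and return (tokens/sec, sec/token).

    The measured time covers the whole call, including prompt processing.
    """
    for _ in range(warmup):
        generate(max_new_tokens)  # untimed runs, e.g. to trigger lazy initialization
    elapsed, produced = 0.0, 0
    for _ in range(repeats):
        start = time.perf_counter()
        produced += generate(max_new_tokens)
        elapsed += time.perf_counter() - start
    return produced / elapsed, elapsed / produced
```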

vLLM

In this section, we use Llama2 fp16 and AWQ models as examples.

Step 1. Download model weights from HuggingFace, build the Docker image, then log into the container and activate the Python environment:

git lfs install
git clone https://huggingface.co/TheBloke/Llama-2-7B-fp16
# You can also clone AWQ models, e.g.
# git clone https://huggingface.co/TheBloke/Llama-2-7B-AWQ
# git clone https://huggingface.co/TheBloke/Llama-2-70B-AWQ
docker build --no-cache -t llm-perf-vllm:v0.1    \
    -f ./docker/Dockerfile.cu118.vllm .
./docker/bash.sh llm-perf-vllm:v0.1
conda activate python311

Step 2. Patch the vLLM scripts and run the benchmark.

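To lift vLLM's limit on the maximum number of batched tokens, the following commands comment out argument verification in vllm/config.py. They also add tok/s and s/tok printouts to benchmark_latency.py to make the results easier to read. The sed line numbers (287, 63, 64) assume the vLLM commit listed under Setup Details: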

sed -i '287s/self._verify_args()/# self._verify_args()/' /vllm/vllm/config.py
sed -i '63i\    print(f"Speed: {args.output_len / np.mean(latencies):.2f} tok/s")' /vllm/benchmarks/benchmark_latency.py
sed -i '64i\    print(f"Speed: {np.mean(latencies)/ args.output_len:.5f} s/tok")' /vllm/benchmarks/benchmark_latency.py
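The two inserted print lines report reciprocal views of the same measurement: tokens per second is output_len divided by the mean request latency, and seconds per token is the inverse. A standalone sketch of that arithmetic follows (`speed_metrics` is our name for it). The latency covers the whole request, including the short prompt, so the result is slightly conservative for pure decoding.

```python
from statistics import mean


def speed_metrics(latencies_s, output_len):
    """Return (tok/s, s/tok) as computed by the lines added to benchmark_latency.py."""
    avg_latency = mean(latencies_s)
    return output_len / avg_latency, avg_latency / output_len


tok_per_s, s_per_tok = speed_metrics([2.56, 2.56], output_len=256)
print(f"Speed: {tok_per_s:.2f} tok/s")   # Speed: 100.00 tok/s
print(f"Speed: {s_per_tok:.5f} s/tok")   # Speed: 0.01000 s/tok
```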

To benchmark fp16 performance:

MODEL_PATH=/workspace/Llama-2-7B-fp16/
OUTPUT_LEN=256
GPU_NUM=1
cd /vllm && python benchmarks/benchmark_latency.py \
--model $MODEL_PATH \
--output-len $OUTPUT_LEN \
--tensor-parallel-size $GPU_NUM \
--batch-size 1 \
--input-len 7 # for prompt "What is the meaning of life?"

To benchmark a 4-bit AWQ model:

MODEL_PATH=/workspace/Llama-2-7B-AWQ/
OUTPUT_LEN=256
GPU_NUM=1
cd /vllm && python benchmarks/benchmark_latency.py \
--model $MODEL_PATH \
--output-len $OUTPUT_LEN \
--tensor-parallel-size $GPU_NUM \
--batch-size 1 \
--quantization awq \
--input-len 7 # for prompt "What is the meaning of life?"
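The fp16 and AWQ runs use the same command and differ only in --quantization awq. When sweeping GPU counts, it can help to build the argument list in one place. `vllm_latency_cmd` is a hypothetical helper that uses only the flags shown above.

```python
def vllm_latency_cmd(model_path, output_len=256, gpu_num=1, input_len=7, quantization=None):
    """Assemble the benchmark_latency.py command line used in this section.

    input_len=7 matches the token count given above for "What is the meaning of life?".
    """
    cmd = [
        "python", "benchmarks/benchmark_latency.py",
        "--model", model_path,
        "--output-len", str(output_len),
        "--tensor-parallel-size", str(gpu_num),
        "--batch-size", "1",
        "--input-len", str(input_len),
    ]
    if quantization is not None:
        cmd += ["--quantization", quantization]
    return cmd


for gpus in (1, 2, 4):
    print(" ".join(vllm_latency_cmd("/workspace/Llama-2-7B-AWQ/", gpu_num=gpus, quantization="awq")))
```

Inside the container, each list can be passed to `subprocess.run(cmd, cwd="/vllm", check=True)`.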

Setup Details

We used the following commits:

  • MLC LLM commit, TVM commit on 10/04/2023;
  • ExllamaV2 commit on 10/05/2023;
  • Llama.cpp commit on 10/02/2023;
  • vLLM commit on 10/06/2023;
  • HuggingFace transformers 4.33.3 on 10/06/2023.
