#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# container: docker.io/cphsieh/ruler:0.1.0
# bash run.sh MODEL_NAME BENCHMARK_NAME
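# Example invocation (illustrative names; valid values are defined in
# config_models.sh and config_tasks.sh):
#   bash run.sh llama2-7b-chat synthetic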
if [ $# -ne 2 ]; then
    echo "Usage: $0 <model_name> <benchmark_name>"
    exit 1
fi
# Root Directories
GPUS="" # number of GPUs used for tensor parallelism.
ROOT_DIR="" # the path that stores generated task samples and model predictions.
MODEL_DIR="" # the path that contains individual model folders from HuggingFace.
ENGINE_DIR="" # the path that contains individual engine folders from TensorRT-LLM.
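# Illustrative values only (adjust to your environment):
#   GPUS="8"
#   ROOT_DIR="benchmark_root"
#   MODEL_DIR="/models/huggingface"
#   ENGINE_DIR="/models/trt_engines"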
# Model and Tokenizer
source config_models.sh
MODEL_NAME=${1}
MODEL_CONFIG=$(MODEL_SELECT ${MODEL_NAME} ${MODEL_DIR} ${ENGINE_DIR})
IFS=":" read MODEL_PATH MODEL_TEMPLATE_TYPE MODEL_FRAMEWORK TOKENIZER_PATH TOKENIZER_TYPE OPENAI_API_KEY GEMINI_API_KEY AZURE_ID AZURE_SECRET AZURE_ENDPOINT <<< "$MODEL_CONFIG"
if [ -z "${MODEL_PATH}" ]; then
echo "Model: ${MODEL_NAME} is not supported"
exit 1
fi
export OPENAI_API_KEY=${OPENAI_API_KEY}
export GEMINI_API_KEY=${GEMINI_API_KEY}
export AZURE_API_ID=${AZURE_ID}
export AZURE_API_SECRET=${AZURE_SECRET}
export AZURE_API_ENDPOINT=${AZURE_ENDPOINT}
# Benchmark and Tasks
source config_tasks.sh
BENCHMARK=${2}
declare -n TASKS=$BENCHMARK
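# `declare -n` makes TASKS a nameref, i.e. an alias for the array named by
# $BENCHMARK. For example, if config_tasks.sh defined (illustratively)
#   synthetic=("niah_single_1" "vt" "cwe")
# then BENCHMARK=synthetic would make the loops below iterate over those tasks.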
if [ -z "${TASKS}" ]; then
echo "Benchmark: ${BENCHMARK} is not supported"
exit 1
fi
# Start server (you may want to run it in a separate container).
if [ "$MODEL_FRAMEWORK" == "vllm" ]; then
python pred/serve_vllm.py \
--model=${MODEL_PATH} \
--tensor-parallel-size=${GPUS} \
--dtype bfloat16 \
--disable-custom-all-reduce \
&
elif [ "$MODEL_FRAMEWORK" == "trtllm" ]; then
python pred/serve_trt.py \
--model_path=${MODEL_PATH} \
&
elif [ "$MODEL_FRAMEWORK" == "sglang" ]; then
python -m sglang.launch_server \
--model-path ${MODEL_PATH} \
--tp ${GPUS} \
--port 5000 \
--enable-flashinfer \
&
# use sglang/test/killall_sglang.sh to kill sglang server if it hangs
fi
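# The server runs in the background (note the trailing "&"). Optionally, wait
# for it to accept connections before starting the client; a minimal sketch
# (the port depends on the serving script; 5000 is only known for sglang above):
#   while ! nc -z localhost 5000; do sleep 5; done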
# Start client (prepare data / call model API / obtain final metrics)
total_time=0
for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do
    RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME}/${BENCHMARK}/${MAX_SEQ_LENGTH}"
    DATA_DIR="${RESULTS_DIR}/data"
    PRED_DIR="${RESULTS_DIR}/pred"
    mkdir -p ${DATA_DIR}
    mkdir -p ${PRED_DIR}

    for TASK in "${TASKS[@]}"; do
        # Prepare task samples at the current sequence length.
        python data/prepare.py \
            --save_dir ${DATA_DIR} \
            --benchmark ${BENCHMARK} \
            --task ${TASK} \
            --tokenizer_path ${TOKENIZER_PATH} \
            --tokenizer_type ${TOKENIZER_TYPE} \
            --max_seq_length ${MAX_SEQ_LENGTH} \
            --model_template_type ${MODEL_TEMPLATE_TYPE} \
            --num_samples ${NUM_SAMPLES} \
            ${REMOVE_NEWLINE_TAB}

        # Query the model server / API and save predictions.
        start_time=$(date +%s)
        python pred/call_api.py \
            --data_dir ${DATA_DIR} \
            --save_dir ${PRED_DIR} \
            --benchmark ${BENCHMARK} \
            --task ${TASK} \
            --server_type ${MODEL_FRAMEWORK} \
            --model_name_or_path ${MODEL_PATH} \
            --temperature ${TEMPERATURE} \
            --top_k ${TOP_K} \
            --top_p ${TOP_P} \
            ${STOP_WORDS}
        end_time=$(date +%s)
        time_diff=$((end_time - start_time))
        total_time=$((total_time + time_diff))
    done

    # Compute final metrics over all tasks at this sequence length.
    python eval/evaluate.py \
        --data_dir ${PRED_DIR} \
        --benchmark ${BENCHMARK}
done
echo "Total time spent on call_api: $total_time seconds"