-
Notifications
You must be signed in to change notification settings - Fork 243
/
run_tuning.sh
92 lines (82 loc) · 2.27 KB
/
run_tuning.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/bin/bash
# Parse --key=value options, then launch tuning via run_glue_tune.py.
set -x  # echo every command as it runs, so the launched invocation is visible in logs

# Entry point: populate the option globals from the CLI, then start the tuning run.
main() {
  init_params "$@"
  run_tuning
}
#######################################
# Parse --key=value command-line options into global variables.
# Globals (written):
#   topology, dataset_location, input_model, tuned_checkpoint
# Arguments:
#   "$@" - options: --topology=, --dataset_location=, --input_model=,
#          --output_model= (the latter sets tuned_checkpoint)
# Exits 1 (message on stderr) on any unrecognised argument.
# NOTE(review): dataset_location is parsed but not read by run_tuning.
#######################################
function init_params {
  local var
  # Default output directory when --output_model is not given.
  tuned_checkpoint=saved_results
  for var in "$@"; do
    # ${var#*=} strips only up to the FIRST '=', so values that themselves
    # contain '=' (e.g. paths) are preserved intact; no echo/cut subshell needed.
    case "$var" in
      --topology=*)
        topology=${var#*=}
        ;;
      --dataset_location=*)
        dataset_location=${var#*=}
        ;;
      --input_model=*)
        input_model=${var#*=}
        ;;
      --output_model=*)
        tuned_checkpoint=${var#*=}
        ;;
      *)
        echo "Error: No such parameter: ${var}" >&2
        exit 1
        ;;
    esac
  done
}
#######################################
# Patch conf.yaml for the selected topology and launch run_glue_tune.py
# in tuning mode (--tune, --no_cuda).
# Globals (read):
#   topology, input_model, tuned_checkpoint
# Side effects:
#   edits conf.yaml in place (model name and quantization approach).
# Exits 1 (message on stderr) if topology is unsupported or input_model
# is empty, before conf.yaml is touched.
#######################################
function run_tuning {
  local -a extra_cmd=()   # extra CLI flags for run_glue_tune.py (none by default)
  local batch_size=16
  local MAX_SEQ_LENGTH=128
  local approach='post_training_dynamic_quant'
  local TASK_NAME model_type model_name_or_path

  # Map topology -> (GLUE task, model type). Unknown topologies used to fall
  # through with TASK_NAME/model path unset and produce a malformed python
  # command line; fail early with a clear message instead.
  case "${topology}" in
    distilbert_base_MRPC) TASK_NAME='MRPC'; model_type='distilbert' ;;
    albert_base_MRPC)     TASK_NAME='MRPC'; model_type='albert' ;;
    funnel_MRPC)          TASK_NAME='MRPC'; model_type='funnel' ;;
    mbart_WNLI)           TASK_NAME='WNLI'; model_type='mbart' ;;
    transfo_xl_MRPC)      TASK_NAME='MRPC'; model_type='transfo-xl-wt103' ;;
    ctrl_MRPC)            TASK_NAME='MRPC'; model_type='ctrl' ;;
    xlm_roberta_MRPC)     TASK_NAME='MRPC'; model_type='xlm' ;;
    *)
      echo "Error: unsupported topology: '${topology}'" >&2
      exit 1
      ;;
  esac

  model_name_or_path=$input_model
  if [[ -z "${model_name_or_path}" ]]; then
    echo "Error: --input_model is required" >&2
    exit 1
  fi

  # NOTE(review): this only rewrites a 'name:' line that currently contains
  # ': bert', so once conf.yaml has been edited for another model a later run
  # with a different topology will not update it — confirm this is intended.
  sed -i "/: bert/s|name:.*|name: $model_type|g" conf.yaml
  sed -i "/approach:/s|approach:.*|approach: $approach|g" conf.yaml

  # All expansions are quoted so paths with spaces survive; the extra_cmd
  # array expands to nothing when empty.
  python -u ./run_glue_tune.py \
    --model_name_or_path "${model_name_or_path}" \
    --task_name "${TASK_NAME}" \
    --do_eval \
    --do_train \
    --max_seq_length "${MAX_SEQ_LENGTH}" \
    --per_device_eval_batch_size "${batch_size}" \
    --no_cuda \
    --output_dir "${tuned_checkpoint}" \
    --tune \
    "${extra_cmd[@]}"
}
# Entry point: forward all command-line arguments to main.
main "$@"