Merged

Commits (44)
812a72c  refactor (mayank31398, Oct 17, 2022)
21713ac  refactor (mayank31398, Oct 17, 2022)
0ea738a  refactor (mayank31398, Oct 17, 2022)
6239fc6  refactor (mayank31398, Oct 17, 2022)
01e9515  refactor (mayank31398, Oct 17, 2022)
b5a29b8  test (mayank31398, Dec 1, 2022)
e4a29b5  test (mayank31398, Dec 1, 2022)
48f0aa0  test (mayank31398, Dec 1, 2022)
646b63b  test (mayank31398, Dec 1, 2022)
3281d16  test (mayank31398, Dec 1, 2022)
2fbb6c3  test (mayank31398, Dec 1, 2022)
1090704  test (mayank31398, Dec 1, 2022)
ef8ec7c  test (mayank31398, Dec 1, 2022)
b94ea81  fp32, bf16, int8 (mayank31398, Dec 3, 2022)
17534fd  fp32, bf16, int8 (mayank31398, Dec 3, 2022)
e7230b5  fp32, bf16, int8 (mayank31398, Dec 3, 2022)
38c616b  use_cache (mayank31398, Dec 3, 2022)
15a2c80  use_cache (mayank31398, Dec 3, 2022)
80ba9bb  gc (mayank31398, Dec 3, 2022)
f28f8ac  benchmark (mayank31398, Dec 3, 2022)
d04dc14  benchmark (mayank31398, Dec 4, 2022)
9dc5268  benchmark (mayank31398, Dec 4, 2022)
23a5eb1  fix (mayank31398, Dec 4, 2022)
391e055  fix (mayank31398, Dec 4, 2022)
856c77b  fix (mayank31398, Dec 4, 2022)
9d99f46  fp32 (mayank31398, Dec 4, 2022)
dfe8cb3  bf16 (mayank31398, Dec 4, 2022)
7344ae0  bf16 (mayank31398, Dec 4, 2022)
a4c3b81  ds-inference (mayank31398, Dec 4, 2022)
a0f308d  device map (mayank31398, Dec 4, 2022)
0947688  device map (mayank31398, Dec 4, 2022)
379bfd9  fix (mayank31398, Dec 4, 2022)
6dc0c07  fp32 (mayank31398, Dec 5, 2022)
7dc67ea  bf16 (mayank31398, Dec 5, 2022)
2ac761d  int8 (mayank31398, Dec 5, 2022)
28e1e71  attention_type (mayank31398, Dec 5, 2022)
b2c7de7  fp32 (mayank31398, Dec 5, 2022)
76b3b8d  bf16 (mayank31398, Dec 6, 2022)
c149ee9  fp32 (mayank31398, Dec 6, 2022)
8427b94  int8 (mayank31398, Dec 6, 2022)
487954f  fp16 (mayank31398, Dec 6, 2022)
0253839  total params (mayank31398, Dec 6, 2022)
893c521  models (mayank31398, Dec 6, 2022)
daea92d  Add code to vary input length (#5) (minimario, Dec 6, 2022)
68 changes: 68 additions & 0 deletions Makefile
@@ -0,0 +1,68 @@
batch_size := 1

install-mqa-transformers:
	git clone https://github.com/bigcode-project/transformers.git; \
	cd transformers; \
	git checkout mayank/multi_query; \
	pip install .; \
	cd ..; \
	rm -rf transformers;

# BLOOM AliBi
hf-1b-bloom-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype float32 --batch_size ${batch_size}

hf-1b-bloom-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype bfloat16 --batch_size ${batch_size}

hf-1b-bloom-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype int8 --batch_size ${batch_size}

ds-inference-1b-bloom-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class BLOOM --batch_size ${batch_size}

# GPT2 MHA
hf-1b-GPT2-mha-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mha-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mha-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mha-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --batch_size ${batch_size}

# GPT2 MQA
hf-1b-GPT2-mqa-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mqa-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mqa-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mqa-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --batch_size ${batch_size}

# GPT2 MQA1
hf-1b-GPT2-mqa1-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mqa1-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mqa1-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mqa1-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --batch_size ${batch_size}

# Input length experiments
hf-1b-GPT2-mqa1-int8-input-length:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype int8 --batch_size ${batch_size} --max_input_length ${max_input_length}

hf-1b-GPT2-mha-int8-input-length:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype int8 --batch_size ${batch_size} --max_input_length ${max_input_length}
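
Each target accepts a `batch_size` override on the command line (Make variables after the target name), and the two input-length targets additionally expect `max_input_length`. For example:

```sh
# benchmark the 1B BLOOM config in bf16 at batch size 8
make hf-1b-bloom-bf16 batch_size=8

# one point of an input-length sweep (value shown is illustrative)
make hf-1b-GPT2-mqa1-int8-input-length batch_size=32 max_input_length=512
```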
123 changes: 122 additions & 1 deletion README.md
@@ -1 +1,122 @@
# bigcode-inference-benchmark
All benchmarks below were run on a single A100 80GB GPU.

## BLOOM
```python
hidden_size = 2048
n_head = 16
n_layer = 24
total_params = 1311535104
```

Throughput (tokens/sec | msec/token)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------------:|:---------------:|:---------------:|:-------------------:|
| 1 | 51.59 \| 19.38 | 47.46 \| 21.07 | 16.53 \| 60.49 | 61.61 \| 16.23 |
| 2 | 103.92 \| 9.62 | 96.88 \| 10.32 | 33.79 \| 29.60 | 121.55 \| 8.23 |
| 4 | 211.96 \| 4.72 | 193.72 \| 5.16 | 67.38 \| 14.84 | 240.06 \| 4.17 |
| 8 | 411.79 \| 2.43 | 370.67 \| 2.70 | 134.34 \| 7.44 | 492.42 \| 2.03 |
| 16 | 804.55 \| 1.24 | 781.29 \| 1.28 | 275.69 \| 3.63 | 970.59 \| 1.03 |
| 32 | 1574.68 \| 0.64 | 1539.19 \| 0.65 | 537.14 \| 1.86 | 1999.04 \| 0.50 |
| 64 | 2712.46 \| 0.37 | 3038.01 \| 0.33 | 1070.50 \| 0.93 | 3971.09 \| 0.25 |
| 128 | 2974.36 \| 0.34 | 5795.97 \| 0.17 | 2055.34 \| 0.49 | 7514.59 \| 0.13 |
| 256 | 3695.44 \| 0.27 | 8216.27 \| 0.12 | 3523.77 \| 0.28 | 10226.50 \| 0.10 |
| 384 | 3591.13 \| 0.28 | 9328.18 \| 0.11 | 4585.33 \| 0.22 | 11094.27 \| 0.09 |
| 512 | 3708.54 \| 0.27 | 9446.34 \| 0.11 | 5416.48 \| 0.18 | 11390.85 \| 0.09 |
| 640 | 3859.43 \| 0.26 | 9572.53 \| 0.10 | 6113.65 \| 0.16 | 11625.71 \| 0.09 |
| 768 | 3804.82 \| 0.26 | 9464.75 \| 0.11 | 6582.52 \| 0.15 | 11814.31 \| 0.08 |
| 896 | 3652.42 \| 0.27 | 9482.11 \| 0.11 | 7111.08 \| 0.14 | 11744.38 \| 0.09 |
| 1024 | oom | 9710.46 \| 0.10 | 7486.36 \| 0.13 | 11534.95 \| 0.09 |
| 1152 | oom | 9712.39 \| 0.10 | 7544.99 \| 0.13 | oom |
| 1280 | oom | 9667.19 \| 0.10 | 7858.91 \| 0.13 | oom |
| 1408 | oom | 9771.91 \| 0.10 | 8116.30 \| 0.12 | oom |
| 1536 | oom | 9744.56 \| 0.10 | 8201.28 \| 0.12 | oom |
| 1664 | oom | 9719.82 \| 0.10 | 8227.56 \| 0.12 | oom |
| 1792 | oom | 9690.61 \| 0.10 | 8344.36 \| 0.12 | oom |
| 1920 | oom | oom | oom | oom |

Latency (sec)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------:|:---------:|:---------:|:-------------------:|
| 1 | 1.94 | 2.11 | 6.05 | 1.62 |
| 2 | 1.92 | 2.06 | 5.92 | 1.65 |
| 4 | 1.89 | 2.06 | 5.94 | 1.67 |
| 8 | 1.94 | 2.16 | 5.96 | 1.62 |
| 16 | 1.99 | 2.05 | 5.80 | 1.65 |
| 32 | 2.03 | 2.08 | 5.96 | 1.60 |
| 64 | 2.36 | 2.11 | 5.98 | 1.61 |
| 128 | 4.30 | 2.21 | 6.23 | 1.70 |
| 256 | 6.93 | 3.12 | 7.26 | 2.50 |
| 384 | 10.69 | 4.12 | 8.37 | 3.46 |
| 512 | 14.82 | 5.42 | 9.45 | 4.49 |
| 640 | 19.85 | 6.69 | 10.47 | 5.51 |
| 768 | 20.18 | 8.11 | 11.67 | 6.50 |
| 896 | 24.53 | 9.45 | 12.60 | 7.63 |
| 1024 | oom | 10.55 | 13.68 | 8.88 |
| 1152 | oom | 11.86 | 15.27 | oom |
| 1280 | oom | 13.24 | 16.29 | oom |
| 1408 | oom | 14.41 | 17.35 | oom |
| 1536 | oom | 15.76 | 18.73 | oom |
| 1664 | oom | 17.12 | 20.22 | oom |
| 1792 | oom | 18.49 | 21.48 | oom |
| 1920 | oom | oom | oom | oom |
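
The two throughput numbers per cell are reciprocals (msec/token = 1000 / tokens/sec), and latency is the wall-clock time to generate `max_new_tokens = 100` tokens for the whole batch, so throughput ≈ batch_size × 100 / latency; the same relationships hold for the GPT2 tables below. A quick check against the batch-size-2 row (this relationship is inferred from the numbers, not stated in the repo):

```python
# Sanity-check the inferred relationship between the table columns.
batch_size, latency_s, max_new_tokens = 2, 1.92, 100

throughput = batch_size * max_new_tokens / latency_s  # tokens/sec
print(round(throughput, 2))         # 104.17 -- table reports 103.92
print(round(1000 / throughput, 2))  # 9.6    -- table reports 9.62
```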

## GPT2 Multi-Head Attention
```python
hidden_size = 2048
n_head = 16
n_layer = 24
total_params = 1315725312
```

Throughput (tokens/sec | msec/token)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------------:|:----------------:|:----------------:|:-------------------:|
| 1 | 43.11 \| 23.20 | 40.69 \| 24.57 | 32.29 \| 30.97 | 122.76 \| 8.15 |
| 2 | 80.76 \| 12.38 | 80.87 \| 12.37 | 63.54 \| 15.74 | 247.85 \| 4.03 |
| 4 | 160.38 \| 6.24 | 154.98 \| 6.45 | 131.00 \| 7.63 | 503.52 \| 1.99 |
| 8 | 328.62 \| 3.04 | 332.90 \| 3.00 | 260.16 \| 3.84 | 1022.20 \| 0.98 |
| 16 | 662.08 \| 1.51 | 669.27 \| 1.49 | 523.29 \| 1.91 | 2027.35 \| 0.49 |
| 32 | 1314.92 \| 0.76 | 1287.95 \| 0.78 | 1055.57 \| 0.95 | 4231.82 \| 0.24 |
| 64 | 2118.17 \| 0.47 | 2487.35 \| 0.40 | 1969.26 \| 0.51 | 8311.39 \| 0.12 |
| 128 | 2860.26 \| 0.35 | 4268.99 \| 0.23 | 3581.49 \| 0.28 | 15879.15 \| 0.06 |
| 256 | 3487.86 \| 0.29 | 6917.01 \| 0.14 | 6132.47 \| 0.16 | 21635.49 \| 0.05 |
| 384 | 3794.16 \| 0.26 | 8821.31 \| 0.11 | 7774.37 \| 0.13 | 23872.25 \| 0.04 |
| 512 | 3804.37 \| 0.26 | 10068.51 \| 0.10 | 8872.88 \| 0.11 | 25009.06 \| 0.04 |
| 640 | 4124.01 \| 0.24 | 10547.88 \| 0.09 | 9956.58 \| 0.10 | oom |
| 768 | 3950.39 \| 0.25 | 10675.09 \| 0.09 | 10584.21 \| 0.09 | oom |
| 896 | 3937.28 \| 0.25 | 10780.82 \| 0.09 | 10994.00 \| 0.09 | oom |
| 1024 | oom | 11192.55 \| 0.09 | 11306.37 \| 0.09 | oom |
| 1152 | oom | 11178.30 \| 0.09 | 11290.51 \| 0.09 | oom |
| 1280 | oom | 11383.98 \| 0.09 | 11459.89 \| 0.09 | oom |
| 1408 | oom | 11477.66 \| 0.09 | 11565.90 \| 0.09 | oom |
| 1536 | oom | 11382.66 \| 0.09 | 11491.99 \| 0.09 | oom |
| 1664 | oom | 11571.52 \| 0.09 | 11603.73 \| 0.09 | oom |
| 1792 | oom | 11394.20 \| 0.09 | 11412.46 \| 0.09 | oom |
| 1920 | oom | oom | oom | oom |

Latency (sec)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------:|:---------:|:---------:|:-------------------:|
| 1 | 2.32 | 2.46 | 3.10 | 0.81 |
| 2 | 2.48 | 2.47 | 3.15 | 0.81 |
| 4 | 2.49 | 2.58 | 3.05 | 0.79 |
| 8 | 2.43 | 2.40 | 3.07 | 0.78 |
| 16 | 2.42 | 2.39 | 3.06 | 0.79 |
| 32 | 2.43 | 2.48 | 3.03 | 0.76 |
| 64 | 3.02 | 2.57 | 3.25 | 0.77 |
| 128 | 4.48 | 3.00 | 3.57 | 0.81 |
| 256 | 7.34 | 3.70 | 4.17 | 1.18 |
| 384 | 10.12 | 4.35 | 4.94 | 1.61 |
| 512 | 13.46 | 5.09 | 5.77 | 2.05 |
| 640 | 15.52 | 6.07 | 6.43 | oom |
| 768 | 19.44 | 7.19 | 7.26 | oom |
| 896 | 22.76 | 8.31 | 8.15 | oom |
| 1024 | oom | 9.15 | 9.06 | oom |
| 1152 | oom | 10.31 | 10.20 | oom |
| 1280 | oom | 11.24 | 11.17 | oom |
| 1408 | oom | 12.27 | 12.17 | oom |
| 1536 | oom | 13.49 | 13.37 | oom |
| 1664 | oom | 14.38 | 14.34 | oom |
| 1792 | oom | 15.73 | 15.70 | oom |
| 1920 | oom | oom | oom | oom |
5 changes: 0 additions & 5 deletions benchmark.sh

This file was deleted.

14 changes: 14 additions & 0 deletions run_batch_size.sh
@@ -0,0 +1,14 @@
export CUDA_VISIBLE_DEVICES=0

rm -rf ./tmp

for bs in {1,2,4,8,16,32,64}
do
    make $1 batch_size=$bs
done

for i in {1..20}
do
    bs=$(($i*128))
    make $1 batch_size=$bs
done
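
The script takes a Makefile target as `$1` and sweeps batch sizes 1–64 in powers of two, then 128–2560 in steps of 128:

```sh
bash run_batch_size.sh hf-1b-bloom-bf16
```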
8 changes: 8 additions & 0 deletions run_input_length.sh
@@ -0,0 +1,8 @@
export CUDA_VISIBLE_DEVICES=0

rm -rf ./tmp

for max_input_length in {4,8,16,32,64,128,256,512,1024,1536,1900}
do
    make $1 batch_size=32 max_input_length=$max_input_length
done
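
This sweep fixes the batch size at 32 and caps the prompt at 1900 tokens, presumably so that the input plus the default 100 generated tokens stays within `n_positions = 2048`. It only makes sense with the two `*-input-length` targets, which forward `max_input_length`:

```sh
bash run_input_length.sh hf-1b-GPT2-mha-int8-input-length
```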
3 changes: 2 additions & 1 deletion src/main.py
@@ -7,7 +7,8 @@ def main() -> None:

    args = get_args(get_arg_parser())

-   inputs = get_dummy_batch(args.batch_size)
+   inputs = get_dummy_batch(args.batch_size, args.max_input_length)

    generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)

    pipeline_class = getattr(pipelines, args.pipeline_class)
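The new `max_input_length` argument is threaded into `get_dummy_batch`, whose definition is not part of this diff. A plausible minimal sketch of such a helper (the name is real, the body and truncation behavior are assumptions; `-1` is taken to mean "no cap"):

```python
from typing import List

def get_dummy_batch(batch_size: int, max_input_length: int = -1) -> List[str]:
    # Hypothetical reconstruction: repeat a fixed prompt and optionally
    # truncate it (crudely, by whitespace-separated words) to the cap.
    text = "Hello, I'm a language model, " * 64
    if max_input_length > 0:
        text = " ".join(text.split()[:max_input_length])
    return [text] * batch_size
```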
3 changes: 1 addition & 2 deletions src/pipelines/ds_inference.py
@@ -3,7 +3,6 @@

import deepspeed
import torch
-from transformers import BloomForCausalLM

from .pipeline import Pipeline

@@ -16,7 +15,7 @@ def __init__(self, args: Namespace) -> None:

        # with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
        #     model = BloomForCausalLM._from_config(config, torch_dtype=torch.bfloat16)
-       self.model = BloomForCausalLM._from_config(self.config, torch_dtype=torch.bfloat16)
+       self.model = self.model_class.from_pretrained("tmp", torch_dtype=torch.bfloat16)
        self.model.eval()

        # checkpoints_json = os.path.join(args.model_name, "checkpoints.json")
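The hunk cuts off before the model is actually handed to DeepSpeed; pipelines like this one typically wrap the loaded model with `deepspeed.init_inference`, which is what the fp16 DS-inference columns in the README measure. A sketch of that step under those assumptions (the exact kwargs are not visible in this diff):

```python
import deepspeed
import torch
from torch import nn

def wrap_with_ds_inference(model: nn.Module) -> nn.Module:
    # Replace the HF layers with DeepSpeed's fused fp16 inference kernels
    # on a single GPU (mp_size=1), as launched by `deepspeed --num_gpus 1`.
    return deepspeed.init_inference(
        model,
        mp_size=1,
        dtype=torch.float16,
        replace_with_kernel_inject=True,
    )
```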
5 changes: 3 additions & 2 deletions src/pipelines/hf.py
@@ -1,7 +1,6 @@
from argparse import Namespace

import torch
-from transformers import BloomForCausalLM

from .pipeline import Pipeline

@@ -11,13 +10,15 @@ def __init__(self, args: Namespace, device: str = "cpu") -> None:
        super().__init__(args)

        model_kwargs = {}

        if args.dtype == torch.int8:
            model_kwargs["load_in_8bit"] = True
            model_kwargs["device_map"] = "auto"
        else:
            model_kwargs["torch_dtype"] = args.dtype

        self.input_device = device
-       self.model = BloomForCausalLM._from_config(self.config, **model_kwargs).to(self.input_device)
+       self.model = self.model_class.from_pretrained("tmp", **model_kwargs).to(self.input_device)
        self.model.eval()


77 changes: 41 additions & 36 deletions src/pipelines/pipeline.py
@@ -1,46 +1,14 @@
import os
from argparse import Namespace
-from typing import List, Tuple
+from typing import List, Tuple, Union

import torch
-from transformers import AutoTokenizer, BloomConfig
+from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, GPT2Config, GPT2LMHeadModel


class Pipeline:
    def __init__(self, args: Namespace) -> None:
-       self.config = BloomConfig.from_dict(
-           {
-               "apply_residual_connection_post_layernorm": False,
-               "architectures": ["BloomModel"],
-               "attention_dropout": 0.0,
-               "attention_softmax_in_fp32": True,
-               "bias_dropout_fusion": True,
-               "bos_token_id": 1,
-               "eos_token_id": 2,
-               "hidden_dropout": 0.0,
-               "hidden_size": args.hidden_size,
-               "initializer_range": 0.02,
-               "layer_norm_epsilon": 1e-05,
-               "masked_softmax_fusion": True,
-               "model_type": "bloom",
-               "n_head": args.n_head,
-               "n_inner": None,
-               "n_layer": args.n_layer,
-               "offset_alibi": 100,
-               "pad_token_id": 3,
-               "pretraining_tp": 1,
-               "skip_bias_add": True,
-               "skip_bias_add_qkv": False,
-               "slow_but_exact": False,
-               "transformers_version": "4.22.2",
-               "unk_token_id": 0,
-               "use_cache": True,
-               "vocab_size": 250880,
-           }
-       )
-
-       # hardcoded for now to bigscience/bloom
-       self.tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
-
+       self.config, self.tokenizer, self.model_class = get_config_tokenizer_model_class(args)
        self.model = None
        self.input_device = None

@@ -69,3 +37,40 @@ def get_num_parameters(self) -> int:
        for i in self.model.parameters():
            param_count += i.numel()
        return param_count


# Note: the PR annotates the return type as Union[BloomConfig, GPT2Config],
# but the function returns a (config, tokenizer, model class) triple.
def get_config_tokenizer_model_class(args: Namespace) -> Tuple[Union[BloomConfig, GPT2Config], AutoTokenizer, type]:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    if args.model_class.lower() == "bloom":
        config = BloomConfig(
            attention_softmax_in_fp32=True,
            hidden_size=args.hidden_size,
            n_head=args.n_head,
            n_layer=args.n_layer,
            vocab_size=len(tokenizer),
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
        model_class = BloomForCausalLM
    elif args.model_class.lower() == "gpt2":
        config = GPT2Config(
            n_embd=args.hidden_size,
            n_head=args.n_head,
            n_layer=args.n_layer,
            n_positions=args.n_positions,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            attention_type=args.attention_type,
            print_details=False,
            vocab_size=len(tokenizer),
            use_cache=True,
        )
        model_class = GPT2LMHeadModel

    if not os.path.exists("tmp"):
        model_class._from_config(config).save_pretrained("tmp")

    return config, tokenizer, model_class
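
A design note worth spelling out: the first call materializes a randomly initialized model and saves it to `tmp/`, so the fp32/bf16/int8 pipelines for a given shape all load identical weights via `from_pretrained("tmp")`, and the sweep scripts `rm -rf ./tmp` to start fresh. A minimal usage sketch (the import path is an assumption, and `attention_type` requires the bigcode transformers fork installed by `install-mqa-transformers`):

```python
from argparse import Namespace

import torch

# Assumed import path relative to src/.
from pipelines.pipeline import get_config_tokenizer_model_class

args = Namespace(model_class="GPT2", hidden_size=2048, n_head=16,
                 n_layer=24, n_positions=2048, attention_type=2)

# First call saves random weights to "tmp"; later calls (any dtype) reuse them.
config, tokenizer, model_class = get_config_tokenizer_model_class(args)
model = model_class.from_pretrained("tmp", torch_dtype=torch.bfloat16)
```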
5 changes: 5 additions & 0 deletions src/utils/arguments.py
@@ -6,14 +6,19 @@
def get_arg_parser() -> ArgumentParser:
    parser = ArgumentParser()
    parser.add_argument("--pipeline_class", default="HF_GPU_Pipeline", type=str)
+   parser.add_argument("--model_class", default="GPT2", type=str)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--dtype", default="bfloat16", type=str)
+   parser.add_argument("--max_input_length", default=-1, type=int)
    parser.add_argument("--max_new_tokens", default=100, type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--hidden_size", type=int)
+   parser.add_argument("--attention_type", type=int)
+   parser.add_argument("--n_positions", type=int)
    parser.add_argument("--n_head", type=int)
    parser.add_argument("--n_layer", type=int)
    parser.add_argument("--benchmark_cycles", type=int, default=5)
+   parser.add_argument("--clear_every_run", action="store_true")
    return parser
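
`--dtype` is parsed as a string, yet `hf.py` compares `args.dtype == torch.int8`, so `get_args` (not shown in this diff) must map the string onto a `torch.dtype`. A plausible sketch of that mapping (the function name and body are assumptions):

```python
import torch

def convert_dtype(name: str) -> torch.dtype:
    # Hypothetical reconstruction of the string -> torch.dtype conversion
    # implied by hf.py's `args.dtype == torch.int8` check.
    return {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "int8": torch.int8,
    }[name]
```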

