Merged
8 changes: 5 additions & 3 deletions scripts/bloom-inference-scripts/README.md
@@ -18,8 +18,9 @@ Throughput in msecs on 8x80GB gpus:
| accelerate int8 | 286.56 | 40.92 | 22.65 | 13.27 | oom | | | |
| ds-inference fp16 | 44.02 | 5.70 | 3.01 | 1.68 | 1.00 | 0.69 | oom | |
| ds-inference int8 | 89.09 | 11.44 | 5.88 | 3.09 | 1.71 | 1.02 | 0.71 | oom |
| ds-zero | 283 | 34.88 | oom | | | | | |
| | | | | | | | | |
| ds-zero bf16 | 283 | 34.88 | oom | | | | | |

Note: since DeepSpeed-ZeRO can process multiple generate streams in parallel, its throughput can be further divided by 8 or 16, depending on whether 8 or 16 GPUs were used during generation. And, of course, this means it can process a batch size of 64 in the case of 8x80GB A100s (the table above).
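As a worked instance of this division (the function name is illustrative; the numbers come from the ds-zero bf16 row in the table above):

```python
# Illustrative arithmetic for DeepSpeed-ZeRO throughput: each of the N GPUs
# runs its own independent generate stream, so the measured msecs/token can
# be divided by the number of participating GPUs to get per-stream cost.

def effective_msecs_per_token(measured_msecs: float, num_gpus: int) -> float:
    """Per-stream cost when each GPU serves an independent stream."""
    return measured_msecs / num_gpus

# ds-zero bf16 at batch size 1 measured 283 msecs/token on 8x80GB A100s:
print(effective_msecs_per_token(283, 8))  # → 35.375 msecs/token across 8 streams
```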

Time from start to ready-to-generate, in secs (mainly model loading and data preparation time):

@@ -39,7 +40,6 @@ Throughput in msecs 4x80GB A100:
| :---------------- | :----- | :---- | :---- | :---- | :--- | :--- |
| accelerate int8 | 284.15 | 40.14 | 21.97 | oom | | |
| ds-inference int8 | 156.51 | 20.11 | 10.38 | 5.50 | 2.96 | oom |
| | | | | | | |

To get the benchmark results, simply add `--benchmark` to any of the 3 scripts discussed below.

@@ -145,6 +145,8 @@ Note that the script currently runs the same inputs on all GPUs, but you can run
deepspeed --num_gpus 8 scripts/bloom-inference-scripts/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
```

Please remember that with ZeRO the user can generate multiple unique streams at the same time, and thus the overall performance should be the throughput in secs/token divided by the number of participating GPUs, so 8x to 16x faster depending on whether 8 or 16 GPUs were used!
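One way to give each rank its own unique stream is to slice the prompt list by rank, sketched here under the assumption that `RANK` and `WORLD_SIZE` are set in the environment by the `deepspeed` launcher (the helper and the prompt list are illustrative, not part of the script):

```python
import os

def prompts_for_rank(all_prompts, rank, world_size):
    # round-robin slice: rank r takes prompts r, r + world_size, r + 2*world_size, ...
    return all_prompts[rank::world_size]

# illustrative prompt list; in the real script these would be user inputs
all_prompts = [f"prompt {i}" for i in range(16)]
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
inputs = prompts_for_rank(all_prompts, rank, world_size)
```

With this slicing, 8 GPUs each generate from 2 distinct prompts in the same wall-clock time one GPU would need for 2 prompts, which is where the 8x effective-throughput figure comes from.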

You can also try the offloading solutions with just one small GPU, which will take a long time to run, but if you don't have 8 huge GPUs this is as good as it gets.


12 changes: 4 additions & 8 deletions scripts/bloom-inference-scripts/bloom-accelerate-inference.py
@@ -163,26 +163,22 @@ def generate():

return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

print_rank0(f"*** Running generate")
t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
for i,o,_ in generated:
print_rank0(f"{'-'*60}\nin={i}\nout={o}\n")


### Benchmark

if args.benchmark:
# clear cache / free memory
torch.cuda.empty_cache()
gc.collect()

### Benchmark

if args.benchmark:
print_rank0(f"*** Running benchmark")

# warm up
for i in range(1):
_ = generate()
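The warmup-before-timing pattern that this hunk (and the other scripts) rely on can be sketched generically as follows; `timed_generate` is an illustrative helper, not part of the scripts:

```python
import time

def timed_generate(generate, warmup_steps=1):
    # The warmup pass triggers one-time costs (kernel compilation/autotuning,
    # allocator growth), so it must run before, and stay out of, the timing.
    for _ in range(warmup_steps):
        generate()
    t0 = time.time()
    out = generate()
    return out, time.time() - t0
```

Skipping the warmup would fold those one-time costs into the first measured pass, which is why the README notes the first 100-token pass taking 23 secs versus 4 secs for the next one.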
2 changes: 2 additions & 0 deletions scripts/bloom-inference-scripts/bloom-ds-inference.py
@@ -251,8 +251,10 @@ def generate():

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
print_rank0(f"*** Running generate warmup")
_ = generate()

print_rank0(f"*** Running generate")
t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
16 changes: 7 additions & 9 deletions scripts/bloom-inference-scripts/bloom-ds-zero-inference.py
@@ -105,7 +105,7 @@ def print_rank0(*msg):
device="nvme",
pin_memory=True,
nvme_path=args.nvme_offload_path,
buffer_size=6e8
buffer_size=4e9,
)

dschf = HfDeepSpeedConfig(ds_config) # this tells from_pretrained to instantiate directly on gpus
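The NVMe-offload fragment touched above can be sketched in isolation as follows; the path is illustrative, and the reasoning about `buffer_size` is an assumption: the offload buffer likely needs to accommodate the largest offloaded parameter tensor, which would explain the bump from 6e8 to 4e9 bytes:

```python
# Sketch of the ZeRO stage-3 NVMe param-offload config shape used by the
# script. nvme_path is illustrative (the script takes it from a CLI arg);
# buffer_size is in bytes, matching the value in the diff.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",  # illustrative; args.nvme_offload_path in the script
            "pin_memory": True,
            "buffer_size": 4e9,
        },
    },
}
```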
@@ -130,6 +130,7 @@ def print_rank0(*msg):

if args.benchmark:
t_ready = time.time()
deepspeed.runtime.utils.see_memory_usage('start-of-generate', force=True)


### Generate
@@ -175,25 +176,22 @@ def generate():

# XXX: this is currently doing world_size streams on world_size gpus, so we can feed it different inputs on each! and hence the time can be divided by world_size

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

print_rank0(f"*** Running generate")
t_generate_start = time.time()
pairs = generate()
t_generate_span = time.time() - t_generate_start
for i,o,_ in pairs:
print_rank0(f"{'-'*60}\nin={i}\nout={o}\n")


### Benchmark

if args.benchmark:
# clear cache / free memory
torch.cuda.empty_cache()
gc.collect()
deepspeed.runtime.utils.see_memory_usage('end-of-run', force=True)
deepspeed.runtime.utils.see_memory_usage('end-of-generate', force=True)

### Benchmark

if args.benchmark:
print_rank0(f"*** Running benchmark")

# warm up