[deepspeed] bigscience/T0* multi-gpu inference with ZeRO #15399

Closed · AADeLucia opened this issue Jan 28, 2022 · 57 comments

AADeLucia commented Jan 28, 2022

Environment info

  • transformers version: 4.17.0.dev0
  • Platform: Linux-5.13.0-27-generic-x86_64-with-glibc2.10
  • Python version: 3.8.0
  • PyTorch version (GPU?): 1.10.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes (deepspeed)
  • Note: I installed DeepSpeed from source

Who can help

Models: (I'm actually trying to use T0pp but T5 is close enough)

Information

Model I am using (Bert, XLNet ...): T0pp / T0_3B

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

I want to load T0pp across 2 24GB GPUs and only run inference. From reading the documentation I know DeepSpeed with ZeRO stage 3 is the way to go for this. I am following the HuggingFace example here to use DeepSpeed without a Trainer object.

The error I get is

[2022-01-28 18:36:41,193] [INFO] [partition_parameters.py:456:__exit__] finished initializing model with 2.85B parameters
Traceback (most recent call last):
  File "multi_gpu_T0pp.py", line 26, in <module>
    engine = deepspeed.initialize(model=model, config_params=ds_config)
AttributeError: module 'transformers.deepspeed' has no attribute 'initialize'

My code:

Run with CUDA_VISIBLE_DEVICES="0,1" deepspeed <script.py>

"""
Example code to load a PyTorch model across GPUs
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import deepspeed
import pandas as pd
import torch
import pdb
import os

seed = 42
torch.manual_seed(seed)

ds_config = {
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 0,
    "steps_per_print": 2000,
    "train_batch_size": 2,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}

if __name__ == "__main__":
    # must run before instantiating the model
    # ds_config is deepspeed config object or path to the file
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    model_name = "bigscience/T0_3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    engine = deepspeed.initialize(model=model, config_params=ds_config)

    inputs = tokenizer.encode(
        "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy",
        return_tensors="pt")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0]))

Expected behavior

T0pp (or T0_3B) to load across 2 GPUs, generate an answer, and then quit.

stas00 commented Jan 29, 2022

My apologies, it looks like I wrote the wrong instructions for the non-HF-Trainer case here: https://huggingface.co/docs/transformers/master/main_classes/deepspeed#nontrainer-deepspeed-integration - is that where you found this code, or was it somewhere else? I'm asking so that we can make sure it's fixed everywhere.

It should be just import deepspeed instead of from transformers import deepspeed - but let me double check that it all works.

stas00 self-assigned this Jan 29, 2022

stas00 commented Jan 29, 2022

ok, so indeed the import was wrong. I will fix the doc at https://huggingface.co/docs/transformers/master/main_classes/deepspeed#nontrainer-deepspeed-integration => #15400

But where did you get the rest of the code from? It can't possibly work.

You may want to look into using / adapting https://github.com/huggingface/transformers/blob/master/examples/pytorch/text-generation/run_generation.py - I see it doesn't currently support t5 models.

@VictorSanh, you were working on a multi-gpu generation with t0 models, what's the latest incarnation of the code that you were using if you don't mind sharing. I think it was with Deepspeed-Inference, right? Thanks.

Perhaps examples/pytorch/text-generation/run_generation.py could be updated to include support for T0 models?

AADeLucia commented Jan 29, 2022

It should be just import deepspeed instead of from transformers import deepspeed - but let me double check that it all works.

That works! Now I'm running into a different issue: figuring out which of the default config arguments to change.

My apologies, it looks like I wrote wrong instructions for the non-HF Trainer case here: https://huggingface.co/docs/transformers/master/main_classes/deepspeed#nontrainer-deepspeed-integration - is that where you found this code or in another place. I'm asking so that we ensure it's fixed everywhere.

That was the only place I found that line.

But where did you take the rest of the code from? it can't possibly work.

Which part of the code can't work? The T0pp/T0_3B is just from the model card: https://huggingface.co/bigscience/T0pp

Updated code:

"""
Example code to load a PyTorch model across GPUs
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import pandas as pd
import torch
import pdb
import os

seed = 42
torch.manual_seed(seed)


if __name__ == "__main__":
    # must run before instantiating the model
    # ds_config is deepspeed config object or path to the file
    ds_config = "ds_config_zero3_gpu.json"
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    model_name = "bigscience/T0_3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    engine = deepspeed.initialize(model=model, config_params=ds_config)

    inputs = tokenizer.encode(
        "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy",
        return_tensors="pt")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0]))

I moved the config into an external file because I was getting weird errors:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": 1,
        "stage3_prefetch_bucket_size": 1,
        "stage3_param_persistence_threshold": 1,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 0,
    "steps_per_print": 2000,
    "train_batch_size": 1,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}

New error:

    self._configure_train_batch_size()
  File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 1050, in _configure_train_batch_size
    self._batch_assertion()
  File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 997, in _batch_assertion
    assert train_batch == micro_batch * grad_acc * self.world_size, (
AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size1 != 1 * 1 * 2
[2022-01-28 22:06:24,208] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 643653

Now just playing with the arguments... I'm not even training, I just want to run inference.

@AADeLucia

Looking in the HfTrainerDeepSpeedConfig for some argument clues: https://github.com/huggingface/transformers/blob/v4.16.1/src/transformers/deepspeed.py#L206

Every time I change train_batch_size, it's still broken. I change it to 2 and it says it's supposed to be 1; I change it to 1 and it says it's supposed to be 2!

    self._batch_assertion()
  File "/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/lib/python3.8/site-packages/deepspeed/runtime/config.py", line 997, in _batch_assertion
    assert train_batch == micro_batch * grad_acc * self.world_size, (
AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size2 != 1 * 1 * 1

stas00 commented Jan 29, 2022

Inference is a relatively new thing; I think until now most of the work was done on training, so please bear with us. Lots of tech is being developed as we speak, and it's being polished to be super easy and fast.

Until then, let's focus on something that works now.

So this works with gpt2

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GPT2LMHeadModel
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

# model_name = "bigscience/T0_3B"

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

ds_config = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": True
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 0,
    "steps_per_print": 2000,
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
engine = deepspeed.initialize(model=model, config_params=ds_config)

text = "Is this review " 

inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
outputs = model.generate(inputs)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(tokenizer.decode(outputs[0]))

run:

$ deepspeed --num_gpus 2 gpt2.py
[...]
Is this review  of the book?
I'm not sure if I'm going to read

With a small adjustment it works with t5:

from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

model_name = "t5-base"
#model_name = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

ds_config = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": True
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 0,
    "steps_per_print": 2000,
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
engine = deepspeed.initialize(model=model, config_params=ds_config)

text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"

inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
outputs = model.generate(inputs)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(tokenizer.decode(outputs[0]))

run:

$ deepspeed --num_gpus 2 t0.py
[...]
<pad> True</s>

but if I switch the t5 model in the example above to "bigscience/T0_3B" it generates gibberish under Deepspeed but works fine w/o Deepspeed.

<pad> d is is is is is is is is is is is is is is is is is

This is very puzzling

stas00 commented Jan 29, 2022

OK, I figured out the culprit - the model breaks when run under fp16, like many other bf16-pretrained models. Most t5 models have this issue.

Here are 2 possible solutions:

  1. So if you're on an Ampere GPU (A100, RTX-30*) you can use bf16, like so:
    "bf16": {
        "enabled": True,
    },

(and as of this moment deepspeed@master is needed to use bf16 - they will make a new release any day now)

  2. Or you have to turn off fp16, like so:
    "fp16": {
        "enabled": False,
    },

Now you're running in fp32 (more memory).

So this works with either of the 2 fixes from above:

from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

model_name = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

ds_config = {
    "fp16": {
        "enabled": False,
    },    
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "steps_per_print": 2000,
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
engine = deepspeed.initialize(model=model, config_params=ds_config, optimizer=None, lr_scheduler=None)

text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"

inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
outputs = model.generate(inputs)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(tokenizer.decode(outputs[0]))

run:

$ deepspeed --num_gpus 2 t0.py
[...]
<pad> Positive</s>

Note that in the code above I pass optimizer=None, lr_scheduler=None, which saves a huge amount of GPU memory since you don't need those for inference. Also, in the ds config I enabled cpu offloading, which nicely uses some of your CPU memory to help with shortages of GPU memory.

The ds config needs more work to become efficient if you plan to use this in production or anywhere you care about speed. Since you're not using the HF integration you will have to put the right numbers together yourself; hint:

hidden_size = model.config.hidden_size
self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)

If you don't care to shave a few %s off, then leave the above as is and it'll use the Deepspeed untuned defaults.
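
In the non-Trainer case that roughly means computing those values yourself from the model config and writing them into the ds_config dict before calling deepspeed.initialize - a minimal sketch, assuming the ds_config dict from the script above (T5/T0 configs expose the hidden size as d_model):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/T0_3B")
hidden_size = config.d_model  # hidden size of the model

# the same values the HF Trainer integration would fill in for "auto"
ds_config["zero_optimization"]["reduce_bucket_size"] = hidden_size * hidden_size
ds_config["zero_optimization"]["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
ds_config["zero_optimization"]["stage3_param_persistence_threshold"] = 10 * hidden_size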

Please let me know if you're able to see it working for yourself.

If you have any questions please ask.

If all is satisfactory you may close this Issue.

As I said in the previous comment the whole inference experience should get much much better really soon now.

@AADeLucia

It's working! Thank you SO MUCH!!! I did have to use all the space-saving tips (bf16 and changing the defaults) because I want the entire model on GPU without off-loading parameters to CPU.

Here is the completed code:

"""
Example code to load a PyTorch model across GPUs
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import torch
import pdb
import os

seed = 42
torch.manual_seed(seed)

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
model_hidden_size = 4096  # this is hard-coded to T0pp

ds_config = {
    "fp16": {
        "enabled": False,
    },
    "bf16": {
        "enabled": True,
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}


if __name__ == "__main__":
    # must run before instantiating the model
    # ds_config is deepspeed config object or path to the file
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    model_name = "bigscience/T0pp"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()

    engine = deepspeed.initialize(model=model, config_params=ds_config, optimizer=None, lr_scheduler=None)

    text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
    inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)

    outputs = model.generate(inputs)
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))

stas00 commented Jan 30, 2022

So glad it worked for you, Alexandra. But you can do better.

Currently, with this code each gpu processes the same input, i.e. duplicated effort, but you can do parallel processing at 0 extra cost.

You just need to change the end to something like:

    rank = torch.distributed.get_rank()
    if rank == 0:
        text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
    elif rank == 1:
        text = "Is this review positive or negative? Review: this is the worst restaurant ever"

    inputs = tokenizer.encode(text, return_tensors="pt").to(device=local_rank)
    outputs = model.generate(inputs)
    print(f"rank{rank}", tokenizer.decode(outputs[0], skip_special_tokens=True))

(the code is untested)

You can of course do that for more than 2 gpus, and each gpu will handle its own unique input.

And of course, you can do batches too if you have enough memory left.
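
For example, a batched variant of the same ending might look roughly like this (an untested sketch; texts is a hypothetical per-rank list of prompts, while rank, local_rank, tokenizer and model come from the script above):

# hypothetical per-rank batch of prompts - in practice each rank would get its own slice of your data
texts = [
    "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy",
    "Is this review positive or negative? Review: this is the worst restaurant ever",
]

# pad to the longest prompt in the batch and keep the attention mask
batch = tokenizer(texts, return_tensors="pt", padding=True).to(device=local_rank)
outputs = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
for text, out in zip(texts, outputs):
    print(f"rank{rank}: {text} => {tokenizer.decode(out, skip_special_tokens=True)}")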

Beware: prints from multiple processes tend to interleave, so you can use the following hack to overcome this issue:
https://github.com/stas00/toolbox/blob/master/pytorch/multi-gpu-non-interleaved-print.py
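
(That script isn't reproduced here; as a simple alternative sketch, assuming torch.distributed is already initialized, you can let the ranks take turns printing:)

import torch.distributed as dist

def print_in_rank_order(msg):
    """Let each rank print in turn so the outputs don't interleave."""
    for r in range(dist.get_world_size()):
        dist.barrier()                 # wait for the previous rank to finish printing
        if r == dist.get_rank():
            print(msg, flush=True)
    dist.barrier()                     # keep the ranks together after the last print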

This Issue would be a good example for the DIY Deepspeed-integration inference docs.

stas00 commented Jan 30, 2022

Also please add the distributed init, so that the logging knows to not repeat the same logs for more than 1 gpu:

deepspeed.init_distributed()
engine = deepspeed.initialize(model=model, config_params=ds_config, optimizer=None, lr_scheduler=None)

stas00 commented Jan 30, 2022

OK, here is a much improved program which also integrates some enhancements from @VictorSanh's work.

This script can now handle both cpu offload and/or multiple gpu.

E.g. I can process "bigscience/T0_3B" on an 8GB GPU no problem.

I added a bunch of notes before and through the code - please let me know if anything is unclear or missing:

#!/usr/bin/env python

# This script demonstrates how to use Deepspeed ZeRO in an inference mode when one can't fit a model
# into a single GPU
#
# 1. Use 1 GPU with CPU offload
# 2. Or use multiple GPUs instead
#
# First you need to install deepspeed: pip install deepspeed
#
# Here we use a 3B "bigscience/T0_3B" model which needs about 15GB GPU RAM - so 1 largish or 2
# small GPUs can handle it, or 1 small GPU and a lot of CPU memory.
#
# To use a larger model like "bigscience/T0" which needs about 50GB, unless you have an 80GB GPU -
# you will need 2-4 gpus. And then you can adapt the script to handle more gpus if you want to
# process multiple inputs at once.
#
# The provided deepspeed config also activates CPU memory offloading, so chances are that if you
# have a lot of available CPU memory and you don't mind a slowdown you should be able to load a
# model that doesn't normally fit into a single GPU. If you have enough GPU memory the program will
# run faster if you don't want offload to CPU - so disable that section then.
#
# To deploy on 1 gpu:
#
# deepspeed --num_gpus 1 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=1 t0.py
#
# To deploy on 2 gpus:
#
# deepspeed --num_gpus 2 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=2 t0.py


from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false" # To avoid warnings about parallelism in tokenizers

# distributed setup
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()

model_name = "bigscience/T0_3B"

config = AutoConfig.from_pretrained(model_name)
model_hidden_size = config.d_model

# batch size has to be divisible by world_size, but can be bigger than world_size
train_batch_size = 1 * world_size

# ds_config notes
#
# - enable bf16 if you use Ampere or higher GPU - this will run in mixed precision and will be
# faster.
#
# - for older GPUs you can enable fp16, but it'll only work for non-bf16 pretrained models - e.g.
# all official t5 models are bf16-pretrained
#
# - set offload_param.device to "none" or completely remove the `offload_param` section if you
#   don't want CPU offload
#
# - if using `offload_param` you can manually finetune stage3_param_persistence_threshold to
#   control which params should remain on gpus - the larger the value the smaller the offload size
#
# For indepth info on Deepspeed config see
# https://huggingface.co/docs/transformers/master/main_classes/deepspeed
ds_config = {
    "fp16": {
        "enabled": False,
    },
    "bf16": {
        "enabled": False,
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

# next line instructs transformers to partition the model directly over multiple gpus using
# deepspeed.zero.Init when model's `from_pretrained` method is called.
#
# **it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)**
#
# otherwise the model will first be loaded normally and only partitioned at forward time which is
# less efficient and when there is little CPU RAM may fail
dschf = HfDeepSpeedConfig(ds_config) # keep this object alive

# now a model can be loaded.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# we are ready to initialise deepspeed ZeRO now
ds_engine = deepspeed.initialize(model=model,
                                 config_params=ds_config,
                                 model_parameters=None,
                                 optimizer=None,
                                 lr_scheduler=None)[0]
ds_engine.module.eval() # inference

# Deepspeed ZeRO can process unrelated inputs on each GPU. So for 2 gpus you process 2 inputs at once.
# If you use more GPUs adjust for more.
# And of course if you have just one input to process you then need to pass the same string to both gpus
# If you use only one GPU, then you will have only rank 0.
rank = torch.distributed.get_rank()
if rank == 0:
    text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")

Let's take it for a run:

$ deepspeed --num_gpus 2 t0.py
rank0:
   in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
  out=Positive
rank1:
   in=Is this review positive or negative? Review: this is the worst restaurant ever
  out=negative

@AADeLucia

I tested the provided script with bigscience/T0_3B and it worked as expected. However when I run it with bigscience/T0pp I get the following output:

[2022-01-30 20:09:26,784] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 861020
[2022-01-30 20:09:26,784] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 861021
[2022-01-30 20:09:26,784] [ERROR] [launch.py:184:sigkill_handler] ['/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/bin/python', '-u', 'hf_zero_example.py', '--local_rank=1'] exits with return code = -9

I tried with only using rank 0 (just not doing anything with rank 1) and the same thing happened.

And when I change bf16 to True, only rank 0 completes and then the program hangs. I can see the process is still on GPU 1 but has exited GPU 0; rank 1 never finishes and the run just hangs.

rank0:
   in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
  out=Positive
[2022-01-30 18:42:22,779] [INFO] [launch.py:210:main] Process 775166 exits successfully.
Sun Jan 30 19:34:16 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.86       Driver Version: 470.86       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   56C    P8    24W / 350W |     70MiB / 24260MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:21:00.0 Off |                  N/A |
| 49%   65C    P2   155W / 350W |  14083MiB / 24268MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1491      G   /usr/lib/xorg/Xorg                 56MiB |
|    0   N/A  N/A      1696      G   /usr/bin/gnome-shell                9MiB |
|    1   N/A  N/A      1491      G   /usr/lib/xorg/Xorg                  4MiB |
|    1   N/A  N/A    775167      C   ..._cersi_tobacco/bin/python    14033MiB |
+-----------------------------------------------------------------------------+

And then with bf16=True and passing the same input to both GPUs, it works.

rank = torch.distributed.get_rank()
text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
if rank == 0:
    print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")
rank0:
   in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
  out=Positive
[2022-01-30 20:19:47,097] [INFO] [launch.py:210:main] Process 861553 exits successfully.
[2022-01-30 20:19:48,099] [INFO] [launch.py:210:main] Process 861552 exits successfully.

Also,

To use a larger model like "bigscience/T0" which needs about 50GB

Why does T0 need 50GB? The model plus the vocab/config files are ~42GB according to the model card. I have two 24GB GPUs (48GB total) and I was hoping to fit the entire model across both GPUs without offloading to CPU, but I seem to run out of space even with the ZeRO optimizations you suggested. Is there really ~8GB of extras loaded onto the GPUs? (I only expect to fit a batch size of 1.)


Summary:

  • The parallelism doesn't seem to work with T0
  • What is taking up space on the GPUs?

stas00 commented Jan 31, 2022

Ah right! I keep forgetting this is not the HF Trainer integration, so everything has to be done manually. With the HF Trainer integration it's all done already and you don't need to think about any of this.

Please change the code to:

outputs = ds_engine.module.generate(inputs, synced_gpus=True)

(generate's signature includes `synced_gpus: Optional[bool] = None`)

I fixed the example above.

All GPUs have to work in sync, even if one GPU's output is shorter than another's, which can happen when the inputs are different. Without syncing, if one GPU finishes early the whole ensemble hangs, because each GPU holds a shard of the model that the other GPUs depend on - they gather the missing shards in the pre-forward call. So if one GPU stops, the rest can't continue.

When you use the same input, the GPUs are automatically in sync because they all finish at the same time.


What is taking up space on the GPUs?

We are now talking Inference only:

  1. cuda kernels - 1-2GB per gpu
  2. and then it also depends on mixed precision vs fp32: in fp32 you have params * 4 bytes = 44GB. For mixed-precision inference I'd actually need to measure, because in training mixed precision costs more memory than fp32 in some parts (4+2 bytes per param) but saves memory in other parts of the execution path, at the point of the matrix multiply. (See the rough numbers below.)
  3. memory taken by activations and temps - these depend primarily on seqlen and batch size hence I gave very rough estimations.
  4. additionally the setting for generate can make an additional memory demand - depending on the size of the beam search for example
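
As a rough back-of-the-envelope check for item 2 (assuming T0pp's ~11B parameters):

params = 11e9  # bigscience/T0pp has ~11B parameters

print(f"fp32 weights: ~{params * 4 / 1e9:.0f} GB")  # ~44 GB, matching the estimate above
print(f"bf16 weights: ~{params * 2 / 1e9:.0f} GB")  # ~22 GB
# on top of that: 1-2GB of cuda kernels per gpu, activations/temp buffers (seqlen- and
# batch-size-dependent), and whatever `generate` needs, e.g. for beam search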

For training please see:
https://huggingface.co/docs/transformers/performance#anatomy-of-models-memory

@AADeLucia

Thank you, it works with T0pp now!

And thanks for the memory analysis. My confusion was because I kept getting this cryptic error when I ran the provided script without offloading to CPU:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
[2022-01-31 17:33:28,620] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 883030
[2022-01-31 17:33:28,621] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 883031
[2022-01-31 17:33:28,621] [ERROR] [launch.py:184:sigkill_handler] ['/home/aadelucia/miniconda3/envs/fda_cersi_tobacco/bin/python', '-u', 'hf_zero_example.py', '--local_rank=1'] exits with return code = 1

My guess is that it is a weird manifestation of an out-of-memory error, since the program works just fine when I run it with CUDA_LAUNCH_BLOCKING=1. This essentially undoes the parallelism of the process and makes it run serially, correct?

More reading on CUBLAS_STATUS_NOT_INITIALIZED turns up threads that all point to different culprits.

stas00 commented Jan 31, 2022

Glad to hear it's not hanging any longer.

CUDA_LAUNCH_BLOCKING=1 is a useful debugging tool which tells CUDA not to run any operations asynchronously, because normally when an async operation fails it often lacks the context it was called from and is thus very difficult to debug. That env var forces blocking operations, which slows everything down, but when the same code fails it now reports exactly what happened, including the full context.

Enabling it, however, shouldn't impact the overall memory usage, so if everything works when you enable CUDA_LAUNCH_BLOCKING=1 then something is broken.

How do I reproduce this issue?

stas00 changed the title from "'transformers.deepspeed' has no attribute 'initialize'" to "[deepspeed] bigscience/T0* multi-gpu inference with ZeRO" on Jan 31, 2022
stas00 commented Jan 31, 2022

I adjusted the title of this Issue and re-opened it since we are clearly still working through it.

stas00 reopened this Jan 31, 2022
@tuhinjubcse

I have been trying this, @stas00. I am using 1 A100 GPU, but it's pretty slow; I tried batch size 10 but don't see much difference. Any idea how to improve inference speed? It's using only 4285MiB / 40536MiB.

stas00 commented Feb 1, 2022

Please give me more context, @tuhinjubcse. Are you trying the script I shared above?

So 40GB A100 - got that part.

For T0_3B, disable cpu offloading (instructions are in the script) and it will be much faster.

I don't think the bigger T0 models will fit into a single GPU w/o offload, unless you use full bf16 instead of mixed precision - so basically cast your model and input to .to(dtype=torch.bfloat16) and don't use deepspeed at all. E.g. you can do a simple experiment with the examples by just passing --bf16_full_eval - of course it may affect the results somewhat, but as the model was trained in amp/bf16 it should be pretty close. You will also want the latest pytorch, as some earlier versions weren't doing the right thing in full half-precision (no amp) mode.

At bf16 you will only need 22GB for the weights and then some for activations. Let me know how it went.

I don't think Deepspeed can do non-mixed precision. Let me ask if this can be done somehow.

@tuhinjubcse

I am using T0pp and yes, trying the script shared above. Happy to use multiple GPUs, but do you have any idea if that will make it faster?

I didn't get this part: what is the logic of assigning one prompt to rank 0 and another to rank 1?

if rank == 0:
    text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"

stas00 commented Feb 1, 2022

Happy to use multiple GPUs but do you have any idea if that will make it faster?

disable cpu offload.

what is the logic of assigning one prompt to rank 0 and another to rank 1

If you get a chance please read https://huggingface.co/docs/transformers/parallelism#zero-data-parallelism - you will hopefully see that each gpu will assemble a full layer and run the inputs as if it had the full model all along.

So when you use 2 gpus, you can process 2 unrelated batches at once.

With 4 gpus, 4, etc.

If you're using a single batch and replicating it to all gpus, they will each calculate an identical output, so your efficiency is 1/n_gpus.

You're asking how to make things run faster: if you have 4 gpus, give each gpu a different input and you have a 4x speed up! Whoah!

The more gpus you use, the bigger batch size you can use, which gives you a further speedup. The example above is a demo, so you will want to switch to batches and not single inputs.
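
A rough sketch of that (a hypothetical prompts list; rank, world_size, local_rank, tokenizer and ds_engine as in the script shared above) - each rank takes its own slice of the work and processes it as a batch:

# hypothetical list of prompts to run through the model
prompts = [f"Is this review positive or negative? Review: example review number {i}" for i in range(32)]

# round-robin split: rank r handles prompts r, r + world_size, r + 2 * world_size, ...
my_prompts = prompts[rank::world_size]

batch = tokenizer(my_prompts, return_tensors="pt", padding=True).to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        synced_gpus=True,  # keep all ranks stepping together, as discussed earlier in the thread
    )
for prompt, out in zip(my_prompts, outputs):
    print(f"rank{rank}: {prompt} => {tokenizer.decode(out, skip_special_tokens=True)}")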

@AADeLucia

How do I reproduce this issue?

  • enable bf16
  • do not offload to CPU
  • 2 24GB NVIDIA GeForce RTX 3090
  • pytorch=1.10.1
  • cudatoolkit=11.3.1

Then run the script you provided with T0pp

stas00 commented Feb 2, 2022

I'm trying to think where to find a similar setup, as I only have 1x RTX 3090, I've been trying to buy a 2nd one for a long time and it's just not available to buy :(

@AADeLucia

Is there anything I can try in the meantime?

stas00 commented Feb 2, 2022

I tried to reproduce on 2x A100 40GB and there is no problem there. The program completes w/o hanging or errors.

michaelroyzen commented Apr 2, 2022

Thanks @stas00 for the great example. I'm trying to run T0_3B inference on a single A10 GPU, so I don't need ZeRO here or multi-GPU inference. Using your suggestion to run bf16 inference without deepspeed, I'm casting both the model and inputs to bfloat16, but PyTorch returns RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABFloat16Type instead (while checking arguments for embedding).

The rough flow is as follows:

t0_tokenizer = AutoTokenizer.from_pretrained('bigscience/T0_3B')
t0_model = AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0_3B').to(device='cuda:0', dtype=torch.bfloat16)
_ = t0_model.eval()

inputs = t0_tokenizer.encode(text_input, return_tensors="pt").to(device='cuda:0', dtype=torch.bfloat16)
outputs = t0_model.generate(inputs, min_length=64,
                max_length=256,
                do_sample=False,
                num_beams=2,
                early_stopping=True,
                no_repeat_ngram_size=3)
answer = t0_tokenizer.decode(outputs[0], skip_special_tokens=True)

Your help would be greatly appreciated. I've tried it with casting only the model to bfloat16 but not the inputs -- it runs, but there's no speedup over FP32. GPU utilization is also a lot lower with BF16 model weights vs. FP32.

Hardware: Nvidia A10, 24GB VRAM
Software: Python 3.8, torch==1.11.0, transformers==4.17.0, deepspeed=0.6.1

stas00 commented Apr 3, 2022

@michaelroyzen,

  1. As you discovered, you shouldn't cast the inputs away from their torch.int64 (int) dtype. The model will automatically convert the activations to the correct float dtype at the moment of the embedding lookup. So only the model needs to be cast.

  2. By default when you run in fp32, you actually run under tf32, which on A10 is 2x faster than the former. If you want to compare bf16 to fp32 you have to explicitly turn tf32 off. See this.

  3. You can load the model directly in bf16 if you're short on memory and want it to load faster, instead of casting it from fp32 (a combined sketch of points 1-3 follows after this list):

AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0_3B', torch_dtype=torch.bfloat16)

  4. On a small batch it's usually very difficult to notice any speed improvement from tweaking the dtype, as the overhead of all the other components dominates the end-to-end path. You can study these reports to get a feeling for when one starts benefiting from mixed or half precision: [Benchmark] HF Trainer on A100 #15026 - and that is for training, which has about 3x the math of inference. Inference only has a forward path and is thus likely to require an even bigger batch to show a noticeable impact over a batch size of 1.

    You could, for example, experiment with a smaller model and a larger batch size to compare the different dtypes, as they speed things up with a growing batch size.

  5. In the future, for new discussions it's best to start a new Issue.
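
A combined sketch of points 1-3 above for the single-GPU, no-DeepSpeed case (an assumed benchmarking setup, not a definitive recipe):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# point 2: for a fair bf16 vs fp32 comparison on Ampere-class GPUs, explicitly disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

model_name = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# point 3: load the weights directly in bf16 instead of loading in fp32 and casting
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda:0").eval()

# point 1: the inputs stay as integer token ids - don't cast them
inputs = tokenizer.encode("Is this review positive or negative? Review: best skillet ever",
                          return_tensors="pt").to("cuda:0")
with torch.no_grad():
    outputs = model.generate(inputs, max_length=256, num_beams=2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))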

@michaelroyzen

Thank you for your response @stas00. Yeah, by casting I meant the approach you described in part 3 of your answer -- AutoModelForSeq2SeqLM.from_pretrained('bigscience/T0_3B', torch_dtype=torch.bfloat16) is what I ended up using.

And my apologies, next time I'll start a new issue.

@archieCanada

Hello,

I was able to run the code stas00 mentioned above. Though my task is along the same lines, it is a little more demanding, and I am struggling to make the adjustments. I would appreciate any help.

The description of the task:

  1. I have a list of csv documents.
    list = [doc1,doc2,doc3,...,docN]

  2. Each document contains a dataframe with two columns: dataframe['question'] and dataframe['context']. There are around 25 rows in each dataframe.

  3. Without parallelization, I generate the text by using:

for element in list:
    dataframe = pd.read_csv(element)
    for index, row in dataframe.iterrows():
          query_and_docs = "question: {} context: {}".format(row['question'], row['context'])
          model_input = tokenizer(query_and_docs, padding=True, return_tensors="pt")
          generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
                                                   attention_mask=model_input["attention_mask"].to(device),
                                                   min_length=200,
                                                   max_length=400,
                                                   do_sample=False, 
                                                   early_stopping=True,
                                                   num_beams=8,
                                                   temperature=1.0,
                                                   top_k=None,
                                                   top_p=None,
                                                   eos_token_id=tokenizer.eos_token_id,
                                                   no_repeat_ngram_size=3,
                                                   num_return_sequences=1)
          Answer = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,clean_up_tokenization_spaces=True)[0]
          row['answer'] = Answer

Question:

  1. How can I adjust this code for multi-gpu inference?
  2. Would the combination of GPUs and CPUs (let's say 40 GPU and 40 CPU) be more efficient than just GPUs (40)?

stas00 commented Sep 8, 2022

  • How can I adjust this code for multi-gpu inference?

You just get each rank to generate its unique sequence. See the simple example for 2 gpus here:

https://huggingface.co/docs/transformers/main/main_classes/deepspeed#custom-deepspeed-zero-inference

Specifically this part:

rank = torch.distributed.get_rank()
if rank == 0:
    text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")

If you run that example verbatim, you will see that each rank (gpu) will print its own generated answer.

I hope you can see that your code remains exactly the same. You just make the input different for each rank.

  • Would the combination of GPUs and CPUs (let's say 40 GPU and 40 CPU) be more efficient than just GPUs (40)?

How do you imagine using both together?

You can use CPU offload if the gpus are too few, but that would make the speed slower. Please see:

https://huggingface.co/docs/transformers/main/main_classes/deepspeed#deployment-with-one-gpu

But overall, no, it won't be more efficient.

stas00 commented Sep 8, 2022

Additionally, Deepspeed has recently released a new product called Deepspeed-Inference which speeds up inference by splitting the processing over multiple GPUs using Tensor Parallelism.

And it can even handle quantized int8 inputs, thus requiring half the resources, at the cost, of course, of slower execution.

See https://github.com/bigscience-workshop/Megatron-DeepSpeed/tree/main/scripts/bloom-inference-scripts#deepspeed-inference though it's a temporary home - most likely the scripts will move to another location soon.

The demo scripts are written for BLOOM but can be adapted to other models. If you run into problems please ask directly at Deepspeed Issues as this is not my code (just the bloom demos are mine ;).

@archieCanada

Thank you, Stas. This is exactly my question.

Let's say I have 40 GPUs but just 24 inputs in one document. Would that mean that I can take advantage of only 24 GPUs at a time? It is particularly an issue since the number of inputs in each document varies. If I assign one input to each GPU using the example you mentioned, I will run into the issue that while processing some documents I will have many idle GPUs (if I have 40 GPUs and 24 inputs in a particular document).

  1. Is it possible to assign not the inputs, but the documents to different GPUs while the model (T0) itself is deployed among all the GPUs?

Or maybe there is a smarter way to iterate the GPUs over the many inputs in many documents.

Thank you for your help with this issue.

stas00 commented Sep 8, 2022

You need to understand how ZeRO works: all gpus must always work in sync, so you never have idling gpus. I suggest perhaps reading the main paper: https://arxiv.org/abs/1910.02054

You can feed a single stream and then all the other gpus will process it too, or you can send unique streams - so you can have 24 unique streams while the rest get whatever input, and you simply ignore their results. But again, all gpus have to work in sync, because each gpu carries a unique shard of the weights and the other gpus can't continue w/o it.

Most likely you will want to research Deepspeed-Inference, which is faster than ZeRO-Inference, and there you don't need to bother with multiple streams, as it always has just a single stream - you just feed it a large batch instead, and of course you can change the batch size on every generate call.

Deepspeed-Inference also uses custom fused kernels, which makes it super-fast. If some model isn't supported you can ask the Deepspeed team to add the support - it should be pretty quick.
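
For orientation only, a heavily hedged sketch of what the Deepspeed-Inference entry point typically looks like (deepspeed.init_inference is the public API, but kernel-injection support and dtype coverage for T0/T5 are not guaranteed - treat the BLOOM scripts linked below as the reference, and ask in DeepSpeed Issues if it doesn't work):

import os
import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model_name = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Tensor parallelism: the model is sliced across `world_size` gpus and there is a single
# input stream - every rank feeds the same batch.
ds_model = deepspeed.init_inference(
    model,
    mp_size=world_size,               # number of gpus to shard the model across
    dtype=torch.bfloat16,             # assumption: verify which dtypes/kernels DS supports for t5
    replace_with_kernel_inject=True,  # use fused kernels where the model is supported
)

inputs = tokenizer(["Is this review positive or negative? Review: best skillet ever"],
                   return_tensors="pt").to(local_rank)
with torch.no_grad():
    outputs = ds_model.module.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

It would be launched with the deepspeed launcher, e.g. deepspeed --num_gpus 2 script.py, the same way as the ZeRO examples above.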

You can see the benchmark results here (albeit for a much larger model - bloom-176b)
https://github.com/bigscience-workshop/Megatron-DeepSpeed/tree/main/scripts/bloom-inference-scripts#bloom-inference-solutions

If you're planning to build a server solution, there are several WIP solutions as well, one is:
https://github.com/bigscience-workshop/Megatron-DeepSpeed/tree/main/scripts/bloom-inference-server

@archieCanada

Thank you, Stas. I will follow your advice and do more reading.

@archieCanada

Hello Stas,

It's me again. I encountered the following problem: according to my limited understanding, deepspeed has to initialize the model first, and for this initialization phase I need CPU memory proportional to the number of GPUs I use for inference. The faster I want my inference to be (more GPUs), the more CPU memory I have to have for the initialization of the model.

This is very unfortunate, since this bottleneck takes away many of the benefits of the parallelization. For example, if I have 4 GPUs with 32GB RAM each, I still can't run T0pp if I have just 50GB of CPU RAM total.

Are there any ways around this issue?

stas00 commented Sep 16, 2022

Yes, there is. You can pre-shard the model weights into small shards.

You're in luck since I already did that for T0pp: https://huggingface.co/bigscience/T0pp/tree/sharded

so all you need to do is from_pretrained(..., revision="sharded")

please note that the revision is named sharded just because I called it so, there is no standard.

All new models from the last few months are uploaded sharded into 5-10GB shards by default; for the old ones I sharded many - you can see the status here: #16884

And you can further reshard those models into even smaller chunks, so that even with little CPU RAM you can concurrently load into many gpus no problem, e.g. into 5GB shards:

python -c 'from transformers import AutoModelForSeq2SeqLM; \
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp"); \
model.save_pretrained("t0pp-sharded", max_shard_size="5GB")'
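
Loading then works the same way as before - a small sketch (the local directory name t0pp-sharded comes from the command above):

from transformers import AutoModelForSeq2SeqLM

# load the pre-sharded branch from the Hub ...
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", revision="sharded")

# ... or the locally re-sharded copy saved above
# model = AutoModelForSeq2SeqLM.from_pretrained("t0pp-sharded")

# (combine this with the HfDeepSpeedConfig / deepspeed.initialize setup from the scripts earlier
# in this thread so the weights go straight to the gpus as they are loaded)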

@archieCanada

Wonderful news! Thank you very much.

Is there any document/readme which can educate me on how to choose the right shard size (5GB, 10GB) and how it affects inference speed?

Do I understand correctly that, according to your calculations in the above-mentioned issue, running inference with T0pp would take (5GB (max_shard_size) * #-of-GPUs) of CPU RAM?

stas00 commented Sep 16, 2022

The size of the shard is only important at the loading time for many concurrent processes.

The formula would be roughly this:

model loading RAM required = n_gpus * shard_size * 2

so say 4 gpus and 5gb shard:

4 * 5 * 2 = 40

so 40GB of additional CPU memory will be needed to load the model.

The 2x is because at some point you have both the state_dict containing the shard and the model itself in memory, though in the case of deepspeed/zero3 the weights get offloaded to the gpus right away. So in that case it'll be 1x shard + the largest submodule's weights, since ZeRO needs to gather those from all the gpus to update them with the loaded weights.
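
As a quick helper for that formula (a sketch; shard_size_gb is whatever max_shard_size you re-sharded with):

def model_loading_cpu_ram_gb(n_gpus: int, shard_size_gb: float) -> float:
    """Rough extra CPU RAM needed while concurrently loading a sharded checkpoint on n_gpus processes."""
    return n_gpus * shard_size_gb * 2  # 2x: the state_dict shard plus the module being populated

print(model_loading_cpu_ram_gb(4, 5))  # 40 (GB), matching the example above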

But also look into the very recent solutions, which will be even faster than deepspeed zero: https://huggingface.co/blog/bloom-inference-pytorch-scripts - the article is written for BLOOM, but there is no reason why it shouldn't work for T0 models; definitely for Accelerate, as it's model-agnostic. For Deepspeed-Inference, which uses custom CUDA kernels, I haven't tried - if it doesn't work please ask in a Deepspeed Issue. These will be much faster solutions, unless you infer a different stream on each gpu with deepspeed-zero - please read the article and you should be able to see what would work best for you.

If you try Deepspeed-Inference w/ T0, please report back success/failure - I'd like to know. Thank you!

SaeedShadkam commented Sep 20, 2022

OK, here is a much improved program which also integrates some enhancements from @VictorSanh's work.

This script can now handle both cpu offload and/or multiple gpu.

e.g. I can process "bigscience/T0_3B" on a 8GB GPU no problem.

I added a bunch of notes before and through the code - please let me know if anything is unclear or missing:

#!/usr/bin/env python

# This script demonstrates how to use Deepspeed ZeRO in an inference mode when one can't fit a model
# into a single GPU
#
# 1. Use 1 GPU with CPU offload
# 2. Or use multiple GPUs instead
#
# First you need to install deepspeed: pip install deepspeed
#
# Here we use a 3B "bigscience/T0_3B" model which needs about 15GB GPU RAM - so 1 largish or 2
# small GPUs can handle it. or 1 small GPU and a lot of CPU memory.
#
# To use a larger model like "bigscience/T0" which needs about 50GB, unless you have an 80GB GPU -
# you will need 2-4 gpus. And then you can adapt the script to handle more gpus if you want to
# process multiple inputs at once.
#
# The provided deepspeed config also activates CPU memory offloading, so chances are that if you
# have a lot of available CPU memory and you don't mind a slowdown you should be able to load a
# model that doesn't normally fit into a single GPU. If you have enough GPU memory the program will
# run faster if you don't want offload to CPU - so disable that section then.
#
# To deploy on 1 gpu:
#
# deepspeed --num_gpus 1 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=1 t0.py
#
# To deploy on 2 gpus:
#
# deepspeed --num_gpus 2 t0.py
# or:
# python -m torch.distributed.run --nproc_per_node=2 t0.py


from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers.deepspeed import HfDeepSpeedConfig
import deepspeed
import os
import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false" # To avoid warnings about parallelism in tokenizers

# distributed setup
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()

model_name = "bigscience/T0_3B"

config = AutoConfig.from_pretrained(model_name)
model_hidden_size = config.d_model

# batch size has to be divisible by world_size, but can be bigger than world_size
train_batch_size = 1 * world_size

# ds_config notes
#
# - enable bf16 if you use Ampere or higher GPU - this will run in mixed precision and will be
# faster.
#
# - for older GPUs you can enable fp16, but it'll only work for non-bf16 pretrained models - e.g.
# all official t5 models are bf16-pretrained
#
# - set offload_param.device to "none" or completely remove the `offload_param` section if you don't
# - want CPU offload
#
# - if using `offload_param` you can manually finetune stage3_param_persistence_threshold to control
# - which params should remain on gpus - the larger the value the smaller the offload size
#
# For indepth info on Deepspeed config see
# https://huggingface.co/docs/transformers/master/main_classes/deepspeed
ds_config = {
    "fp16": {
        "enabled": False,
    },
    "bf16": {
        "enabled": False,
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}

# next line instructs transformers to partition the model directly over multiple gpus using
# deepspeed.zero.Init when model's `from_pretrained` method is called.
#
# **it has to be run before loading the model AutoModelForSeq2SeqLM.from_pretrained(model_name)**
#
# otherwise the model will first be loaded normally and only partitioned at forward time which is
# less efficient and when there is little CPU RAM may fail
dschf = HfDeepSpeedConfig(ds_config) # keep this object alive

# now a model can be loaded.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# we are ready to initialise deepspeed ZeRO now
ds_engine = deepspeed.initialize(model=model,
                                 config_params=ds_config,
                                 model_parameters=None,
                                 optimizer=None,
                                 lr_scheduler=None)[0]
ds_engine.module.eval() # inference

# Deepspeed ZeRO can process unrelated inputs on each GPU. So for 2 gpus you process 2 inputs at once.
# If you use more GPUs adjust for more.
# And of course if you have just one input to process you then need to pass the same string to both gpus
# If you use only one GPU, then you will have only rank 0.
rank = torch.distributed.get_rank()
if rank == 0:
    text_in = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
elif rank == 1:
    text_in = "Is this review positive or negative? Review: this is the worst restaurant ever"

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(text_in, return_tensors="pt").to(device=local_rank)
with torch.no_grad():
    outputs = ds_engine.module.generate(inputs, synced_gpus=True)
text_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"rank{rank}:\n   in={text_in}\n  out={text_out}")

Let's take it for a run:

$ deepspeed --num_gpus 2 t0.py
rank0:
   in=Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy
  out=Positive
rank1:
   in=Is this review positive or negative? Review: this is the worst restaurant ever
  out=negative
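
As an aside: if your GPUs have enough memory and are Ampere or newer, a config variant along these lines should run noticeably faster (an illustrative sketch, not part of the original script - it drops the CPU offload and enables bf16, as the notes above suggest):

ds_config_no_offload = {
    "bf16": {
        "enabled": True  # Ampere+ only; otherwise keep both bf16 and fp16 disabled (fp32)
    },
    "zero_optimization": {
        "stage": 3,
        # no `offload_param` section -> the sharded params stay on the GPUs
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "stage3_prefetch_bucket_size": 0.9 * model_hidden_size * model_hidden_size,
        "stage3_param_persistence_threshold": 10 * model_hidden_size
    },
    "steps_per_print": 2000,
    "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}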

@stas00 Hello Stas, I have been experimenting with this code you posted, and I got strange results. I was wondering if you have any thoughts/suggestions for me to improve the code, especially speed-wise.

I ran my code with deepspeed on 4 V100 32GB GPUs and then on a single one. To generate answers for exactly the same questions, 1 GPU finishes the job in 70 minutes, while four GPUs take 112 minutes!

I use your code except for the last couple of lines, which I change to:

number_cores = 4  # or 1 depending on the number of GPUs I am using
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_name)
answers = []
for i in range(len(questions_list) // number_cores):
    # each rank handles a different question for this step
    number = number_cores * i + rank
    query_and_docs = f"question: {questions_list[number]} context: {context}"
    model_input = tokenizer(query_and_docs, max_length=512, padding='max_length', return_tensors="pt")
    with torch.no_grad():
        generated_answers_encoded = ds_engine.module.generate(input_ids=model_input["input_ids"].to(device),
                                                              attention_mask=model_input["attention_mask"].to(device),
                                                              synced_gpus=True,
                                                              min_length=20,
                                                              max_length=400,
                                                              do_sample=False,
                                                              early_stopping=True,
                                                              num_beams=8,
                                                              temperature=1.0)
        answer = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    answers.append(answer)

@stas00
Contributor

stas00 commented Sep 20, 2022

It's not strange at all. With ZeRO, 1 gpu will always be faster if you can fit the model on it - this is because of the comms overhead of multiple gpus, which a 1-gpu setup doesn't have.

This is about using the right tool for the right job. ZeRO was written for models that can't fit into a single GPU. If you can use a single GPU, use it ;)

Also the following nuance is very important: a 4-gpu setup in your case generates 4 different outputs while 1 gpu generates only one, so the effective speed of 4 gpus is 1/4th of 112 minutes, i.e. 28 minutes per gpu. Does that make sense?

And there is a much faster Deepspeed-Inference solution that was released just recently (https://huggingface.co/blog/bloom-inference-pytorch-scripts) that will indeed make your 4 gpus speed up inference considerably and beat 1 gpu. The Accelerate solution is also likely to be faster than or on par with ZeRO when you feed the latter unique streams. I haven't tried either with this particular model to tell for sure.

@SaeedShadkam

Oh, I see. Although I knew the purpose of ZeRO is to fit big models, I assumed it would also help with speed!

I do not understand the nuance you mentioned, though; I generated the same number of answers in both the 4-GPU and 1-GPU setups. To be more specific, to answer 60 questions, the 4-GPU setup took 112 minutes while the 1-GPU setup only took 70 minutes.

I will try to implement the model using the document you mentioned. That's very helpful.

@stas00
Contributor

stas00 commented Sep 20, 2022

I do not understand the nuance you mentioned, though; I generated the same number of answers in both the 4-GPU and 1-GPU setups. To be more specific, to answer 60 questions, the 4-GPU setup took 112 minutes while the 1-GPU setup only took 70 minutes.

I wasn't able to derive that this was the case - it's possible I missed that.

If it is as you say, then the comms overhead is really big and 4 gpus are indeed slower even with unique streams.

When you have fast intra-node connectivity like NVLink (as compared to PCIe), the comms overhead is usually lower, compute dominates, and gpus excel at what they do - fast results. When comms are slow, the gpus idle a lot - slow results.

The same goes for multiple nodes - one node is fast, more than one node is usually much slower since inter-node networks are usually slower than intra-node - but it's not always the case (e.g. NVSwitch can connect many nodes at almost the same speed as NVLink within one node).

You can also watch your GPU utilization in nvidia-smi while processing in 1 vs 4 gpus - if it's always close to 100% then comms are super fast - if it's jumping between 0 and 100, then the overhead of getting the data around is large. nvidia-smi doesn't show the exact compute util, but it's a good enough indication.

If, for example, you try NVMe offload, the gpu will be heavily under-utilized since disk IO will be slow.
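
For instance, a quick way to sample GPU utilization from Python while the job runs (a small sketch; it assumes the pynvml package is installed - watching nvidia-smi in a second terminal works just as well):

import time
import pynvml

pynvml.nvmlInit()
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(pynvml.nvmlDeviceGetCount())]
for _ in range(30):  # sample once a second for ~30 seconds
    utils = [pynvml.nvmlDeviceGetUtilizationRates(h).gpu for h in handles]
    print("GPU util %:", utils)
    time.sleep(1)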

@archieCanada

archieCanada commented Sep 20, 2022

The size of the shard is only important at the loading time for many concurrent processes.

The formula would be roughly this:

model loading RAM required = n_gpus * shard_size * 2

so say 4 gpus and 5gb shard:

4 * 5 * 2 = 40

so 40GB of additional CPU memory will be needed to load the model.

The 2x is because at some point you have both the state_dict containing the shard and the model itself, though in the case of deepspeed/zero3 it'll get offloaded to the gpus right away. So in this case it'll be 1x shard + the largest submodule's weights, as it needs to gather those from all the gpus to update them with the loaded weights.
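
Put as a tiny helper, the same formula looks like this (the numbers are just the illustration above):

def loading_ram_gb(n_gpus, shard_size_gb):
    # rough CPU RAM needed while n_gpus concurrent processes each hold a shard's
    # state_dict plus the model copy being populated
    return n_gpus * shard_size_gb * 2

print(loading_ram_gb(4, 5))   # 40 GB with 5GB shards
print(loading_ram_gb(4, 10))  # 80 GB with the 10GB shards of the presharded revision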

But also look into the very recent solutions which will be even faster than deepspeed zero: https://huggingface.co/blog/bloom-inference-pytorch-scripts - the article is written for BLOOM, but there is no reason why it shouldn't work for t0 models. Definitely for Accelerate, as it's model agnostic. For Deepspeed-Inference, which uses custom CUDA kernels, I haven't tried - if it doesn't work, please ask at a Deepspeed Issue. These will be much faster solutions, unless you infer different streams for each gpu with deepspeed-zero. Please read the article and you should be able to see what would work best for you.

If you try Deepspeed-Inference w/ T0 please report back success/failure I'd like to know. Thank you!

Hello Stas,

Here is my report.
I have been able to run the DeepSpeed inference with T0_3B. I used the code from End-to-End GPT NEO 2.7B Inference tutorial:
https://www.deepspeed.ai/tutorials/inference-tutorial/

The tricky part is that I had to downgrade my version of "transformers" to version 4.21.3. Otherwise, it doesn't work.
I provide code here for convenience:

import os

import deepspeed
import torch
from transformers import pipeline

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
generator = pipeline('text-generation', model='bigscience/T0_3B',
                     device=local_rank)

generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.float,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)

string = generator("DeepSpeed is", do_sample=True, min_length=20)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(string)

However, I struggle with integrating "sharded revision" into this code. I assume a pipeline can't be used for it, right?

@stas00
Contributor

stas00 commented Sep 20, 2022

Thank you for sharing the outcome, @archieCanada

I'd recommend opening an Issue at https://github.com/microsoft/DeepSpeed/issues so that it can be fixed on their side.

Besides your repro code please include the traceback of the failure you had with the latest transformers.

I wonder if they need to add tests to ensure this keeps working.

However, I struggle with integrating "sharded revision" into this code. I assume a pipeline can't be used for it, right?

I'm not quite sure, I use explicit code rather than pipelines so that I have full control over all parts.

I think you could open a feature request to have pipeline support a revision argument if it doesn't already - this makes total sense.

To hack around it you could download the revision you want and load it locally using the path to the clone rather than the model name.
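
For what it's worth, a pipeline can also be handed an already-loaded model, so something along these lines might work for the sharded revision (an untested sketch; local_rank as in the earlier snippet, and note that T0 is a seq2seq model, so text2text-generation is the matching pipeline task):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

model_name = "bigscience/T0pp"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# load the weights from the "sharded" revision explicitly, then wrap them in a pipeline
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, revision="sharded", low_cpu_mem_usage=True)
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=local_rank)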

@au-revoir

@stas00 I see that you have mentioned synced_gpus=True when using model.module.generate, but in my case I am generating the logits with the forward method: model.module(input_ids=input_ids, attention_mask=attn_mask, labels=decoder_ids).logits, which doesn't support synced_gpus=True. Currently the process hangs after showing Process 36678 exits successfully. The rank 1 GPU becomes empty while rank 0 is still running, but since rank 1 no longer has the shard that rank 0 needs, could you please tell me if there is any way I could make this work?

@stas00
Contributor

stas00 commented Sep 25, 2022

That's the thing about ZeRO - all gpus participating in the sharding must always work in sync. If one gpu finishes early for some reason, it must continue running forward until all gpus finish. If it doesn't, the other gpus will block waiting for all gpus to send their shards to each other, which leads to hanging.

You can see how I implemented it in generate - follow synced_gpus code branches - but of course any other way would work as well.

You can also use py-spy to see where the gpus hang, but it'll be some forward call.
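
To make the sync idea concrete, here is a minimal sketch of the pattern for plain forward calls (not the actual synced_gpus implementation in generate - it assumes each rank has its own list of batches plus a dummy_batch with the same keys, and that torch.distributed is already initialized):

import torch
import torch.distributed as dist

def synced_forward(ds_engine, my_batches, dummy_batch, device):
    """Run forward over this rank's batches while keeping all ZeRO ranks in lock-step."""
    outputs, exhausted = [], False
    it = iter(my_batches)
    while True:
        try:
            batch = next(it)
        except StopIteration:
            exhausted = True
            batch = dummy_batch  # keep doing forwards so the other ranks can still gather our shards
        with torch.no_grad():
            out = ds_engine.module(**{k: v.to(device) for k, v in batch.items()})
        if not exhausted:
            outputs.append(out.logits)
        # every rank leaves the loop on the same iteration: only once all ranks are exhausted
        done = torch.tensor(1.0 if exhausted else 0.0, device=device)
        dist.all_reduce(done, op=dist.ReduceOp.MIN)
        if done.item() == 1.0:
            break
    return outputs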

@archieCanada

archieCanada commented Sep 26, 2022

Hello Stas and community,

I try to implement the sharded version of the T0pp model, but I fail to do so for reasons I don't understand.
Here is my code:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import torch
import deepspeed

model_name = "bigscience/T0pp"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.save_pretrained("t0pp-sharded", max_shard_size="5GB")

I have the following resources:

2 GPUs (Tesla V100-SXM2-32GB)
60 GB CPU RAM per GPU (120 GB CPU RAM total)

According to my understanding these resources should be enough, but the system crashes already on the line:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
with an out-of-memory error.

@stas00
Contributor

stas00 commented Sep 26, 2022

Most likely you have too little CPU memory - try AutoModelForSeq2SeqLM.from_pretrained(model_name, low_cpu_mem_usage=True)

Normally, with an unsharded model, you need 2x the model size in CPU memory to load it.
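
To put rough numbers on that (illustrative arithmetic, assuming the ~11B-parameter T0pp stored in fp32):

params = 11e9                   # T0pp has ~11B parameters
bytes_per_param = 4             # fp32 checkpoint
model_size_gb = params * bytes_per_param / 2**30
print(round(model_size_gb))     # ~41 GB for the weights alone
print(round(2 * model_size_gb)) # ~82 GB peak without low_cpu_mem_usage
# and if several ranks each load their own copy at once, multiply by the number of processes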

@archieCanada

archieCanada commented Sep 26, 2022

Hello Stas,

I probably don't understand everything, because I am missing your point.

Here is my situation:
I want to use the sharded model you mentioned earlier to avoid the bottleneck with many GPUs and limited CPU RAM. But I don't understand the algorithm of how to do that.

What I have tried:

model_name = "bigscience/t0pp-sharded"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

model_name = "t0pp-sharded"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In both cases, I get an Error:
t0pp-sharded is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'

Probably this is not the right way of loading the sharded model.

Should I then save the model as sharded myself, using the following algorithm?

  1. download the full model using 2 GPUs and 200 GB CPU RAM
  2. save the model as sharded using:
    model.save_pretrained("t0pp-sharded", max_shard_size="5GB")
  3. Start a new session with more than 2 GPUs but less CPU RAM (because now I can use the sharded model)
  4. Download the sharded model using:
    model_name = "t0pp-sharded"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

Please, correct me where I am wrong.
Thank you.

@stas00
Contributor

stas00 commented Sep 27, 2022

I was replying to your comment about your program getting killed.

I think I now understand what you are struggling with.

There is no such model as bigscience/t0pp-sharded.

Here is what you can do:

  1. Use the 10GB-shard presharded model:
python -c 'from transformers import AutoModelForSeq2SeqLM; \
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", revision="sharded");'

If 10GB is too big, then try the next option and make your own sharded model:

  2. Shard it yourself, to say 5-GB shards:
python -c 'from transformers import AutoModelForSeq2SeqLM; \
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp"); \
model.save_pretrained("/path/to/t0pp-sharded", max_shard_size="5GB")'

This step will require 2x model-size CPU memory and then a bit more, so 100GB of CPU memory should be enough.

and then use the resulting model like so:

python -c 'from transformers import AutoModelForSeq2SeqLM; \
model = AutoModelForSeq2SeqLM.from_pretrained("/path/to/t0pp-sharded");'
  3. You can also take the output in /path/to/t0pp-sharded, upload it to the hub and then use that model:

python -c 'from transformers import AutoModelForSeq2SeqLM; \
model = AutoModelForSeq2SeqLM.from_pretrained("MYUSERNAME/MYMODELNAME");'

You will of course have to adapt the uppercase name to your situation.
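
The upload step itself isn't shown above; one way to do it is a sketch like this (it assumes you are logged in via huggingface-cli login, and the repo name is a placeholder):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("/path/to/t0pp-sharded")
model.push_to_hub("MYUSERNAME/MYMODELNAME")  # re-saves the checkpoint and uploads its shards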

None of these needs a GPU.

Please let me know if any of the 3 worked for you

@archieCanada

Hello Stas,
Thank you for your kind help.

I tried method 2) you described above.
Here is my code:

#Inputs
questionSet = ['question0', 'question1', 'question2']
context = 'Some context'

# import the necessary packages
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import deepspeed

#download the model
model_name_token = "bigscience/T0pp"
model_name = "/path/to/t0pp-sharded"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

tokenizer = AutoTokenizer.from_pretrained(model_name_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

model = model.to(device=local_rank)

model = deepspeed.init_inference(model,
                                 mp_size=world_size,
                                 dtype=torch.float,
                                 replace_method='auto',
                                 replace_with_kernel_inject=True)

#generate answers
answerSet = []
for query in questionSet:
  query_and_docs = "question: {} context: {}".format(query, context)
  model_input = tokenizer(query_and_docs, padding=True, return_tensors="pt")
  generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device=local_rank),
                                           attention_mask=model_input["attention_mask"].to(device=local_rank),
                                           min_length=200,
                                           max_length=400,
                                           do_sample=False,
                                           early_stopping=True,
                                           num_beams=8,
                                           temperature=1.0,
                                           top_k=None,
                                           top_p=None,
                                           eos_token_id=tokenizer.eos_token_id,
                                           no_repeat_ngram_size=3,
                                           num_return_sequences=1)
  answer = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,clean_up_tokenization_spaces=True)[0]
  answerSet.append(answer)

for i in range(0, len(answerSet)):
  print(answerSet[i])

My resources:

I tried two different options:

  1. 2 GPUs (Tesla V100-SXM2-32GB), 60 GB CPU RAM per GPU (120 GB CPU RAM total)
  2. 3 GPUs (Tesla V100-SXM2-32GB), 60 GB CPU RAM per GPU (180 GB CPU RAM total)

Both times I received the following error:

RuntimeError: CUDA out of memory. Tried to allocate 160.00 MiB (GPU 0; 31.75 GiB total capacity; 30.99 GiB already allocated; 118.19 MiB free; 30.99 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@stas00
Contributor

stas00 commented Sep 27, 2022

But really, as I suggested originally, you should post an issue at https://github.com/microsoft/DeepSpeed/issues and tag RezaYazdaniAminabadi - DeepSpeed-Inference isn't integrated into transformers the way Deepspeed-ZeRO is. So all problems related to the former should be reported to the Deepspeed project and not here. This thread is about Deepspeed-ZeRO.

Deepspeed is a project that has multiple related frameworks - we have only the ZeRO project integrated into HF Trainer and Accelerate. DeepSpeed-Inference requires no integration.
