
Model Parallelism and accelerate's usage of DDP aren't compatible #1368

Open
RobertKirk opened this issue Apr 28, 2023 · 18 comments
Labels
enhancement (New feature or request), feature request (Request for a new feature to be added to Accelerate)

Comments

@RobertKirk

System Info

- `Accelerate` version: 0.18.0
- Platform: Linux-5.4.0-124-generic-x86_64-with-glibc2.31
- Python version: 3.9.12
- PyTorch version (GPU?): 1.12.0 (True)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: no
        - use_cpu: False
        - num_processes: 16
        - machine_rank: 0
        - num_machines: 16
        - main_process_ip: 192.168.1.1
        - main_process_port: 8080
        - rdzv_backend: static
        - same_network: False
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

If I use model parallelism (for example using huggingface parallelize), and I'm using accelerate with a standard multi-GPU environment (that uses DDP), then when I prepare the model I get the following error:

  File "/private/home/raileanu/new-rlvsil/rlvsil/experiment_accel.py", line 689, in main
    model, optimizer, lr_scheduler, *prepared_dataloaders = accelerator.prepare(
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/accelerate/accelerator.py", line 1122, in prepare
    result = tuple(
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/accelerate/accelerator.py", line 1123, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/accelerate/accelerator.py", line 977, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/accelerate/accelerator.py", line 1202, in prepare_model
    model = torch.nn.parallel.DistributedDataParallel(
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 571, in __init__
    self._log_and_throw(
  File "/private/home/raileanu/.conda/envs/rob/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 674, in _log_and_throw
    raise err_type(err_msg)
ValueError: DistributedDataParallel device_ids and output_device arguments only work with single-device/multiple-device GPU modules or CPU modules, but got device_ids [0], output_device 0, and module parameters {device(type='cuda', index=0), device(type='cuda', index=1)}.

I think this is because at the line

model = torch.nn.parallel.DistributedDataParallel(

Accelerate initialises the DDP model with device_ids and output_device set, whereas both should be left as None when using model parallelism.

You should be able to reproduce this on a 4-GPU machine with something like the following:

from transformers import GPTJForCausalLM
import accelerate

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
device_map = {
    0: [0, 1, 2, 3, 4, 5, 6],
    1: [7, 8, 9, 10, 11, 12, 13],
    2: [14, 15, 16, 17, 18, 19, 20],
    3: [21, 22, 23, 24, 25, 26, 27],
}
model.parallelize(device_map)

accelerator = accelerate.Accelerator()

model = accelerator.prepare(model)
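Note that the error only shows up when the script runs under a multi-GPU (DDP) launch, for example (with the snippet above saved as repro.py):

accelerate launch --multi_gpu --num_processes 2 repro.py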

I'm currently getting around this by wrapping the model in DDP myself with the correct arguments, and then doing accelerator._models.append(model).
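Roughly, the workaround looks like this (a minimal sketch, assuming the model and accelerator from the snippet above; note that _models is a private attribute, so it may break across versions):

import torch

# For a module whose parameters span multiple devices, DDP requires
# device_ids and output_device to be left unset (None, the default).
model = torch.nn.parallel.DistributedDataParallel(model)
# Register the wrapped model with the accelerator by hand, since calling
# accelerator.prepare(model) is what raises the ValueError above.
accelerator._models.append(model)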

Expected behavior

I'd expect accelerate's usage of DDP to be compatible with naïve model parallelism, as DDP is compatible with it.

I think the fix would be to adjust

model = torch.nn.parallel.DistributedDataParallel(

so that if the model has parameters on multiple devices, or the hf_device_map uses multiple devices (or perhaps if the user passes an explicit parameter saying they're using model parallelism), the DDP initialisation doesn't set device_ids and output_device. I'd be happy to submit a PR to make that change if that seems reasonable.
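For illustration, the check could look something like this standalone helper (a sketch only, not actual Accelerate code; the name wrap_in_ddp is made up):

import torch

def wrap_in_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Does the module (or its hf_device_map, if it has one) span multiple devices?
    device_map = getattr(model, "hf_device_map", None)
    multi_device = len({p.device for p in model.parameters()}) > 1 or (
        device_map is not None and len(set(device_map.values())) > 1
    )
    if multi_device:
        # Model-parallel module: leave device_ids/output_device as None.
        return torch.nn.parallel.DistributedDataParallel(model)
    # Single-device module: pin DDP to this process's GPU as usual.
    return torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank
    )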

@RobertKirk RobertKirk changed the title Model Parallelism and DDP aren't compatible Model Parallelism and accelerate's usage of DDP aren't compatible Apr 28, 2023
@sgugger
Collaborator

sgugger commented Apr 28, 2023

Yes, Accelerate does not support DDP with model parallelism. I'm not sure your proposed fix would work, as DDP will all-reduce the gradients across GPUs, yet the GPUs don't all hold the same parameters.

For the pipeline parallelism you are trying to achieve, use FSDP or DeepSpeed.

@RobertKirk
Author

If you make sure each accelerate process gets multiple GPUs, then I think DDP will work as expected: with 1 accelerate process, and hence 1 DDP model, per 4 GPUs (for example), you should get the correct synchronisation. That's what this tutorial implies, for example.

In my setup, preparing the model separately does work as intended, with 4 accelerate processes of 3 GPUs each and the model layers split across those 3 GPUs. However, I'm launching each of those processes with a separate call, so I'm uncertain how you'd do it with a single call on a multi-GPU machine if you wanted 4 processes with 2 GPUs each rather than 8 processes with 1 GPU each.

I'll look into getting FSDP to work in my setup as well.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jun 6, 2023
@Andcircle

Andcircle commented Aug 8, 2023

@sgugger I need your guidance. I want to use

model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb_config, device_map={'':torch.cuda.current_device()}, trust_remote_code=True
    )

to train a 40B model, but I also want DDP. How should I achieve that? Thanks.

@sgugger
Collaborator

sgugger commented Aug 8, 2023

You can use DDP if your model is only on one device like this.

@Andcircle

@sgugger Thanks for your fast help. But what if the model is too big for one GPU device?

@sgugger
Collaborator

sgugger commented Aug 9, 2023

Then you cannot use DDP + device_map="auto". You need to use DeepSpeed or FSDP.

@sgugger
Collaborator

sgugger commented Aug 9, 2023

I feel like you are not listening. You cannot use DDP + device_map="auto", and thus cannot use DDP + device_map="auto" + DeepSpeed either. You need to just use DeepSpeed ZeRO-3 to shard your model across several devices and train it with a mix of model parallelism and data parallelism.

@Andcircle

Sorry for my misunderstanding, I got your point now

@k21993

k21993 commented Aug 9, 2023

@sgugger Just to make sure my understanding is correct: can we use DeepSpeed support with the Trainer API to do model + data parallelism (without setting device_map), or do we have to write pure DeepSpeed code without HF transformers to load the model? Sorry if my question is repetitive.

@sgugger
Collaborator

sgugger commented Aug 10, 2023

Yes, as long as you properly configure DeepSpeed ZeRO-3 you won't need device_map="auto", and the model will be loaded across several GPUs (each weight will be split).
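For example, a minimal ZeRO-3 setup through the Trainer could look like the sketch below (values are placeholders; see the DeepSpeed/Trainer docs for the full set of options):

from transformers import TrainingArguments

# Minimal ZeRO-3 DeepSpeed config; "auto" lets the Trainer fill in values
# from its own arguments.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    bf16=True,
    deepspeed=ds_config,  # no device_map needed: ZeRO-3 shards the weights
)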

@maxidl

maxidl commented Aug 23, 2023

Just to document my experience getting DDP + MP (2x2 on 4 GPUs) to work with Accelerate (via the HF Trainer):

I modified the current main branch to initialize the DDP model with device_ids and output_device set to None, as described in the PyTorch docs for multi-device modules. Additionally, I had to remove some ValueErrors that are being raised (for no good reason?).

I launch two processes with torchrun; each is supposed to use 2 GPUs.
Each process uses a different device map to load the model (llama-2-7b).
Process 1 uses GPUs 0 and 2, and process 2 uses GPUs 1 and 3.
Example device map for process 1:

{'model.embed_tokens': 0, 'model.norm': 2, 'lm_head': 2, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0 , 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 2, 'model.layers.17': 2, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 2, 'model.layers.28': 2, 'model.layers.29': 2, 'model.layers.30': 2, 'model.layers.31': 2}
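A sketch of how such a per-process map can be built from the local rank (assuming the 2-process, 4-GPU layout above; make_device_map is just an illustrative name):

import os

def make_device_map(num_layers: int = 32) -> dict:
    # torchrun sets LOCAL_RANK; process 0 gets GPUs (0, 2), process 1 gets (1, 3).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    first, second = local_rank, local_rank + 2
    device_map = {"model.embed_tokens": first, "model.norm": second, "lm_head": second}
    # First half of the decoder layers on the first GPU, the rest on the second.
    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = first if i < num_layers // 2 else second
    return device_map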

I set training_args.place_model_on_device=False, as the model is already placed on devices:

model = transformers.LlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
    device_map=device_map
)

This works for me and produces exactly the same loss curves compared to using only DDP or only MP.

@muellerzr
Collaborator

@maxidl can you share your modified code? Curious what those exceptions are that exist for "no good reason"

@maxidl

maxidl commented Aug 23, 2023

@muellerzr I do think these errors are necessary if one does not also modify the DDP construction. In fact, they are correct if one has created the device_map with 'auto'. However, the errors get triggered even when using a custom device_map.

You can find my fork here: maxidl@332d960

Also, note that I did not run any tests to check whether this breaks other behavior.

Now, why do I think it is useful to have DDP + MP (in the classic pipeline-of-layers sense): in my case, I am running GPUs without a fast interconnect (NVLink), which makes FSDP-style training very slow.

@muellerzr muellerzr reopened this Aug 23, 2023
@muellerzr
Collaborator

Thanks @maxidl, as an approach here's what the team has decided we will do:

  1. I'll put a PR in today that lets you explicitly disable the blocking behavior and will set device_ids and output_device to None as you have shown in your example.
  2. We'll keep this issue open, and I ask that the community react with a 👍 to this message if you wind up using this. We want to see what kind of usage folks are having with this, and when we can turn it from a "power-user" feature into something more folks are using.
  3. Long term, we'll see how to enable these kinds of native TP trainings directly with accelerate + proper config, once we get a decent number of folks wanting this.

Seem reasonable @maxidl? And thank you for this reproducer!

@muellerzr muellerzr added the enhancement and feature request labels Aug 23, 2023
@maxidl

maxidl commented Aug 23, 2023

Sure, that sounds great. Once the changes are in (no rush with that), I might create a tutorial-style GitHub repo for it and do some benchmarking, to be shared via Twitter (sorry, "X" ....).

@Andcircle

@muellerzr

Sorry to bring this up again, but is it possible to add this functionality as a feature? For background: we want to tune a 70B or 8x7B model as a teacher. We tried FSDP, but lots of features are not supported in FSDP (DeepSpeed is even worse), e.g. nested quantization and sliding-window attention, so the final total saving is actually not that much.

The following is my testing code. Basically, each node has 8 A100 80GB GPUs, and each training process will take 2 GPUs:
(btw, I haven't checked @maxidl's approach yet, will have a look for sure)

def create_and_prepare_model():
    compute_dtype = getattr(torch, "float16")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        # load_in_8bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 0, 'model.layers.27': 0, 'model.layers.28': 0, 'model.layers.29': 0, 'model.layers.30': 0, 'model.layers.31': 0, 'model.layers.32': 0, 'model.layers.33': 0, 'model.layers.34': 0, 'model.layers.35': 0, 'model.layers.36': 1, 'model.layers.37': 1, 'model.layers.38': 1, 'model.layers.39': 1, 'model.layers.40': 1, 'model.layers.41': 1, 'model.layers.42': 1, 'model.layers.43': 1, 'model.layers.44': 1, 'model.layers.45': 1, 'model.layers.46': 1, 'model.layers.47': 1, 'model.layers.48': 1, 'model.layers.49': 1, 'model.layers.50': 1, 'model.layers.51': 1, 'model.layers.52': 1, 'model.layers.53': 1, 'model.layers.54': 1, 'model.layers.55': 1, 'model.layers.56': 1, 'model.layers.57': 1, 'model.layers.58': 1, 'model.layers.59': 1, 'model.layers.60': 1, 'model.layers.61': 1, 'model.layers.62': 1, 'model.layers.63': 1, 'model.layers.64': 1, 'model.layers.65': 1, 'model.layers.66': 1, 'model.layers.67': 1, 'model.layers.68': 1, 'model.layers.69': 1, 'model.layers.70': 1, 'model.layers.71': 1, 'model.layers.72': 1, 'model.layers.73': 1, 'model.layers.74': 1, 'model.layers.75': 1, 'model.layers.76': 1, 'model.layers.77': 1, 'model.layers.78': 1, 'model.layers.79': 1, 'model.norm': 1, 'lm_head': 1}
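    # remap the template map (written for devices 0 and 1) onto this process's pair of GPUs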
    device1 = torch.cuda.current_device()
    device2 = device1 + 2
    for k,v in device_map.items():
        device_map[k] = device1 if v == 0 else device2
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb_config, trust_remote_code=True, 
        # device_map='auto',
        device_map=device_map,
        # device_map={'':torch.cuda.current_device()},
        use_flash_attention_2=True
        )
    print(model)
    print(model.hf_device_map)
    
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
      
    if '70b' in model_from:
        model.config.max_position_embeddings = 4096
        
    peft_config = LoraConfig(
        lora_alpha=alpha,
        lora_dropout=0.1,
        r=rank,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    
    # for llama family
    # tokenizer.padding_side = 'right'
    # # for mistral family
    # tokenizer.padding_side = 'left'

    return model, peft_config, tokenizer

# import multiprocessing
# NUM_PROC = multiprocessing.cpu_count() #should be num cpus
save_dir = "/sensei-fs/tenants/Sensei-AdobeSearch/CreativeLLM/zhangli/cpt-hf/runai_experiment/"


training_arguments = TrainingArguments(
    output_dir=os.path.join(save_dir, run_name),
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumlate_steps,
    optim="paged_adamw_8bit",
    
    # Can't resume training: Error invalid argument at line 393 in file /mmfs1/gscratch/zlab/timdettmers/git/bitsandbytes/csrc/pythonInterface.c 
    # optim="paged_adamw_8bit",
    
    save_steps=500,
    logging_steps=10,
    learning_rate=lr,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=100,
    warmup_ratio=0.03,
    # group_by_length=True,
    lr_scheduler_type="constant",
    run_name=run_name,
    evaluation_strategy="steps",
    eval_steps=200,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    # weight_decay=0.01,
    # dataloader_num_workers=NUM_PROC//2
)

model, peft_config, tokenizer = create_and_prepare_model()
model.config.use_cache = False # because of gradient checkpointing


trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=True,
)

trainer.train()

Btw, besides this, is there any other memory optimization approach I can take?

@Andcircle

@muellerzr @maxidl, because I loaded the model in 4-bit, I also commented out this line:
https://github.com/maxidl/accelerate/blob/332d960d625deda76090c32a6e67dee70be76761/src/accelerate/accelerator.py#L1342

I don't know whether there are any bad effects; it starts to train at least. Could you please elaborate on the potential consequences?
