
finetuning with PEFT int-8bit + LoRA on single node multiGPU was working, now doesn't any more #1840

Closed
niccolor opened this issue Aug 12, 2023 · 10 comments

niccolor commented Aug 12, 2023

I have been experimenting with fine-tuning the mpt-7b-instruct model on a private dataset. I am developing in Databricks notebooks.

This was my setup:
cluster: single driver node with g5.12xlarge

bitsandbytes config:

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
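(Side note: the bnb_4bit_* options only take effect when load_in_4bit=True; with load_in_8bit=True they are silently ignored. A minimal sketch of the two un-mixed variants, using the same BitsAndBytesConfig as above:)

# 8-bit only, which is what this setup actually exercises
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# or, if NF4 double quantization was the intent
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)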

model config:

config = AutoConfig.from_pretrained(
  'mosaicml/mpt-7b-instruct',
  trust_remote_code=True
)
config.update({"max_seq_len": 4096})

model = AutoModelForCausalLM.from_pretrained(
  'mosaicml/mpt-7b-instruct',
  config=config,
  quantization_config=bnb_config,
  trust_remote_code=True,
  #torch_dtype=torch.bfloat16,
  device_map = 'balanced'
)

LoRA config (here weight_query_key_modules = [key for key, _ in model.named_modules() if 'Wqkv' in key] is the list of attention projection module names to adapt):

lora_config = LoraConfig(
    r=8,
    lora_alpha=32, 
    target_modules=weight_query_key_modules, 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)

PEFT + LoRA model preparation:

from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
model_prepared = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) # mpt-7b-instruct does not do gradient checkpointing
peft_model = get_peft_model(model_prepared, lora_config)
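As a quick sanity check at this point (not shown in the original snippet, but print_trainable_parameters is a standard PEFT helper), one can confirm that only the LoRA adapters are trainable:

# should report a small trainable fraction (the LoRA adapters) out of the ~7B total parameters
peft_model.print_trainable_parameters()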

Skipping the dataset preparation, but ultimately I get the usual input_ids and attention_mask tensors.

training:

trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=my_dataset.select_columns(['input_ids', 'attention_mask']),
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=50,
        num_train_epochs=2,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        log_level='debug',
        log_level_replica='debug',
        logging_strategy='steps',
        #dataloader_pin_memory=False,
        output_dir=f"/dbfs/FileStore/user/###/{datetime.now().strftime('outputs_%Y%m%d_%H_%M_%S')}",
        optim="paged_adamw_8bit", #"adafactor"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
peft_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

trainer.train()

Up until about two weeks ago, this worked. Now it doesn't anymore.

First, the Trainer initialization crashes and hints that I should upgrade to bitsandbytes==0.41.1 - looks like this error.

So I update, but now trainer.train() crashes with

ValueError: You can't train a model that has been loaded with `device_map='auto'` in any distributed mode. Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`.

which I do not understand, because I am not running any script and I did not use device_map='auto' (OK, the result is the same since my model ends up distributed across GPUs, but it is still not the clearest error), and this was not a problem before!

Then I tried to work around the error manually by setting peft_model._is_quantized_training_enabled = True before Trainer initialization.
With that, the Trainer initializes correctly, although I get the additional warning

Found safetensors installation, but --save_safetensors=False. Safetensors should be a preferred weights saving format due to security and performance reasons. If your model cannot be saved by safetensors please feel free to open an issue at https://github.com/huggingface/safetensors!
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).

which I do not follow - I did not touch the save_safetensors option, so why the additional warning?

That works, but then trainer.train() crashes with the same error

ValueError: You can't train a model that has been loaded with `device_map='auto'` in any distributed mode. Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`.

I tried other combinations of versions (downgrading accelerate to 0.21.0, downgrading transformers to 4.30.0 or 4.31.0...), but nothing seems to work (honestly, I do not know whether I tried every possible combination). I mostly get the same ValueError as before, although sometimes I get the following error:

ValueError: You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode. In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism. Therefore you should not specify that you are under any distributed regime in your accelerate config.

from here, which again does not make much sense, since I am not passing any accelerate config and since the same code was working up to a couple of weeks ago.
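(For reference, both ValueErrors come from the same kind of guard inside accelerate: the Trainer builds an Accelerator internally, and when that Accelerator detects a distributed environment while the model carries an hf_device_map spanning more than one device, it refuses to prepare the model. A rough diagnostic sketch, assuming the base model variable from the setup above:)

from accelerate import Accelerator

# devices the quantized model was sharded onto by device_map='balanced'
print(set(model.hf_device_map.values()))
# what accelerate believes about the launch environment; anything other than
# DistributedType.NO, combined with a multi-device map, trips the guard
print(Accelerator().distributed_type)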

What changed?

Unfortunately I do not know exactly what package versions I had installed two weeks ago. My installation setup was:

%pip install langchain==0.0.162 InstructorEmbedding==1.0.0 sentence-transformers==2.2.2 transformers==4.28.1 accelerate einops nvidia-ml-py3

%pip install -q -U bitsandbytes==0.39.1
%pip install -q -U git+https://github.com/huggingface/transformers.git 
%pip install -q -U git+https://github.com/huggingface/peft.git
%pip install -q -U git+https://github.com/huggingface/accelerate.git
%pip install -q datasets==2.13.1
!pip install protobuf==3.20.*

It is obviously redundant to pin a transformers version and then reinstall it from source, but I hadn't needed to clean this up until now.

Now, the same installation doesn't work anymore.

I looked at similar issues (this and this, for example), but was unable to find a solution to my problem. It seems both issues have been fixed by a PR, but even when installing these libraries from source, my code still doesn't work.

Can someone help me here? There must be some configuration of the many libraries involved (torch, transformers, accelerate, peft, bitsandbytes...) that works for what I am trying to do.
@younesbelkada tagging you since you seem to have helped a lot of people here :) and I'm hoping you can add me to the list.


sgugger commented Aug 13, 2023

Using device_map="balanced" is the same as using device_map="auto": it puts the model on several GPUs and is not compatible with using DistributedDataParallel for training. You need to launch your script with python and not accelerate launch.

niccolor (Author) replied:

> Using device_map="balanced" is the same as using device_map="auto": it puts the model on several GPUs and is not compatible with using DistributedDataParallel for training. You need to launch your script with python and not accelerate launch.

Thanks for your reply!

However, my code was working with device_map="balanced" up until a couple of weeks ago. And I am running it from a Databricks notebook, without any script. I am not launching with accelerate; in fact, I am not even instantiating an Accelerator object in my code.

Can you help me figure out why it does not work anymore?
What versions of the many libraries

%pip install -q -U git+https://github.com/huggingface/transformers.git 
%pip install -q -U git+https://github.com/huggingface/peft.git
%pip install -q -U git+https://github.com/huggingface/accelerate.git

involved would I have gotten, if I had been installing from source on, say, July 30?

Again, the code used to work.
Was it doing both model parallelism and data parallelism? Probably not. It was certainly loading the model onto multiple GPUs (both nvidia-smi and checking which device held each of the model's named parameters showed a roughly equal distribution across the 4 GPUs).

Maybe it was not doing data parallelism correctly, but at least I had resolved the OOM errors. Trying to run the same notebook on a single GPU goes OOM immediately, meaning the code was indeed leveraging the multiple GPUs somehow.
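(For reference, a rough sketch of the placement check described above, assuming the model variable from the original setup:)

from collections import Counter

# how from_pretrained(device_map='balanced') spread the shards across the GPUs
print(Counter(model.hf_device_map.values()))
# cross-check against where the parameters actually live
print(Counter(str(p.device) for p in model.parameters()))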


sgugger commented Aug 15, 2023

Could you please give us the full traceback so we can understand what is going on?


niccolor commented Aug 17, 2023

Apologies for the delay. Sure.

The setup (model download and quantization, LoRA layers, PEFT adapter preparation) is exactly as in my initial message.

The notebook cell

trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=test_dataset_after_tokenization.select_columns(['input_ids', 'attention_mask']),
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        warmup_steps=2,
        max_steps=50,
        num_train_epochs=2,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        log_level='debug',
        log_level_replica='debug',
        logging_strategy='steps',
        #dataloader_pin_memory=False,
        output_dir=f"/dbfs/FileStore/user/###/{datetime.now().strftime('outputs_%Y%m%d_%H_%M_%S')}",
        optim="paged_adamw_8bit", #"adafactor"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    callbacks=[GPUMemoryLoggerCallback, PrinterCallback]
)
peft_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

errors out with traceback

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File <command-1150587063542666>:1
----> 1 trainer = transformers.Trainer(
      2     model=peft_model,
      3     train_dataset=test_dataset_after_tokenization.select_columns(['input_ids', 'attention_mask']),
      4     args=transformers.TrainingArguments(
      5         per_device_train_batch_size=1,
      6         gradient_accumulation_steps=1,
      7         warmup_steps=2,
      8         max_steps=50,
      9         num_train_epochs=2,
     10         learning_rate=2e-4,
     11         fp16=True,
     12         logging_steps=10,
     13         log_level='debug',
     14         log_level_replica='debug',
     15         logging_strategy='steps',
     16         #dataloader_pin_memory=False,
     17         output_dir=f"/dbfs/FileStore/user/###/{datetime.now().strftime('outputs_%Y%m%d_%H_%M_%S')}",
     18         optim="paged_adamw_8bit", #"adafactor"
     19     ),
     20     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
     21     callbacks=[GPUMemoryLoggerCallback, PrinterCallback]
     22 )
     23 peft_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-4d9998b4-a56e-4ea6-a7f3-02ddc4dade59/lib/python3.10/site-packages/transformers/trainer.py:405, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    399         logger.info(
    400             "The model is quantized. To train this model you need to add additional modules"
    401             " inside the model such as adapters using `peft` library and freeze the model weights. Please"
    402             " check the examples in https://github.com/huggingface/peft for more details."
    403         )
    404     else:
--> 405         raise ValueError(
    406             "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
    407             " model, please make sure that you have installed `bitsandbytes>=0.41.1`. "
    408         )
    410 # Setup Sharded DDP training
    411 self.sharded_ddp = None

with the following library versions

Name: accelerate
Version: 0.22.0.dev0
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: sylvain@huggingface.co
License: Apache
Location: /local_disk0/.ephemeral_nfs/envs/pythonEnv-4d9998b4-a56e-4ea6-a7f3-02ddc4dade59/lib/python3.10/site-packages
Requires: numpy, packaging, psutil, pyyaml, torch
Required-by: peft
---
Name: bitsandbytes
Version: 0.39.1
Summary: k-bit optimizers and matrix multiplication routines.
Home-page: https://github.com/TimDettmers/bitsandbytes
Author: Tim Dettmers
Author-email: dettmers@cs.washington.edu
License: MIT
Location: /local_disk0/.ephemeral_nfs/envs/pythonEnv-4d9998b4-a56e-4ea6-a7f3-02ddc4dade59/lib/python3.10/site-packages
Requires: 
Required-by: 
---
Name: transformers
Version: 4.32.0.dev0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /local_disk0/.ephemeral_nfs/envs/pythonEnv-4d9998b4-a56e-4ea6-a7f3-02ddc4dade59/lib/python3.10/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, sentence-transformers
---
Name: peft
Version: 0.5.0.dev0
Summary: Parameter-Efficient Fine-Tuning (PEFT)
Home-page: https://github.com/huggingface/peft
Author: The HuggingFace team
Author-email: sourab@huggingface.co
License: Apache
Location: /local_disk0/.ephemeral_nfs/envs/pythonEnv-4d9998b4-a56e-4ea6-a7f3-02ddc4dade59/lib/python3.10/site-packages
Requires: accelerate, numpy, packaging, psutil, pyyaml, safetensors, torch, tqdm, transformers
Required-by: 

So I restart from the beginning, update bitsandbytes to the required version (%pip install -q -U bitsandbytes==0.41.1), and re-run the rest of the preparation steps.

This time the trainer instantiation (exactly as before) succeeds, with cell output

You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set to `True` to avoid any unexpected behavior such as device placement mismatching.
The model is quantized. To train this model you need to add additional modules inside the model such as adapters using `peft` library and freeze the model weights. Please check the examples in https://github.com/huggingface/peft for more details.
max_steps is given, it will override any value given in num_train_epochs

(probably due to my high debug level? not sure)

but the next notebook cell

trainer.train()

fails immediately, with traceback

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File <command-1150587063542667>:1
----> 1 trainer.train()

File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:434, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
    419 if (
    420     active_session_failed
    421     or autologging_is_disabled(autologging_integration)
   (...)
    428     # warning behavior during original function execution, since autologging is being
    429     # skipped
    430     with set_non_mlflow_warnings_behavior_for_current_thread(
    431         disable_warnings=False,
    432         reroute_warnings=False,
    433     ):
--> 434         return original(*args, **kwargs)
    436 # Whether or not the original / underlying function has been called during the
    437 # execution of patched code
    438 original_has_been_called = False

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b7c9217e-9167-48b0-9ab9-c4a22823650e/lib/python3.10/site-packages/transformers/trainer.py:1545, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1543         hf_hub_utils.enable_progress_bars()
   1544 else:
-> 1545     return inner_training_loop(
   1546         args=args,
   1547         resume_from_checkpoint=resume_from_checkpoint,
   1548         trial=trial,
   1549         ignore_keys_for_eval=ignore_keys_for_eval,
   1550     )

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b7c9217e-9167-48b0-9ab9-c4a22823650e/lib/python3.10/site-packages/transformers/trainer.py:1674, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1672         model = self.accelerator.prepare(self.model)
   1673     else:
-> 1674         model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
   1675 else:
   1676     # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
   1677     model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
   1678         self.model, self.optimizer, self.lr_scheduler
   1679     )

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-b7c9217e-9167-48b0-9ab9-c4a22823650e/lib/python3.10/site-packages/accelerate/accelerator.py:1201, in Accelerator.prepare(self, device_placement, *args)
   1195 for obj in args:
   1196     if (
   1197         isinstance(obj, torch.nn.Module)
   1198         and self.verify_device_map(obj)
   1199         and self.distributed_type != DistributedType.NO
   1200     ):
-> 1201         raise ValueError(
   1202             "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
   1203             " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
   1204         )
   1206 if self.distributed_type == DistributedType.FSDP:
   1207     from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

Expected behavior: running the cell

trainer.train()

actually fine-tunes the model, running the requested number of steps as intended without going OOM (that is what used to happen until late July).


sgugger commented Aug 17, 2023

cc @younesbelkada and @pacman100 since it's a PEFT model.


rich-caputo-oc commented Aug 17, 2023

I'm facing a very similar issue also in Databricks. On a single-node g5.2xlarge (1 GPU), the script runs completely fine, but upon moving to a single-node 4 GPU cluster, I get the following error message when trying to run trainer.train():

ValueError: You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode. In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism. Therefore you should not specify that you are under any distributed regime in your accelerate config.

Traceback

ValueError                                Traceback (most recent call last)
File <command-91325968453>, line 6
      4 except Exception as e:
      5     mlflow.end_run()
----> 6     raise e

File <command-91325968453>, line 3
      1 # Train the model
      2 try:
----> 3     trainer.train()
      4 except Exception as e:
      5     mlflow.end_run()

File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:432, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
    417 if (
    418     active_session_failed
    419     or autologging_is_disabled(autologging_integration)
   (...)
    426     # warning behavior during original function execution, since autologging is being
    427     # skipped
    428     with set_non_mlflow_warnings_behavior_for_current_thread(
    429         disable_warnings=False,
    430         reroute_warnings=False,
    431     ):
--> 432         return original(*args, **kwargs)
    434 # Whether or not the original / underlying function has been called during the
    435 # execution of patched code
    436 original_has_been_called = False

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1534     self.model_wrapped = self.model
   1536 inner_training_loop = find_executable_batch_size(
   1537     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1538 )
-> 1539 return inner_training_loop(
   1540     args=args,
   1541     resume_from_checkpoint=resume_from_checkpoint,
   1542     trial=trial,
   1543     ignore_keys_for_eval=ignore_keys_for_eval,
   1544 )

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/transformers/trainer.py:1656, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1654         model = self.accelerator.prepare(self.model)
   1655     else:
-> 1656         model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
   1657 else:
   1658     # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
   1659     model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
   1660         self.model, self.optimizer, self.lr_scheduler
   1661     )

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/accelerate/accelerator.py:1202, in Accelerator.prepare(self, device_placement, *args)
   1200     result = self._prepare_megatron_lm(*args)
   1201 else:
-> 1202     result = tuple(
   1203         self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1204     )
   1205     result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
   1207 if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
   1208     # 2. grabbing new model parameters

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/accelerate/accelerator.py:1203, in <genexpr>(.0)
   1200     result = self._prepare_megatron_lm(*args)
   1201 else:
   1202     result = tuple(
-> 1203         self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1204     )
   1205     result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
   1207 if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
   1208     # 2. grabbing new model parameters

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/accelerate/accelerator.py:1030, in Accelerator._prepare_one(self, obj, first_pass, device_placement)
   1028     return self.prepare_data_loader(obj, device_placement=device_placement)
   1029 elif isinstance(obj, torch.nn.Module):
-> 1030     return self.prepare_model(obj, device_placement=device_placement)
   1031 elif isinstance(obj, torch.optim.Optimizer):
   1032     optimizer = self.prepare_optimizer(obj, device_placement=device_placement)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/accelerate/accelerator.py:1270, in Accelerator.prepare_model(self, model, device_placement, evaluation_mode)
   1268 model_devices = set(model.hf_device_map.values())
   1269 if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
-> 1270     raise ValueError(
   1271         "You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode."
   1272         " In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
   1273         " Therefore you should not specify that you are under any distributed regime in your accelerate config."
   1274     )
   1275 current_device = list(model_devices)[0]
   1276 current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device

ValueError: You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode. In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism. Therefore you should not specify that you are under any distributed regime in your accelerate config.

Script Snippet:

# Load model directly
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer

# Set up BNB config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Get tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Get LLM model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config,
    device_map="auto",
    use_cache=False
)

# Set up Lora and peft
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)


def format_instruction(sample):
    ...


# Get datasets
train_dataset, test_dataset = ...

# Set up training
args = TrainingArguments(
    output_dir=MODEL_ID.replace("/", "-"),
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm=False,  # disable tqdm since with packing the progress values are incorrect
    ddp_find_unused_parameters=False,
)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=True,
    tokenizer=tokenizer,
    formatting_func=format_instruction,
    args=args,
)

trainer.train()

Dependencies

accelerate==0.21.0
bitsandbytes==0.41.1
peft==0.4.0
transformers==4.31.0
trl==0.5.0

Debugging

I tried removing my bnb_config, but this yielded yet another error when running trainer.train():

ValueError: DistributedDataParallel device_ids and output_device arguments only work with single-device/multiple-device GPU modules or CPU modules, but got device_ids [0], output_device 0, and module parameters {device(type='cuda', index=0), device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)}.

Hope this helps!


younesbelkada commented Aug 17, 2023

Hi there, thanks all for the ping; let me try to answer this as best I can.

Firstly, I am very surprised that DDP + multi-GPU + device_map="auto" used to work with 8-bit quantized models, as it shouldn't; see related: huggingface/transformers#22628, and more precisely: huggingface/peft#269 (comment)

The root cause of this issue is that device_map="auto" evenly dispatches the model across all available GPU devices, while running the script in a multi-GPU configuration makes accelerate try to wrap the model in a DDP module. As this leads to errors similar to huggingface/peft#269 (comment), we decided to guard against that scenario by simply not supporting it.

Hence, two scenarios are left, depending on the initial training setup:

1. Run the training setup with Naive PP (Naive Pipeline Parallelism)

If the model does not fit entirely into a single GPU, you can continue using device_map="auto", but instead of running the script with accelerate launch xxx or python -m torch.distributed.run xxx, run it simply with python xxx.py. Note that NPP is a naive sequential paradigm that keeps only a single GPU busy at any time while all the other GPUs sit idle. Read more about it here: #1523

2. Use DDP (if the model fits on a single GPU)

DDP + quantized models should work if and only if the training setup (model weights, gradients and intermediate hidden states) fits entirely on a single GPU, which I assume is the case since you said:

> On a single-node g5.2xlarge (1 GPU), the script runs completely fine, but upon moving to a single-node 4 GPU cluster, I get the following error message

You need a small hack so that each worker process loads the entire model on its own GPU: simply replace device_map="auto" with the solution below:

Solution
from accelerate import Accelerator

device_index = Accelerator().process_index
device_map = {"": device_index}

...

model = AutoModelForCausalLM.from_pretrained(
   model_id,
   device_map=device_map,
   ...
)

then run your script with accelerate launch xxx or python -m torch.distributed.run xxx.

This is precisely what we do in the TRL library here, and it seems to work fine so far, so I am confident it should fix the issue for both of you.
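(Since the original report runs from a Databricks notebook rather than a script, one way to get the per-process launch described above without accelerate launch is accelerate's notebook_launcher. This is a rough sketch, not something verified in this thread: training_function is a hypothetical wrapper around the model/Trainer setup from the original post, and notebook_launcher requires that CUDA has not yet been initialized in the notebook process.)

from accelerate import Accelerator, notebook_launcher

def training_function():
    # each spawned worker loads its own full copy of the model onto its own GPU
    device_map = {"": Accelerator().process_index}
    # ... load the tokenizer and model with device_map=device_map, wrap with PEFT,
    # ... build the Trainer exactly as in the original post, then call trainer.train()

# one process per GPU on a g5.12xlarge (4x A10G)
notebook_launcher(training_function, num_processes=4)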

Hope that helps

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


xinmengZ commented Oct 5, 2023

@rich-caputo-oc I am struggling with the same issue on Databricks as well. Have you solved your problem?


owos commented Feb 7, 2024

> @rich-caputo-oc I am struggling with the same issue on Databricks as well. Have you solved your problem?

try this, it worked for me:

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
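(For context, a rough sketch of where this call would sit relative to the setup in the original post; gradient_checkpointing_kwargs needs a reasonably recent transformers release, the architecture has to support gradient checkpointing, and use_reentrant=False selects the non-reentrant checkpointing implementation, which tends to behave better under DDP:)

# enable checkpointing before wrapping with PEFT (assumes the architecture supports it;
# the original post had to disable it for mpt-7b-instruct)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
peft_model = get_peft_model(model, lora_config)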
