
Tensor not on the same device in finetuning a OPT model with device_map=auto #31037

Closed
2 of 4 tasks
shiningrain opened this issue May 26, 2024 · 3 comments · Fixed by #31092

Comments


shiningrain commented May 26, 2024

System Info

  • transformers version: 4.41.1
  • Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes. Using three GPUs and device_map=auto in training

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

While following the official example to fine-tune an OPT-2.7B model across multiple devices, training raises the error: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward).

Traceback (most recent call last):
  File "/home/user/data/finetune/demo.py", line 165, in <module>
    trainer.train()
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/transformers/trainer.py", line 3264, in compute_loss
    outputs = model(**inputs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/peft/peft_model.py", line 296, in forward
    return self.get_base_model()(*args, **kwargs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 1433, in forward
    start_loss = loss_fct(start_logits, start_positions)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/user/anaconda3/envs/dealrec/lib/python3.10/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)

During debugging, I found that the problem occurs in the call to torch._C._nn.cross_entropy_loss: the input logits produced by a hidden layer are on device 2, while the target from the dataset is on device 0, so the two tensors end up on different devices.
However, this problem does not happen when fine-tuning this model on multiple GPUs (the implementation follows this example).
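
For context, here is a minimal sketch of the device alignment that appears to be missing in the QA loss computation. The names start_logits, end_logits, start_positions, and end_positions follow modeling_opt.py, but this snippet only illustrates the idea of moving the labels onto the logits' device before computing the loss; it is an illustration, not a patch.

from torch.nn import CrossEntropyLoss

def qa_loss(start_logits, end_logits, start_positions, end_positions):
    # Sketch only: with device_map="auto" the logits can live on a different
    # GPU than the labels coming from the dataloader, so move the labels first.
    start_positions = start_positions.to(start_logits.device)
    end_positions = end_positions.to(end_logits.device)
    loss_fct = CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2

The following is the reproduction code.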

from datasets import load_dataset
from transformers import AutoTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2" # on three GPUs
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
)
import transformers
from datasets import load_dataset


model_dir="facebook/opt-2.7b"
output_dir='./tmp_model-opt27-QA'

squad_v2 = False
datasets = load_dataset("squad_v2" if squad_v2 else "squad")
lora_r=8
lora_alpha=16
lora_dropout=0.05#https://opendelta.readthedocs.io/en/latest/modules/deltas.html
lora_target_modules=[
        "q_proj",
        "v_proj",
]
batch_size = 16
max_length = 384 
doc_stride = 128 
lr=2e-5
epochs=3


model = AutoModelForQuestionAnswering.from_pretrained(
    model_dir,
    load_in_8bit=True,
    # torch_dtype=torch.float32,
    device_map='auto',
    # max_new_tokens=1024
)
tokenizer = AutoTokenizer.from_pretrained(model_dir,padding_side="left")
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
pad_on_right = tokenizer.padding_side == "right"

def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lot of space), so we remove that
    # left whitespace.
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possibly giving several features when a context is long, each of those features having a
    # context that overlaps a bit with the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        # cls_index = input_ids.index(tokenizer.cls_token_id)
        if tokenizer.cls_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.cls_token_id)
        elif tokenizer.bos_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.bos_token_id)
        else:
            cls_index = 0

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type='QUESTION_ANS'
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names,batch_size=batch_size, num_proc=4)

args = transformers.TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy = "epoch",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    fp16=False,
)
from transformers import default_data_collator

data_collator = default_data_collator
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(output_dir)

Expected behavior

I don't know whether this is a bug or an implementation error in my code. I have read and tried several solutions from related issues and discussions, but none of them worked.

Thanks a lot

@amyeroberts
Collaborator

cc @younesbelkada

@younesbelkada
Contributor

Hi @shiningrain

There is indeed a fix that we had not propagated to the xxxForQuestionAnswering classes; I fixed it in #31092.
However, whenever possible, I would advise going for DDP training by loading the entire model on a single GPU and replicating the training process across all available GPUs. First run accelerate config and select "multi-GPU", then modify the script above as follows:

from datasets import load_dataset
from transformers import AutoTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2" # on three GPUs
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
)
import transformers
from datasets import load_dataset

+ from accelerate import PartialState

model_dir="facebook/opt-2.7b"
output_dir='./tmp_model-opt27-QA'

squad_v2 = False
datasets = load_dataset("squad_v2" if squad_v2 else "squad")
lora_r=8
lora_alpha=16
lora_dropout=0.05#https://opendelta.readthedocs.io/en/latest/modules/deltas.html
lora_target_modules=[
        "q_proj",
        "v_proj",
]
batch_size = 16
max_length = 384 
doc_stride = 128 
lr=2e-5
epochs=3


model = AutoModelForQuestionAnswering.from_pretrained(
    model_dir,
    load_in_8bit=True,
    # torch_dtype=torch.float32,
-   device_map="auto",
+   device_map={'': PartialState().process_index},
    # max_new_tokens=1024
)
tokenizer = AutoTokenizer.from_pretrained(model_dir,padding_side="left")
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
pad_on_right = tokenizer.padding_side == "right"

def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lot of space), so we remove that
    # left whitespace.
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possibly giving several features when a context is long, each of those features having a
    # context that overlaps a bit with the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        # cls_index = input_ids.index(tokenizer.cls_token_id)
        if tokenizer.cls_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.cls_token_id)
        elif tokenizer.bos_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.bos_token_id)
        else:
            cls_index = 0

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type='QUESTION_ANS'
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names,batch_size=batch_size, num_proc=4)

args = transformers.TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy = "epoch",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    fp16=False,
)
from transformers import default_data_collator

data_collator = default_data_collator
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(output_dir)

The line

device_map={'': PartialState().process_index},

will make sure the entire model is placed on GPU i for each process i, and you should run your training script with accelerate launch xxx.py
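
For illustration, here is a minimal sketch (not part of the training script) of what that mapping resolves to on each process when started with accelerate launch; the print is only there to verify the per-process assignment:

from accelerate import PartialState

# Sketch only: each process sees a different process_index, so the whole
# model is mapped onto a single, distinct GPU per process.
state = PartialState()
device_map = {"": state.process_index}  # e.g. {'': 0} on process 0, {'': 1} on process 1, ...
print(f"process {state.process_index}/{state.num_processes} -> device_map={device_map}")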

@shiningrain
Author

@younesbelkada Thank you very much!
