Tensor not on the same device when fine-tuning an OPT model with device_map=auto
#31037
System Info

- transformers version: 4.41.1
- using device_map="auto" in training

Who can help?

No response

Information

Tasks

- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction

When I follow the official example case to fine-tune an OPT-2.7B model on multiple devices, it raises an error:

Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)

During debugging, I found that this problem occurs in the computation of `torch._C._nn.cross_entropy_loss`: the `input` (on device 2), produced by a hidden layer, and the `target` from the dataset (on device 0) are not on the same device. However, this problem did not happen when fine-tuning this model on multiple GPUs (the implementation refers to this example). My reproduction code is the script quoted, with the suggested fix marked, in the comment below; a standalone sketch of the mismatch follows.
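A minimal standalone sketch of this failure mode (illustrative only, not part of the training script; it assumes at least two visible GPUs):

```python
# Minimal sketch (assumes >= 2 visible GPUs): the loss computation fails when
# logits produced on one device meet labels that were collated onto another.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, device="cuda:1")          # e.g. output of a layer placed on cuda:1
labels = torch.randint(0, 10, (4,), device="cuda:0")  # labels still on cuda:0
try:
    F.cross_entropy(logits, labels)  # internally calls nll_loss -> wrapper_CUDA_nll_loss_forward
except RuntimeError as e:
    print(e)  # Expected all tensors to be on the same device, but found at least two devices...
```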
Expected behavior

I don't know whether it is a bug or an implementation error in my code. I have read and tried some solutions from related issues and discussions, but they didn't work.

Thanks a lot.

Comments

Hi @shiningrain There is indeed a fix that we did not propagate yet. Here is your script with the fix applied; the changed lines are marked with + and -:

```diff
from datasets import load_dataset
from transformers import AutoTokenizer
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"  # on three GPUs
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import bitsandbytes as bnb  # noqa: F401  (must be installed for load_in_8bit)
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
)
import transformers
from datasets import load_dataset
+from accelerate import PartialState

model_dir = "facebook/opt-2.7b"
output_dir = "./tmp_model-opt27-QA"
squad_v2 = False
datasets = load_dataset("squad_v2" if squad_v2 else "squad")
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05  # https://opendelta.readthedocs.io/en/latest/modules/deltas.html
lora_target_modules = [
    "q_proj",
    "v_proj",
]
batch_size = 16
max_length = 384
doc_stride = 128
lr = 2e-5
epochs = 3

model = AutoModelForQuestionAnswering.from_pretrained(
    model_dir,
    load_in_8bit=True,
    # torch_dtype=torch.float32,
-    device_map="auto",
+    device_map={'': PartialState().process_index},
    # max_new_tokens=1024
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left")
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
pad_on_right = tokenizer.padding_side == "right"

def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make
    # truncation of the context fail (the tokenized question will take a lot of space), so we remove
    # that left whitespace.
    examples["question"] = [q.lstrip() for q in examples["question"]]
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This
    # results in one example possibly giving several features when a context is long, each of those
    # features having a context that overlaps a bit with the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # Since one example might give us several features if it has a long context, we need a map from a
    # feature to its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context.
    # This will help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")
    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        # cls_index = input_ids.index(tokenizer.cls_token_id)
        # OPT's tokenizer has no CLS token, so fall back to the BOS token (or index 0).
        if tokenizer.cls_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.cls_token_id)
        elif tokenizer.bos_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.bos_token_id)
        else:
            cls_index = 0
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        # One example can give several spans; this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1
            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1
            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)
    return tokenized_examples

config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="QUESTION_ANS",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
tokenized_datasets = datasets.map(
    prepare_train_features,
    batched=True,
    remove_columns=datasets["train"].column_names,
    batch_size=batch_size,
    num_proc=4,
)
args = transformers.TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    weight_decay=0.01,
    fp16=False,
)
from transformers import default_data_collator

data_collator = default_data_collator
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(output_dir)
```

The line `device_map={'': PartialState().process_index}` will make sure to set the entire model on GPU i for each process i, and you should run your training script with a distributed launcher such as `accelerate launch` or `torchrun`.
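For context, a minimal sketch of what that fix does under a distributed launcher; the script name `train_opt_qa.py` is hypothetical:

```python
# Minimal sketch, assuming `accelerate` is installed and the script is started
# with a distributed launcher, e.g.:
#   accelerate launch --num_processes 3 train_opt_qa.py   # script name is hypothetical
# Each of the three processes sees a distinct process_index (0, 1, 2), so
# device_map={'': PartialState().process_index} pins the entire model ('' is the
# root module) to that process's GPU. Training then runs as ordinary data
# parallelism, instead of device_map="auto" splitting one model across GPUs.
from accelerate import PartialState

state = PartialState()
device_map = {"": state.process_index}  # whole model on cuda:<process_index>
print(f"process {state.process_index}/{state.num_processes} -> cuda:{state.process_index}")
```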
@younesbelkada Thank you very much!