IndexError: too many indices for tensor of dimension 2 #3560

Closed
1 task done
heroding77 opened this issue May 3, 2024 · 3 comments
Labels
solved This problem has been already solved

Comments

@heroding77

Reminder

  • I have read the README and searched the existing issues.

Reproduction

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path xxx \
    --finetuning_type lora \
    --template llama3 \
    --flash_attn fa2 \
    --dataset_dir data \
    --dataset LongAlpaca-12k \
    --cutoff_len 32768 \
    --learning_rate 2e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type constant_with_warmup \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 140 \
    --warmup_steps 20 \
    --optim adamw_torch \
    --shift_attn True \
    --report_to none \
    --output_dir xxx \
    --fp16 True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --use_dora True \
    --lora_target all \
    --plot_loss True

Expected behavior

When I enable flash_attn fa2 and shift_attn together, training fails with IndexError: too many indices for tensor of dimension 2. Enabling either option on its own works fine. What could be causing this?
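
As background on what the message itself means (a minimal illustration only, not the actual LLaMA-Factory code path): PyTorch raises this error whenever a tensor is indexed with more indices than it has dimensions, for example when code written for a padded 4-D layout receives one of the flattened 2-D tensors used on FlashAttention-2's unpadded path.

import torch

# Hypothetical illustration: indexing a 2-D tensor as if it were 4-D.
hidden = torch.randn(128, 4096)  # e.g. (total_tokens, hidden_size)

try:
    _ = hidden[0, 0, 0]  # three indices on a 2-D tensor
except IndexError as err:
    print(err)  # -> too many indices for tensor of dimension 2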

System Info

  • transformers version: 4.40.0
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Others

flash_attn version: 2.5.8
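
For anyone comparing environments, a quick sketch to print the versions relevant here (standard PyPI package names assumed; flash_attn exposes __version__ like the others):

import torch
import transformers
import flash_attn

# Print the versions most relevant to this issue.
print("PyTorch      :", torch.__version__)
print("CUDA build   :", torch.version.cuda)
print("transformers :", transformers.__version__)
print("flash_attn   :", flash_attn.__version__)
print("GPU available:", torch.cuda.is_available())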

@heroding77
Author

Additional error information:
[screenshot of the traceback attached]

@tanghui315

Same here. How can this problem be solved?

hiyouga closed this as completed in 0f8f7d3 on May 7, 2024
@hiyouga
Owner

hiyouga commented May 7, 2024

fixed

hiyouga added the solved (This problem has been already solved) label on May 7, 2024
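
The issue was closed by commit 0f8f7d3. For anyone still hitting the error on an older checkout, a minimal sketch (assuming LLaMA-Factory was installed from a git clone and git is on PATH) to check whether the local clone already contains that commit:

import subprocess

FIX_COMMIT = "0f8f7d3"  # commit referenced above as closing this issue

# `git merge-base --is-ancestor A B` exits with 0 iff A is an ancestor of B.
result = subprocess.run(
    ["git", "merge-base", "--is-ancestor", FIX_COMMIT, "HEAD"],
    cwd=".",  # run from the LLaMA-Factory repository root
    capture_output=True,
)
if result.returncode == 0:
    print("Fix commit is present in this checkout.")
else:
    print("Fix commit not found; update to the latest main branch.")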