From fdb7a4da899f03f8f175789110be268336fcfe65 Mon Sep 17 00:00:00 2001 From: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Date: Wed, 22 May 2024 18:58:48 +0800 Subject: [PATCH] [TorchAcc][Experimental] Integrate more model in torchacc (#683) --- .../baichuan2_13b_chat/acc_lora_dp_sft.sh | 34 ++ .../baichuan2_13b_chat/acc_lora_fsdp_sft.sh | 34 ++ .../baichuan2_13b_chat/swift_lora_sft.sh | 27 + .../torchacc/chatglm3_6b/acc_lora_dp_sft.sh | 35 ++ .../torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh | 35 ++ .../torchacc/chatglm3_6b/swift_lora_sft.sh | 27 + .../llama2_13b_chat/acc_lora_dp_sft.sh | 35 ++ .../llama2_13b_chat/acc_lora_fsdp_sft.sh | 36 ++ .../llama2_13b_chat/swift_lora_sft.sh | 27 + .../llama3_8b_instruct/acc_lora_dp_sft.sh | 37 ++ .../llama3_8b_instruct/acc_lora_fsdp_sft.sh | 37 ++ .../llama3_8b_instruct/swift_lora_sft.sh | 28 + .../qwen1half_14b_chat/acc_lora_dp_sft.sh | 35 ++ .../qwen1half_14b_chat/acc_lora_fsdp_sft.sh | 36 ++ .../qwen1half_14b_chat/swift_lora_sft.sh | 26 + .../qwen1half_32b_chat/acc_lora_fsdp_sft.sh | 34 ++ .../qwen1half_32b_chat/swift_lora_sft.sh | 26 + .../qwen_72b_chat/acc_full_fsdp_sft.sh} | 2 + .../qwen_72b_chat/acc_lora_fsdp_sft.sh} | 6 +- .../torchacc/qwen_72b_chat/swift_lora_sft.sh | 26 + .../torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh | 34 ++ .../torchacc/yi_34b_chat/swift_lora_sft.sh | 27 + swift/llm/sft.py | 7 +- swift/llm/utils/argument.py | 10 +- swift/torchacc_utils.py | 490 +++++++++++++++++- swift/trainers/arguments.py | 2 + swift/trainers/callback.py | 25 +- swift/trainers/mixin.py | 5 +- swift/trainers/trainers.py | 5 +- swift/utils/__init__.py | 2 +- swift/utils/torch_utils.py | 4 + 31 files changed, 1181 insertions(+), 13 deletions(-) create mode 100644 examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/chatglm3_6b/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/swift_lora_sft.sh rename examples/pytorch/llm/scripts/{qwen_72b_chat/torchacc/full_fsdp_sft.sh => torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh} (95%) rename examples/pytorch/llm/scripts/{qwen_72b_chat/torchacc/lora_fsdp_sft.sh => 
torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh} (90%) create mode 100644 examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/swift_lora_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh create mode 100644 examples/pytorch/llm/scripts/torchacc/yi_34b_chat/swift_lora_sft.sh diff --git a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh new file mode 100644 index 000000000..d6c143b27 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_dp_sft.sh @@ -0,0 +1,34 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. +# torchacc dp +export USE_TORCHACC=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +MASTER_PORT=27829 \ +swift sft \ + --model_id_or_path baichuan-inc/Baichuan2-13B-Chat \ + --model_layer_cls_name BaichuanLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..5277721c6 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/acc_lora_fsdp_sft.sh @@ -0,0 +1,34 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
+# torchacc fsdp +export USE_TORCHACC=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path baichuan-inc/Baichuan2-13B-Chat \ + --model_layer_cls_name BaichuanLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 2 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..e4d13ba2e --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/baichuan2_13b_chat/swift_lora_sft.sh @@ -0,0 +1,27 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path baichuan-inc/Baichuan2-13B-Chat \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 2 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh new file mode 100644 index 000000000..d059dc7c1 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_dp_sft.sh @@ -0,0 +1,35 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
+# torchacc dp +export USE_TORCHACC=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +MASTER_PORT=27829 \ +swift sft \ + --model_id_or_path ZhipuAI/chatglm3-6b \ + --model_layer_cls_name GLMBlock \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..5993ab853 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/acc_lora_fsdp_sft.sh @@ -0,0 +1,35 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. +# torchacc fsdp +export USE_TORCHACC=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path ZhipuAI/chatglm3-6b \ + --model_layer_cls_name GLMBlock \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 2 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/swift_lora_sft.sh new file mode 100644 index 000000000..c8b666158 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/chatglm3_6b/swift_lora_sft.sh @@ -0,0 +1,27 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
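+# swift lora (plain SWIFT, without TorchAcc)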
+ +# MASTER_ADDR=127.0.0.1 \ +# MASTER_PORT=12356 \ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path ZhipuAI/chatglm3-6b \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 4 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh new file mode 100644 index 000000000..c5df3e81f --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_dp_sft.sh @@ -0,0 +1,35 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +export USE_TORCHACC=1 +export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path modelscope/Llama-2-13b-chat-ms \ + --model_layer_cls_name LlamaDecoderLayer \ + --dataset codefuse-python-en \ + --template_type llama \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..5d84ea90f --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/acc_lora_fsdp_sft.sh @@ -0,0 +1,36 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
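+# torchacc fsdp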
+export USE_TORCHACC=1 +export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +MASTER_PORT=27829 \ +swift sft \ + --model_id_or_path modelscope/Llama-2-13b-chat-ms \ + --model_layer_cls_name LlamaDecoderLayer \ + --dataset codefuse-python-en \ + --template_type llama \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 24 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 2 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..2fd5c1921 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama2_13b_chat/swift_lora_sft.sh @@ -0,0 +1,27 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path modelscope/Llama-2-13b-chat-ms \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh new file mode 100644 index 000000000..f86b55436 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_dp_sft.sh @@ -0,0 +1,37 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
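+# torchacc dp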
+ +export USE_TORCHACC=1 +export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select +export XLA_COORDINATOR_PORT=12457 + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +MASTER_PORT=21779 \ +swift sft \ + --model_id_or_path LLM-Research/Meta-Llama-3-8B-Instruct \ + --model_layer_cls_name LlamaDecoderLayer \ + --dataset codefuse-python-en \ + --template_type llama3 \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..36f69b179 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/acc_lora_fsdp_sft.sh @@ -0,0 +1,37 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. +export USE_TORCHACC=1 +export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=100000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select +# export XLA_COORDINATOR_PORT=12457 + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +MASTER_PORT=27829 \ +swift sft \ + --model_id_or_path LLM-Research/Meta-Llama-3-8B-Instruct \ + --model_layer_cls_name LlamaDecoderLayer \ + --dataset codefuse-python-en \ + --template_type llama3 \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 2 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/swift_lora_sft.sh new file mode 100644 index 000000000..659c9edf1 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/llama3_8b_instruct/swift_lora_sft.sh @@ -0,0 +1,28 @@ +# Experimental environment: 2 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
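+# swift lora (plain SWIFT, without TorchAcc)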
+ +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_id_or_path LLM-Research/Meta-Llama-3-8B-Instruct \ + --dataset codefuse-python-en \ + --template_type llama3 \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 16 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh new file mode 100644 index 000000000..4a6f0894a --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_dp_sft.sh @@ -0,0 +1,35 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. +export USE_TORCHACC=1 +# export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=1000000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=2,3 \ +MASTER_PORT=23797 \ +swift sft \ +--model_type qwen1half-14b-chat \ + --model_layer_cls_name Qwen2DecoderLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 8 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 1 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..6e57d9c5e --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/acc_lora_fsdp_sft.sh @@ -0,0 +1,36 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
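+# torchacc fsdp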
+DEBUG_PREFIX=qwen15_14b +DEBUG_PATH=torchacc_debug/qwen15/ +export USE_TORCHACC=1 +# export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=1000000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select +MASTER_PORT=23783 \ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ +--model_type qwen1half-14b-chat \ + --model_layer_cls_name Qwen2DecoderLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 2 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..8e2a8e0d3 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_14b_chat/swift_lora_sft.sh @@ -0,0 +1,26 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift sft \ + --model_type qwen1half-14b-chat \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 4 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..2a5d6644f --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/acc_lora_fsdp_sft.sh @@ -0,0 +1,34 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
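+# torchacc fsdp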
+export USE_TORCHACC=1 +# export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=1000000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift sft \ +--model_type qwen1half-32b-chat \ + --model_layer_cls_name Qwen2DecoderLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 4 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..88d1aa325 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen1half_32b_chat/swift_lora_sft.sh @@ -0,0 +1,26 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift sft \ + --model_type qwen1half-32b-chat \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 1 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ diff --git a/examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/full_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh similarity index 95% rename from examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/full_fsdp_sft.sh rename to examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh index 07e0853fa..c819d3d44 100644 --- a/examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/full_fsdp_sft.sh +++ b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_full_fsdp_sft.sh @@ -30,4 +30,6 @@ swift sft \ --eval_steps 200 \ --save_steps 200 \ --logging_steps 100 \ + --metric_warmup_step 0.1 \ --report_to 'none' + --fsdp_num 32 diff --git a/examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh similarity index 90% rename from examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/lora_fsdp_sft.sh rename to examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh index aa61abf8a..df3cdf35f 100644 --- a/examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/lora_fsdp_sft.sh +++ b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/acc_lora_fsdp_sft.sh @@ -1,4 +1,4 @@ -# Experimental environment: 4 * A800 +# Experimental environment: 4 * A100 # 80GB GPU memory # Note: TorchAcc is currently only available internally. 
@@ -18,7 +18,7 @@ swift sft \ --output_dir output_qwen_72b \ --num_train_epochs 1 \ --max_length 2048 \ - --batch_size 6 \ + --batch_size 4 \ --use_flash_attn true \ --gradient_accumulation_steps 1 \ --gradient_checkpointing no \ @@ -26,4 +26,6 @@ swift sft \ --eval_steps 200 \ --save_steps 200 \ --logging_steps 100 \ + --metric_warmup_step 0.1 \ --report_to 'none' \ + --fsdp_num 4 \ diff --git a/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..384cd1047 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/qwen_72b_chat/swift_lora_sft.sh @@ -0,0 +1,26 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. + +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift sft \ + --model_id_or_path qwen/Qwen-72B-Chat \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 1 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ diff --git a/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh new file mode 100644 index 000000000..17bb17dbd --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/acc_lora_fsdp_sft.sh @@ -0,0 +1,34 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. +export USE_TORCHACC=1 +export TORCHACC_TRIM_GRAPH=1 +export XLA_FLAGS='--xla_gpu_force_compilation_parallelism=32 --xla_multiheap_size_constraint_per_heap=4831838208 --xla_disable_hlo_passes=all-gather-combiner,all-reduce-combiner,reduce-scatter-combiner,gpu-convert-async-collectives-to-sync,rematerialization' +export XLA_IR_SHAPE_CACHE_SIZE=1000000000 +export XLA_ALLOCATOR_FRACTION=0.95 +export XLA_EXPERIMENTAL=nonzero:masked_select + +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift sft \ +--model_type yi-34b-chat \ + --model_layer_cls_name LlamaDecoderLayer \ + --dataset codefuse-python-en \ + --sft_type lora \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 12 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --gradient_checkpointing no \ + --tuner_backend 'peft' \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --fsdp_num 4 \ + --report_to 'none' diff --git a/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/swift_lora_sft.sh b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/swift_lora_sft.sh new file mode 100644 index 000000000..281ef9434 --- /dev/null +++ b/examples/pytorch/llm/scripts/torchacc/yi_34b_chat/swift_lora_sft.sh @@ -0,0 +1,27 @@ +# Experimental environment: 4 * A100 +# 80GB GPU memory +# Note: TorchAcc is currently only available internally. 
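+# swift lora (plain SWIFT, without TorchAcc)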
+ +# MASTER_ADDR=127.0.0.1 \ + +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift sft \ + --model_type yi-34b-chat \ + --dataset codefuse-python-en \ + --sft_type lora \ + --dtype AUTO \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --batch_size 1 \ + --use_flash_attn true \ + --gradient_accumulation_steps 1 \ + --dataset_test_ratio 0 \ + --save_strategy no \ + --eval_steps 2000000 \ + --save_steps 2000000 \ + --logging_steps 100 \ + --preprocess_num_proc 1 \ + --metric_warmup_step 0.1 \ + --report_to 'none' diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 563cd2831..5b5dadd16 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -13,6 +13,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import is_torch_npu_available +from swift.torchacc_utils import patch_acc_model from swift.trainers import Seq2SeqTrainer from swift.trainers.utils import can_return_loss, find_labels from swift.utils import (check_json_format, compute_acc_metrics, compute_nlg_metrics, get_dist_setting, get_logger, @@ -28,6 +29,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: + logger.info(f'args: {args}') seed_everything(args.seed) training_args = args.training_args @@ -128,7 +130,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: # wrapper the model and make these properties wrong. label_names = find_labels(model) return_loss = can_return_loss(model) - model = ta.patch_qwen_model(model) + model = patch_acc_model(model, args) # Preparing LoRA model, callbacks = prepare_model(model, args) @@ -149,7 +151,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: logger.info('Setting model.config.use_cache: False') model = ta_accelerate( model, - world_size, + args.fsdp_num, args.model_layer_cls_name, args.bf16, args.fp16, @@ -175,6 +177,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: model_author=args.model_author) train_dataset, val_dataset = args._handle_dataset_compat(train_dataset, val_dataset) + training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0 logger.info(f'train_dataset: {train_dataset}') logger.info(f'val_dataset: {val_dataset}') template_kwargs = {} diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index d51e5cd3c..863b42b3c 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -21,7 +21,7 @@ from swift.trainers import Seq2SeqTrainingArguments from swift.tuners import Swift from swift.utils import (add_version_to_work_dir, get_dist_setting, get_logger, get_pai_tensorboard_dir, is_dist, - is_local_master, is_mp, is_pai_training_job) + is_local_master, is_mp, is_pai_training_job, use_torchacc) from .dataset import DATASET_MAPPING, _dataset_name_exists, get_dataset, register_dataset_info_file, sample_dataset from .model import (MODEL_MAPPING, dtype_mapping, get_additional_saved_files, get_default_lora_target_modules, get_default_template_type) @@ -539,6 +539,7 @@ class SftArguments(ArgumentsBase): logging_steps: int = 5 dataloader_num_workers: int = 1 dataloader_pin_memory: bool = True + dataloader_drop_last: bool = False # push to ms hub push_to_hub: bool = False @@ -614,6 +615,8 @@ class SftArguments(ArgumentsBase): neftune_alpha: Optional[float] = None deepspeed_config_path: Optional[str] = None model_cache_dir: Optional[str] = None + metric_warmup_step: Optional[float] = 0 # only use in torchacc + fsdp_num: int = 1 custom_train_dataset_path: List[str] = 
field(default_factory=list) custom_val_dataset_path: List[str] = field(default_factory=list) @@ -831,6 +834,9 @@ def __post_init__(self) -> None: elif not support_gradient_checkpointing and self.gradient_checkpointing: logger.warning(f'{self.model_type} not support gradient_checkpointing.') + if use_torchacc(): + self.dataloader_drop_last = True + self._init_training_args() if self.add_output_dir_suffix is None: @@ -912,8 +918,10 @@ def _init_training_args(self) -> None: acc_strategy=self.acc_strategy, save_safetensors=self.save_safetensors, logging_first_step=True, + metric_warmup_step=self.metric_warmup_step, fsdp=self.fsdp, fsdp_config=self.fsdp_config, + dataloader_drop_last=self.dataloader_drop_last, **kwargs) training_args.ddp_find_unused_parameters = self.ddp_find_unused_parameters diff --git a/swift/torchacc_utils.py b/swift/torchacc_utils.py index 4b2d113fe..86569a9d4 100644 --- a/swift/torchacc_utils.py +++ b/swift/torchacc_utils.py @@ -1,8 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import math import os import sys +import types from collections import OrderedDict -from typing import List +from typing import List, Optional, Tuple import safetensors import torch @@ -12,7 +14,7 @@ from transformers import PreTrainedModel, trainer from transformers.modeling_utils import unwrap_model -from swift.utils import get_logger +from swift.utils import get_logger, torchacc_trim_graph, use_torchacc logger = get_logger() @@ -254,3 +256,487 @@ def save_ta_checkpoint(self_model, tokenizer, args, output_dir): if tokenizer is not None and args.should_save: tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save) + + +def ta_trim_graph(): + if use_torchacc() and torchacc_trim_graph(): + import torchacc as ta + ta.mark_step() + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
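+
+    Example (illustrative only; the shapes below are hypothetical):
+        >>> q = torch.randn(1, 2, 5, 8)  # [batch_size, num_heads, seq_len, head_dim]
+        >>> k = torch.randn(1, 2, 5, 8)
+        >>> cos = torch.randn(16, 8)     # [max_seq_len, head_dim]
+        >>> sin = torch.randn(16, 8)
+        >>> position_ids = torch.arange(5).unsqueeze(0)  # [batch_size, seq_len]
+        >>> q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
+        >>> q_rot.shape, k_rot.shape
+        (torch.Size([1, 2, 5, 8]), torch.Size([1, 2, 5, 8]))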
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def patch_acc_model(model, args): + if not args.use_flash_attn: + logger.warn('Currently use flash attn for torchacc.') + if args.model_type.startswith('qwen1half'): + model = patch_qwen2_model(model) + elif args.model_type.startswith('qwen'): + import torchacc as ta + model = ta.patch_qwen_model(model) + elif args.model_type.startswith('baichuan'): + model = patch_baichuan_model(model) + elif args.model_type.startswith('llama') or args.model_type.startswith('yi'): + model = patch_llama_model(model) + elif args.model_type.startswith('chatglm'): + model = patah_chatglm_model(model) + return model + + +def patch_llama_model(model): + + def llama_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + from torchacc.ops import flash_attn_varlen_xla + import einops + + bsz, q_len, _ = hidden_states.size() + + query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + key_states = ( + self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) + value_states = ( + self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, 'past_key_value is not supported' + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + assert not output_attentions, 'output_attentions is not supported' + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) if use_cache else None + + # See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + # if attention_mask is not None: + # value_states = value_states * attention_mask.unsqueeze(1).unsqueeze(-1) + q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...') + k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...') + v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...') + max_s = q_len + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True) + output = einops.rearrange(output, '(b s) ... 
-> b s ...', b=bsz) + + return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value + + for layer in model.model.layers: + layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn) + + return model + + +def patah_chatglm_model(model): + + def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + def chatglm_attn_forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.unsqueeze(-2) + 
value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + + # ================================== + # core attention computation + # ================================== + + from torchacc.ops import flash_attn_varlen_qkvpacked_xla + import einops + + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + bsz, _, q_len, _ = query_layer.size() + qkv = torch.stack([query_layer, key_layer, value_layer], dim=2) + qkv = qkv.transpose(1, 3) + qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...') + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + context_layer = flash_attn_varlen_qkvpacked_xla(qkv, cu_q_lens, q_len, 0.0, None, True, False) + context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz) + context_layer = context_layer.permute(1, 0, 2, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, ) + context_layer = context_layer.reshape(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + def torchacc_swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]).to(x[0].dtype) * x[1] + + # patch attention + for layer in model.transformer.encoder.layers: + layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention) + layer.mlp.activation_func = torchacc_swiglu + + return model + + +def patch_baichuan_model(model): + + def baichuan_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + import einops + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)) + query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + from torchacc.ops import flash_attn_varlen_xla + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]] + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True) + output = einops.rearrange(output, '(b s) ... 
-> b s ...', b=bsz) + output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')) + return output, None, past_key_value + + for layer in model.base_model.layers: + layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn) + + return model + + +def patch_qwen2_model(model): + + def qwen2_attn_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.') + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + rotary_seq_len = kv_seq_len + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + from torchacc.ops import flash_attn_varlen_xla + import einops + + q, k, v = [einops.rearrange(x, 'b s ... 
-> (b s) ...') for x in [query_states, key_states, value_states]] + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + + attn_output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def qwen2_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds') + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if 
output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + from transformers.modeling_outputs import BaseModelOutputWithPast + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + for layer in model.model.layers: + layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn) + + model.model.forward = types.MethodType(qwen2_forward, model.model) + return model diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 63528bf9c..de79be1f7 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -21,6 +21,8 @@ class SwiftArgumentsMixin: default='push_best', metadata={'choices': {'end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'}}) acc_strategy: str = field(default='token', metadata={'choices': ['token', 'sentence']}) additional_saved_files: Optional[List[str]] = None + metric_warmup_step: Optional[float] = 0 + train_dataset_sample: Optional[int] = -1 def __post_init__(self): if is_dist() and self.ddp_backend == 'nccl' and torch.cuda.is_available() and is_accelerate_available(): diff --git a/swift/trainers/callback.py b/swift/trainers/callback.py index f510dc301..a787bf049 100644 --- a/swift/trainers/callback.py +++ b/swift/trainers/callback.py @@ -1,13 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import time import json from tqdm.auto import tqdm from transformers.trainer_callback import (DefaultFlowCallback, ProgressCallback, TrainerCallback, TrainerControl, TrainerState) -from transformers.trainer_utils import IntervalStrategy, has_length +from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics -from swift.utils import is_pai_training_job +from swift.utils import is_pai_training_job, use_torchacc from .arguments import TrainingArguments @@ -17,6 +18,11 @@ def on_train_begin(self, args, state, control, **kwargs): if state.is_local_process_zero: self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True) self.current_step = 0 + if use_torchacc(): + self.warmup_start_time = 0 + self.warmup_metric = None + self.metric_warmup_step = int(args.metric_warmup_step + * state.max_steps) if args.metric_warmup_step < 1 else args.metric_warmup_step def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs): if state.is_local_process_zero and has_length(eval_dataloader): @@ -29,6 +35,21 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): logs['global_step'] = state.global_step + if use_torchacc(): + if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0: + self.warmup_start_time = time.time() + self.metric_warmup_step = state.global_step + if state.max_steps == state.global_step and self.warmup_metric is None: + num_steps = state.max_steps - self.metric_warmup_step + num_total_samples = args.train_dataset_sample + num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps) + self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples, + num_steps) + self.warmup_metric['num_total_samples'] = num_total_samples + self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples + if 'train_samples_per_second' in logs: + logs.update(self.warmup_metric) + state.log_history[-1] = logs for k, v in logs.items(): if isinstance(v, float): logs[k] = round(logs[k], 8) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 84b7218ee..705db931d 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -31,7 +31,8 @@ from swift.hub import Repository from swift.hub.check_model import check_local_model_is_latest -from swift.torchacc_utils import save_ta_checkpoint, ta_load_optimizer_and_scheduler, ta_save_optimizer_and_scheduler +from swift.torchacc_utils import (save_ta_checkpoint, ta_load_optimizer_and_scheduler, ta_save_optimizer_and_scheduler, + ta_trim_graph) from swift.tuners import SwiftModel from swift.utils import check_json_format, create_ms_repo, get_logger, use_torchacc from swift.utils.constants import Invoke @@ -522,6 +523,8 @@ def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs): if self.control.should_log: + if use_torchacc(): + ta_trim_graph() self.control.should_log = False logs: Dict[str, float] = {} metrics_log = {'loss': tr_loss} # loss first diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 86ed5877b..a6d00dac5 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -14,7 +14,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.utils import is_peft_available -from 
swift.torchacc_utils import ta_eval_dataloader, ta_test_dataloader, ta_train_dataloader +from swift.torchacc_utils import ta_eval_dataloader, ta_test_dataloader, ta_train_dataloader, ta_trim_graph from swift.utils import use_torchacc from .callback import DefaultFlowCallbackNew, PrinterCallbackNew, ProgressCallbackNew from .mixin import PushToMsHubMixin, SwiftMixin @@ -206,7 +206,8 @@ def compute_loss(self, model, inputs, return_outputs=None): loss = self.label_smoother(outputs, labels) else: loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0] - + if use_torchacc(): + ta_trim_graph() preds = outputs.logits.argmax(dim=2)[..., :-1] if labels is None: labels = inputs['labels'] diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 0f751fb40..2c8d6e1b7 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -9,7 +9,7 @@ from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing from .torch_utils import (activate_model_parameters, broadcast_string, freeze_model_parameters, get_dist_setting, get_model_info, is_ddp_plus_mp, is_dist, is_local_master, is_master, is_mp, is_on_same_device, - show_layers, time_synchronize, use_torchacc) + show_layers, time_synchronize, torchacc_trim_graph, use_torchacc) from .utils import (add_version_to_work_dir, check_json_format, get_pai_tensorboard_dir, is_pai_training_job, lower_bound, parse_args, read_multi_line, safe_ddp_context, seed_everything, subprocess_run, test_time, upper_bound) diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index d452b03d6..450b166f2 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -78,6 +78,10 @@ def use_torchacc() -> bool: return os.getenv('USE_TORCHACC', '0') == '1' +def torchacc_trim_graph(): + return os.getenv('TORCHACC_TRIM_GRAPH', '0') == '1' + + def is_dist(): """Determine if the training is distributed""" if use_torchacc():