
[TorchAcc][Experimental] Integrate TorchAcc. #647

Merged — 11 commits merged into modelscope:main on Apr 9, 2024

Conversation

baoleai (Collaborator) commented on Apr 2, 2024

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Model or Dataset Support

PR information

TorchAcc is a framework developed by Alibaba PAI to accelerate PyTorch model training. It provides computational acceleration based on compilation optimization and on distributed strategies such as FSDP and TP+SP. This PR uses TorchAcc to accelerate Swift SFT training in both the LoRA and full-parameter scenarios, and adds example scripts for qwen-72b-chat. Users can enable TorchAcc acceleration by setting export USE_TORCHACC=1. This feature is still experimental and currently only available internally.

Experiment results

Tested on 4 * 80G A100 GPUs with qwen-72b-chat LoRA, using the script:
sh examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/lora_fsdp_sft.sh

{"eval_loss": 0.35520425, "eval_acc": 0.87996558, "eval_runtime": 753.7452, "eval_samples_per_second": 0.362, "eval_steps_per_second": 0.004, "epoch": 1.0, "global_step": 1123}
{"train_runtime": 13049.8116, "train_samples_per_second": 2.065, "train_steps_per_second": 0.086, "total_flos": 0.0, "train_loss": 0.36302515, "epoch": 1.0, "global_step": 1123}

NNODES=4 \
NPROC_PER_NODE=8 \
swift sft \
--model_type qwen-72b-chat \
Collaborator: The alignment here needs to be fixed.

Author: fixed

--gradient_accumulation_steps 1 \
--gradient_checkpointing no \
--tuner_backend 'peft' \
--eval_steps 2000000 \
Collaborator: Isn't this eval_steps value too large?

Author: fixed

--save_steps 2000000 \
--logging_steps 10 \
--preprocess_num_proc 1 \
--dataloader_num_workers 0 \
Collaborator: Suggest changing this to 4 to improve processing efficiency.

Author: fixed

@@ -0,0 +1,31 @@
# Experimental environment: 4 * A800
Collaborator: Same as above.

Author: fixed

# Patching below wraps the model and would make these properties wrong, so read them first.
label_names = find_labels(model)
return_loss = can_return_loss(model)
model = ta.patch_qwen_model(model)
Collaborator: Is only qwen supported at the moment? This method name seems a bit model-specific.

Author: This will be addressed in a follow-up PR.

swift/llm/sft.py Outdated
sft_main = get_main(SftArguments, llm_sft)

def get_sft_main(args, llm):
    if use_torchacc():
        logger.warning('TorchAcc is currently only available internally.')
Collaborator: Suggest replacing "internally" with a more concrete description of the scenario; otherwise users will wonder what counts as an internal scenario.

Author: fixed

    return [max_length // 4 * (i + 1) for i in range(4)]


def _get_bucket(bucket_sizes, data_length):
Collaborator: Suggest moving the TorchAcc code into its own separate .py file.

Author: fixed
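
For context, a sketch of what this bucketing pair computes. get_bucket_sizes is taken from the diff above; the _get_bucket body is an assumption, since the diff only shows its signature:

def get_bucket_sizes(max_length):
    # From the diff above: buckets at 1/4, 2/4, 3/4 and 4/4 of max_length.
    return [max_length // 4 * (i + 1) for i in range(4)]

def _get_bucket(bucket_sizes, data_length):
    # Assumed body: pick the smallest bucket that fits the sample, so padded
    # batches come in a few fixed shapes and TorchAcc's XLA compiler only
    # compiles one graph per bucket instead of one per sequence length.
    for size in bucket_sizes:
        if data_length <= size:
            return size
    return bucket_sizes[-1]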

@@ -428,6 +453,32 @@ def data_collator(self,
            loss_scale, batch_first=True, padding_value=0.)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        if use_torchacc():
            rank, _, world_size, _ = get_dist_setting()
Collaborator: Wrap this in a method and move it into the dedicated TorchAcc .py file, so that readers of this code are not confused.

Author: done
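
A hypothetical shape for the extracted helper. The name ta_pad_to_bucket and the padding values are assumptions, and the sketch omits the per-rank batch splitting that the diff's rank/world_size variables suggest:

import torch.nn.functional as F

def ta_pad_to_bucket(input_ids, labels, bucket_sizes):
    # Assumed helper: pad the already right-padded batch up to the nearest
    # bucket length so XLA sees only a few distinct input shapes.
    bucket = _get_bucket(bucket_sizes, input_ids.shape[-1])
    pad = bucket - input_ids.shape[-1]
    input_ids = F.pad(input_ids, (0, pad), value=0)
    labels = F.pad(labels, (0, pad), value=-100)  # -100 is the ignore index
    return input_ids, labels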

        if not use_torchacc():
            return super()._save_tpu(output_dir)

        import torch_xla.core.xla_model as xm
Collaborator: Likewise, suggest wrapping this in its own method and moving it into the separate .py file.

Author: done
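
A minimal sketch of what such a save helper could look like; the function name and file layout are assumptions, and only the torch_xla import comes from the diff:

import os
import torch_xla.core.xla_model as xm

def ta_save_checkpoint(model, output_dir):
    # xm.save moves XLA tensors to CPU and, by default, writes only on the
    # master ordinal, so every rank can call this without clobbering files.
    xm.mark_step()  # flush pending lazy ops before reading the state dict
    xm.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))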

            return super().get_train_dataloader()
        else:
            # Patch skip_first_batches for the customized dataloader.
            def acc_skip_first_batches(dataloader, num_batches=0):
Collaborator: Same as above.

Author: done
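
A plausible body for this patch, assuming the custom dataloader only needs already-consumed batches dropped when training resumes (the actual implementation in the PR may differ):

import itertools

def acc_skip_first_batches(dataloader, num_batches=0):
    # Assumed behavior: transformers' stock skip_first_batches expects an
    # accelerate-prepared DataLoader, so the customized TorchAcc loader gets
    # a simple replacement that skips the first num_batches batches.
    return itertools.islice(iter(dataloader), num_batches, None)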

swift/llm/sft.py Outdated
@@ -181,8 +203,21 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
    if val_dataset is not None:
        val_dataset = LazyLLMDataset(val_dataset, template)

    bucket_sizes = get_bucket_sizes(
Collaborator: Could there be a problem here?

Author (Apr 8, 2024): deleted.

tastelikefeet merged commit 1e9f8be into modelscope:main on Apr 9, 2024
2 checks passed
tastelikefeet added a commit to tastelikefeet/swift that referenced this pull request Apr 10, 2024
* main:
  update Agent best practice with Modelscope-Agent (modelscope#676)
  [TorchAcc][Experimental] Integrate TorchAcc. (modelscope#647)