[TorchAcc][Experimental] Integrate TorchAcc. #647
Conversation
NNODES=4 \
NPROC_PER_NODE=8 \
swift sft \
    --model_type qwen-72b-chat \
The alignment needs to be fixed.
fixed
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing no \
    --tuner_backend 'peft' \
    --eval_steps 2000000 \
Isn't this eval_steps value too large?
fixed
    --save_steps 2000000 \
    --logging_steps 10 \
    --preprocess_num_proc 1 \
    --dataloader_num_workers 0 \
Suggest changing this to 4 to improve preprocessing throughput.
fixed
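Assembling the diff fragments reviewed above, a hypothetical consolidated launch command would look roughly like this. The flag values are the pre-review ones shown in the snippets, and flags not shown in the snippets (dataset, output directory, etc.) are omitted rather than guessed:

```shell
# Illustrative only: values come from the reviewed snippets, before the
# review fixes (e.g. eval_steps/save_steps were flagged as too large).
export USE_TORCHACC=1
NNODES=4 \
NPROC_PER_NODE=8 \
swift sft \
    --model_type qwen-72b-chat \
    --tuner_backend 'peft' \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing no \
    --eval_steps 2000000 \
    --save_steps 2000000 \
    --logging_steps 10 \
    --preprocess_num_proc 1 \
    --dataloader_num_workers 0
```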
@@ -0,0 +1,31 @@
# Experimental environment: 4 * A800
Same as above.
fixed
# Wrapping the model breaks these properties, so compute them first.
label_names = find_labels(model)
return_loss = can_return_loss(model)
model = ta.patch_qwen_model(model)
Is only qwen supported for now? This method name seems a bit model-specific.
This will be addressed in a follow-up PR.
swift/llm/sft.py
Outdated
sft_main = get_main(SftArguments, llm_sft)
def get_sft_main(args, llm):
    if use_torchacc():
        logger.warning('TorchAcc is currently only available internally.')
Suggest replacing "internally" with a more concrete scenario; otherwise users will wonder what an internal scenario is.
fixed
swift/llm/utils/template.py
Outdated
    return [max_length // 4 * (i + 1) for i in range(4)]

def _get_bucket(bucket_sizes, data_length):
Suggest creating a separate .py file for the torchacc helpers.
fixed
@@ -428,6 +453,32 @@ def data_collator(self,
        loss_scale, batch_first=True, padding_value=0.)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)

    if use_torchacc():
        rank, _, world_size, _ = get_dist_setting()
Wrap this into a method and move it into the separate torchacc .py file, so readers of this code are not confused.
done
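To illustrate why the collator takes a TorchAcc-specific branch: XLA compiles a new graph for every distinct tensor shape, so the collator pads each batch up to a fixed bucket length rather than the batch maximum. `pad_to_bucket` below is a hypothetical helper sketching that idea, not the PR's actual code (which works on tensors and also handles world-size-divisible batching):

```python
def pad_to_bucket(input_ids, bucket_sizes, pad_id=0):
    # Pad every sequence in the batch up to the smallest bucket that
    # fits the longest one, so XLA sees only a handful of shapes.
    max_len = max(len(seq) for seq in input_ids)
    bucket = next((b for b in bucket_sizes if max_len <= b),
                  bucket_sizes[-1])
    return [seq + [pad_id] * (bucket - len(seq)) for seq in input_ids]
```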
swift/trainers/mixin.py
Outdated
if not use_torchacc():
    return super()._save_tpu(output_dir)

import torch_xla.core.xla_model as xm
Likewise, suggest wrapping this in a separate method and moving it into the separate .py file.
done
swift/trainers/trainers.py
Outdated
    return super().get_train_dataloader()
else:
    # patch skip_first_batches for customized dataloader.
    def acc_skip_first_batches(dataloader, num_batches=0):
Same as above.
done
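Context for the patch above: accelerate's `skip_first_batches` only understands dataloader types it knows about, so a customized dataloader needs a replacement when resuming mid-epoch. A minimal sketch of what such a replacement could do (the PR's actual implementation may differ):

```python
import itertools


def acc_skip_first_batches(dataloader, num_batches=0):
    # Generic fallback: consume and discard the first num_batches
    # from any iterable dataloader, then yield the remainder.
    return itertools.islice(iter(dataloader), num_batches, None)
```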
swift/llm/sft.py
Outdated
@@ -181,8 +203,21 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
    if val_dataset is not None:
        val_dataset = LazyLLMDataset(val_dataset, template)

    bucket_sizes = get_bucket_sizes(
Could this cause a problem?
deleted.
Force-pushed e0fe1d4 to dd5b360
Force-pushed dd5b360 to 30ad8c8
Force-pushed 2a849fa to 48ff067
* main: update Agent best practice with Modelscope-Agent (modelscope#676)
* [TorchAcc][Experimental] Integrate TorchAcc. (modelscope#647)
PR type
PR information
TorchAcc is a framework developed by Alibaba PAI to accelerate PyTorch model training. It provides computational acceleration through compilation optimization and distributed strategies such as FSDP and TP+SP. This PR uses TorchAcc to accelerate Swift SFT training in both the LoRA and full-parameter scenarios, and provides example scripts for qwen-72b-chat. Users can enable TorchAcc acceleration by setting export USE_TORCHACC=1. Currently, this feature is still in the experimental stage and only available internally.
Experiment results
Tested with 4 * 80G A100 on qwen-72b-chat LoRA with the script:
sh examples/pytorch/llm/scripts/qwen_72b_chat/torchacc/lora_fsdp_sft.sh
{"eval_loss": 0.35520425, "eval_acc": 0.87996558, "eval_runtime": 753.7452, "eval_samples_per_second": 0.362, "eval_steps_per_second": 0.004, "epoch": 1.0, "global_step": 1123}
{"train_runtime": 13049.8116, "train_samples_per_second": 2.065, "train_steps_per_second": 0.086, "total_flos": 0.0, "train_loss": 0.36302515, "epoch": 1.0, "global_step": 1123}
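The USE_TORCHACC gate described in the PR text can be sketched as follows. This is a plausible reconstruction of the `use_torchacc()` helper referenced throughout the diff, assuming it simply checks the environment variable; the actual implementation may differ:

```python
import os


def use_torchacc() -> bool:
    # TorchAcc codepaths are taken only when the user has explicitly
    # exported USE_TORCHACC=1 before launching training.
    return os.getenv('USE_TORCHACC', '0') == '1'


os.environ['USE_TORCHACC'] = '1'
print(use_torchacc())  # True
```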