-
Notifications
You must be signed in to change notification settings - Fork 78
Adding FSDP Support to Training Library #213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c6b14f4
cdca42c
1114539
2e33add
80e65d3
802f74f
d4851d1
ca10236
4a89949
5c3d96c
cb9fec7
32b7265
7a2ca02
b02fa03
b736c25
6d4bb46
64595d4
5f582c1
ca274f6
85ac691
2f96481
807d9f7
3907e38
0208e64
b5e6d38
d3010a7
d9b71c0
fa8c1ba
291acdb
eff3fef
edf7618
14f2b08
fdbe288
8fb86ca
ffb7fab
2b35ed1
065df3e
e472b4d
8f8dc8f
72abd67
7f7e9f2
723a85e
7cd6747
95eb2c0
dd47117
4cdfb8d
70ff83c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,24 +1,39 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # First Party | ||
| from instructlab.training.tokenizer_utils import SpecialTokens | ||
| from instructlab.training.tokenizer_utils import SpecialTokens, TokenInfo | ||
|
|
||
| SPECIAL_TOKENS = SpecialTokens( | ||
| bos="<s>", | ||
| eos="</s>", | ||
| user="[INST]", | ||
| assistant="[/INST]", | ||
| bos=TokenInfo("<s>", add_to_tokenizer=True), | ||
| eos=TokenInfo("</s>", add_to_tokenizer=True), | ||
Maxusmusti marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| user=TokenInfo("[INST]", add_to_tokenizer=False), | ||
| assistant=TokenInfo("[/INST]", add_to_tokenizer=False), | ||
| ) | ||
|
|
||
| CHAT_TEMPLATE = ( | ||
| "{%- if messages[0]['role'] == 'system' %}" | ||
| "{%- set system_message = messages[0]['content'] %}" | ||
| "{%- set loop_messages = messages[1:] %}" | ||
| "{%- else %}" | ||
| "{%- set loop_messages = messages %}" | ||
| "{%- endif %}" | ||
| "{{ '<s>' }}" | ||
| "{% for message in messages %}" | ||
| "{% if message['role'] == 'pretraining' %}" | ||
| "{{'<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>'}}" | ||
| "{% elif message['role'] == 'user' %}" | ||
| "{{ '[INST] ' + message['content'] + ' [/INST]' }}" | ||
| "{% elif message['role'] == 'assistant' %}" | ||
| "{{ message['content'] + '</s>'}}" | ||
| "{% endif %}" | ||
| "{% endfor %}" | ||
| "{%- for message in loop_messages %}" | ||
| "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" | ||
| "{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}" | ||
| "{%- endif %}" | ||
| "{%- if message['role'] == 'user' %}" | ||
| "{%- if loop.first and system_message is defined %}" | ||
| "{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}" | ||
| "{%- else %}" | ||
| "{{- ' [INST] ' + message['content'] + ' [/INST]' }}" | ||
| "{%- endif %}" | ||
| "{%- elif message['role'] == 'pretraining' %}" | ||
| "{{- '<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>' }}" | ||
| "{%- elif message['role'] == 'assistant' %}" | ||
| "{{- ' ' + message['content'] + '</s>'}}" | ||
| "{%- else %}" | ||
| "{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}" | ||
| "{%- endif %}" | ||
| "{%- endfor %}" | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,12 @@ class DeepSpeedOffloadStrategy(Enum): | |
| NONE = None | ||
|
|
||
|
|
||
| # public API | ||
| class DistributedBackend(Enum): | ||
| FSDP: str = "fsdp" | ||
| DEEPSPEED: str = "deepspeed" | ||
|
|
||
|
|
||
| # public API | ||
| class QuantizeDataType(Enum): | ||
| """ | ||
|
|
@@ -111,6 +117,24 @@ class DeepSpeedOptions(BaseModel): | |
| save_samples: int | None = None | ||
|
|
||
|
|
||
| # public API | ||
| class ShardingStrategies(Enum): | ||
| FULL_SHARD = "FULL_SHARD" | ||
| SHARD_GRAD_OP = "SHARD_GRAD_OP" | ||
| NO_SHARD = "NO_SHARD" | ||
| HYBRID_SHARD = "HYBRID_SHARD" | ||
Maxusmusti marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| # public API | ||
| class FSDPOptions(BaseModel): | ||
| """ | ||
| Represents the options for configuring FSDP which are exposed by the Training Library | ||
| """ | ||
|
|
||
| cpu_offload_params: Optional[bool] = False | ||
| sharding_strategy: ShardingStrategies = ShardingStrategies.SHARD_GRAD_OP | ||
|
|
||
|
|
||
| # public API | ||
| class TrainingArgs(BaseModel): | ||
| """ | ||
|
|
@@ -157,6 +181,12 @@ class TrainingArgs(BaseModel): | |
| cpu_offload_optimizer_pin_memory=False, | ||
| ) | ||
| ) | ||
| fsdp_options: FSDPOptions = Field( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this need to be a factory? I think it can just be an assignment.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm following the current convention set by DeepSpeedOptions in the file, so imo if we want to change this, we should make a follow-up PR that updates both of them |
||
| default_factory=lambda: FSDPOptions( | ||
| cpu_offload_params=False, sharding_strategy=ShardingStrategies.SHARD_GRAD_OP | ||
| ) | ||
| ) | ||
| distributed_backend: DistributedBackend = DistributedBackend.DEEPSPEED | ||
|
|
||
| disable_flash_attn: Optional[bool] = False | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.