Add changes to support FSDP #598

Merged · 20 commits · Jan 23, 2024

12 changes: 12 additions & 0 deletions examples/language-modeling/fsdp_config.json
@@ -0,0 +1,12 @@
{
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_forward_prefetch": false,
"fsdp_offload_params": false,
"fsdp_sharding_strategy": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sync_module_states": true,
"fsdp_use_orig_params": true,
"transformer_layer_cls_to_wrap": "GaudiLlamaDecoderLayer",
"fsdp_activation_checkpointing": false
}
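
For context, the Trainer consumes this JSON through its `fsdp_config` argument, with the `fsdp` argument switching FSDP on. A minimal sketch of how that wiring might look in Python (the output directory and the `"auto_wrap"` mode are illustrative assumptions, not taken from this PR):

```python
from optimum.habana import GaudiTrainingArguments

# Illustrative only: the JSON above supplies the FSDP details (wrap policy,
# sharding strategy, state-dict type, ...), while `fsdp="auto_wrap"` enables
# FSDP and requests transformer-based auto-wrapping.
training_args = GaudiTrainingArguments(
    output_dir="./lora_fsdp_output",  # hypothetical path
    use_habana=True,
    bf16=True,
    fsdp="auto_wrap",
    fsdp_config="examples/language-modeling/fsdp_config.json",
)
```
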
13 changes: 8 additions & 5 deletions examples/language-modeling/run_lora_clm.py
@@ -30,11 +30,8 @@
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
TaskType,
get_peft_model,
)
from peft import LoraConfig, TaskType, get_peft_model, tuners
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -45,6 +42,7 @@
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from optimum.habana.peft.layer import GaudiLoraLayerLinearForward
from optimum.habana.utils import set_seed


@@ -674,6 +672,7 @@ def compute_metrics(eval_preds):
)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward
lora_model = get_peft_model(model, peft_config)
if training_args.bf16:
lora_model = lora_model.to(torch.bfloat16)
@@ -695,6 +694,10 @@ def compute_metrics(eval_preds):
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)

# Workaround for caveat (1) at https://github.com/huggingface/peft/blob/v0.6.2/README.md#caveats
if training_args.fsdp and training_args.fsdp_config["auto_wrap_policy"] == "TRANSFORMER_BASED_WRAP":
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(lora_model)

if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
12 changes: 12 additions & 0 deletions examples/question-answering/fsdp_config.json
@@ -0,0 +1,12 @@
{
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_forward_prefetch": false,
"fsdp_offload_params": false,
"fsdp_sharding_strategy": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sync_module_states": true,
"fsdp_use_orig_params": true,
"transformer_layer_cls_to_wrap": "BertLayer",
"fsdp_activation_checkpointing": false
}
96 changes: 92 additions & 4 deletions optimum/habana/accelerate/accelerator.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import contextlib
import functools
import math
import os
import sys
@@ -37,7 +38,6 @@
DistributedDataParallelKwargs,
DistributedType,
FP8RecipeKwargs,
FullyShardedDataParallelPlugin,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -50,8 +50,10 @@
check_os_kernel,
convert_outputs_to_fp32,
is_deepspeed_available,
is_torch_version,
parse_choice_from_env,
)
from accelerate.utils.constants import FSDP_PYTORCH_VERSION
from accelerate.utils.operations import _gpu_gather
from accelerate.utils.other import is_compiled_module
from torch.optim.lr_scheduler import LRScheduler
@@ -68,7 +70,12 @@

from .data_loader import gaudi_prepare_data_loader
from .state import GaudiAcceleratorState, GaudiPartialState
from .utils import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)


logger = get_logger(__name__)
@@ -87,7 +94,7 @@ def __init__(
gradient_accumulation_steps: int = 1,
cpu: bool = False,
deepspeed_plugin: DeepSpeedPlugin | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
fsdp_plugin: GaudiFullyShardedDataParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -142,6 +149,27 @@ def __init__(
deepspeed_plugin.set_mixed_precision(mixed_precision)
deepspeed_plugin.set_deepspeed_weakref()

if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
fsdp_plugin, GaudiFullyShardedDataParallelPlugin
):
import importlib.metadata

torch_version = importlib.metadata.version("torch")
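# Keep the installed wheel's local suffix (e.g. "a0+git..." from "2.1.0a0+git...") and append it to
# FSDP_PYTORCH_VERSION so that pre-release Habana builds are not rejected by the version check below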
torch_version = torch_version[5:]
if is_torch_version("<", FSDP_PYTORCH_VERSION + torch_version):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
GaudiFullyShardedDataParallelPlugin()
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
else None
)
else:
if not isinstance(fsdp_plugin, GaudiFullyShardedDataParallelPlugin):
raise TypeError("`fsdp_plugin` must be a GaudiFullyShardedDataParallelPlugin object.")
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

# Kwargs handlers
self.ddp_handler = None
self.scaler_handler = None
@@ -370,6 +398,54 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif self.distributed_type == GaudiDistributedType.FSDP:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

# Check if the model is already an FSDP model due to `Manual Wrapping` and, if so,
# don't wrap it again.
# Likewise, if the model has already been compiled with PyTorch 2.0 and the module it
# wraps is an FSDP model, don't wrap it again.
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
)

if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
"sync_module_states": fsdp_plugin.sync_module_states,
"backward_prefetch": fsdp_plugin.backward_prefetch,
"forward_prefetch": fsdp_plugin.forward_prefetch,
"use_orig_params": fsdp_plugin.use_orig_params,
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": torch.device("hpu"),
}
model = FSDP(model, **kwargs)
if fsdp_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)

apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)
# if the previous and current models are the same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model):
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
@@ -672,7 +748,11 @@ def gather(self, tensor):
tensor([0, 1, 2, 3])
```
"""
if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
if GaudiPartialState().distributed_type in [
GaudiDistributedType.MULTI_HPU,
GaudiDistributedType.DEEPSPEED,
GaudiDistributedType.FSDP,
]:
return _gpu_gather(tensor)
else:
return tensor
@@ -719,6 +799,14 @@ def get_state_dict(self, model, unwrap=True):
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
# copied from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/accelerator.py#L3057
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state_dict = model.state_dict()
else:
if unwrap:
model = self.unwrap_model(model)
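
Putting the accelerator changes together, preparing a model for FSDP on Gaudi would look roughly like the sketch below when run under a distributed launcher. The toy model, optimizer, and default-constructed plugin are assumptions for illustration, not code from this PR:

```python
import torch

from optimum.habana.accelerate import GaudiAccelerator
from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin

# Plugin fields left unset fall back to the FSDP_* environment variables read in
# GaudiFullyShardedDataParallelPlugin.__post_init__ (sharding strategy, prefetch, ...).
fsdp_plugin = GaudiFullyShardedDataParallelPlugin()
accelerator = GaudiAccelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(1024, 1024)  # toy model, illustrative only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() routes through prepare_model(), which wraps the model in FSDP with
# device_id=torch.device("hpu") as added above.
model, optimizer = accelerator.prepare(model, optimizer)
```
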
10 changes: 10 additions & 0 deletions optimum/habana/accelerate/state.py
@@ -67,6 +67,11 @@ def __init__(self, cpu: bool = False, **kwargs):
deepspeed.init_distributed(dist_backend=self.backend, **kwargs)
logger.info("DeepSpeed is enabled.")
self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config
elif os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
self.distributed_type = GaudiDistributedType.FSDP
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size)
logger.info("Enabled distributed run.")
else:
self.distributed_type = GaudiDistributedType.MULTI_HPU
if not torch.distributed.is_initialized():
@@ -115,6 +120,7 @@ def wait_for_everyone(self):
GaudiDistributedType.MULTI_CPU,
GaudiDistributedType.DEEPSPEED,
GaudiDistributedType.MULTI_HPU,
GaudiDistributedType.FSDP,
):
torch.distributed.barrier()

@@ -171,6 +177,10 @@ def __init__(
)
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
self.deepspeed_plugin = deepspeed_plugin
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" and not cpu:
if self._mixed_precision != "no":
fsdp_plugin.set_mixed_precision(self._mixed_precision)
self.fsdp_plugin = fsdp_plugin
GaudiPartialState._shared_state["distributed_type"] = self.distributed_type
self.use_ipex = False

7 changes: 6 additions & 1 deletion optimum/habana/accelerate/utils/__init__.py
@@ -1 +1,6 @@
from .dataclasses import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
38 changes: 38 additions & 0 deletions optimum/habana/accelerate/utils/dataclasses.py
@@ -17,6 +17,9 @@
from dataclasses import dataclass
from enum import Enum

import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool

@@ -31,12 +34,14 @@ class GaudiDistributedType(str, Enum):
- **NO** -- Not a distributed environment, just a single process.
- **MULTI_HPU** -- Distributed on multiple HPUs.
- **DEEPSPEED** -- Using DeepSpeed.
- **FSDP** -- Using FSDP.
"""

# Subclassing str as well as Enum allows the `GaudiDistributedType` to be JSON-serializable out of the box.
NO = "NO"
MULTI_HPU = "MULTI_HPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"


class GaudiDynamoBackend(str, BaseEnum):
@@ -106,3 +111,36 @@ def __post_init__(self):
self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
if self.dynamic is None:
self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1


@dataclass
class GaudiFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin):
def __post_init__(self):
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy

prefix = "FSDP_"
if self.sharding_strategy is None:
self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))

if self.cpu_offload is None:
if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
self.cpu_offload = CPUOffload(offload_params=True)
else:
self.cpu_offload = CPUOffload(offload_params=False)

if self.backward_prefetch is None:
prefetch_policy = os.environ.get(prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]:
self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1)

if self.state_dict_type is None:
state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
self.set_state_dict_type(state_dict_type_policy)
self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

if self.sync_module_states:
device = torch.device("hpu")
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
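
For illustration, the plugin can be configured entirely through the FSDP_* environment variables that a launcher would normally export from a config file such as the fsdp_config.json examples above. A small sketch with hand-set, illustrative values:

```python
import os

from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin

# Illustrative values mirroring the example fsdp_config.json; a launcher would
# normally export these rather than setting them by hand.
os.environ["FSDP_SHARDING_STRATEGY"] = "1"  # FULL_SHARD
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"
os.environ["FSDP_STATE_DICT_TYPE"] = "FULL_STATE_DICT"
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"

plugin = GaudiFullyShardedDataParallelPlugin()
print(plugin.sharding_strategy)          # ShardingStrategy.FULL_SHARD
print(plugin.param_init_fn is not None)  # True: sync_module_states installs a to_empty(device="hpu") init fn
```
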
1 change: 1 addition & 0 deletions optimum/habana/peft/__init__.py
@@ -0,0 +1 @@
from .layer import GaudiLoraLayerLinearForward
31 changes: 31 additions & 0 deletions optimum/habana/peft/layer.py
@@ -0,0 +1,31 @@
from typing import Any

import torch


def GaudiLoraLayerLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# Copied from https://github.com/huggingface/peft/blob/4b02148af252c17e36b0a4b995f9e8519806fbb5/src/peft/tuners/lora/layer.py#L354C1-L376C22
# The only differences are: (1) "result" is not updated in place, which would otherwise trigger an error
# from torch Dynamo in torch.compile execution mode, and (2) self.base_layer is replaced by self._linear.
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._linear(x, *args, **kwargs)
elif self.merged:
result = self._linear(x, *args, **kwargs)
else:
result = self._linear(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result = result.clone() + lora_B(lora_A(dropout(x))) * scaling

result = result.to(previous_dtype)
return result
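
As a usage note, the patched forward is installed by reassigning the PEFT LoRA `Linear.forward` before the PEFT model is built, exactly as the run_lora_clm.py change above does. A condensed sketch; the base checkpoint and LoRA hyperparameters are placeholders:

```python
from peft import LoraConfig, TaskType, get_peft_model, tuners
from transformers import AutoModelForCausalLM

from optimum.habana.peft.layer import GaudiLoraLayerLinearForward

# Monkey-patch the LoRA Linear forward with the Gaudi-friendly version
# (no in-place update of `result`), then build the PEFT model as usual.
tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint, illustrative only
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16)  # values illustrative
lora_model = get_peft_model(base_model, peft_config)
```
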
10 changes: 8 additions & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -60,8 +60,14 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states):
- override RMSNorm with Habana fused RMSNorm
"""
if hidden_states.device.type == "hpu" and FusedRMSNorm:
hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)
return hidden_states
# FusedRMSNorm does not handle mixed dtypes; both inputs must have the same dtype
if hidden_states.dtype != self.weight.dtype:
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)
return hidden_states
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)