Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading #7685
```diff
@@ -274,6 +274,7 @@ def __init__(
         use_ema_warmup: bool = False,
         inv_gamma: Union[float, int] = 1.0,
         power: Union[float, int] = 2 / 3,
+        foreach: bool = False,
         model_cls: Optional[Any] = None,
         model_config: Dict[str, Any] = None,
         **kwargs,
```
```diff
@@ -288,6 +289,7 @@ def __init__(
             inv_gamma (float):
                 Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
             power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
             device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                 weights will be stored on CPU.
```
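As a quick illustration (not part of the diff): the `foreach` path replaces the per-parameter Python loop with a few multi-tensor `torch._foreach_*` calls over the whole parameter list, which is where the expected speedup comes from. A minimal sketch showing the two updates agree:

```python
# Minimal sketch (not from the PR): the per-parameter loop and the
# torch._foreach_* path compute the same EMA update; the latter batches
# the work into a few multi-tensor kernels instead of one kernel per tensor.
import torch

decay = 0.999
one_minus_decay = 1 - decay

params = [torch.randn(4, 4) for _ in range(3)]
shadow_loop = [p.clone() + 0.1 for p in params]
shadow_foreach = [s.clone() for s in shadow_loop]

# Loop version (existing behaviour).
for s, p in zip(shadow_loop, params):
    s.sub_(one_minus_decay * (s - p))

# foreach version (what `foreach=True` switches to).
torch._foreach_sub_(shadow_foreach, torch._foreach_sub(shadow_foreach, params), alpha=one_minus_decay)

for a, b in zip(shadow_loop, shadow_foreach):
    assert torch.allclose(a, b)
```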
```diff
@@ -342,16 +344,17 @@ def __init__(
         self.power = power
         self.optimization_step = 0
         self.cur_decay_value = None  # set in `step()`
+        self.foreach = foreach

         self.model_cls = model_cls
         self.model_config = model_config

     @classmethod
-    def from_pretrained(cls, path, model_cls) -> "EMAModel":
+    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
         _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)

-        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

         ema_model.load_state_dict(ema_kwargs)
         return ema_model
```
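A rough usage sketch of the new flag (not part of the diff); `UNet2DConditionModel` and the checkpoint paths are stand-ins for whatever model is actually being trained:

```python
# Rough usage sketch (not from the PR); the model class and paths are placeholders.
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

unet = UNet2DConditionModel.from_pretrained("path/to/unet")

# Opt in to the batched torch._foreach update path.
ema_unet = EMAModel(
    unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True
)

# Reloading a saved EMA checkpoint carries the same opt-in flag.
ema_unet = EMAModel.from_pretrained("path/to/ema_unet", model_cls=UNet2DConditionModel, foreach=True)
```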
```diff
@@ -418,15 +421,37 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
             import deepspeed

-        for s_param, param in zip(self.shadow_params, parameters):
+        if self.foreach:
             if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
-                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

             with context_manager():
-                if param.requires_grad:
-                    s_param.sub_(one_minus_decay * (s_param - param))
-                else:
-                    s_param.copy_(param)
+                params_grad = [param for param in parameters if param.requires_grad]
+                s_params_grad = [
+                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
+                ]
+
+                if len(params_grad) < len(parameters):
+                    torch._foreach_copy_(
+                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
+                        [param for param in parameters if not param.requires_grad],
+                        non_blocking=True,
+                    )
+
+                torch._foreach_sub_(
+                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
+                )
+
+        else:
+            for s_param, param in zip(self.shadow_params, parameters):
+                if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+
+                with context_manager():
+                    if param.requires_grad:
+                        s_param.sub_(one_minus_decay * (s_param - param))
+                    else:
+                        s_param.copy_(param)

     def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
```
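To make the foreach branch above concrete (not part of the diff), here is a minimal sketch of the partitioning it relies on: parameters that require grad receive the EMA update, the rest are copied verbatim, and both sub-lists stay aligned with their shadow counterparts because they are filtered from the same zip in the same order. It assumes a recent PyTorch that provides `torch._foreach_copy_`:

```python
# Sketch (not from the PR) of the requires_grad partitioning in the foreach branch.
import torch

shadow_params = [torch.zeros(2), torch.zeros(2), torch.zeros(2)]
parameters = [torch.nn.Parameter(torch.ones(2)) for _ in range(3)]
parameters[1].requires_grad_(False)  # e.g. a frozen parameter

one_minus_decay = 0.01

with torch.no_grad():  # EMAModel.step() also runs under @torch.no_grad()
    params_grad = [p for p in parameters if p.requires_grad]
    s_params_grad = [s for s, p in zip(shadow_params, parameters) if p.requires_grad]

    # Frozen parameters are mirrored straight into their shadow copies.
    if len(params_grad) < len(parameters):
        torch._foreach_copy_(
            [s for s, p in zip(shadow_params, parameters) if not p.requires_grad],
            [p for p in parameters if not p.requires_grad],
        )

    # Trainable parameters get the usual EMA step: s -= (1 - decay) * (s - p).
    torch._foreach_sub_(s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay)

print([s.tolist() for s in shadow_params])
# shadows 0 and 2 move slightly toward 1.0 (to 0.01); shadow 1 becomes exactly 1.0
```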
```diff
@@ -438,18 +463,34 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             `ExponentialMovingAverage` was initialized will be used.
         """
         parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.to(param.device).data)
+        if self.foreach:
+            torch._foreach_copy_(
+                [param.data for param in parameters],
+                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
+            )
+        else:
+            for s_param, param in zip(self.shadow_params, parameters):
+                param.data.copy_(s_param.to(param.device).data)

-    def to(self, device=None, dtype=None) -> None:
+    def pin_memory(self) -> None:
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
+        offloading EMA params to the host.
+        """
+
+        self.shadow_params = [p.pin_memory() for p in self.shadow_params]
+
+    def to(self, device=None, dtype=None, non_blocking=False) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.

         Args:
             device: like `device` argument to `torch.Tensor.to`
         """
         # .to() on the tensors handles None correctly
         self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            if p.is_floating_point()
+            else p.to(device=device, non_blocking=non_blocking)
             for p in self.shadow_params
         ]
```

Comment on lines +468 to +469

Member: Just one question. Should we add `non_blocking=True` here as well?

Contributor (Author): I don't think there's much benefit to doing so. I expect `copy_to` to be used only for validation and model saving rather than on every training step, so there is little to gain from a non-blocking transfer for something that runs at most every few minutes and more realistically hours apart. I've also had to troubleshoot a few issues with unexpectedly increased VRAM usage when switching between validation and training, presumably from tensors not being freed fast enough. Combined with the risk that someone might do a non-blocking `copy_to` of an EMA state into the model's parameters and then save the model, which could mean saving tensors that are not ready yet, I think it is safer for these copies to stay blocking.

Member: Alright. That sounds reasonable to me.
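A hedged sketch (not part of the diff) of the offload workflow `pin_memory()` and the `non_blocking` `to()` are aimed at; the tiny model and device handling are placeholders, and only the `EMAModel` calls correspond to this diff:

```python
# Sketch (not from the PR): keep the EMA weights in pinned host memory and move
# them to the accelerator asynchronously only when they are needed.
import torch
from diffusers.training_utils import EMAModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8)                    # stand-in for the real model, created on CPU
ema = EMAModel(model.parameters(), foreach=True)

if device == "cuda":
    ema.pin_memory()                             # page-locked buffers enable async H2D/D2H copies
model.to(device)

# Overlap the EMA transfer with other host-side work, then update on-device.
ema.to(device=device, non_blocking=True)
ema.step(model.parameters())

# When accelerator memory is tight, park the EMA weights back on the host.
ema.to(device="cpu", non_blocking=True)
if device == "cuda":
    torch.cuda.synchronize()                     # make sure the copies have landed before reusing them
```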
```diff
@@ -493,8 +534,13 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         if self.temp_stored_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
-        for c_param, param in zip(self.temp_stored_params, parameters):
-            param.data.copy_(c_param.data)
+        if self.foreach:
+            torch._foreach_copy_(
+                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
+            )
+        else:
+            for c_param, param in zip(self.temp_stored_params, parameters):
+                param.data.copy_(c_param.data)

         # Better memory-wise.
         self.temp_stored_params = None
```
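For context (not part of the diff), a hedged sketch of the store / copy_to / restore round trip around validation; with `foreach=True`, both `copy_to` and `restore` take the batched `torch._foreach_copy_` path shown above. The model and the validation pass are placeholders:

```python
# Sketch (not from the PR) of swapping EMA weights in and out around validation.
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters(), foreach=True)

# ... training happens; ema.step(model.parameters()) is called after each optimizer step ...

ema.store(model.parameters())      # stash the current training weights
ema.copy_to(model.parameters())    # evaluate with the EMA weights
with torch.no_grad():
    _ = model(torch.randn(2, 8))   # placeholder for the real validation pass
ema.restore(model.parameters())    # put the training weights back
```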