Merged
148 changes: 94 additions & 54 deletions torchft/local_sgd.py
@@ -10,7 +10,7 @@
"""
import logging
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

import torch
from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
model: nn.Module,
optimizer: optim.Optimizer,
sync_every: int,
backup_device: Optional[torch.device] = None,
pin_memory: bool = True,
) -> None:
"""
Args:
@@ -78,21 +76,8 @@ def __init__(
self._local_step = 0
self._sync_every = sync_every
assert sync_every >= 1, "sync_every must be greater than or equal to 1"
device = backup_device or torch.device("cpu")
self._backup_parameters: Dict[str, torch.Tensor] = {}
for name, p in self._model.named_parameters():
t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
if (
pin_memory
and t.device == torch.device("cpu")
and torch.cuda.is_available()
):
t = t.pin_memory()
self._backup_parameters[name] = t

self._hooks: List[RemovableHandle] = []
# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

def __enter__(self) -> "LocalSGD":
# Add optimizer hook which increments the local step counter and syncs if necessary
@@ -108,30 +93,15 @@ def __exit__(
traceback: Optional[TracebackType],
) -> bool:
# Handle any cleanup or error handling here
if exc_type is not None:
# If an exception occurred, restore parameters
self._restore_parameters()
# Clean up hooks
for hook in self._hooks:
hook.remove()
self._hooks.clear()

return False # Propagate exceptions

def _save_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
self._backup_parameters[name].copy_(p.data, non_blocking=True)

def _restore_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
p.data.copy_(self._backup_parameters[name], non_blocking=False)

def _step_post_hook(
self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
@@ -151,30 +121,31 @@ def sync(self) -> None:
def _perform_sync(self) -> None:
"""
Performs the synchronization of the model weights across the manager.
This method is intended to be overridden by subclasses to implement custom
synchronization logic.
"""
self._average()
averaged_parameters = self._average()
if self._manager.should_commit():
self._save_parameters()
else:
# commit failed, restore from the backup parameters
self._restore_parameters()

def _average(self) -> None:
# TODO: do we need to broadcast buffers like DDP does?
# Update the model parameters with the averaged values
for param, avg_param in zip(self._model.parameters(), averaged_parameters):
param.data.copy_(avg_param)

def _average(self) -> list[torch.Tensor]:
"""
Averages the model parameters across the manager and returns the averaged parameters.
"""
works = []

averaged_parameters = []
for p in self._model.parameters():
# TODO: bucketize parameters
works.append(self._manager.allreduce(p.data.detach()))

# Create a new tensor to store the averaged parameter
p.data.grad = None
avg_param = p.data.clone()
Review comment (Member): Should we be checking/setting p.data.grad to None to minimize memory impact from the copy here?

Reply (Contributor Author): I think p.data.grad should always be None, but I can add a line to ensure it is None. I think only p.grad gets set, not sure when p.data.grad would get updated.
works.append(self._manager.allreduce(avg_param))
averaged_parameters.append(avg_param)
for work in works:
work.wait()
return averaged_parameters
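
On the review exchange above: a quick standalone check (plain PyTorch, assumed recent version, not part of this PR) of the claim that only `p.grad` gets populated while `p.data.grad` stays `None`:

```python
import torch

# Standalone check of the review point above (assumes a recent PyTorch, not torchft-specific):
# autograd writes gradients to p.grad, while p.data returns a detached view whose own
# .grad attribute stays None, so clearing p.data.grad is effectively a defensive no-op.
p = torch.nn.Parameter(torch.randn(4))
p.sum().backward()
print(p.grad is None)       # False: backward() populated p.grad
print(p.data.grad is None)  # True: the detached .data view carries no gradient
```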


class DiLoCo(LocalSGD):
class DiLoCo:
"""
DiLoCo is similar to LocalSGD, but instead of averaging the weights directly it averages and
synchronizes the pseudogradients (delta of the previous global weights and current local weights) and applies them with an outer optimizer.
@@ -197,27 +168,96 @@ def __init__(
"Using DiLoCo require synchronous quorum to be enabled. "
"Ensure that the manager is initialized with use_async_quorum=False"
)
super().__init__(
manager, model, inner_optimizer, sync_every, backup_device, pin_memory
)
super().__init__()
self._manager = manager
self._model = model
self._local_optimizer = inner_optimizer
self._local_step = 0
self._sync_every = sync_every
assert sync_every >= 1, "sync_every must be greater than or equal to 1"
self._backup_device = backup_device
self._pin_memory = pin_memory

self._hooks: List[RemovableHandle] = []
self._outer_optimizer = outer_optimizer
self.original_parameters: Dict[str, torch.Tensor] = {}
for name, p in self._model.named_parameters():
t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
if (
self._pin_memory
and t.device == torch.device("cpu")
and torch.cuda.is_available()
):
t = t.pin_memory()
self.original_parameters[name] = t

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

def _save_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
self.original_parameters[name].copy_(p.data, non_blocking=True)

def _restore_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
p.data.copy_(self.original_parameters[name], non_blocking=False)
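
For reference, a simplified single-process sketch of the pinned-CPU backup/restore pattern used by `_save_parameters` and `_restore_parameters` above (illustrative only, not the torchft API):

```python
import torch
from torch import nn

# Simplified sketch of the backup/restore pattern above (illustrative, not the torchft API):
# keep an (optionally pinned) CPU copy of each parameter so the save can use a non-blocking
# device-to-host copy, and the restore can roll the weights back to the previous state.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8).to(device)

backup = {}
for name, p in model.named_parameters():
    t = torch.empty_like(p, device="cpu")
    if torch.cuda.is_available():
        t = t.pin_memory()  # pinned memory enables truly asynchronous device<->host copies
    backup[name] = t

with torch.no_grad():
    for name, p in model.named_parameters():
        backup[name].copy_(p.data, non_blocking=True)   # save (device -> host)
    for name, p in model.named_parameters():
        p.data.copy_(backup[name], non_blocking=False)  # restore (host -> device)
```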

def __enter__(self) -> "DiLoCo":
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
)
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> bool:
# Handle any cleanup or error handling here
# Clean up hooks
for hook in self._hooks:
hook.remove()
self._hooks.clear()

return False # Propagate exceptions

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
self._local_step += 1
if self._local_step >= self._sync_every:
self.sync()
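
The step counting above hinges on `Optimizer.register_step_post_hook`, which is stock `torch.optim`; a minimal standalone illustration (the names and the sync-every-2 threshold are made up for the example):

```python
import torch
from torch import nn, optim

# Minimal illustration of the optimizer hook mechanism used above (plain torch.optim,
# no torchft): the post-step hook fires after every optimizer.step() call.
model = nn.Linear(4, 4)
opt = optim.SGD(model.parameters(), lr=0.1)

step_count = 0

def post_step(_optim: optim.Optimizer, _args, _kwargs) -> None:
    global step_count
    step_count += 1
    if step_count % 2 == 0:  # stand-in for "sync every N local steps"
        print(f"would sync at step {step_count}")

handle = opt.register_step_post_hook(post_step)
for _ in range(4):
    model(torch.randn(2, 4)).sum().backward()
    opt.step()
    opt.zero_grad()
handle.remove()  # hooks return a RemovableHandle, as collected in self._hooks above
```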

def sync(self) -> None:
"""
Synchronizes and averages the model weights across the manager.
"""
self._manager.start_quorum()
self._perform_sync()
self._local_step = 0

def _perform_sync(self) -> None:
"""
Calculates the pseudogradients, averages them across the manager group, and
steps using the outer optimizer.
"""

# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model.named_parameters():
assert name in self._backup_parameters
pseudogradient = p.data - self._backup_parameters[name]
pseudogradient = p.data - self.original_parameters[name]
p.grad = pseudogradient

self._average_grads()
# Restore the parameters back to the previous state
self._restore_parameters()

if self._manager.should_commit():
# Use the outer optimizer to update the model parameters
self._outer_optimizer.step()
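
For intuition, here is a single-process sketch of the pseudogradient update performed by `_perform_sync` above, with the manager's allreduce and commit check omitted (illustrative only; the optimizers and learning rates are arbitrary):

```python
import torch
from torch import nn, optim

# Single-process sketch of the pseudogradient outer step above (illustrative; the real
# _perform_sync also allreduce-averages the pseudogradients via the manager and only
# applies the outer step when the commit succeeds).
model = nn.Linear(4, 4)
inner_opt = optim.AdamW(model.parameters(), lr=1e-3)
outer_opt = optim.SGD(model.parameters(), lr=0.7)
original = {n: p.detach().clone() for n, p in model.named_parameters()}

# ... sync_every local inner-optimizer steps would happen here ...

for name, p in model.named_parameters():
    p.grad = p.data - original[name]          # pseudogradient, as in the diff
with torch.no_grad():
    for name, p in model.named_parameters():
        p.data.copy_(original[name])          # restore the previous global weights
outer_opt.step()                              # outer optimizer applies the delta
```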