26 changes: 19 additions & 7 deletions .github/workflows/examples.yaml
@@ -19,26 +19,38 @@ jobs:
         run: |
           uv venv && source .venv/bin/activate
           uv pip install ".[examples]"
+      - name: Detect GPU availability.
+        id: gpu_check
+        run: |
+          source .venv/bin/activate
+          if python -c "import torch; raise SystemExit(0 if torch.cuda.is_available() else 1)"; then
+            echo "has_gpu=true" >> "$GITHUB_OUTPUT"
+          else
+            echo "has_gpu=false" >> "$GITHUB_OUTPUT"
+            echo "::warning::No usable GPU detected on this runner; GPU-only steps will be skipped."
+          fi
       - name: Run single GPU example with Adam to serve as a baseline.
         run: |
           source .venv/bin/activate
           CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=adam batch_size=1024
       - name: Run single GPU examples with Distributed Shampoo and different graftings on CPU.
         run: |
           source .venv/bin/activate
-          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdaGradPreconditionerConfig,epsilon:1e-8}' epochs=1 batch_size=1024
-          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
-          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.RMSpropPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
-          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}' epochs=1 batch_size=1024
+          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdaGradPreconditionerConfig,epsilon:1e-8}' epochs=1 batch_size=1024
+          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
+          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.RMSpropPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
+          CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}' epochs=1 batch_size=1024
       - name: Run single GPU example on GPU.
+        if: steps.gpu_check.outputs.has_gpu == 'true'
         run: |
           source .venv/bin/activate
-          python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
+          python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024
       - name: Run DDP example on CPU.
         run: |
           source .venv/bin/activate
-          CUDA_VISIBLE_DEVICES="" torchrun --standalone --nnodes=1 --nproc_per_node=2 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=15 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 backend=gloo
+          CUDA_VISIBLE_DEVICES="" torchrun --standalone --nnodes=1 --nproc_per_node=2 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=15 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 backend=gloo
       - name: Run DDP example on GPU.
+        if: steps.gpu_check.outputs.has_gpu == 'true'
         run: |
           source .venv/bin/activate
-          torchrun --standalone --nnodes=1 --nproc_per_node=1 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024
+          torchrun --standalone --nnodes=1 --nproc_per_node=1 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024
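Note on the new `Detect GPU availability.` step above: the GPU-only steps are gated on the exit status of an inline `python -c` probe. A minimal sketch of that probe as a standalone script (the file name is illustrative):

    # check_gpu.py: a sketch of the inline probe the workflow runs.
    # Exits with status 0 when a CUDA device is usable and 1 otherwise,
    # so a shell `if` can branch on GPU availability.
    import torch

    raise SystemExit(0 if torch.cuda.is_available() else 1)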
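The example commands also change how the grafting config is overridden: instead of assigning `optimizer.grafting_config={...}` in place, they first delete the node (`~optimizer.grafting_config`) and then add it back (`+optimizer.grafting_config={...}`), so the new value replaces the existing struct wholesale rather than being merged into it. A minimal sketch of the same delete-then-add overrides through Hydra's compose API (the `conf` directory and `train` config name are hypothetical):

    # A sketch assuming a Hydra config directory "conf" with root config
    # "train" (both hypothetical). "~key" removes the node; "+key=value"
    # re-adds it, replacing the old struct instead of merging into it.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="conf"):
        cfg = compose(
            config_name="train",
            overrides=[
                "~optimizer.grafting_config",
                "+optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}",
            ],
        )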
14 changes: 7 additions & 7 deletions distributed_shampoo/__init__.py
@@ -34,7 +34,8 @@
 from distributed_shampoo.shampoo_types import (
     AdaGradPreconditionerConfig,
     AdamPreconditionerConfig,
-    AmortizedPreconditionerConfig,
+    BaseShampooPreconditionerConfig,
+    ClassicShampooPreconditionerConfig,
     DDPDistributedConfig,
     DefaultEigenvalueCorrectedShampooConfig,
     DefaultShampooConfig,
@@ -59,7 +60,6 @@
     RootInvShampooPreconditionerConfig,
     ScheduleFreeConfig,
     SGDPreconditionerConfig,
-    ShampooPreconditionerConfig,
     ShampooPT2CompileConfig,
     SignDescentPreconditionerConfig,
     SingleDeviceDistributedConfig,
@@ -85,14 +85,14 @@
     # `precision_config`.
     # `preconditioner_config` options.
     "PreconditionerConfig",  # Abstract base class.
-    "AmortizedPreconditionerConfig",  # Abstract base class (based on `PreconditionerConfig`).
-    "ShampooPreconditionerConfig",  # Abstract base class (based on `AmortizedPreconditionerConfig`).
-    "RootInvShampooPreconditionerConfig",  # Based on `ShampooPreconditionerConfig`.
+    "BaseShampooPreconditionerConfig",  # Abstract base class (based on `PreconditionerConfig`).
+    "ClassicShampooPreconditionerConfig",  # Abstract base class (based on `BaseShampooPreconditionerConfig`).
+    "RootInvShampooPreconditionerConfig",  # Based on `ClassicShampooPreconditionerConfig`.
     "DefaultShampooConfig",  # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`.
     "RootInvKLShampooPreconditionerConfig",  # Based on `RootInvShampooPreconditionerConfig`.
-    "EigendecomposedShampooPreconditionerConfig",  # Based on `ShampooPreconditionerConfig`.
+    "EigendecomposedShampooPreconditionerConfig",  # Based on `ClassicShampooPreconditionerConfig`.
     "EigendecomposedKLShampooPreconditionerConfig",  # Based on `EigendecomposedShampooPreconditionerConfig`.
-    "EigenvalueCorrectedShampooPreconditionerConfig",  # Based on `AmortizedPreconditionerConfig`.
+    "EigenvalueCorrectedShampooPreconditionerConfig",  # Based on `BaseShampooPreconditionerConfig`.
     "DefaultEigenvalueCorrectedShampooConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`.
     "DefaultSOAPConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`.
     "SpectralDescentPreconditionerConfig",  # Based on `PreconditionerConfig`.
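Since `AmortizedPreconditionerConfig` and `ShampooPreconditionerConfig` are no longer exported, downstream imports have to move to the new names. A minimal sketch of updated usage (the `describe` helper is illustrative, and it assumes `DefaultShampooConfig` is a concrete instance, as its comment above suggests):

    # Sketch of imports after the rename; the comments name the old exports.
    from distributed_shampoo import (
        BaseShampooPreconditionerConfig,  # was: AmortizedPreconditionerConfig
        ClassicShampooPreconditionerConfig,  # was: ShampooPreconditionerConfig
        DefaultShampooConfig,
    )

    def describe(config: BaseShampooPreconditionerConfig) -> str:
        # The abstract bases serve for annotations and isinstance checks;
        # concrete configs such as DefaultShampooConfig are what get used.
        kind = (
            "classic Shampoo"
            if isinstance(config, ClassicShampooPreconditionerConfig)
            else "eigenvalue-corrected or other"
        )
        return f"{type(config).__name__}: {kind}"

    print(describe(DefaultShampooConfig))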
34 changes: 16 additions & 18 deletions distributed_shampoo/distributed_shampoo.py
@@ -621,10 +621,9 @@ def _initialize_blocked_parameters_state(self) -> None:
             for block_info in state_lists[DISTRIBUTOR].local_block_info_list:
                 param_state = self.state[block_info.param]
                 assert (
-                    block_index := block_info.composable_block_ids[1]
-                ) not in param_state, (
-                    "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
-                )
+                    (block_index := block_info.composable_block_ids[1])
+                    not in param_state
+                ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
                 param_state[block_index] = {}
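The reformatted assertion above keeps the walrus binding: `block_index` is bound inside the assert condition and reused on the following line. A minimal standalone sketch of the pattern (the tuple value is illustrative):

    # The walrus expression must be parenthesized inside the condition;
    # the bound name remains in scope after the assert.
    param_state: dict[int, dict] = {}
    composable_block_ids = ("fqn.param", 3)  # illustrative (id, block index)

    assert (
        (block_index := composable_block_ids[1]) not in param_state
    ), "There should not exist any optimizer state yet."
    param_state[block_index] = {}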

@torch.no_grad()
@@ -667,9 +666,9 @@ def _preconditioner_config_to_list_cls(
                 weighting_factor = 1 - beta2
                 use_bias_correction = False
             case AdaGradPreconditionerConfig():
-                beta2: float = 1.0
-                weighting_factor: float = 1.0
-                use_bias_correction: bool = False
+                beta2 = 1.0
+                weighting_factor = 1.0
+                use_bias_correction = False
             case _:
                 raise AssertionError(
                     f"Unexpected preconditioner config: {preconditioner_config}"
@@ -719,9 +718,9 @@ def _preconditioner_config_to_list_cls(
                     preconditioner_config=preconditioner_config,
                 )
             case SpectralDescentPreconditionerConfig():
-                assert group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2, (
-                    f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
-                )
+                assert (
+                    group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
+                ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
                 return SpectralDescentPreconditionerList(
                     block_list=state_lists[DISTRIBUTOR].local_blocked_params,
                     preconditioner_config=preconditioner_config,
@@ -734,9 +733,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
         for state_lists, group in zip(
             self._per_group_state_lists, self.param_groups, strict=True
         ):
-            assert group[PRECONDITIONER_CONFIG] is not None, (
-                f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
-            )
+            assert (
+                group[PRECONDITIONER_CONFIG] is not None
+            ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
             state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
                 self._preconditioner_config_to_list_cls(
                     state_lists=state_lists,
@@ -1629,8 +1628,7 @@ def _pre_load_state_dict_hook(optimizer: Optimizer, state_dict: StateDict) -> None:
             )
             if group[ITERATE_AVERAGING_CONFIG] is not None
         ]
-        # type: ignore
-        optimizer._pre_load_train_modes = saved_train_modes
+        optimizer._pre_load_train_modes = saved_train_modes  # type: ignore[attr-defined]
 
     @staticmethod
     def _post_load_state_dict_hook(optimizer: Optimizer) -> None:
@@ -1655,9 +1653,9 @@ def _post_load_state_dict_hook(optimizer: Optimizer) -> None:
         if saved_train_modes:
             # Mixed train/eval modes across parameter groups is not supported
             # since train() and eval() always operate on all groups uniformly.
-            assert all(m == saved_train_modes[0] for m in saved_train_modes), (
-                "Mixed train/eval modes across parameter groups is not supported."
-            )
+            assert all(
+                m == saved_train_modes[0] for m in saved_train_modes
+            ), "Mixed train/eval modes across parameter groups is not supported."
             operator.attrgetter("train" if saved_train_modes[0] else "eval")(
                 optimizer
             )()
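On the `operator.attrgetter(...)(optimizer)()` line above: it looks up `train` or `eval` by name on the optimizer and calls it. A minimal sketch of the dispatch (the `Mode` class stands in for the optimizer):

    import operator

    class Mode:
        def train(self) -> None:
            print("train mode restored")

        def eval(self) -> None:
            print("eval mode restored")

    saved_train_mode = True
    # attrgetter picks the bound method; the trailing () invokes it.
    operator.attrgetter("train" if saved_train_mode else "eval")(Mode())()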
@@ -66,12 +66,14 @@ def __init__(
         )[0]
         self._group_rank: int = dist.get_rank(group=self._dist_group)
 
-        should_assign_param_idx = (
-            lambda i: i % self._group_size == self._group_rank
-            if self._param_assignment_strategy
-            == FSDPParamAssignmentStrategy.ROUND_ROBIN
-            else True
-        )
+        def should_assign_param_idx(i: int) -> bool:
+            if (
+                self._param_assignment_strategy
+                == FSDPParamAssignmentStrategy.ROUND_ROBIN
+            ):
+                return i % self._group_size == self._group_rank
+            return True
+
         self._assigned_params_mask: tuple[bool, ...] = tuple(
             should_assign_param_idx(idx) for idx in range(len(param_group[PARAMS]))
         )
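The removed lambda and the new `def` implement the same ROUND_ROBIN policy: in a group of size g, rank r takes every parameter index i with i % g == r, while other strategies assign every index. A minimal sketch with illustrative sizes:

    # Round-robin assignment; group size and rank normally come from
    # torch.distributed, the values here are illustrative.
    group_size, group_rank = 4, 1

    def should_assign_param_idx(i: int) -> bool:
        return i % group_size == group_rank

    mask = tuple(should_assign_param_idx(idx) for idx in range(8))
    print(mask)  # (False, True, False, False, False, True, False, False)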
@@ -72,12 +72,14 @@ def __init__(
 
         # Stores full parameters (as opposed to DTensors) for the model parameters assigned to this rank.
         # For example, when the strategy is REPLICATE, it stores the full parameters on all ranks.
-        should_assign_param_idx = (
-            lambda i: i % self._shard_group_size == self._shard_group_rank
-            if self._param_assignment_strategy
-            == FSDPParamAssignmentStrategy.ROUND_ROBIN
-            else True
-        )
+        def should_assign_param_idx(i: int) -> bool:
+            if (
+                self._param_assignment_strategy
+                == FSDPParamAssignmentStrategy.ROUND_ROBIN
+            ):
+                return i % self._shard_group_size == self._shard_group_rank
+            return True
+
         with torch.no_grad():
             self._assigned_params_mask: tuple[bool, ...] = tuple(
                 should_assign_param_idx(idx) for idx in range(len(param_group[PARAMS]))
@@ -155,7 +157,7 @@ def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]:
                 if assigned and (full_grad is None or full_grad.numel() > 0)
             )
         else:
-            return filter(lambda p: p.numel() > 0, self._assigned_full_params)
+            return (p for p in self._assigned_full_params if p.numel() > 0)
 
     @torch.no_grad()
     def update_params(
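In `_get_params_or_grads`, the `filter(lambda ...)` call becomes an equivalent generator expression, dropping the lambda and matching the generator used in the gradient branch above it. A minimal sketch of the equivalence (the tensors are illustrative):

    import torch

    # Both forms lazily yield only the non-empty tensors.
    params = [torch.empty(0), torch.ones(2, 3), torch.empty(0), torch.ones(4)]
    via_filter = list(filter(lambda p: p.numel() > 0, params))
    via_genexp = list(p for p in params if p.numel() > 0)
    assert all(a is b for a, b in zip(via_filter, via_genexp, strict=True))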
@@ -373,7 +373,7 @@ def test_checkpoint_preemption_bitwise_equivalence(
         model1 = self._construct_model(
             model_linear_layers_dims=model_linear_layers_dims,
             model_dead_layers_dims=model_dead_layers_dims,
-            post_model_decoration=post_model_decoration,
+            post_model_decoration=post_model_decoration,  # type: ignore[arg-type]
         )
         input_dim = model_linear_layers_dims[0]
         optimizer1 = self._shampoo_optim_factory(distributed_config=config)(
@@ -408,7 +408,7 @@ def test_checkpoint_preemption_bitwise_equivalence(
         model2 = self._construct_model(
             model_linear_layers_dims=model_linear_layers_dims,
             model_dead_layers_dims=model_dead_layers_dims,
-            post_model_decoration=post_model_decoration,
+            post_model_decoration=post_model_decoration,  # type: ignore[arg-type]
         )
         # Use tiny lr for init step to minimize impact
         optimizer2 = self._shampoo_optim_factory(distributed_config=config, lr=1e-10)(
@@ -101,9 +101,7 @@ def _construct_model(
         assert (
             sum(param.numel() for param in model.parameters())
             == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2
-        ), (
-            f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
-        )
+        ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
         distributed_config.param_to_metadata = compile_fsdp_parameter_metadata(
             model
         )
@@ -102,9 +102,7 @@ def _construct_model(
         assert (
             sum(param.numel() for param in model.parameters())
             == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2
-        ), (
-            f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
-        )
+        ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
         distributed_config.param_to_metadata = compile_fsdp_parameter_metadata(
             model
         )
14 changes: 7 additions & 7 deletions distributed_shampoo/distributor/shampoo_ddp_distributor.py
@@ -308,9 +308,9 @@ def all_gather_into_tensor() -> None:
             global_buffers = self._global_dist_blocked_buffers
 
         if self._communicate_params:
-            assert len(local_params) == len(blocked_search_directions), (
-                f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
-            )
+            assert (
+                len(local_params) == len(blocked_search_directions)
+            ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
 
             # torch._foreach only accepts non-empty list
             if blocked_search_directions:
@@ -335,9 +335,9 @@ def all_gather_into_tensor() -> None:
             )
 
         else:
-            assert len(local_buffers) == len(blocked_search_directions), (
-                f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
-            )
+            assert (
+                len(local_buffers) == len(blocked_search_directions)
+            ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
 
             # torch._foreach only accepts non-empty list
             if blocked_search_directions:
@@ -576,7 +576,7 @@ def _allocate_zeros_distributed_tensor(
             mesh=tuple(batched(iterable=ranks_in_group, n=self._group_size)),
             mesh_dim_names=("replicate", "shard"),
         )
-        replicate_submesh = device_mesh_2d._get_all_submeshes(
+        replicate_submesh = device_mesh_2d._get_all_submeshes(  # type: ignore[attr-defined]
             mesh_dim_name="replicate"
         )[group_source_rank]
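For context on the mesh construction above: the flat rank list is chunked into rows of `group_size`, and the two mesh dimensions are named "replicate" and "shard". A minimal sketch of the layout only, with illustrative sizes (no process groups are created; `itertools.batched` needs Python 3.12+):

    from itertools import batched

    ranks_in_group = range(8)  # illustrative world of 8 ranks
    group_size = 4

    # Rows of the 2-D mesh; dim 0 is "replicate", dim 1 is "shard".
    mesh = tuple(batched(ranks_in_group, group_size))
    print(mesh)  # ((0, 1, 2, 3), (4, 5, 6, 7))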
