From 12c32f1454e9f391741b7bca21e79da52da77a58 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 11 Apr 2026 20:38:32 +0100 Subject: [PATCH 1/6] Remove stale preconditioner config names from `__init__.py` `AmortizedPreconditionerConfig` was renamed to `BaseShampooPreconditionerConfig` and `ShampooPreconditionerConfig` to `ClassicShampooPreconditionerConfig` in `shampoo_types.py`, but `__init__.py` still imported and re-exported the old names. Replace them with the new names in the imports and `__all__`. Co-Authored-By: Claude Opus 4.6 (1M context) --- distributed_shampoo/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index b767a52..767cf6a 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -34,7 +34,8 @@ from distributed_shampoo.shampoo_types import ( AdaGradPreconditionerConfig, AdamPreconditionerConfig, - AmortizedPreconditionerConfig, + BaseShampooPreconditionerConfig, + ClassicShampooPreconditionerConfig, DDPDistributedConfig, DefaultEigenvalueCorrectedShampooConfig, DefaultShampooConfig, @@ -59,7 +60,6 @@ RootInvShampooPreconditionerConfig, ScheduleFreeConfig, SGDPreconditionerConfig, - ShampooPreconditionerConfig, ShampooPT2CompileConfig, SignDescentPreconditionerConfig, SingleDeviceDistributedConfig, @@ -85,14 +85,14 @@ # `precision_config`. # `preconditioner_config` options. "PreconditionerConfig", # Abstract base class. - "AmortizedPreconditionerConfig", # Abstract base class (based on `PreconditionerConfig`). - "ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`). - "RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "BaseShampooPreconditionerConfig", # Abstract base class (based on `PreconditionerConfig`). + "ClassicShampooPreconditionerConfig", # Abstract base class (based on `BaseShampooPreconditionerConfig`). + "RootInvShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`. "DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`. "RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`. - "EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "EigendecomposedShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`. "EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. - "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`. + "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `BaseShampooPreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`. "SpectralDescentPreconditionerConfig", # Based on `PreconditionerConfig`. 
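For downstream users of the package, the practical effect of Patch 1 is only that imports must switch to the new names. A minimal, hypothetical migration sketch (Python), assuming nothing beyond the re-exports shown in the `__all__` hunk above:

    # Hypothetical downstream import after the rename; these three names are exactly
    # the entries added to `__all__` in Patch 1, replacing the removed
    # `AmortizedPreconditionerConfig` and `ShampooPreconditionerConfig` imports.
    from distributed_shampoo import (
        BaseShampooPreconditionerConfig,     # abstract base (formerly AmortizedPreconditionerConfig)
        ClassicShampooPreconditionerConfig,  # abstract base (formerly ShampooPreconditionerConfig)
        RootInvShampooPreconditionerConfig,  # concrete config based on the classic base
    )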
From 010e368d22d30a6b7ecdc9f382ea987a155924a8 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 25 Apr 2026 17:03:21 +0100 Subject: [PATCH 2/6] Fix examples workflow grafting config overrides MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CLI overrides `optimizer.grafting_config={...}` use OmegaConf's default merge semantics, so fields from the base shampoo.yaml grafting_config (notably `beta2: 0.999`) leak into overrides whose `_target_` does not accept them — breaking the AdaGrad and SGD grafting cases. Replace each override with a delete (`~`) followed by add (`+`) so the new mapping fully replaces the inherited one. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/examples.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/examples.yaml b/.github/workflows/examples.yaml index 2cb302e..e0f3e99 100644 --- a/.github/workflows/examples.yaml +++ b/.github/workflows/examples.yaml @@ -26,19 +26,19 @@ jobs: - name: Run single GPU examples with Distributed Shampoo and different graftings on CPU. run: | source .venv/bin/activate - CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdaGradPreconditionerConfig,epsilon:1e-8}' epochs=1 batch_size=1024 - CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 - CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.RMSpropPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 - CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}' epochs=1 batch_size=1024 + CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdaGradPreconditionerConfig,epsilon:1e-8}' epochs=1 batch_size=1024 + CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 + CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.RMSpropPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 + CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}' epochs=1 batch_size=1024 - name: Run single GPU example on GPU. 
run: | source .venv/bin/activate - python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 + python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 - name: Run DDP example on CPU. run: | source .venv/bin/activate - CUDA_VISIBLE_DEVICES="" torchrun --standalone --nnodes=1 --nproc_per_node=2 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=15 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 backend=gloo + CUDA_VISIBLE_DEVICES="" torchrun --standalone --nnodes=1 --nproc_per_node=2 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=15 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 backend=gloo - name: Run DDP example on GPU. run: | source .venv/bin/activate - torchrun --standalone --nnodes=1 --nproc_per_node=1 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=30 'optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 + torchrun --standalone --nnodes=1 --nproc_per_node=1 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 From cafa945a93735c9cbdcc82b11b81ccef72e4c45e Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 25 Apr 2026 17:33:46 +0100 Subject: [PATCH 3/6] Skip GPU-only example steps when runner has no usable GPU The `4-core-ubuntu-gpu-t4` runner currently has an outdated NVIDIA driver, so `torch.cuda.is_available()` returns False and the "Run DDP example on GPU." step crashes with `ProcessGroupNCCL is only supported with GPUs, no GPUs found!`. The "Run single GPU example on GPU." step "passes" only because torch silently falls back to CPU. Add a `gpu_check` step that probes `torch.cuda.is_available()` and gate both GPU-only steps on its output. If no GPU is detected, those steps are skipped (with a workflow warning) and the job stays green. When the runner image is fixed and a GPU is actually available, both steps run as before with no other changes needed. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/examples.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/examples.yaml b/.github/workflows/examples.yaml index e0f3e99..5a0fe19 100644 --- a/.github/workflows/examples.yaml +++ b/.github/workflows/examples.yaml @@ -19,6 +19,16 @@ jobs: run: | uv venv && source .venv/bin/activate uv pip install ".[examples]" + - name: Detect GPU availability. 
+ id: gpu_check + run: | + source .venv/bin/activate + if python -c "import torch; raise SystemExit(0 if torch.cuda.is_available() else 1)"; then + echo "has_gpu=true" >> "$GITHUB_OUTPUT" + else + echo "has_gpu=false" >> "$GITHUB_OUTPUT" + echo "::warning::No usable GPU detected on this runner; GPU-only steps will be skipped." + fi - name: Run single GPU example with Adam to serve as a baseline. run: | source .venv/bin/activate @@ -31,6 +41,7 @@ jobs: CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.RMSpropPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 CUDA_VISIBLE_DEVICES="" python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.SGDPreconditionerConfig}' epochs=1 batch_size=1024 - name: Run single GPU example on GPU. + if: steps.gpu_check.outputs.has_gpu == 'true' run: | source .venv/bin/activate python -m distributed_shampoo.examples.cifar10_example optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 batch_size=1024 @@ -39,6 +50,7 @@ jobs: source .venv/bin/activate CUDA_VISIBLE_DEVICES="" torchrun --standalone --nnodes=1 --nproc_per_node=2 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=15 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 backend=gloo - name: Run DDP example on GPU. + if: steps.gpu_check.outputs.has_gpu == 'true' run: | source .venv/bin/activate torchrun --standalone --nnodes=1 --nproc_per_node=1 -m distributed_shampoo.examples.cifar10_example parallelism=ddp optimizer=shampoo optimizer.precondition_frequency=30 '~optimizer.grafting_config' '+optimizer.grafting_config={_target_:distributed_shampoo.AdamPreconditionerConfig,beta2:0.999,epsilon:1e-8}' epochs=1 local_batch_size=1024 From 44aaa9f9aedf3c96629bba2c2a244b74aa871fc8 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 25 Apr 2026 17:15:54 +0100 Subject: [PATCH 4/6] Fix ruff lint errors (E402, E731) and normalize copyright headers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit E402 errors (~92 total across 12 files): every affected file had a duplicate triple-quoted copyright block at the top followed by the real module docstring (or, in gpa/, a `#`-style copyright). The duplicate triple-quoted block was a non-import statement that triggered E402 for every subsequent import. Drop the duplicate so the real module docstring is the first statement. E731 errors (2 sites in *_lossless_distributor.py): convert `should_assign_param_idx = lambda i: ...` to a `def`. Header convention: 79 of 79 untouched files use a single triple-quoted copyright block (some with merged module docstring text). Normalize the 6 gpa files I touched and gpa/gpa_adamw.py to that style — drop the redundant `#`-comment copyright block in favor of triple-quoted. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../_shampoo_fully_shard_lossless_distributor.py | 14 ++++++++------ .../_shampoo_hybrid_shard_lossless_distributor.py | 14 ++++++++------ distributed_shampoo/examples/cifar10_example.py | 9 --------- distributed_shampoo/examples/parallelism.py | 9 --------- distributed_shampoo/examples/resolvers.py | 9 --------- distributed_shampoo/examples/utils.py | 9 --------- .../gpu_tests/iterate_averaging_test.py | 9 --------- distributed_shampoo/gpu_tests/weight_decay_test.py | 9 --------- gpa/examples/cifar10_example.py | 12 ++---------- gpa/gpa_adamw.py | 12 +++--------- gpa/gpa_types.py | 9 --------- gpa/gpu_tests/gpa_adamw_numerics_test.py | 9 --------- gpa/tests/gpa_adamw_test.py | 9 --------- gpa/tests/gpa_equivalence_test.py | 13 ++----------- gpa/tests/gpa_test_utils.py | 9 --------- 15 files changed, 23 insertions(+), 132 deletions(-) diff --git a/distributed_shampoo/distributor/_shampoo_fully_shard_lossless_distributor.py b/distributed_shampoo/distributor/_shampoo_fully_shard_lossless_distributor.py index d8e2894..b378c1c 100644 --- a/distributed_shampoo/distributor/_shampoo_fully_shard_lossless_distributor.py +++ b/distributed_shampoo/distributor/_shampoo_fully_shard_lossless_distributor.py @@ -66,12 +66,14 @@ def __init__( )[0] self._group_rank: int = dist.get_rank(group=self._dist_group) - should_assign_param_idx = ( - lambda i: i % self._group_size == self._group_rank - if self._param_assignment_strategy - == FSDPParamAssignmentStrategy.ROUND_ROBIN - else True - ) + def should_assign_param_idx(i: int) -> bool: + if ( + self._param_assignment_strategy + == FSDPParamAssignmentStrategy.ROUND_ROBIN + ): + return i % self._group_size == self._group_rank + return True + self._assigned_params_mask: tuple[bool, ...] = tuple( should_assign_param_idx(idx) for idx in range(len(param_group[PARAMS])) ) diff --git a/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py b/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py index 3d6fcbb..621c78f 100644 --- a/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py +++ b/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py @@ -72,12 +72,14 @@ def __init__( # Stores full parameters (as opposed to DTensors) for the model parameters assigned to this rank. # For example, when the strategy is REPLICATE, it stores the full parameters on all ranks. - should_assign_param_idx = ( - lambda i: i % self._shard_group_size == self._shard_group_rank - if self._param_assignment_strategy - == FSDPParamAssignmentStrategy.ROUND_ROBIN - else True - ) + def should_assign_param_idx(i: int) -> bool: + if ( + self._param_assignment_strategy + == FSDPParamAssignmentStrategy.ROUND_ROBIN + ): + return i % self._shard_group_size == self._shard_group_rank + return True + with torch.no_grad(): self._assigned_params_mask: tuple[bool, ...] = tuple( should_assign_param_idx(idx) for idx in range(len(param_group[PARAMS])) diff --git a/distributed_shampoo/examples/cifar10_example.py b/distributed_shampoo/examples/cifar10_example.py index f043fc7..6980850 100644 --- a/distributed_shampoo/examples/cifar10_example.py +++ b/distributed_shampoo/examples/cifar10_example.py @@ -1,12 +1,3 @@ -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. 
- -""" - #!/usr/bin/env python3 """CIFAR-10 training example. diff --git a/distributed_shampoo/examples/parallelism.py b/distributed_shampoo/examples/parallelism.py index 5d39f0d..cfb9623 100644 --- a/distributed_shampoo/examples/parallelism.py +++ b/distributed_shampoo/examples/parallelism.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. - Parallelism strategies for distributed training. """ diff --git a/distributed_shampoo/examples/resolvers.py b/distributed_shampoo/examples/resolvers.py index afde3a6..4ae34f3 100644 --- a/distributed_shampoo/examples/resolvers.py +++ b/distributed_shampoo/examples/resolvers.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. - Hydra resolvers for complex types in YAML configs. """ diff --git a/distributed_shampoo/examples/utils.py b/distributed_shampoo/examples/utils.py index 0481482..498ff84 100644 --- a/distributed_shampoo/examples/utils.py +++ b/distributed_shampoo/examples/utils.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. - Utility functions for CIFAR-10 training examples. """ diff --git a/distributed_shampoo/gpu_tests/iterate_averaging_test.py b/distributed_shampoo/gpu_tests/iterate_averaging_test.py index 5b40bef..277a24c 100644 --- a/distributed_shampoo/gpu_tests/iterate_averaging_test.py +++ b/distributed_shampoo/gpu_tests/iterate_averaging_test.py @@ -1,12 +1,3 @@ -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. - -""" - #!/usr/bin/env python3 """Tests for validating iterate averaging (GPA and Schedule-Free) equivalence. diff --git a/distributed_shampoo/gpu_tests/weight_decay_test.py b/distributed_shampoo/gpu_tests/weight_decay_test.py index 27365ca..0d20bc6 100644 --- a/distributed_shampoo/gpu_tests/weight_decay_test.py +++ b/distributed_shampoo/gpu_tests/weight_decay_test.py @@ -1,12 +1,3 @@ -""" -Copyright (c) Meta Platforms, Inc. and affiliates. -All rights reserved. - -This source code is licensed under the BSD-style license found in the -LICENSE file in the root directory of this source tree. - -""" - #!/usr/bin/env python3 """ diff --git a/gpa/examples/cifar10_example.py b/gpa/examples/cifar10_example.py index 751f5fc..1a91e27 100644 --- a/gpa/examples/cifar10_example.py +++ b/gpa/examples/cifar10_example.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + """ Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 
@@ -5,16 +7,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" CIFAR-10 Training Example with GPAAdamW Optimizer. Single GPU example demonstrating GPA optimizer usage with a simple ConvNet. diff --git a/gpa/gpa_adamw.py b/gpa/gpa_adamw.py index aaa1d6d..99048e4 100644 --- a/gpa/gpa_adamw.py +++ b/gpa/gpa_adamw.py @@ -7,12 +7,6 @@ """ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - # pyre-unsafe from logging import getLogger from typing import Callable, Optional, Union @@ -425,9 +419,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] lr_max = max(lr, self.state[first_param][LR_MAX].item()) self.state[first_param][LR_MAX].fill_(lr_max) - assert lr_max > 0, ( - f"lr_max must be positive, got lr_max={lr_max}. Check that lr={lr} is positive." - ) + assert ( + lr_max > 0 + ), f"lr_max must be positive, got lr_max={lr_max}. Check that lr={lr} is positive." # Compute avg_coeff ONCE per step (before the parameter loop). # This is important for Schedule-Free: the coefficient should be the same diff --git a/gpa/gpa_types.py b/gpa/gpa_types.py index b97e0e2..d1d851a 100644 --- a/gpa/gpa_types.py +++ b/gpa/gpa_types.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" Constants for optimizers - parameter group and state keys. This is one single place to define all the keys used in optimizer state. """ diff --git a/gpa/gpu_tests/gpa_adamw_numerics_test.py b/gpa/gpu_tests/gpa_adamw_numerics_test.py index 48a5397..aa9382a 100644 --- a/gpa/gpu_tests/gpa_adamw_numerics_test.py +++ b/gpa/gpu_tests/gpa_adamw_numerics_test.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" GPU convergence tests for GPAAdamW optimizer. This module contains convergence tests that run on both CPU and GPU using diff --git a/gpa/tests/gpa_adamw_test.py b/gpa/tests/gpa_adamw_test.py index b4860e0..271a88b 100644 --- a/gpa/tests/gpa_adamw_test.py +++ b/gpa/tests/gpa_adamw_test.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" Consolidated unit tests for GPAAdamW optimizer. 
Test classes: diff --git a/gpa/tests/gpa_equivalence_test.py b/gpa/tests/gpa_equivalence_test.py index 5e1426b..7cbb7f8e 100644 --- a/gpa/tests/gpa_equivalence_test.py +++ b/gpa/tests/gpa_equivalence_test.py @@ -5,17 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -""" Equivalence tests for GPA optimizers. This module tests that GPA reduces to known optimizers under specific @@ -30,6 +19,8 @@ python -m unittest gpa.tests.gpa_equivalence_test -v """ +# pyre-unsafe + import unittest import torch diff --git a/gpa/tests/gpa_test_utils.py b/gpa/tests/gpa_test_utils.py index 8abf391..d41e678 100644 --- a/gpa/tests/gpa_test_utils.py +++ b/gpa/tests/gpa_test_utils.py @@ -5,15 +5,6 @@ This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. -""" - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" Shared test utilities for GPA optimizer tests. This module provides common helper functions used across the GPA optimizer From 17cee6ec7bdf4b94cba4ea9a717053e7bcc8064f Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 25 Apr 2026 17:16:01 +0100 Subject: [PATCH 5/6] Apply ruff-format pass to resolve formatter drift Pre-commit's pinned `ruff-format` (v0.8.0 in .pre-commit-config.yaml) considers 14 files mis-formatted. The diffs are purely whitespace choices around `assert ..., (...)` wrapping introduced by ruff version changes. Apply the formatter so the `ruff-format` hook passes on the first run instead of failing after auto-modifying files. Co-Authored-By: Claude Opus 4.7 (1M context) --- distributed_shampoo/distributed_shampoo.py | 25 +++++++++---------- .../shampoo_fsdp_distributor_test.py | 4 +-- .../shampoo_hsdp_distributor_test.py | 4 +-- .../distributor/shampoo_ddp_distributor.py | 12 ++++----- .../distributor/shampoo_distributor.py | 12 ++++----- .../distributor/shampoo_fsdp_distributor.py | 6 ++--- .../distributor/shampoo_fsdp_utils.py | 25 ++++++++++--------- .../distributor/shampoo_hsdp_distributor.py | 18 ++++++------- .../shampoo_hybrid_shard_distributor.py | 12 ++++----- distributed_shampoo/examples/convnet.py | 4 +-- .../preconditioner/matrix_functions.py | 12 ++++----- .../shampoo_preconditioner_list.py | 18 ++++++------- .../utils/shampoo_quantization.py | 6 ++--- distributed_shampoo/utils/shampoo_utils.py | 18 ++++++------- 14 files changed, 85 insertions(+), 91 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 2f5168e..f7125bf 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -621,10 +621,9 @@ def _initialize_blocked_parameters_state(self) -> None: for block_info in state_lists[DISTRIBUTOR].local_block_info_list: param_state = self.state[block_info.param] assert ( - block_index := block_info.composable_block_ids[1] - ) not in param_state, ( - "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions." 
- ) + (block_index := block_info.composable_block_ids[1]) + not in param_state + ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions." param_state[block_index] = {} @torch.no_grad() @@ -719,9 +718,9 @@ def _preconditioner_config_to_list_cls( preconditioner_config=preconditioner_config, ) case SpectralDescentPreconditionerConfig(): - assert group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2, ( - f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." - ) + assert ( + group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2 + ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." return SpectralDescentPreconditionerList( block_list=state_lists[DISTRIBUTOR].local_blocked_params, preconditioner_config=preconditioner_config, @@ -734,9 +733,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None: for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - assert group[PRECONDITIONER_CONFIG] is not None, ( - f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo." - ) + assert ( + group[PRECONDITIONER_CONFIG] is not None + ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo." state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( self._preconditioner_config_to_list_cls( state_lists=state_lists, @@ -1655,9 +1654,9 @@ def _post_load_state_dict_hook(optimizer: Optimizer) -> None: if saved_train_modes: # Mixed train/eval modes across parameter groups is not supported # since train() and eval() always operate on all groups uniformly. - assert all(m == saved_train_modes[0] for m in saved_train_modes), ( - "Mixed train/eval modes across parameter groups is not supported." - ) + assert all( + m == saved_train_modes[0] for m in saved_train_modes + ), "Mixed train/eval modes across parameter groups is not supported." 
operator.attrgetter("train" if saved_train_modes[0] else "eval")( optimizer )() diff --git a/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py b/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py index 7d91201..44947d9 100644 --- a/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py +++ b/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py @@ -101,9 +101,7 @@ def _construct_model( assert ( sum(param.numel() for param in model.parameters()) == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2 - ), ( - f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" - ) + ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" distributed_config.param_to_metadata = compile_fsdp_parameter_metadata( model ) diff --git a/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py index 1678109..8b82b1b 100644 --- a/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py +++ b/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py @@ -102,9 +102,7 @@ def _construct_model( assert ( sum(param.numel() for param in model.parameters()) == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2 - ), ( - f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" - ) + ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" distributed_config.param_to_metadata = compile_fsdp_parameter_metadata( model ) diff --git a/distributed_shampoo/distributor/shampoo_ddp_distributor.py b/distributed_shampoo/distributor/shampoo_ddp_distributor.py index d40af73..ea9b2cf 100644 --- a/distributed_shampoo/distributor/shampoo_ddp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_ddp_distributor.py @@ -308,9 +308,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert len(local_params) == len(blocked_search_directions), ( - f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_params) == len(blocked_search_directions) + ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -335,9 +335,9 @@ def all_gather_into_tensor() -> None: ) else: - assert len(local_buffers) == len(blocked_search_directions), ( - f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_buffers) == len(blocked_search_directions) + ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." 
# torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/distributor/shampoo_distributor.py b/distributed_shampoo/distributor/shampoo_distributor.py index 4509ca5..9d6bf4f 100644 --- a/distributed_shampoo/distributor/shampoo_distributor.py +++ b/distributed_shampoo/distributor/shampoo_distributor.py @@ -280,9 +280,9 @@ def _merge_and_block_gradients( assert grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite(grad).all(), ( - f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping." - ) + assert torch.isfinite( + grad + ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping." # Obtain blocks for each gradient after merging. blocks_within_grad = multi_dim_split( @@ -354,9 +354,9 @@ def update_params( else self._local_blocked_params ) - assert len(blocked_search_directions) == len(target_params), ( - f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}." - ) + assert ( + len(blocked_search_directions) == len(target_params) + ), f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/distributor/shampoo_fsdp_distributor.py b/distributed_shampoo/distributor/shampoo_fsdp_distributor.py index b63cf6b..cbec4c8 100644 --- a/distributed_shampoo/distributor/shampoo_fsdp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_fsdp_distributor.py @@ -154,9 +154,9 @@ def _merge_and_block_gradients( assert flattened_grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite(flattened_grad).all(), ( - f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." - ) + assert torch.isfinite( + flattened_grad + ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." # Split flattened gradients into valid tensor blocks of the gradient. split_grads = FSDPDistributor._split_tensor_block_recovery( diff --git a/distributed_shampoo/distributor/shampoo_fsdp_utils.py b/distributed_shampoo/distributor/shampoo_fsdp_utils.py index 6b18d18..59446b0 100644 --- a/distributed_shampoo/distributor/shampoo_fsdp_utils.py +++ b/distributed_shampoo/distributor/shampoo_fsdp_utils.py @@ -41,9 +41,9 @@ def compile_fsdp_parameter_metadata( shard_param_infos = flat_param._shard_param_infos sharding_strategy = fsdp_module.sharding_strategy - assert flat_param._params is not None, ( - "flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True " - ) + assert ( + flat_param._params is not None + ), "flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True " "would populate flat_param._params." 
params = flat_param._params @@ -156,12 +156,13 @@ def partition_param_list( ) assert ( - unioned_keys := fsdp_params_dict.keys() - | hsdp_params_dict.keys() - | other_params_dict.keys() - ) == original_params_dict.keys(), ( - f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}" - ) + ( + unioned_keys := fsdp_params_dict.keys() + | hsdp_params_dict.keys() + | other_params_dict.keys() + ) + == original_params_dict.keys() + ), f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}" for (name1, dict1), (name2, dict2) in itertools.combinations( ( ("fsdp_params_dict", fsdp_params_dict), @@ -170,9 +171,9 @@ def partition_param_list( ), 2, ): - assert not (common_keys := dict1.keys() & dict2.keys()), ( - f"{common_keys} exist in both {name1} and {name2}!" - ) + assert not ( + common_keys := dict1.keys() & dict2.keys() + ), f"{common_keys} exist in both {name1} and {name2}!" return ( list(fsdp_params_dict.items()), diff --git a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py index 60ab4e2..0f9a361 100644 --- a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py @@ -262,9 +262,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert len(local_params) == len(blocked_search_directions), ( - f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_params) == len(blocked_search_directions) + ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -289,9 +289,9 @@ def all_gather_into_tensor() -> None: ) else: - assert len(local_buffers) == len(blocked_search_directions), ( - f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_buffers) == len(blocked_search_directions) + ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -564,9 +564,9 @@ def _merge_and_block_gradients( assert flattened_grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite(flattened_grad).all(), ( - f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." - ) + assert torch.isfinite( + flattened_grad + ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." # Split flattened gradients into valid tensor blocks of the gradient. 
split_grads = HSDPDistributor._split_tensor_block_recovery( diff --git a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py index 0138ef6..1d687d2 100644 --- a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py @@ -287,9 +287,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert len(local_params) == len(blocked_search_directions), ( - f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_params) == len(blocked_search_directions) + ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -314,9 +314,9 @@ def all_gather_into_tensor() -> None: ) else: - assert len(local_buffers) == len(blocked_search_directions), ( - f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." - ) + assert ( + len(local_buffers) == len(blocked_search_directions) + ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." # torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/examples/convnet.py b/distributed_shampoo/examples/convnet.py index 9302228..23d4815 100644 --- a/distributed_shampoo/examples/convnet.py +++ b/distributed_shampoo/examples/convnet.py @@ -67,8 +67,6 @@ def _infer_conv_output_shape( output_shape = [] for input_length in input_shape: output_length = (input_length - kernel_size + 2 * padding) / stride + 1 - assert output_length.is_integer(), ( - f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!" - ) + assert output_length.is_integer(), f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!" output_shape.append(int(output_length)) return output_shape diff --git a/distributed_shampoo/preconditioner/matrix_functions.py b/distributed_shampoo/preconditioner/matrix_functions.py index 65ad1c8..29e70db 100644 --- a/distributed_shampoo/preconditioner/matrix_functions.py +++ b/distributed_shampoo/preconditioner/matrix_functions.py @@ -808,9 +808,9 @@ def qr_algorithm( Q = eigenvectors_estimate # This assertion provides a more clear error message than the internal error message in `torch.mm`, and assertion makes sure that user-side is unable to catch the error. - assert Q.dtype == A.dtype, ( - f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}" - ) + assert ( + Q.dtype == A.dtype + ), f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}" eigenvalues_estimate = Q.T @ A @ Q iteration = 0 @@ -877,9 +877,9 @@ def qr_algorithm( func=eigh_eigenvalue_decomposition, config=eigendecomposition_config )(A=A_ridge) case QREigendecompositionConfig(): - assert eigenvectors_estimate is not None, ( - "eigenvectors_estimate should not be None when QR algorithm is used." - ) + assert ( + eigenvectors_estimate is not None + ), "eigenvectors_estimate should not be None when QR algorithm is used." 
return _assign_function_args_from_config( func=qr_algorithm, config=eigendecomposition_config )(A=A_ridge, eigenvectors_estimate=eigenvectors_estimate) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index bba4509..217c79c 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -1374,14 +1374,14 @@ def _precondition_grad( ) -> Tensor: # TODO: Need to refactor this function to be more efficient. Ideally eliminate those branches. # Might consider einsum? - assert sum(preconditioned_dims_selector) == len(preconditioner_list), ( - f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})." - ) + assert ( + sum(preconditioned_dims_selector) == len(preconditioner_list) + ), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})." # Extract all dtypes and assert they are unique - assert len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1, ( - f"All preconditioners must have the same dtype, but found: {unique_dtypes}" - ) + assert ( + len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1 + ), f"All preconditioners must have the same dtype, but found: {unique_dtypes}" # Use the single dtype if preconditioners exist, otherwise use grad dtype target_dtype = next(iter(unique_dtypes), grad.dtype) @@ -1566,9 +1566,9 @@ def _get_inverse_exponent(self, dimension: int, order: int) -> float: return inverse_exponent_override_on_order.get( dimension, 1 / (2 * max(order, 1)) ) - assert isinstance(inverse_exponent_override_on_order, float), ( - f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead." - ) + assert isinstance( + inverse_exponent_override_on_order, float + ), f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead." return inverse_exponent_override_on_order def _create_preconditioned_dims_selector( diff --git a/distributed_shampoo/utils/shampoo_quantization.py b/distributed_shampoo/utils/shampoo_quantization.py index de05058..437f61f 100644 --- a/distributed_shampoo/utils/shampoo_quantization.py +++ b/distributed_shampoo/utils/shampoo_quantization.py @@ -153,9 +153,9 @@ def __init__( value.dtype == quantized_dtype for value in self.quantized_value_list ) self.quantized_dtype = quantized_dtype - assert computation_dtype in _FLOAT_DTYPES, ( - f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!" - ) + assert ( + computation_dtype in _FLOAT_DTYPES + ), f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!" self.computation_dtype = computation_dtype # All min/max values should be None, or no min/max values are None diff --git a/distributed_shampoo/utils/shampoo_utils.py b/distributed_shampoo/utils/shampoo_utils.py index 23060d9..1a0e690 100644 --- a/distributed_shampoo/utils/shampoo_utils.py +++ b/distributed_shampoo/utils/shampoo_utils.py @@ -79,9 +79,9 @@ def merge_small_dims( return (0,) if isinstance(target_tensor_dimensionality, float): - assert target_tensor_dimensionality == math.inf, ( - f"{target_tensor_dimensionality=} has to be an integer or math.inf." 
- ) + assert ( + target_tensor_dimensionality == math.inf + ), f"{target_tensor_dimensionality=} has to be an integer or math.inf." return tensor_shape # Squeeze tensor shape to remove dimension with 1; if all dimensions are 1, @@ -151,9 +151,9 @@ def multi_dim_split(tensor: Tensor, split_size: int | float) -> tuple[Tensor, .. """ if isinstance(split_size, float): - assert split_size == math.inf, ( - f"{split_size=} has to be an integer or math.inf." - ) + assert ( + split_size == math.inf + ), f"{split_size=} has to be an integer or math.inf." return (tensor,) return reduce( @@ -190,9 +190,9 @@ def compress_list( Only elements from complete_list where the corresponding selector is True are included. """ - assert len(complete_list) == len(selector), ( - f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!" - ) + assert ( + len(complete_list) == len(selector) + ), f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!" return tuple(compress(complete_list, selector)) From ce5b5521b5dcfd9848345dbec5a2aa4dcb92ccac Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 25 Apr 2026 17:23:45 +0100 Subject: [PATCH 6/6] Fix mypy errors - gpa/gpa_adamw.py: * step(): add @overload pair matching torch.optim.Optimizer * step(): annotate the per-group buffer lists as list[Tensor] * step(): rename group-local first_param to group_first_param to avoid mypy union with the Optional[Parameter] from the train-mode check loop, and skip empty parameter groups via `continue` - gpa/tests/gpa_test_utils.py: annotate `devices` as tuple[device, ...] - gpa/gpu_tests/gpa_adamw_numerics_test.py: ignore missing `parameterized` stubs (no public stub package) - gpa/examples/cifar10_example.py: ignore len() on Dataset[Any] - distributed_shampoo/distributed_shampoo.py: * drop redundant re-annotations on AdaGrad branch (no-redef) * move stray `# type: ignore` onto the `_pre_load_train_modes` assignment line so it actually applies - distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py: * replace filter(lambda, ...) with a generator expression to dodge a confusing TypeGuard overload mypy picks for filter() - distributed_shampoo/distributor/shampoo_*_distributor.py (3 files): * ignore attr-defined for the private DeviceMesh._get_all_submeshes - distributed_shampoo/distributor/shampoo_distributor.py: ignore assignment narrowing of map(partial(tuple), ...) - distributed_shampoo/distributor/gpu_tests/shampoo_checkpoint_test.py: ignore arg-type for partial[FSDPModule] passed to a Callable[[Module], Module] parameter (FSDPModule mixes in but mypy can't see it) - distributed_shampoo/examples/parallelism.py: same FSDPModule ignore for two fully_shard call sites - distributed_shampoo/examples/utils.py: annotate `sampler` - distributed_shampoo/examples/tests/utils_test.py: replace assertIsNotNone with `assert ... 
is not None` so mypy narrows `_fmt` before assertIn - distributed_shampoo/utils/load_balancing_utils.py: extend the existing `# type: ignore[misc]` to also cover `call-overload` from max(float, floating[Any]) Co-Authored-By: Claude Opus 4.7 (1M context) --- distributed_shampoo/distributed_shampoo.py | 9 +++-- ...ampoo_hybrid_shard_lossless_distributor.py | 2 +- .../gpu_tests/shampoo_checkpoint_test.py | 4 +-- .../distributor/shampoo_ddp_distributor.py | 2 +- .../distributor/shampoo_distributor.py | 2 +- .../distributor/shampoo_hsdp_distributor.py | 2 +- .../shampoo_hybrid_shard_distributor.py | 2 +- distributed_shampoo/examples/parallelism.py | 4 +-- .../examples/tests/utils_test.py | 4 +-- distributed_shampoo/examples/utils.py | 6 ++-- .../utils/load_balancing_utils.py | 2 +- gpa/examples/cifar10_example.py | 2 +- gpa/gpa_adamw.py | 35 ++++++++++++------- gpa/gpu_tests/gpa_adamw_numerics_test.py | 2 +- gpa/tests/gpa_test_utils.py | 2 +- 15 files changed, 45 insertions(+), 35 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index f7125bf..faaedef 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -666,9 +666,9 @@ def _preconditioner_config_to_list_cls( weighting_factor = 1 - beta2 use_bias_correction = False case AdaGradPreconditionerConfig(): - beta2: float = 1.0 - weighting_factor: float = 1.0 - use_bias_correction: bool = False + beta2 = 1.0 + weighting_factor = 1.0 + use_bias_correction = False case _: raise AssertionError( f"Unexpected preconditioner config: {preconditioner_config}" @@ -1628,8 +1628,7 @@ def _pre_load_state_dict_hook(optimizer: Optimizer, state_dict: StateDict) -> No ) if group[ITERATE_AVERAGING_CONFIG] is not None ] - # type: ignore - optimizer._pre_load_train_modes = saved_train_modes + optimizer._pre_load_train_modes = saved_train_modes # type: ignore[attr-defined] @staticmethod def _post_load_state_dict_hook(optimizer: Optimizer) -> None: diff --git a/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py b/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py index 621c78f..adf4467 100644 --- a/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py +++ b/distributed_shampoo/distributor/_shampoo_hybrid_shard_lossless_distributor.py @@ -157,7 +157,7 @@ def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None if assigned and (full_grad is None or full_grad.numel() > 0) ) else: - return filter(lambda p: p.numel() > 0, self._assigned_full_params) + return (p for p in self._assigned_full_params if p.numel() > 0) @torch.no_grad() def update_params( diff --git a/distributed_shampoo/distributor/gpu_tests/shampoo_checkpoint_test.py b/distributed_shampoo/distributor/gpu_tests/shampoo_checkpoint_test.py index d095088..5c469b8 100644 --- a/distributed_shampoo/distributor/gpu_tests/shampoo_checkpoint_test.py +++ b/distributed_shampoo/distributor/gpu_tests/shampoo_checkpoint_test.py @@ -373,7 +373,7 @@ def test_checkpoint_preemption_bitwise_equivalence( model1 = self._construct_model( model_linear_layers_dims=model_linear_layers_dims, model_dead_layers_dims=model_dead_layers_dims, - post_model_decoration=post_model_decoration, + post_model_decoration=post_model_decoration, # type: ignore[arg-type] ) input_dim = model_linear_layers_dims[0] optimizer1 = self._shampoo_optim_factory(distributed_config=config)( @@ -408,7 +408,7 @@ def 
test_checkpoint_preemption_bitwise_equivalence( model2 = self._construct_model( model_linear_layers_dims=model_linear_layers_dims, model_dead_layers_dims=model_dead_layers_dims, - post_model_decoration=post_model_decoration, + post_model_decoration=post_model_decoration, # type: ignore[arg-type] ) # Use tiny lr for init step to minimize impact optimizer2 = self._shampoo_optim_factory(distributed_config=config, lr=1e-10)( diff --git a/distributed_shampoo/distributor/shampoo_ddp_distributor.py b/distributed_shampoo/distributor/shampoo_ddp_distributor.py index ea9b2cf..deaeee5 100644 --- a/distributed_shampoo/distributor/shampoo_ddp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_ddp_distributor.py @@ -576,7 +576,7 @@ def _allocate_zeros_distributed_tensor( mesh=tuple(batched(iterable=ranks_in_group, n=self._group_size)), mesh_dim_names=("replicate", "shard"), ) - replicate_submesh = device_mesh_2d._get_all_submeshes( + replicate_submesh = device_mesh_2d._get_all_submeshes( # type: ignore[attr-defined] mesh_dim_name="replicate" )[group_source_rank] diff --git a/distributed_shampoo/distributor/shampoo_distributor.py b/distributed_shampoo/distributor/shampoo_distributor.py index 9d6bf4f..e77dc62 100644 --- a/distributed_shampoo/distributor/shampoo_distributor.py +++ b/distributed_shampoo/distributor/shampoo_distributor.py @@ -225,7 +225,7 @@ def _merge_and_block_parameters(self) -> None: """ self._global_blocked_params: tuple[Tensor, ...] self._global_num_blocks_per_param: tuple[int, ...] - self._global_blocked_params, self._global_num_blocks_per_param = map( + self._global_blocked_params, self._global_num_blocks_per_param = map( # type: ignore[assignment] partial(tuple), self._merge_and_block_with_params(params=self._get_params_or_grads()), ) diff --git a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py index 0f9a361..288eeff 100644 --- a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py @@ -874,7 +874,7 @@ def _allocate_zeros_distributed_tensor( # For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]]. # Note that the group source rank must belong to {0, 1, 2} in this case. # Suppose the group_source_rank = 1, then this would get the submesh [11, 35]. - replicate_submesh = device_mesh_2d._get_all_submeshes( + replicate_submesh = device_mesh_2d._get_all_submeshes( # type: ignore[attr-defined] mesh_dim_name="replicate" )[group_source_rank] diff --git a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py index 1d687d2..0e63f3f 100644 --- a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py @@ -605,7 +605,7 @@ def _allocate_zeros_distributed_tensor( # For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]]. # Note that the group source rank must belong to {0, 1, 2} in this case. # Suppose the group_source_rank = 1, then this would get the submesh [11, 35]. 
- replicate_submesh = device_mesh_2d._get_all_submeshes( + replicate_submesh = device_mesh_2d._get_all_submeshes( # type: ignore[attr-defined] mesh_dim_name="replicate" )[group_source_rank] diff --git a/distributed_shampoo/examples/parallelism.py b/distributed_shampoo/examples/parallelism.py index cfb9623..a8c5678 100644 --- a/distributed_shampoo/examples/parallelism.py +++ b/distributed_shampoo/examples/parallelism.py @@ -174,7 +174,7 @@ def wrap_model( device_mesh: DeviceMesh | None = None, ) -> WrappedModel: config = self.distributed_config() if self.distributed_config else None - return WrappedModel(model=fully_shard(model), distributed_config=config) + return WrappedModel(model=fully_shard(model), distributed_config=config) # type: ignore[arg-type] @dataclass @@ -203,6 +203,6 @@ def wrap_model( if self.distributed_config: config = self.distributed_config(device_mesh=device_mesh) return WrappedModel( - model=fully_shard(model, mesh=device_mesh), + model=fully_shard(model, mesh=device_mesh), # type: ignore[arg-type] distributed_config=config, ) diff --git a/distributed_shampoo/examples/tests/utils_test.py b/distributed_shampoo/examples/tests/utils_test.py index de17cdb..b73ef6e 100644 --- a/distributed_shampoo/examples/tests/utils_test.py +++ b/distributed_shampoo/examples/tests/utils_test.py @@ -224,7 +224,7 @@ def test_formatter_with_distributed_initialized(self, mock_dist: MagicMock) -> N mock_dist.get_rank.return_value = 3 formatter = PerRankLoggingFormatter() - self.assertIsNotNone(formatter._fmt) + assert formatter._fmt is not None self.assertIn("RANK 3", formatter._fmt) @patch("distributed_shampoo.examples.utils.dist") @@ -236,7 +236,7 @@ def test_formatter_without_distributed_initialized( formatter = PerRankLoggingFormatter() # When fmt=None is passed to Formatter.__init__, it defaults to '%(message)s' - self.assertIsNotNone(formatter._fmt) + assert formatter._fmt is not None self.assertNotIn("RANK", formatter._fmt) diff --git a/distributed_shampoo/examples/utils.py b/distributed_shampoo/examples/utils.py index 498ff84..cc789e4 100644 --- a/distributed_shampoo/examples/utils.py +++ b/distributed_shampoo/examples/utils.py @@ -158,8 +158,10 @@ def get_data_loader_and_sampler( dataset = datasets.CIFAR10( data_path, train=True, download=True, transform=transform ) - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=world_size, rank=rank, shuffle=True + sampler: torch.utils.data.distributed.DistributedSampler = ( + torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=True + ) ) return ( torch.utils.data.DataLoader( diff --git a/distributed_shampoo/utils/load_balancing_utils.py b/distributed_shampoo/utils/load_balancing_utils.py index 64bb48e..a052b8a 100644 --- a/distributed_shampoo/utils/load_balancing_utils.py +++ b/distributed_shampoo/utils/load_balancing_utils.py @@ -54,7 +54,7 @@ class PolynomialComputationalCostModel(CostModel): def cost(self, tensor: torch.Tensor) -> float: return sum( - max(self.min_cost, polyval(x=dim_size, c=self.coefficients)) # type: ignore[misc] + max(self.min_cost, polyval(x=dim_size, c=self.coefficients)) # type: ignore[misc,call-overload] for dim_size in tensor.shape ) diff --git a/gpa/examples/cifar10_example.py b/gpa/examples/cifar10_example.py index 1a91e27..490ec02 100644 --- a/gpa/examples/cifar10_example.py +++ b/gpa/examples/cifar10_example.py @@ -112,7 +112,7 @@ def get_args() -> argparse.Namespace: if batch_idx % 100 == 0: print( - f"Epoch {epoch} [{batch_idx 
* len(data)}/{len(train_loader.dataset)}] " + f"Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] " # type: ignore[arg-type] f"Loss: {loss.item():.4f}" ) diff --git a/gpa/gpa_adamw.py b/gpa/gpa_adamw.py index 99048e4..462ced0 100644 --- a/gpa/gpa_adamw.py +++ b/gpa/gpa_adamw.py @@ -9,10 +9,11 @@ # pyre-unsafe from logging import getLogger -from typing import Callable, Optional, Union +from typing import Callable, Optional, overload, Union import torch import torch.optim +from torch import Tensor from gpa.gpa_types import ( BETA1, BETA2, @@ -351,6 +352,11 @@ def train(self): ) self.state[first_param][TRAIN_MODE].fill_(True) + @overload + def step(self, closure: None = None) -> None: ... + @overload + def step(self, closure: Callable[[], float]) -> float: ... + @torch.no_grad() def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. @@ -377,11 +383,14 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] loss = closure() for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - z_buffer_list = [] + if not group[PARAMS]: + continue + + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + z_buffer_list: list[Tensor] = [] self._init_group( group, @@ -392,12 +401,12 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] z_buffer_list, ) - # Get first_param for accessing shared state. - first_param = group[PARAMS][0] if group[PARAMS] else None + # Get group_first_param for accessing shared state. + group_first_param: Tensor = group[PARAMS][0] # Increment step counter and use it as group step. - self.state[first_param][STEP] += 1 - k = self.state[first_param][STEP].item() + self.state[group_first_param][STEP] += 1 + k = self.state[group_first_param][STEP].item() # Get all group variables. eps = group[EPS] @@ -416,8 +425,8 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] bias_correction2 = 1 - beta2**k # Update LR_MAX in first parameter's state. - lr_max = max(lr, self.state[first_param][LR_MAX].item()) - self.state[first_param][LR_MAX].fill_(lr_max) + lr_max = max(lr, self.state[group_first_param][LR_MAX].item()) + self.state[group_first_param][LR_MAX].fill_(lr_max) assert ( lr_max > 0 @@ -433,7 +442,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] weight_pow_coeff=weight_pow_coeff, lr_max=lr_max, weight_lr_power=weight_lr_power, - weight_sum_ref=self.state[first_param][WEIGHT_SUM], + weight_sum_ref=self.state[group_first_param][WEIGHT_SUM], ) for y, grad, exp_avg, exp_avg_sq, z in zip( diff --git a/gpa/gpu_tests/gpa_adamw_numerics_test.py b/gpa/gpu_tests/gpa_adamw_numerics_test.py index aa9382a..8f8c45b 100644 --- a/gpa/gpu_tests/gpa_adamw_numerics_test.py +++ b/gpa/gpu_tests/gpa_adamw_numerics_test.py @@ -28,7 +28,7 @@ create_gpa_optimizer, create_schedulefree_optimizer, ) -from parameterized import parameterized +from parameterized import parameterized # type: ignore[import-not-found] # ============================================================================= diff --git a/gpa/tests/gpa_test_utils.py b/gpa/tests/gpa_test_utils.py index d41e678..71e7029 100644 --- a/gpa/tests/gpa_test_utils.py +++ b/gpa/tests/gpa_test_utils.py @@ -39,7 +39,7 @@ def get_available_devices() -> tuple[torch.device, ...]: Returns: Tuple of torch.device objects. 
""" - devices = (torch.device("cpu"),) + devices: tuple[torch.device, ...] = (torch.device("cpu"),) if torch.cuda.is_available(): devices = devices + (torch.device("cuda"),) return devices