From e2c42befd829c975b5955ef28080c49d56faed8c Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 19 Nov 2025 07:16:30 -0800 Subject: [PATCH] move weight update validation functions to util (#573) Summary: * Fix the weight update test * Extract common logic to a separate util function; see the next diff D87083010 for how to use them in verifying weights do get updated as part of infra verification when debugging a buggy run. Reviewed By: casteryh Differential Revision: D87005971 --- src/forge/actors/generator.py | 19 +- src/forge/util/weight_verification.py | 217 ++++++++++++++++++ tests/integration_tests/test_policy_update.py | 118 +++++----- 3 files changed, 285 insertions(+), 69 deletions(-) create mode 100644 src/forge/util/weight_verification.py diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 1cb1d5bd2..1696214fc 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -579,16 +579,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] await stop_proc_mesh(actor._fetcher_procs) @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" + async def save_model_params(self): + """Save model parameters before weight update, used for testing purposes only.""" logger.info("[Generator] save model parameters for testing.") - await self.worker._test_save_model_params.call() + await self.worker.save_model_params.call() @endpoint - async def _test_validate_model_params(self, validate_fn): + async def validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" logger.info("[Generator] start validating model parameters.") - return await self.worker._test_validate_model_params.call(validate_fn) + return await self.worker.validate_model_params.call(validate_fn) @dataclass @@ -604,6 +604,9 @@ class GeneratorWorker(ForgeActor): # TODO: Remove below param _test_prev_params = {} + def __post_init__(self): + super().__init__() + @endpoint async def setup(self): self.rank = current_rank().rank @@ -720,8 +723,8 @@ async def update_weights( t.stop() @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" + async def save_model_params(self): + """Save model parameters before weight update, used for testing purposes only.""" logger.info("[GeneratorWorker] save model parameters for testing.") for name, param in self.worker.model_runner.model.named_parameters(): self._test_prev_params[name] = param.detach().cpu() @@ -731,7 +734,7 @@ async def _test_save_model_params(self): ) @endpoint - async def _test_validate_model_params(self, validate_fn): + async def validate_model_params(self, validate_fn): """Validate updated model params using validate_fn.""" logger.info("[GeneratorWorker] start validating model parameters.") return validate_fn( diff --git a/src/forge/util/weight_verification.py b/src/forge/util/weight_verification.py new file mode 100644 index 000000000..aa98d7df1 --- /dev/null +++ b/src/forge/util/weight_verification.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Utilities for verifying model weight updates during training.""" + +import logging +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + + +logger = logging.getLogger(__name__) + + +@dataclass +class WeightSnapshot: + """Snapshot of model weights at a specific point in time.""" + + params: dict[str, torch.Tensor] + version: int | None = None + metadata: dict[str, Any] | None = None + + @classmethod + def from_model( + cls, model: nn.Module, version: int | None = None, device: str = "cpu" + ) -> "WeightSnapshot": + """Create a snapshot of model parameters. + + Args: + model: PyTorch model to snapshot + version: Optional version identifier + device: Device to store snapshot tensors (default: cpu) + + Returns: + WeightSnapshot containing detached copies of all parameters + """ + params = {} + for name, param in model.named_parameters(): + params[name] = param.detach().to(device).clone() + + return cls(params=params, version=version) + + +@dataclass +class WeightVerificationResult: + """Result of weight verification check.""" + + weights_changed: bool + num_params_checked: int + num_params_changed: int + num_params_unchanged: int + num_params_skipped: int + changed_params: list[str] + unchanged_params: list[str] + skipped_params: list[str] + max_delta: float | None = None + mean_delta: float | None = None + + def __str__(self) -> str: + status = "✅ CHANGED" if self.weights_changed else "⚠️ UNCHANGED" + max_delta = f"{self.max_delta:.6e}" if self.max_delta is not None else "N/A" + mean_delta = f"{self.mean_delta:.6e}" if self.mean_delta is not None else "N/A" + + return ( + f"Weight Verification {status}:\n" + f" Checked: {self.num_params_checked}\n" + f" Changed: {self.num_params_changed}\n" + f" Unchanged: {self.num_params_unchanged}\n" + f" Skipped: {self.num_params_skipped}\n" + f" Max delta: {max_delta}\n" + f" Mean delta: {mean_delta}" + ) + + +def verify_weights_changed( + prev_snapshot: WeightSnapshot, + current_model: nn.Module, + atol: float = 1e-6, + rtol: float = 1e-5, + skip_non_float: bool = True, + verbose: bool = False, +) -> WeightVerificationResult: + """Verify that model weights have changed compared to a previous snapshot. 
+ + This is a more robust verification than simple parameter hashing, as it: + - Checks each parameter individually + - Uses proper floating point comparison (torch.allclose) + - Provides detailed information about which parameters changed + - Computes statistics about the magnitude of changes + + Args: + prev_snapshot: Previous weight snapshot to compare against + current_model: Current model to check + atol: Absolute tolerance for considering weights unchanged + rtol: Relative tolerance for considering weights unchanged + skip_non_float: Whether to skip non-floating point parameters + verbose: Whether to log detailed information + + Returns: + WeightVerificationResult with detailed information about changes + """ + changed_params = [] + unchanged_params = [] + skipped_params = [] + deltas = [] + + for name, param in current_model.named_parameters(): + if skip_non_float and not torch.is_floating_point(param): + skipped_params.append(name) + if verbose: + logger.info(f"Skipping non-float param: {name}") + continue + + if name not in prev_snapshot.params: + logger.warning(f"Parameter {name} not found in previous snapshot") + skipped_params.append(name) + continue + + prev_param = prev_snapshot.params[name] + curr_param = param.detach().cpu() + + # Check if parameters are close (i.e., unchanged) + is_close = torch.allclose(prev_param, curr_param, atol=atol, rtol=rtol) + + if is_close: + unchanged_params.append(name) + else: + changed_params.append(name) + # Compute delta for statistics + delta = (curr_param - prev_param).abs().max().item() + deltas.append(delta) + + if verbose: + logger.info( + f"Parameter {name} changed - max delta: {delta:.6e}, " + f"mean delta: {(curr_param - prev_param).abs().mean().item():.6e}" + ) + + # Compute statistics + max_delta = max(deltas) if deltas else 0 + mean_delta = sum(deltas) / len(deltas) if deltas else 0 + + result = WeightVerificationResult( + weights_changed=len(changed_params) > 0, + num_params_checked=len(changed_params) + len(unchanged_params), + num_params_changed=len(changed_params), + num_params_unchanged=len(unchanged_params), + num_params_skipped=len(skipped_params), + changed_params=changed_params, + unchanged_params=unchanged_params, + skipped_params=skipped_params, + max_delta=max_delta, + mean_delta=mean_delta, + ) + + logger.info(str(result)) + + return result + + +def verify_weights_all_zeros( + current_model: nn.Module, + atol: float = 1e-4, + rtol: float = 1e-3, + skip_non_float: bool = True, + verbose: bool = False, +) -> tuple[bool, list[str], list[str]]: + """Verify that all model parameters are zero. 
+ + Args: + current_model: Model to check + atol: Absolute tolerance + rtol: Relative tolerance + skip_non_float: Whether to skip non-floating point parameters + verbose: Whether to log detailed information + + Returns: + Tuple of (all_zeros, zero_params, non_zero_params) + """ + zero_params = [] + non_zero_params = [] + + for name, param in current_model.named_parameters(): + if skip_non_float and not torch.is_floating_point(param): + if verbose: + logger.info(f"Skipping non-float param: {name}") + continue + + param_cpu = param.detach().cpu() + is_zero = torch.allclose( + torch.zeros_like(param_cpu), param_cpu, atol=atol, rtol=rtol + ) + + if is_zero: + zero_params.append(name) + else: + non_zero_params.append(name) + if verbose: + logger.info( + f"Parameter {name} is not zero - " + f"max: {param_cpu.abs().max().item():.6e}, " + f"mean: {param_cpu.abs().mean().item():.6e}" + ) + + all_zeros = len(non_zero_params) == 0 + + logger.info( + f"Zero check: {'✅ PASS' if all_zeros else '⚠️ FAIL'} - " + f"{len(zero_params)} zero, {len(non_zero_params)} non-zero" + ) + + return all_zeros, zero_params, non_zero_params diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index d4151b5b6..645718fcf 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -9,6 +9,7 @@ import shutil from pathlib import Path +import monarch import pytest import pytest_asyncio @@ -22,10 +23,20 @@ from forge.controller.service.service import uuid from forge.types import LauncherConfig, ProvisionerConfig from forge.util.config import resolve_hf_hub_paths +from forge.util.weight_verification import ( + verify_weights_all_zeros, + verify_weights_changed, + WeightSnapshot, +) from monarch.actor import endpoint from omegaconf import DictConfig, OmegaConf +# Workaround for monarch mesh shutdown exit code during teardown +# Without this, proc_mesh.stop will raise exit code 1 after test completes +monarch.actor.unhandled_fault_hook = lambda failure: None + + requires_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA not available", @@ -39,11 +50,10 @@ """ Run tests: +TORCHSTORE_RDMA_ENABLED=0 \ PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \ - --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml --use_dcp=false + --config tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml -PYTHONPATH=. pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \ - --config apps/grpo/qwen3_8b.yaml """ # Temp directory won't work for multi-node because NFS does not cover the tmp path @@ -81,70 +91,53 @@ def _load_config(config_path: str) -> DictConfig: def _test_validate_params_unchanged( prev_params, curr_model, logger ) -> Exception | None: - """Validate that current parameters are the same as prev_params.""" - verified = set() - skipped = set() + """Validate that current parameters are the same as prev_params. + + Uses the new weight_verification utility for robust checking. 
+ """ + prev_snapshot = WeightSnapshot(params=prev_params, version=None) + result = verify_weights_changed( + prev_snapshot, curr_model, atol=1e-3, rtol=1e-2, verbose=False + ) + logger.info( - f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + f"Validation: {result.num_params_checked} params checked, " + f"{result.num_params_changed} changed, {result.num_params_unchanged} unchanged" ) - errs = [] - for name, param in curr_model.named_parameters(): - if not torch.is_floating_point(param): - logger.info(f"Skipping non-float param {name}") - skipped.add(name) - continue - try: - assert name in prev_params, f"Param {name} not found in prev_params" - assert torch.allclose( - prev_params[name], param.cpu(), atol=1e-3, rtol=1e-2 - ), ( - f"current param {name} does not match expected value; " - f"previous param ({prev_params[name].size()})= {prev_params[name]}; " - f"expected = {prev_params[name]} vs got = {param.cpu().size()} {param.cpu()}" - ) - verified.add(name) - except Exception as e: - errs.append((name, e)) - logger.info(f"Verified params = {verified}") - logger.info(f"Skipped params = {skipped}") - if errs: - logger.error( - f"Validation failed for the following params: {[e[0] for e in errs]}" + + # We EXPECT no changes for this validation + if result.weights_changed: + error_msg = ( + f"Weights unexpectedly changed! {result.num_params_changed} params changed " + f"(max_delta={result.max_delta:.6e}). Changed params: {result.changed_params[:5]}" ) - return AssertionError(f"Validation failed: {errs}") + logger.error(error_msg) + return AssertionError(error_msg) def _test_validate_params_all_zeros( prev_params, curr_model, logger ) -> Exception | None: - """Validate all parameters are set to zero. prev_params is actually not used.""" - _ = prev_params - verified = set() - skipped = set() + """Validate all parameters are set to zero.""" + _ = prev_params # Unused + + all_zeros, zero_params, non_zero_params = verify_weights_all_zeros( + curr_model, atol=1e-4, rtol=1e-3, verbose=False + ) + logger.info( - f"Validating model params, all named_parameters() = {curr_model.named_parameters()}" + f"Zero validation: {len(zero_params)} zero params, {len(non_zero_params)} non-zero params" ) - errs = [] - for name, param in curr_model.named_parameters(): - if not torch.is_floating_point(param): - logger.info(f"Skipping non-float param {name}") - skipped.add(name) - continue - try: - param = param.cpu() - assert torch.allclose( - torch.zeros_like(param), param, atol=1e-4, rtol=1e-3 - ), f"param {name} is not zero." - verified.add(name) - except Exception as e: - errs.append((name, e)) - logger.info(f"Verified params = {verified}") - logger.info(f"Skipped params = {skipped}") - if errs: - logger.error( - f"Validation failed for the following params: {[e[0] for e in errs]}" + + if not all_zeros: + error_msg = ( + f"Not all params are zero! {len(non_zero_params)} non-zero params found. 
" + f"First few non-zero: {non_zero_params[:5]}" ) - return AssertionError(f"Validation failed: {errs}") + logger.error(error_msg) + return AssertionError(error_msg) + + return None @pytest_asyncio.fixture(autouse=True) @@ -211,11 +204,14 @@ async def _setup_and_teardown(request): # ---- teardown ---- # logger.info("Shutting down services and cleaning up DCP directory..") - await asyncio.gather( - policy.shutdown(), - ts.shutdown(), - TitanTrainer.shutdown(titan_trainer), - ) + # Call cleanup to destroy process group before shutdown + # This prevents TCPStore connection errors from NCCL heartbeat threads + await titan_trainer.cleanup.call() + + # Shutdown sequentially to avoid race conditions + await policy.shutdown() + await TitanTrainer.shutdown(titan_trainer) + await ts.shutdown() # Cleanup DCP directory path = Path(TEST_DCP_DIR)